Updated tests

This commit is contained in:
Dario Ghunney Ware 2025-07-09 17:07:38 +01:00
parent 00db1ccf95
commit dfeb17c886
3 changed files with 151 additions and 64 deletions

View File

@ -70,29 +70,12 @@ public class JWTService implements JWTServiceInterface {
} }
@Override @Override
public void validateToken(String token) { public void validateToken(String token) throws AuthenticationFailureException {
if (!isJwtEnabled()) { if (!isJwtEnabled()) {
throw new IllegalStateException("JWT is not enabled"); throw new IllegalStateException("JWT is not enabled");
} }
try {
extractAllClaimsFromToken(token); extractAllClaimsFromToken(token);
} catch (SignatureException e) {
log.warn("Invalid signature: {}", e.getMessage());
throw new AuthenticationFailureException("Invalid signature", e);
} catch (MalformedJwtException e) {
log.warn("Invalid token: {}", e.getMessage());
throw new AuthenticationFailureException("Invalid token", e);
} catch (ExpiredJwtException e) {
log.warn("The token has expired: {}", e.getMessage());
throw new AuthenticationFailureException("The token has expired", e);
} catch (UnsupportedJwtException e) {
log.warn("The token is unsupported: {}", e.getMessage());
throw new AuthenticationFailureException("The token is unsupported", e);
} catch (IllegalArgumentException e) {
log.warn("Claims are empty: {}", e.getMessage());
throw new AuthenticationFailureException("Claims are empty", e);
}
} }
@Override @Override
@ -128,20 +111,20 @@ public class JWTService implements JWTServiceInterface {
.parseSignedClaims(token) .parseSignedClaims(token)
.getPayload(); .getPayload();
} catch (SignatureException e) { } catch (SignatureException e) {
log.warn("Invalid JWT signature: {}", e.getMessage()); log.warn("Invalid signature: {}", e.getMessage());
throw new AuthenticationFailureException("Invalid JWT signature", e); throw new AuthenticationFailureException("Invalid signature", e);
} catch (MalformedJwtException e) { } catch (MalformedJwtException e) {
log.warn("Invalid JWT token: {}", e.getMessage()); log.warn("Invalid token: {}", e.getMessage());
throw new AuthenticationFailureException("Invalid JWT token", e); throw new AuthenticationFailureException("Invalid token", e);
} catch (ExpiredJwtException e) { } catch (ExpiredJwtException e) {
log.warn("JWT token is expired: {}", e.getMessage()); log.warn("The token has expired: {}", e.getMessage());
throw new AuthenticationFailureException("JWT token is expired", e); throw new AuthenticationFailureException("The token has expired", e);
} catch (UnsupportedJwtException e) { } catch (UnsupportedJwtException e) {
log.warn("JWT token is unsupported: {}", e.getMessage()); log.warn("The token is unsupported: {}", e.getMessage());
throw new AuthenticationFailureException("JWT token is unsupported", e); throw new AuthenticationFailureException("The token is unsupported", e);
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
log.warn("JWT claims are empty: {}", e.getMessage()); log.warn("Claims are empty: {}", e.getMessage());
throw new AuthenticationFailureException("JWT claims are empty", e); throw new AuthenticationFailureException("Claims are empty", e);
} }
} }

View File

@ -4,16 +4,15 @@ import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException; import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import java.util.List;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockedStatic; import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
@ -24,13 +23,15 @@ import stirling.software.proprietary.security.service.CustomUserDetailsService;
import stirling.software.proprietary.security.service.JWTServiceInterface; import stirling.software.proprietary.security.service.JWTServiceInterface;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.io.PrintWriter;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.contains;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
@ -57,9 +58,16 @@ class JWTAuthenticationFilterTest {
@Mock @Mock
private SecurityContext securityContext; private SecurityContext securityContext;
@InjectMocks @Mock
private PrintWriter printWriter;
private JWTAuthenticationFilter jwtAuthenticationFilter; private JWTAuthenticationFilter jwtAuthenticationFilter;
@BeforeEach
void setUp() {
jwtAuthenticationFilter = new JWTAuthenticationFilter(jwtService, userDetailsService);
}
@Test @Test
void testDoFilterInternalWhenJwtDisabled() throws ServletException, IOException { void testDoFilterInternalWhenJwtDisabled() throws ServletException, IOException {
when(jwtService.isJwtEnabled()).thenReturn(false); when(jwtService.isJwtEnabled()).thenReturn(false);
@ -92,15 +100,22 @@ class JWTAuthenticationFilterTest {
when(request.getRequestURI()).thenReturn("/protected"); when(request.getRequestURI()).thenReturn("/protected");
when(request.getMethod()).thenReturn("GET"); when(request.getMethod()).thenReturn("GET");
when(jwtService.extractTokenFromRequest(request)).thenReturn(token); when(jwtService.extractTokenFromRequest(request)).thenReturn(token);
when(jwtService.validateToken(token)).thenReturn(true); doNothing().when(jwtService).validateToken(token);
when(jwtService.extractUsername(token)).thenReturn(username); when(jwtService.extractUsername(token)).thenReturn(username);
when(userDetails.getAuthorities()).thenReturn((Collection) Arrays.asList(new SimpleGrantedAuthority("ROLE_USER"))); when(userDetails.getAuthorities()).thenReturn(Collections.emptyList());
when(userDetailsService.loadUserByUsername(username)).thenReturn(userDetails); when(userDetailsService.loadUserByUsername(username)).thenReturn(userDetails);
try (MockedStatic<SecurityContextHolder> mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) { try (MockedStatic<SecurityContextHolder> mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) {
when(securityContext.getAuthentication()).thenReturn(null); // Create the authentication token that will be set and returned
UsernamePasswordAuthenticationToken authToken =
new UsernamePasswordAuthenticationToken(userDetails, null, userDetails.getAuthorities());
// Mock the security context behavior:
// - First call (in createAuthToken): returns null
// - Second call (in createAuthToken after setting): returns the created token
when(securityContext.getAuthentication()).thenReturn(null).thenReturn(authToken);
mockedSecurityContextHolder.when(SecurityContextHolder::getContext).thenReturn(securityContext); mockedSecurityContextHolder.when(SecurityContextHolder::getContext).thenReturn(securityContext);
when(jwtService.generateToken(any())).thenReturn(newToken); when(jwtService.generateToken(authToken)).thenReturn(newToken);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
@ -108,7 +123,7 @@ class JWTAuthenticationFilterTest {
verify(jwtService).extractUsername(token); verify(jwtService).extractUsername(token);
verify(userDetailsService).loadUserByUsername(username); verify(userDetailsService).loadUserByUsername(username);
verify(securityContext).setAuthentication(any(UsernamePasswordAuthenticationToken.class)); verify(securityContext).setAuthentication(any(UsernamePasswordAuthenticationToken.class));
verify(jwtService).generateToken(any()); verify(jwtService).generateToken(authToken);
verify(jwtService).addTokenToResponse(response, newToken); verify(jwtService).addTokenToResponse(response, newToken);
verify(filterChain).doFilter(request, response); verify(filterChain).doFilter(request, response);
} }
@ -133,12 +148,14 @@ class JWTAuthenticationFilterTest {
when(request.getRequestURI()).thenReturn("/protected"); when(request.getRequestURI()).thenReturn("/protected");
when(request.getMethod()).thenReturn("GET"); when(request.getMethod()).thenReturn("GET");
when(jwtService.extractTokenFromRequest(request)).thenReturn(null); when(jwtService.extractTokenFromRequest(request)).thenReturn(null);
when(response.getWriter()).thenReturn(printWriter);
assertThrows(AuthenticationFailureException.class, () -> {
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
});
verify(response, never()).sendRedirect(anyString()); verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
verify(response).setContentType("application/json");
verify(response).setCharacterEncoding("UTF-8");
verify(printWriter).write(contains("JWT is missing from the request"));
verify(filterChain, never()).doFilter(request, response); verify(filterChain, never()).doFilter(request, response);
} }
@ -150,13 +167,37 @@ class JWTAuthenticationFilterTest {
when(request.getRequestURI()).thenReturn("/protected"); when(request.getRequestURI()).thenReturn("/protected");
when(request.getMethod()).thenReturn("GET"); when(request.getMethod()).thenReturn("GET");
when(jwtService.extractTokenFromRequest(request)).thenReturn(token); when(jwtService.extractTokenFromRequest(request)).thenReturn(token);
when(jwtService.validateToken(token)).thenReturn(false); doThrow(new AuthenticationFailureException("Invalid token")).when(jwtService).validateToken(token);
when(response.getWriter()).thenReturn(printWriter);
assertThrows(AuthenticationFailureException.class, () -> {
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
});
verify(jwtService).validateToken(token); verify(jwtService).validateToken(token);
verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
verify(response).setContentType("application/json");
verify(response).setCharacterEncoding("UTF-8");
verify(printWriter).write(contains("Invalid token"));
verify(filterChain, never()).doFilter(request, response);
}
@Test
void testDoFilterInternalWithExpiredToken() throws ServletException, IOException {
String token = "expired-jwt-token";
when(jwtService.isJwtEnabled()).thenReturn(true);
when(request.getRequestURI()).thenReturn("/protected");
when(request.getMethod()).thenReturn("GET");
when(jwtService.extractTokenFromRequest(request)).thenReturn(token);
doThrow(new AuthenticationFailureException("The token has expired")).when(jwtService).validateToken(token);
when(response.getWriter()).thenReturn(printWriter);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(jwtService).validateToken(token);
verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
verify(response).setContentType("application/json");
verify(response).setCharacterEncoding("UTF-8");
verify(printWriter).write(contains("The token has expired"));
verify(filterChain, never()).doFilter(request, response); verify(filterChain, never()).doFilter(request, response);
} }
@ -169,7 +210,7 @@ class JWTAuthenticationFilterTest {
when(request.getRequestURI()).thenReturn("/protected"); when(request.getRequestURI()).thenReturn("/protected");
when(request.getMethod()).thenReturn("GET"); when(request.getMethod()).thenReturn("GET");
when(jwtService.extractTokenFromRequest(request)).thenReturn(token); when(jwtService.extractTokenFromRequest(request)).thenReturn(token);
when(jwtService.validateToken(token)).thenReturn(true); doNothing().when(jwtService).validateToken(token);
when(jwtService.extractUsername(token)).thenReturn(username); when(jwtService.extractUsername(token)).thenReturn(username);
when(userDetailsService.loadUserByUsername(username)).thenReturn(null); when(userDetailsService.loadUserByUsername(username)).thenReturn(null);
@ -196,7 +237,7 @@ class JWTAuthenticationFilterTest {
when(request.getRequestURI()).thenReturn("/protected"); when(request.getRequestURI()).thenReturn("/protected");
when(request.getMethod()).thenReturn("GET"); when(request.getMethod()).thenReturn("GET");
when(jwtService.extractTokenFromRequest(request)).thenReturn(token); when(jwtService.extractTokenFromRequest(request)).thenReturn(token);
when(jwtService.validateToken(token)).thenReturn(true); doNothing().when(jwtService).validateToken(token);
when(jwtService.extractUsername(token)).thenReturn(username); when(jwtService.extractUsername(token)).thenReturn(username);
try (MockedStatic<SecurityContextHolder> mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) { try (MockedStatic<SecurityContextHolder> mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) {
@ -298,4 +339,25 @@ class JWTAuthenticationFilterTest {
assertFalse(jwtAuthenticationFilter.shouldNotFilter(request)); assertFalse(jwtAuthenticationFilter.shouldNotFilter(request));
} }
@Test
void testSendUnauthorizedResponseFormat() throws ServletException, IOException {
when(jwtService.isJwtEnabled()).thenReturn(true);
when(request.getRequestURI()).thenReturn("/protected");
when(request.getMethod()).thenReturn("GET");
when(jwtService.extractTokenFromRequest(request)).thenReturn(null);
when(response.getWriter()).thenReturn(printWriter);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(response).setStatus(401);
verify(response).setContentType("application/json");
verify(response).setCharacterEncoding("UTF-8");
verify(printWriter).write(argThat((String json) ->
json.contains("\"error\": \"Unauthorized\"") &&
json.contains("\"message\": \"JWT is missing from the request\"") &&
json.contains("\"status\": 401")
));
verify(printWriter).flush();
}
} }

View File

@ -9,7 +9,6 @@ import io.jsonwebtoken.security.SignatureException;
import jakarta.servlet.http.Cookie; import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@ -18,6 +17,7 @@ import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetails;
import stirling.software.common.model.ApplicationProperties; import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.security.model.exception.AuthenticationFailureException;
import java.security.KeyPair; import java.security.KeyPair;
import java.util.Date; import java.util.Date;
@ -105,19 +105,23 @@ class JWTServiceTest {
void testValidateTokenSuccess() { void testValidateTokenSuccess() {
String token = jwtService.generateToken("testuser", new HashMap<>()); String token = jwtService.generateToken("testuser", new HashMap<>());
assertTrue(jwtService.validateToken(token)); assertDoesNotThrow(() -> jwtService.validateToken(token));
} }
@Test @Test
void testValidateTokenWhenJwtDisabled() { void testValidateTokenWhenJwtDisabled() {
when(securityProperties.isJwtActive()).thenReturn(false); when(securityProperties.isJwtActive()).thenReturn(false);
assertFalse(jwtService.validateToken("any-token")); assertThrows(IllegalStateException.class, () -> {
jwtService.validateToken("any-token");
});
} }
@Test @Test
void testValidateTokenWithInvalidToken() { void testValidateTokenWithInvalidToken() {
assertFalse(jwtService.validateToken("invalid-token")); assertThrows(AuthenticationFailureException.class, () -> {
jwtService.validateToken("invalid-token");
});
} }
@Test @Test
@ -134,7 +138,27 @@ class JWTServiceTest {
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
} }
assertFalse(shortLivedJwtService.validateToken(token)); assertThrows(AuthenticationFailureException.class, () -> {
shortLivedJwtService.validateToken(token);
});
}
@Test
void testValidateTokenWithMalformedToken() {
AuthenticationFailureException exception = assertThrows(AuthenticationFailureException.class, () -> {
jwtService.validateToken("malformed.token");
});
assertTrue(exception.getMessage().contains("Invalid"));
}
@Test
void testValidateTokenWithEmptyToken() {
AuthenticationFailureException exception = assertThrows(AuthenticationFailureException.class, () -> {
jwtService.validateToken("");
});
assertTrue(exception.getMessage().contains("Claims are empty") || exception.getMessage().contains("Invalid"));
} }
@Test @Test
@ -145,6 +169,13 @@ class JWTServiceTest {
assertEquals(username, jwtService.extractUsername(token)); assertEquals(username, jwtService.extractUsername(token));
} }
@Test
void testExtractUsernameWithInvalidToken() {
assertThrows(AuthenticationFailureException.class, () -> {
jwtService.extractUsername("invalid-token");
});
}
@Test @Test
void testExtractAllClaims() { void testExtractAllClaims() {
String username = "testuser"; String username = "testuser";
@ -162,11 +193,9 @@ class JWTServiceTest {
} }
@Test @Test
void testExtractAllClaimsWhenJwtDisabled() { void testExtractAllClaimsWithInvalidToken() {
when(securityProperties.isJwtActive()).thenReturn(false); assertThrows(AuthenticationFailureException.class, () -> {
jwtService.extractAllClaims("invalid-token");
assertThrows(IllegalStateException.class, () -> {
jwtService.extractAllClaims("any-token");
}); });
} }
@ -175,17 +204,30 @@ class JWTServiceTest {
String token = jwtService.generateToken("testuser", new HashMap<>()); String token = jwtService.generateToken("testuser", new HashMap<>());
assertFalse(jwtService.isTokenExpired(token)); assertFalse(jwtService.isTokenExpired(token));
// Create a token that expires immediately
when(jwtProperties.getExpiration()).thenReturn(1L); when(jwtProperties.getExpiration()).thenReturn(1L);
JWTService shortLivedJwtService = new JWTService(securityProperties); JWTService shortLivedJwtService = new JWTService(securityProperties);
String expiredToken = shortLivedJwtService.generateToken("testuser", new HashMap<>()); String expiredToken = shortLivedJwtService.generateToken("testuser", new HashMap<>());
// Wait a bit to ensure expiration
try { try {
Thread.sleep(10); Thread.sleep(10);
} catch (InterruptedException e) { } catch (InterruptedException e) {
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
} }
assertThrows(ExpiredJwtException.class, () -> assertTrue(shortLivedJwtService.isTokenExpired(expiredToken))); // Since expired tokens now throw exceptions in extractAllClaimsFromToken,
// isTokenExpired will also throw an exception
assertThrows(AuthenticationFailureException.class, () -> {
shortLivedJwtService.isTokenExpired(expiredToken);
});
}
@Test
void testIsTokenExpiredWithInvalidToken() {
assertThrows(AuthenticationFailureException.class, () -> {
jwtService.isTokenExpired("invalid-token");
});
} }
@Test @Test