diff --git a/proprietary/src/main/java/stirling/software/proprietary/security/JWTAuthenticationEntryPoint.java b/proprietary/src/main/java/stirling/software/proprietary/security/JWTAuthenticationEntryPoint.java index e1164541f..508865fc9 100644 --- a/proprietary/src/main/java/stirling/software/proprietary/security/JWTAuthenticationEntryPoint.java +++ b/proprietary/src/main/java/stirling/software/proprietary/security/JWTAuthenticationEntryPoint.java @@ -17,7 +17,6 @@ public class JWTAuthenticationEntryPoint implements AuthenticationEntryPoint { HttpServletResponse response, AuthenticationException authException) throws IOException { - response.sendError( - HttpServletResponse.SC_UNAUTHORIZED, "Unauthorized: " + authException.getMessage()); + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, authException.getMessage()); } } diff --git a/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java b/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java index 8090ced3b..6d3caa690 100644 --- a/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java +++ b/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java @@ -73,7 +73,6 @@ public class SecurityConfiguration { private final ApplicationProperties.Security securityProperties; private final AppConfig appConfig; private final UserAuthenticationFilter userAuthenticationFilter; - private final JWTAuthenticationFilter jwtAuthenticationFilter; private final JWTServiceInterface jwtService; private final JWTAuthenticationEntryPoint jwtAuthenticationEntryPoint; private final LoginAttemptService loginAttemptService; @@ -93,7 +92,6 @@ public class SecurityConfiguration { AppConfig appConfig, ApplicationProperties.Security securityProperties, UserAuthenticationFilter userAuthenticationFilter, - JWTAuthenticationFilter jwtAuthenticationFilter, JWTServiceInterface jwtService, JWTAuthenticationEntryPoint jwtAuthenticationEntryPoint, LoginAttemptService loginAttemptService, @@ -111,7 +109,6 @@ public class SecurityConfiguration { this.appConfig = appConfig; this.securityProperties = securityProperties; this.userAuthenticationFilter = userAuthenticationFilter; - this.jwtAuthenticationFilter = jwtAuthenticationFilter; this.jwtService = jwtService; this.jwtAuthenticationEntryPoint = jwtAuthenticationEntryPoint; this.loginAttemptService = loginAttemptService; @@ -138,9 +135,10 @@ public class SecurityConfiguration { } if (loginEnabledValue) { - if (jwtEnabled && jwtAuthenticationFilter != null) { + if (jwtEnabled) { http.addFilterBefore( - jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class) + jwtAuthenticationFilter(), + UsernamePasswordAuthenticationFilter.class) .exceptionHandling( exceptionHandling -> exceptionHandling.authenticationEntryPoint( @@ -370,4 +368,10 @@ public class SecurityConfiguration { public PersistentTokenRepository persistentTokenRepository() { return new JPATokenRepositoryImpl(persistentLoginRepository); } + + @Bean + public JWTAuthenticationFilter jwtAuthenticationFilter() { + return new JWTAuthenticationFilter( + jwtService, userDetailsService, jwtAuthenticationEntryPoint); + } } diff --git a/proprietary/src/main/java/stirling/software/proprietary/security/filter/JWTAuthenticationFilter.java b/proprietary/src/main/java/stirling/software/proprietary/security/filter/JWTAuthenticationFilter.java index 975e5732f..79f4fff3f 100644 --- a/proprietary/src/main/java/stirling/software/proprietary/security/filter/JWTAuthenticationFilter.java +++ b/proprietary/src/main/java/stirling/software/proprietary/security/filter/JWTAuthenticationFilter.java @@ -5,11 +5,12 @@ import java.io.IOException; import org.springframework.boot.autoconfigure.condition.ConditionalOnBooleanProperty; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; -import org.springframework.stereotype.Component; import org.springframework.web.filter.OncePerRequestFilter; import jakarta.servlet.FilterChain; @@ -24,17 +25,20 @@ import stirling.software.proprietary.security.service.CustomUserDetailsService; import stirling.software.proprietary.security.service.JWTServiceInterface; @Slf4j -@Component @ConditionalOnBooleanProperty("security.jwt.enabled") public class JWTAuthenticationFilter extends OncePerRequestFilter { private final JWTServiceInterface jwtService; private final CustomUserDetailsService userDetailsService; + private final AuthenticationEntryPoint authenticationEntryPoint; public JWTAuthenticationFilter( - JWTServiceInterface jwtService, CustomUserDetailsService userDetailsService) { + JWTServiceInterface jwtService, + CustomUserDetailsService userDetailsService, + AuthenticationEntryPoint authenticationEntryPoint) { this.jwtService = jwtService; this.userDetailsService = userDetailsService; + this.authenticationEntryPoint = authenticationEntryPoint; } @Override @@ -59,14 +63,17 @@ public class JWTAuthenticationFilter extends OncePerRequestFilter { response.sendRedirect("/login"); return; } - sendUnauthorizedResponse(response, "JWT is missing from the request"); + handleAuthenticationFailure( + request, + response, + new AuthenticationFailureException("JWT is missing from the request")); return; } try { jwtService.validateToken(jwtToken); } catch (AuthenticationFailureException e) { - sendUnauthorizedResponse(response, e.getMessage()); + handleAuthenticationFailure(request, response, e); return; } @@ -139,26 +146,11 @@ public class JWTAuthenticationFilter extends OncePerRequestFilter { return false; } - private void sendUnauthorizedResponse(HttpServletResponse response, String message) - throws IOException { - int unauthorized = HttpServletResponse.SC_UNAUTHORIZED; - - response.setStatus(unauthorized); - response.setContentType("application/json"); - response.setCharacterEncoding("UTF-8"); - - String jsonResponse = - String.format( - """ - { - "error": "Unauthorized", - "message": "%s", - "status": %d - } - """, - message, unauthorized); - - response.getWriter().write(jsonResponse); - response.getWriter().flush(); + private void handleAuthenticationFailure( + HttpServletRequest request, + HttpServletResponse response, + AuthenticationException authException) + throws IOException, ServletException { + authenticationEntryPoint.commence(request, response, authException); } } diff --git a/proprietary/src/test/java/stirling/software/proprietary/security/JWTAuthenticationEntryPointTest.java b/proprietary/src/test/java/stirling/software/proprietary/security/JWTAuthenticationEntryPointTest.java index 50a7e0442..d2e233d7a 100644 --- a/proprietary/src/test/java/stirling/software/proprietary/security/JWTAuthenticationEntryPointTest.java +++ b/proprietary/src/test/java/stirling/software/proprietary/security/JWTAuthenticationEntryPointTest.java @@ -10,6 +10,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.security.core.AuthenticationException; import java.io.IOException; +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; import static org.mockito.Mockito.*; @@ -23,7 +24,7 @@ class JWTAuthenticationEntryPointTest { private HttpServletResponse response; @Mock - private AuthenticationException authException; + private AuthenticationFailureException authException; @InjectMocks private JWTAuthenticationEntryPoint jwtAuthenticationEntryPoint; @@ -35,8 +36,6 @@ class JWTAuthenticationEntryPointTest { jwtAuthenticationEntryPoint.commence(request, response, authException); - verify(response).sendError( - HttpServletResponse.SC_UNAUTHORIZED, - "Unauthorized: " + errorMessage); + verify(response).sendError(HttpServletResponse.SC_UNAUTHORIZED, errorMessage); } -} \ No newline at end of file +} diff --git a/proprietary/src/test/java/stirling/software/proprietary/security/filter/JWTAuthenticationFilterTest.java b/proprietary/src/test/java/stirling/software/proprietary/security/filter/JWTAuthenticationFilterTest.java index 7eb2caa9f..a337651fb 100644 --- a/proprietary/src/test/java/stirling/software/proprietary/security/filter/JWTAuthenticationFilterTest.java +++ b/proprietary/src/test/java/stirling/software/proprietary/security/filter/JWTAuthenticationFilterTest.java @@ -1,5 +1,6 @@ package stirling.software.proprietary.security.filter; +import jakarta.inject.Inject; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; @@ -7,6 +8,7 @@ import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; @@ -18,6 +20,7 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.security.web.AuthenticationEntryPoint; import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; import stirling.software.proprietary.security.service.CustomUserDetailsService; import stirling.software.proprietary.security.service.JWTServiceInterface; @@ -32,6 +35,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.contains; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; @ExtendWith(MockitoExtension.class) @@ -61,15 +65,14 @@ class JWTAuthenticationFilterTest { @Mock private PrintWriter printWriter; + @Mock + private AuthenticationEntryPoint authenticationEntryPoint; + + @InjectMocks private JWTAuthenticationFilter jwtAuthenticationFilter; - @BeforeEach - void setUp() { - jwtAuthenticationFilter = new JWTAuthenticationFilter(jwtService, userDetailsService); - } - @Test - void testDoFilterInternalWhenJwtDisabled() throws ServletException, IOException { + void shouldNotAuthenticateWhenJwtDisabled() throws ServletException, IOException { when(jwtService.isJwtEnabled()).thenReturn(false); jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); @@ -79,7 +82,7 @@ class JWTAuthenticationFilterTest { } @Test - void testDoFilterInternalWhenShouldNotFilter() throws ServletException, IOException { + void shouldNotFilterWhenPageIsLogin() throws ServletException, IOException { when(jwtService.isJwtEnabled()).thenReturn(true); when(request.getRequestURI()).thenReturn("/login"); when(request.getMethod()).thenReturn("POST"); @@ -91,7 +94,7 @@ class JWTAuthenticationFilterTest { } @Test - void testDoFilterInternalWithValidToken() throws ServletException, IOException { + void testDoFilterInternal() throws ServletException, IOException { String token = "valid-jwt-token"; String newToken = "new-jwt-token"; String username = "testuser"; @@ -106,13 +109,9 @@ class JWTAuthenticationFilterTest { when(userDetailsService.loadUserByUsername(username)).thenReturn(userDetails); try (MockedStatic mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) { - // Create the authentication token that will be set and returned - UsernamePasswordAuthenticationToken authToken = + UsernamePasswordAuthenticationToken authToken = new UsernamePasswordAuthenticationToken(userDetails, null, userDetails.getAuthorities()); - - // Mock the security context behavior: - // - First call (in createAuthToken): returns null - // - Second call (in createAuthToken after setting): returns the created token + when(securityContext.getAuthentication()).thenReturn(null).thenReturn(authToken); mockedSecurityContextHolder.when(SecurityContextHolder::getContext).thenReturn(securityContext); when(jwtService.generateToken(authToken)).thenReturn(newToken); @@ -143,24 +142,7 @@ class JWTAuthenticationFilterTest { } @Test - void testDoFilterInternalWithMissingTokenForNonRootPath() throws ServletException, IOException { - when(jwtService.isJwtEnabled()).thenReturn(true); - when(request.getRequestURI()).thenReturn("/protected"); - when(request.getMethod()).thenReturn("GET"); - when(jwtService.extractTokenFromRequest(request)).thenReturn(null); - when(response.getWriter()).thenReturn(printWriter); - - jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); - - verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); - verify(response).setContentType("application/json"); - verify(response).setCharacterEncoding("UTF-8"); - verify(printWriter).write(contains("JWT is missing from the request")); - verify(filterChain, never()).doFilter(request, response); - } - - @Test - void testDoFilterInternalWithInvalidToken() throws ServletException, IOException { + void validationFailsWithInvalidToken() throws ServletException, IOException { String token = "invalid-jwt-token"; when(jwtService.isJwtEnabled()).thenReturn(true); @@ -168,20 +150,16 @@ class JWTAuthenticationFilterTest { when(request.getMethod()).thenReturn("GET"); when(jwtService.extractTokenFromRequest(request)).thenReturn(token); doThrow(new AuthenticationFailureException("Invalid token")).when(jwtService).validateToken(token); - when(response.getWriter()).thenReturn(printWriter); jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); verify(jwtService).validateToken(token); - verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); - verify(response).setContentType("application/json"); - verify(response).setCharacterEncoding("UTF-8"); - verify(printWriter).write(contains("Invalid token")); + verify(authenticationEntryPoint).commence(eq(request), eq(response), any(AuthenticationFailureException.class)); verify(filterChain, never()).doFilter(request, response); } @Test - void testDoFilterInternalWithExpiredToken() throws ServletException, IOException { + void validationFailsWithExpiredToken() throws ServletException, IOException { String token = "expired-jwt-token"; when(jwtService.isJwtEnabled()).thenReturn(true); @@ -189,20 +167,16 @@ class JWTAuthenticationFilterTest { when(request.getMethod()).thenReturn("GET"); when(jwtService.extractTokenFromRequest(request)).thenReturn(token); doThrow(new AuthenticationFailureException("The token has expired")).when(jwtService).validateToken(token); - when(response.getWriter()).thenReturn(printWriter); jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); verify(jwtService).validateToken(token); - verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); - verify(response).setContentType("application/json"); - verify(response).setCharacterEncoding("UTF-8"); - verify(printWriter).write(contains("The token has expired")); + verify(authenticationEntryPoint).commence(eq(request), eq(response), any()); verify(filterChain, never()).doFilter(request, response); } @Test - void testDoFilterInternalWithUserNotFound() throws ServletException, IOException { + void exceptinonThrown_WhenUserNotFound() throws ServletException, IOException { String token = "valid-jwt-token"; String username = "nonexistentuser"; @@ -218,45 +192,16 @@ class JWTAuthenticationFilterTest { when(securityContext.getAuthentication()).thenReturn(null); mockedSecurityContextHolder.when(SecurityContextHolder::getContext).thenReturn(securityContext); - assertThrows(UsernameNotFoundException.class, () -> { - jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); - }); + UsernameNotFoundException result = assertThrows(UsernameNotFoundException.class, () -> jwtAuthenticationFilter.doFilterInternal(request, response, filterChain)); + assertEquals("User not found: " + username, result.getMessage()); verify(userDetailsService).loadUserByUsername(username); verify(filterChain, never()).doFilter(request, response); } } @Test - void testDoFilterInternalWithExistingAuthentication() throws ServletException, IOException { - String token = "valid-jwt-token"; - String newToken = "new-jwt-token"; - String username = "testuser"; - - when(jwtService.isJwtEnabled()).thenReturn(true); - when(request.getRequestURI()).thenReturn("/protected"); - when(request.getMethod()).thenReturn("GET"); - when(jwtService.extractTokenFromRequest(request)).thenReturn(token); - doNothing().when(jwtService).validateToken(token); - when(jwtService.extractUsername(token)).thenReturn(username); - - try (MockedStatic mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) { - Authentication existingAuth = mock(Authentication.class); - when(securityContext.getAuthentication()).thenReturn(existingAuth); - mockedSecurityContextHolder.when(SecurityContextHolder::getContext).thenReturn(securityContext); - when(jwtService.generateToken(existingAuth)).thenReturn(newToken); - - jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); - - verify(userDetailsService, never()).loadUserByUsername(anyString()); - verify(jwtService).generateToken(existingAuth); - verify(jwtService).addTokenToResponse(response, newToken); - verify(filterChain).doFilter(request, response); - } - } - - @Test - void testShouldNotFilterLoginPost() { + void shouldNotFilterLoginPost() { when(request.getRequestURI()).thenReturn("/login"); when(request.getMethod()).thenReturn("POST"); @@ -264,7 +209,7 @@ class JWTAuthenticationFilterTest { } @Test - void testShouldNotFilterLoginGet() { + void shouldNotFilterLoginGet() { when(request.getRequestURI()).thenReturn("/login"); when(request.getMethod()).thenReturn("GET"); @@ -272,7 +217,7 @@ class JWTAuthenticationFilterTest { } @Test - void testShouldNotFilterPublicPaths() { + void shouldNotFilterPublicPaths() { String[] publicPaths = { "/register", "/error", @@ -298,7 +243,7 @@ class JWTAuthenticationFilterTest { } @Test - void testShouldNotFilterStaticFiles() { + void shouldNotFilterStaticFiles() { String[] staticFiles = { "/some/path/file.svg", "/another/path/image.png", @@ -315,7 +260,7 @@ class JWTAuthenticationFilterTest { } @Test - void testShouldFilterProtectedPaths() { + void shouldFilterProtectedPaths() { String[] protectedPaths = { "/protected", "/api/v1/user/profile", @@ -333,7 +278,7 @@ class JWTAuthenticationFilterTest { } @Test - void testShouldFilterRootPath() { + void shouldFilterRootPath() { when(request.getRequestURI()).thenReturn("/"); when(request.getMethod()).thenReturn("GET"); @@ -341,23 +286,17 @@ class JWTAuthenticationFilterTest { } @Test - void testSendUnauthorizedResponseFormat() throws ServletException, IOException { + void testAuthenticationEntryPointCalledWithCorrectException() throws ServletException, IOException { when(jwtService.isJwtEnabled()).thenReturn(true); when(request.getRequestURI()).thenReturn("/protected"); when(request.getMethod()).thenReturn("GET"); when(jwtService.extractTokenFromRequest(request)).thenReturn(null); - when(response.getWriter()).thenReturn(printWriter); jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); - verify(response).setStatus(401); - verify(response).setContentType("application/json"); - verify(response).setCharacterEncoding("UTF-8"); - verify(printWriter).write(argThat((String json) -> - json.contains("\"error\": \"Unauthorized\"") && - json.contains("\"message\": \"JWT is missing from the request\"") && - json.contains("\"status\": 401") + verify(authenticationEntryPoint).commence(eq(request), eq(response), argThat(exception -> + exception.getMessage().equals("JWT is missing from the request") )); - verify(printWriter).flush(); + verify(filterChain, never()).doFilter(request, response); } -} \ No newline at end of file +} diff --git a/proprietary/src/test/java/stirling/software/proprietary/security/service/JWTServiceTest.java b/proprietary/src/test/java/stirling/software/proprietary/security/service/JWTServiceTest.java index 7d03a6393..980866a2e 100644 --- a/proprietary/src/test/java/stirling/software/proprietary/security/service/JWTServiceTest.java +++ b/proprietary/src/test/java/stirling/software/proprietary/security/service/JWTServiceTest.java @@ -86,7 +86,7 @@ class JWTServiceTest { assertNotNull(token); assertTrue(token.length() > 0); assertEquals(username, jwtService.extractUsername(token)); - + Map extractedClaims = jwtService.extractAllClaims(token); assertEquals("admin", extractedClaims.get("role")); assertEquals("IT", extractedClaims.get("department")); @@ -111,7 +111,7 @@ class JWTServiceTest { @Test void testValidateTokenWhenJwtDisabled() { when(securityProperties.isJwtActive()).thenReturn(false); - + assertThrows(IllegalStateException.class, () -> { jwtService.validateToken("any-token"); }); @@ -148,7 +148,7 @@ class JWTServiceTest { AuthenticationFailureException exception = assertThrows(AuthenticationFailureException.class, () -> { jwtService.validateToken("malformed.token"); }); - + assertTrue(exception.getMessage().contains("Invalid")); } @@ -157,7 +157,7 @@ class JWTServiceTest { AuthenticationFailureException exception = assertThrows(AuthenticationFailureException.class, () -> { jwtService.validateToken(""); }); - + assertTrue(exception.getMessage().contains("Claims are empty") || exception.getMessage().contains("Invalid")); } @@ -204,20 +204,16 @@ class JWTServiceTest { String token = jwtService.generateToken("testuser", new HashMap<>()); assertFalse(jwtService.isTokenExpired(token)); - // Create a token that expires immediately when(jwtProperties.getExpiration()).thenReturn(1L); JWTService shortLivedJwtService = new JWTService(securityProperties); String expiredToken = shortLivedJwtService.generateToken("testuser", new HashMap<>()); - // Wait a bit to ensure expiration try { Thread.sleep(10); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } - // Since expired tokens now throw exceptions in extractAllClaimsFromToken, - // isTokenExpired will also throw an exception assertThrows(AuthenticationFailureException.class, () -> { shortLivedJwtService.isTokenExpired(expiredToken); }); @@ -326,4 +322,4 @@ class JWTServiceTest { JWTService jwtServiceWithNullProps = new JWTService(securityProperties); assertFalse(jwtServiceWithNullProps.isJwtEnabled()); } -} \ No newline at end of file +}