This commit is contained in:
Dario Ghunney Ware 2025-08-07 12:40:13 +01:00
parent 24c35d610d
commit ed6b606958
13 changed files with 349 additions and 305 deletions

View File

@ -50,4 +50,4 @@ spring.main.allow-bean-definition-overriding=true
java.io.tmpdir=${stirling.tempfiles.directory:${java.io.tmpdir}/stirling-pdf} java.io.tmpdir=${stirling.tempfiles.directory:${java.io.tmpdir}/stirling-pdf}
# V2 features # V2 features
v2=true v2=false

View File

@ -59,7 +59,7 @@ security:
idpCert: classpath:okta.cert # The certificate your Provider will use to authenticate your app's SAML authentication requests. Provided by your Provider idpCert: classpath:okta.cert # The certificate your Provider will use to authenticate your app's SAML authentication requests. Provided by your Provider
privateKey: classpath:saml-private-key.key # Your private key. Generated from your keypair privateKey: classpath:saml-private-key.key # Your private key. Generated from your keypair
spCert: classpath:saml-public-cert.crt # Your signing certificate. Generated from your keypair spCert: classpath:saml-public-cert.crt # Your signing certificate. Generated from your keypair
jwt: jwt: # This feature is currently under development and not yet fully supported. Do not use in production.
persistence: true # Set to 'true' to enable JWT key store persistence: true # Set to 'true' to enable JWT key store
enableKeyRotation: true # Set to 'true' to enable key pair rotation enableKeyRotation: true # Set to 'true' to enable key pair rotation
enableKeyCleanup: true # Set to 'true' to enable key pair cleanup enableKeyCleanup: true # Set to 'true' to enable key pair cleanup

View File

@ -1,11 +1,5 @@
// Authentication utility for cookie-based JWT // Authentication utility for cookie-based JWT
window.JWTManager = { window.JWTManager = {
// Check if user is authenticated (simplified for cookie-based auth)
isAuthenticated: function() {
// With cookie-based JWT, we rely on server-side validation
// This is a simplified check - actual authentication status is determined server-side
return document.cookie.includes('stirling_jwt=');
},
// Logout - clear cookies and redirect to login // Logout - clear cookies and redirect to login
logout: function() { logout: function() {
@ -72,7 +66,3 @@ window.fetchWithCsrf = async function(url, options = {}) {
return response; return response;
} }
// Enhanced fetch function that always includes JWT
window.fetchWithJWT = async function(url, options = {}) {
return window.fetchWithCsrf(url, options);
}

View File

@ -21,27 +21,9 @@
// Clean up any JWT tokens from URL (OAuth flow) // Clean up any JWT tokens from URL (OAuth flow)
cleanupTokenFromUrl(); cleanupTokenFromUrl();
// Check if user is authenticated via cookie // Authentication is handled server-side
if (window.JWTManager.isAuthenticated()) { // If user is not authenticated, server will redirect to login
console.log('User is authenticated with JWT cookie'); console.log('JWT initialization complete - authentication handled server-side');
} else {
console.log('User is not authenticated');
// Only redirect to login if we're not already on login/register pages
const currentPath = window.location.pathname;
const currentSearch = window.location.search;
// Don't redirect if we're on logout page or already being logged out
if (!currentPath.includes('/login') &&
!currentPath.includes('/register') &&
!currentPath.includes('/oauth') &&
!currentPath.includes('/saml') &&
!currentPath.includes('/error') &&
!currentSearch.includes('logout=true')) {
// Redirect to login after a short delay to allow other scripts to load
setTimeout(() => {
window.location.href = '/login';
}, 100);
}
}
} }
// No form enhancement needed for cookie-based JWT // No form enhancement needed for cookie-based JWT
@ -51,41 +33,12 @@
// No additional processing needed // No additional processing needed
} }
// Add logout functionality to logout buttons
function enhanceLogoutButtons() {
document.addEventListener('click', function(event) {
const element = event.target;
// Check if clicked element is a logout button/link
if (element.matches('a[href="/logout"], button[data-action="logout"], .logout-btn')) {
event.preventDefault();
window.JWTManager.logout();
}
});
}
// Initialize when DOM is ready // Initialize when DOM is ready
if (document.readyState === 'loading') { if (document.readyState === 'loading') {
document.addEventListener('DOMContentLoaded', function() { document.addEventListener('DOMContentLoaded', function() {
initializeJWT(); initializeJWT();
enhanceLogoutButtons();
}); });
} else { } else {
initializeJWT(); initializeJWT();
enhanceLogoutButtons();
} }
// Handle page visibility changes to check token expiration
document.addEventListener('visibilitychange', function() {
if (!document.hidden && !window.JWTManager.isAuthenticated()) {
// Token expired while page was hidden, redirect to login
const currentPath = window.location.pathname;
if (!currentPath.includes('/login') &&
!currentPath.includes('/register') &&
!currentPath.includes('/oauth') &&
!currentPath.includes('/saml')) {
window.location.href = '/login';
}
}
});
})(); })();

View File

@ -54,15 +54,11 @@ public class CustomAuthenticationSuccessHandler
loginAttemptService.loginSucceeded(userName); loginAttemptService.loginSucceeded(userName);
if (jwtService.isJwtEnabled()) { if (jwtService.isJwtEnabled()) {
try { String jwt =
String jwt = jwtService.generateToken(
jwtService.generateToken( authentication, Map.of("authType", AuthenticationType.WEB));
authentication, Map.of("authType", AuthenticationType.WEB)); jwtService.addToken(response, jwt);
jwtService.addToken(response, jwt); log.debug("JWT generated for user: {}", userName);
log.debug("JWT generated for user: {}", userName);
} catch (Exception e) {
log.error("Failed to generate JWT token for user: {}", userName, e);
}
getRedirectStrategy().sendRedirect(request, response, "/"); getRedirectStrategy().sendRedirect(request, response, "/");
} else { } else {

View File

@ -1,6 +1,5 @@
package stirling.software.proprietary.security.service; package stirling.software.proprietary.security.service;
import jakarta.annotation.PostConstruct;
import java.io.IOException; import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
@ -20,7 +19,7 @@ import java.util.Base64;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cache.Cache; import org.springframework.cache.Cache;
import org.springframework.cache.CacheManager; import org.springframework.cache.CacheManager;
@ -28,6 +27,11 @@ import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.caffeine.CaffeineCache; import org.springframework.cache.caffeine.CaffeineCache;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import jakarta.annotation.PostConstruct;
import lombok.extern.slf4j.Slf4j;
import stirling.software.common.configuration.InstallationPathConfig; import stirling.software.common.configuration.InstallationPathConfig;
import stirling.software.common.model.ApplicationProperties; import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.security.model.JwtVerificationKey; import stirling.software.proprietary.security.model.JwtVerificationKey;

View File

@ -7,6 +7,7 @@ import java.security.spec.InvalidKeySpecException;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import stirling.software.proprietary.security.model.JwtVerificationKey; import stirling.software.proprietary.security.model.JwtVerificationKey;
public interface KeyPersistenceServiceInterface { public interface KeyPersistenceServiceInterface {

View File

@ -1,18 +1,22 @@
package stirling.software.proprietary.security; package stirling.software.proprietary.security;
import jakarta.servlet.http.HttpServletRequest; import static org.mockito.Mockito.*;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException; import java.io.IOException;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks; import org.mockito.InjectMocks;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import stirling.software.common.configuration.AppConfig; import stirling.software.common.configuration.AppConfig;
import stirling.software.common.model.ApplicationProperties; import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.security.service.JwtServiceInterface; import stirling.software.proprietary.security.service.JwtServiceInterface;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
class CustomLogoutSuccessHandlerTest { class CustomLogoutSuccessHandlerTest {

View File

@ -1,32 +1,30 @@
package stirling.software.proprietary.security; package stirling.software.proprietary.security;
import jakarta.servlet.http.HttpServletRequest; import static org.mockito.Mockito.*;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks; import org.mockito.InjectMocks;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import java.io.IOException; import jakarta.servlet.http.HttpServletRequest;
import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; import jakarta.servlet.http.HttpServletResponse;
import static org.mockito.Mockito.*; import stirling.software.proprietary.security.model.exception.AuthenticationFailureException;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
class JwtAuthenticationEntryPointTest { class JwtAuthenticationEntryPointTest {
@Mock @Mock private HttpServletRequest request;
private HttpServletRequest request;
@Mock @Mock private HttpServletResponse response;
private HttpServletResponse response;
@Mock @Mock private AuthenticationFailureException authException;
private AuthenticationFailureException authException;
@InjectMocks @InjectMocks private JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint;
private JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint;
@Test @Test
void testCommence() throws IOException { void testCommence() throws IOException {

View File

@ -1,12 +1,21 @@
package stirling.software.proprietary.security.filter; package stirling.software.proprietary.security.filter;
import jakarta.servlet.FilterChain; import static org.junit.jupiter.api.Assertions.assertEquals;
import jakarta.servlet.ServletException; import static org.junit.jupiter.api.Assertions.assertThrows;
import jakarta.servlet.http.HttpServletRequest; import static org.mockito.ArgumentMatchers.any;
import jakarta.servlet.http.HttpServletResponse; import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@ -20,61 +29,43 @@ import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import stirling.software.common.model.ApplicationProperties; import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; import stirling.software.proprietary.security.model.exception.AuthenticationFailureException;
import stirling.software.proprietary.security.service.CustomUserDetailsService; import stirling.software.proprietary.security.service.CustomUserDetailsService;
import stirling.software.proprietary.security.service.JwtServiceInterface; import stirling.software.proprietary.security.service.JwtServiceInterface;
import stirling.software.proprietary.security.service.UserService; import stirling.software.proprietary.security.service.UserService;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@Disabled @Disabled
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
class JwtAuthenticationFilterTest { class JwtAuthenticationFilterTest {
@Mock @Mock private JwtServiceInterface jwtService;
private JwtServiceInterface jwtService;
@Mock @Mock private CustomUserDetailsService userDetailsService;
private CustomUserDetailsService userDetailsService;
@Mock @Mock private UserService userService;
private UserService userService;
@Mock @Mock private ApplicationProperties.Security securityProperties;
private ApplicationProperties.Security securityProperties;
@Mock @Mock private HttpServletRequest request;
private HttpServletRequest request;
@Mock @Mock private HttpServletResponse response;
private HttpServletResponse response;
@Mock @Mock private FilterChain filterChain;
private FilterChain filterChain;
@Mock @Mock private UserDetails userDetails;
private UserDetails userDetails;
@Mock @Mock private SecurityContext securityContext;
private SecurityContext securityContext;
@Mock @Mock private AuthenticationEntryPoint authenticationEntryPoint;
private AuthenticationEntryPoint authenticationEntryPoint;
@InjectMocks @InjectMocks private JwtAuthenticationFilter jwtAuthenticationFilter;
private JwtAuthenticationFilter jwtAuthenticationFilter;
@Test @Test
void shouldNotAuthenticateWhenJwtDisabled() throws ServletException, IOException { void shouldNotAuthenticateWhenJwtDisabled() throws ServletException, IOException {
@ -113,21 +104,29 @@ class JwtAuthenticationFilterTest {
when(userDetails.getAuthorities()).thenReturn(Collections.emptyList()); when(userDetails.getAuthorities()).thenReturn(Collections.emptyList());
when(userDetailsService.loadUserByUsername(username)).thenReturn(userDetails); when(userDetailsService.loadUserByUsername(username)).thenReturn(userDetails);
try (MockedStatic<SecurityContextHolder> mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) { try (MockedStatic<SecurityContextHolder> mockedSecurityContextHolder =
mockStatic(SecurityContextHolder.class)) {
UsernamePasswordAuthenticationToken authToken = UsernamePasswordAuthenticationToken authToken =
new UsernamePasswordAuthenticationToken(userDetails, null, userDetails.getAuthorities()); new UsernamePasswordAuthenticationToken(
userDetails, null, userDetails.getAuthorities());
when(securityContext.getAuthentication()).thenReturn(null).thenReturn(authToken); when(securityContext.getAuthentication()).thenReturn(null).thenReturn(authToken);
mockedSecurityContextHolder.when(SecurityContextHolder::getContext).thenReturn(securityContext); mockedSecurityContextHolder
when(jwtService.generateToken(any(UsernamePasswordAuthenticationToken.class), eq(claims))).thenReturn(newToken); .when(SecurityContextHolder::getContext)
.thenReturn(securityContext);
when(jwtService.generateToken(
any(UsernamePasswordAuthenticationToken.class), eq(claims)))
.thenReturn(newToken);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(jwtService).validateToken(token); verify(jwtService).validateToken(token);
verify(jwtService).extractClaims(token); verify(jwtService).extractClaims(token);
verify(userDetailsService).loadUserByUsername(username); verify(userDetailsService).loadUserByUsername(username);
verify(securityContext).setAuthentication(any(UsernamePasswordAuthenticationToken.class)); verify(securityContext)
verify(jwtService).generateToken(any(UsernamePasswordAuthenticationToken.class), eq(claims)); .setAuthentication(any(UsernamePasswordAuthenticationToken.class));
verify(jwtService)
.generateToken(any(UsernamePasswordAuthenticationToken.class), eq(claims));
verify(jwtService).addToken(response, newToken); verify(jwtService).addToken(response, newToken);
verify(filterChain).doFilter(request, response); verify(filterChain).doFilter(request, response);
} }
@ -154,12 +153,15 @@ class JwtAuthenticationFilterTest {
when(request.getRequestURI()).thenReturn("/protected"); when(request.getRequestURI()).thenReturn("/protected");
when(request.getContextPath()).thenReturn("/"); when(request.getContextPath()).thenReturn("/");
when(jwtService.extractToken(request)).thenReturn(token); when(jwtService.extractToken(request)).thenReturn(token);
doThrow(new AuthenticationFailureException("Invalid token")).when(jwtService).validateToken(token); doThrow(new AuthenticationFailureException("Invalid token"))
.when(jwtService)
.validateToken(token);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(jwtService).validateToken(token); verify(jwtService).validateToken(token);
verify(authenticationEntryPoint).commence(eq(request), eq(response), any(AuthenticationFailureException.class)); verify(authenticationEntryPoint)
.commence(eq(request), eq(response), any(AuthenticationFailureException.class));
verify(filterChain, never()).doFilter(request, response); verify(filterChain, never()).doFilter(request, response);
} }
@ -171,7 +173,9 @@ class JwtAuthenticationFilterTest {
when(request.getRequestURI()).thenReturn("/protected"); when(request.getRequestURI()).thenReturn("/protected");
when(request.getContextPath()).thenReturn("/"); when(request.getContextPath()).thenReturn("/");
when(jwtService.extractToken(request)).thenReturn(token); when(jwtService.extractToken(request)).thenReturn(token);
doThrow(new AuthenticationFailureException("The token has expired")).when(jwtService).validateToken(token); doThrow(new AuthenticationFailureException("The token has expired"))
.when(jwtService)
.validateToken(token);
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
@ -194,11 +198,19 @@ class JwtAuthenticationFilterTest {
when(jwtService.extractClaims(token)).thenReturn(claims); when(jwtService.extractClaims(token)).thenReturn(claims);
when(userDetailsService.loadUserByUsername(username)).thenReturn(null); when(userDetailsService.loadUserByUsername(username)).thenReturn(null);
try (MockedStatic<SecurityContextHolder> mockedSecurityContextHolder = mockStatic(SecurityContextHolder.class)) { try (MockedStatic<SecurityContextHolder> mockedSecurityContextHolder =
mockStatic(SecurityContextHolder.class)) {
when(securityContext.getAuthentication()).thenReturn(null); when(securityContext.getAuthentication()).thenReturn(null);
mockedSecurityContextHolder.when(SecurityContextHolder::getContext).thenReturn(securityContext); mockedSecurityContextHolder
.when(SecurityContextHolder::getContext)
.thenReturn(securityContext);
UsernameNotFoundException result = assertThrows(UsernameNotFoundException.class, () -> jwtAuthenticationFilter.doFilterInternal(request, response, filterChain)); UsernameNotFoundException result =
assertThrows(
UsernameNotFoundException.class,
() ->
jwtAuthenticationFilter.doFilterInternal(
request, response, filterChain));
assertEquals("User not found: " + username, result.getMessage()); assertEquals("User not found: " + username, result.getMessage());
verify(userDetailsService).loadUserByUsername(username); verify(userDetailsService).loadUserByUsername(username);
@ -207,7 +219,8 @@ class JwtAuthenticationFilterTest {
} }
@Test @Test
void testAuthenticationEntryPointCalledWithCorrectException() throws ServletException, IOException { void testAuthenticationEntryPointCalledWithCorrectException()
throws ServletException, IOException {
when(jwtService.isJwtEnabled()).thenReturn(true); when(jwtService.isJwtEnabled()).thenReturn(true);
when(request.getRequestURI()).thenReturn("/protected"); when(request.getRequestURI()).thenReturn("/protected");
when(request.getContextPath()).thenReturn("/"); when(request.getContextPath()).thenReturn("/");
@ -215,9 +228,15 @@ class JwtAuthenticationFilterTest {
jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); jwtAuthenticationFilter.doFilterInternal(request, response, filterChain);
verify(authenticationEntryPoint).commence(eq(request), eq(response), argThat(exception -> verify(authenticationEntryPoint)
exception.getMessage().equals("JWT is missing from the request") .commence(
)); eq(request),
eq(response),
argThat(
exception ->
exception
.getMessage()
.equals("JWT is missing from the request")));
verify(filterChain, never()).doFilter(request, response); verify(filterChain, never()).doFilter(request, response);
} }
} }

View File

@ -1,23 +1,5 @@
package stirling.software.proprietary.security.saml2; package stirling.software.proprietary.security.saml2;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.NullAndEmptySource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata;
import stirling.software.proprietary.security.service.JwtServiceInterface;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
@ -28,6 +10,28 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.NullAndEmptySource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import stirling.software.proprietary.security.service.JwtServiceInterface;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
class JwtSaml2AuthenticationRequestRepositoryTest { class JwtSaml2AuthenticationRequestRepositoryTest {
@ -35,19 +39,18 @@ class JwtSaml2AuthenticationRequestRepositoryTest {
private Map<String, String> tokenStore; private Map<String, String> tokenStore;
@Mock @Mock private JwtServiceInterface jwtService;
private JwtServiceInterface jwtService;
@Mock @Mock private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private JwtSaml2AuthenticationRequestRepository jwtSaml2AuthenticationRequestRepository; private JwtSaml2AuthenticationRequestRepository jwtSaml2AuthenticationRequestRepository;
@BeforeEach @BeforeEach
void setUp() { void setUp() {
tokenStore = new ConcurrentHashMap<>(); tokenStore = new ConcurrentHashMap<>();
jwtSaml2AuthenticationRequestRepository = new JwtSaml2AuthenticationRequestRepository( jwtSaml2AuthenticationRequestRepository =
tokenStore, jwtService, relyingPartyRegistrationRepository); new JwtSaml2AuthenticationRequestRepository(
tokenStore, jwtService, relyingPartyRegistrationRepository);
} }
@Test @Test
@ -71,7 +74,8 @@ class JwtSaml2AuthenticationRequestRepositoryTest {
when(authRequest.getRelyingPartyRegistrationId()).thenReturn(relyingPartyRegistrationId); when(authRequest.getRelyingPartyRegistrationId()).thenReturn(relyingPartyRegistrationId);
when(jwtService.generateToken(eq(""), anyMap())).thenReturn(token); when(jwtService.generateToken(eq(""), anyMap())).thenReturn(token);
jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(authRequest, request, response); jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(
authRequest, request, response);
verify(request).setAttribute(SAML_REQUEST_TOKEN, relayState); verify(request).setAttribute(SAML_REQUEST_TOKEN, relayState);
verify(response).addHeader(SAML_REQUEST_TOKEN, relayState); verify(response).addHeader(SAML_REQUEST_TOKEN, relayState);
@ -94,20 +98,23 @@ class JwtSaml2AuthenticationRequestRepositoryTest {
var assertingPartyMetadata = mock(AssertingPartyMetadata.class); var assertingPartyMetadata = mock(AssertingPartyMetadata.class);
String relayState = "testRelayState"; String relayState = "testRelayState";
String token = "testToken"; String token = "testToken";
Map<String, Object> claims = Map.of( Map<String, Object> claims =
"id", "testId", Map.of(
"relyingPartyRegistrationId", "stirling-pdf", "id", "testId",
"authenticationRequestUri", "example.com/authnRequest", "relyingPartyRegistrationId", "stirling-pdf",
"samlRequest", "testSamlRequest", "authenticationRequestUri", "example.com/authnRequest",
"relayState", relayState "samlRequest", "testSamlRequest",
); "relayState", relayState);
when(request.getParameter("RelayState")).thenReturn(relayState); when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractClaims(token)).thenReturn(claims); when(jwtService.extractClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(relyingPartyRegistration); when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf"))
.thenReturn(relyingPartyRegistration);
when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf"); when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf");
when(relyingPartyRegistration.getAssertingPartyMetadata()).thenReturn(assertingPartyMetadata); when(relyingPartyRegistration.getAssertingPartyMetadata())
when(assertingPartyMetadata.getSingleSignOnServiceLocation()).thenReturn("https://example.com/sso"); .thenReturn(assertingPartyMetadata);
when(assertingPartyMetadata.getSingleSignOnServiceLocation())
.thenReturn("https://example.com/sso");
tokenStore.put(relayState, token); tokenStore.put(relayState, token);
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
@ -142,17 +149,18 @@ class JwtSaml2AuthenticationRequestRepositoryTest {
var request = mock(MockHttpServletRequest.class); var request = mock(MockHttpServletRequest.class);
String relayState = "testRelayState"; String relayState = "testRelayState";
String token = "testToken"; String token = "testToken";
Map<String, Object> claims = Map.of( Map<String, Object> claims =
"id", "testId", Map.of(
"relyingPartyRegistrationId", "stirling-pdf", "id", "testId",
"authenticationRequestUri", "example.com/authnRequest", "relyingPartyRegistrationId", "stirling-pdf",
"samlRequest", "testSamlRequest", "authenticationRequestUri", "example.com/authnRequest",
"relayState", relayState "samlRequest", "testSamlRequest",
); "relayState", relayState);
when(request.getParameter("RelayState")).thenReturn(relayState); when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractClaims(token)).thenReturn(claims); when(jwtService.extractClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(null); when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf"))
.thenReturn(null);
tokenStore.put(relayState, token); tokenStore.put(relayState, token);
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
@ -168,23 +176,28 @@ class JwtSaml2AuthenticationRequestRepositoryTest {
var assertingPartyMetadata = mock(AssertingPartyMetadata.class); var assertingPartyMetadata = mock(AssertingPartyMetadata.class);
String relayState = "testRelayState"; String relayState = "testRelayState";
String token = "testToken"; String token = "testToken";
Map<String, Object> claims = Map.of( Map<String, Object> claims =
"id", "testId", Map.of(
"relyingPartyRegistrationId", "stirling-pdf", "id", "testId",
"authenticationRequestUri", "example.com/authnRequest", "relyingPartyRegistrationId", "stirling-pdf",
"samlRequest", "testSamlRequest", "authenticationRequestUri", "example.com/authnRequest",
"relayState", relayState "samlRequest", "testSamlRequest",
); "relayState", relayState);
when(request.getParameter("RelayState")).thenReturn(relayState); when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractClaims(token)).thenReturn(claims); when(jwtService.extractClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(relyingPartyRegistration); when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf"))
.thenReturn(relyingPartyRegistration);
when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf"); when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf");
when(relyingPartyRegistration.getAssertingPartyMetadata()).thenReturn(assertingPartyMetadata); when(relyingPartyRegistration.getAssertingPartyMetadata())
when(assertingPartyMetadata.getSingleSignOnServiceLocation()).thenReturn("https://example.com/sso"); .thenReturn(assertingPartyMetadata);
when(assertingPartyMetadata.getSingleSignOnServiceLocation())
.thenReturn("https://example.com/sso");
tokenStore.put(relayState, token); tokenStore.put(relayState, token);
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); var result =
jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(
request, response);
assertNotNull(result); assertNotNull(result);
assertFalse(tokenStore.containsKey(relayState)); assertFalse(tokenStore.containsKey(relayState));
@ -196,7 +209,9 @@ class JwtSaml2AuthenticationRequestRepositoryTest {
var response = mock(HttpServletResponse.class); var response = mock(HttpServletResponse.class);
when(request.getParameter("RelayState")).thenReturn(null); when(request.getParameter("RelayState")).thenReturn(null);
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); var result =
jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(
request, response);
assertNull(result); assertNull(result);
} }
@ -207,7 +222,9 @@ class JwtSaml2AuthenticationRequestRepositoryTest {
var response = mock(HttpServletResponse.class); var response = mock(HttpServletResponse.class);
when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState"); when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState");
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); var result =
jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(
request, response);
assertNull(result); assertNull(result);
} }
@ -220,7 +237,9 @@ class JwtSaml2AuthenticationRequestRepositoryTest {
when(request.getParameter("RelayState")).thenReturn(relayState); when(request.getParameter("RelayState")).thenReturn(relayState);
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); var result =
jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(
request, response);
assertNull(result); assertNull(result);
assertFalse(tokenStore.containsKey(relayState)); assertFalse(tokenStore.containsKey(relayState));

View File

@ -1,29 +1,5 @@
package stirling.software.proprietary.security.service; package stirling.software.proprietary.security.service;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.Collections;
import java.util.Optional;
import stirling.software.proprietary.security.model.JwtVerificationKey;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.core.Authentication;
import stirling.software.proprietary.security.model.User;
import stirling.software.proprietary.security.model.exception.AuthenticationFailureException;
import java.util.HashMap;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
@ -39,23 +15,44 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.core.Authentication;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import stirling.software.proprietary.security.model.JwtVerificationKey;
import stirling.software.proprietary.security.model.User;
import stirling.software.proprietary.security.model.exception.AuthenticationFailureException;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
class JwtServiceTest { class JwtServiceTest {
@Mock @Mock private Authentication authentication;
private Authentication authentication;
@Mock @Mock private User userDetails;
private User userDetails;
@Mock @Mock private HttpServletRequest request;
private HttpServletRequest request;
@Mock @Mock private HttpServletResponse response;
private HttpServletResponse response;
@Mock @Mock private KeyPersistenceServiceInterface keystoreService;
private KeyPersistenceServiceInterface keystoreService;
private JwtService jwtService; private JwtService jwtService;
private KeyPair testKeyPair; private KeyPair testKeyPair;
@ -69,7 +66,8 @@ class JwtServiceTest {
testKeyPair = keyPairGenerator.generateKeyPair(); testKeyPair = keyPairGenerator.generateKeyPair();
// Create test verification key // Create test verification key
String encodedPublicKey = Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()); String encodedPublicKey =
Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded());
testVerificationKey = new JwtVerificationKey("test-key-id", encodedPublicKey); testVerificationKey = new JwtVerificationKey("test-key-id", encodedPublicKey);
jwtService = new JwtService(true, keystoreService); jwtService = new JwtService(true, keystoreService);
@ -81,7 +79,8 @@ class JwtServiceTest {
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair));
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
when(authentication.getPrincipal()).thenReturn(userDetails); when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn(username); when(userDetails.getUsername()).thenReturn(username);
@ -101,7 +100,8 @@ class JwtServiceTest {
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair));
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
when(authentication.getPrincipal()).thenReturn(userDetails); when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn(username); when(userDetails.getUsername()).thenReturn(username);
@ -120,7 +120,8 @@ class JwtServiceTest {
void testValidateTokenSuccess() throws Exception { void testValidateTokenSuccess() throws Exception {
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair));
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
when(authentication.getPrincipal()).thenReturn(userDetails); when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn("testuser"); when(userDetails.getUsername()).thenReturn("testuser");
@ -132,21 +133,28 @@ class JwtServiceTest {
@Test @Test
void testValidateTokenWithInvalidToken() throws Exception { void testValidateTokenWithInvalidToken() throws Exception {
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
assertThrows(AuthenticationFailureException.class, () -> { assertThrows(
jwtService.validateToken("invalid-token"); AuthenticationFailureException.class,
}); () -> {
jwtService.validateToken("invalid-token");
});
} }
@Test @Test
void testValidateTokenWithMalformedToken() throws Exception { void testValidateTokenWithMalformedToken() throws Exception {
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
AuthenticationFailureException exception = assertThrows(AuthenticationFailureException.class, () -> { AuthenticationFailureException exception =
jwtService.validateToken("malformed.token"); assertThrows(
}); AuthenticationFailureException.class,
() -> {
jwtService.validateToken("malformed.token");
});
assertTrue(exception.getMessage().contains("Invalid")); assertTrue(exception.getMessage().contains("Invalid"));
} }
@ -154,13 +162,19 @@ class JwtServiceTest {
@Test @Test
void testValidateTokenWithEmptyToken() throws Exception { void testValidateTokenWithEmptyToken() throws Exception {
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
AuthenticationFailureException exception = assertThrows(AuthenticationFailureException.class, () -> { AuthenticationFailureException exception =
jwtService.validateToken(""); assertThrows(
}); AuthenticationFailureException.class,
() -> {
jwtService.validateToken("");
});
assertTrue(exception.getMessage().contains("Claims are empty") || exception.getMessage().contains("Invalid")); assertTrue(
exception.getMessage().contains("Claims are empty")
|| exception.getMessage().contains("Invalid"));
} }
@Test @Test
@ -171,7 +185,8 @@ class JwtServiceTest {
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair));
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
when(authentication.getPrincipal()).thenReturn(user); when(authentication.getPrincipal()).thenReturn(user);
when(user.getUsername()).thenReturn(username); when(user.getUsername()).thenReturn(username);
@ -183,9 +198,12 @@ class JwtServiceTest {
@Test @Test
void testExtractUsernameWithInvalidToken() throws Exception { void testExtractUsernameWithInvalidToken() throws Exception {
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
assertThrows(AuthenticationFailureException.class, () -> jwtService.extractUsername("invalid-token")); assertThrows(
AuthenticationFailureException.class,
() -> jwtService.extractUsername("invalid-token"));
} }
@Test @Test
@ -195,7 +213,8 @@ class JwtServiceTest {
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair));
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
when(authentication.getPrincipal()).thenReturn(userDetails); when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn(username); when(userDetails.getUsername()).thenReturn(username);
@ -211,15 +230,18 @@ class JwtServiceTest {
@Test @Test
void testExtractClaimsWithInvalidToken() throws Exception { void testExtractClaimsWithInvalidToken() throws Exception {
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
assertThrows(AuthenticationFailureException.class, () -> jwtService.extractClaims("invalid-token")); assertThrows(
AuthenticationFailureException.class,
() -> jwtService.extractClaims("invalid-token"));
} }
@Test @Test
void testExtractTokenWithCookie() { void testExtractTokenWithCookie() {
String token = "test-token"; String token = "test-token";
Cookie[] cookies = { new Cookie("stirling_jwt", token) }; Cookie[] cookies = {new Cookie("stirling_jwt", token)};
when(request.getCookies()).thenReturn(cookies); when(request.getCookies()).thenReturn(cookies);
assertEquals(token, jwtService.extractToken(request)); assertEquals(token, jwtService.extractToken(request));
@ -299,7 +321,8 @@ class JwtServiceTest {
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair));
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
when(authentication.getPrincipal()).thenReturn(userDetails); when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn(username); when(userDetails.getUsername()).thenReturn(username);
@ -307,7 +330,9 @@ class JwtServiceTest {
String token = jwtService.generateToken(authentication, claims); String token = jwtService.generateToken(authentication, claims);
// Mock extraction of key ID and verification (lenient to avoid unused stubbing) // Mock extraction of key ID and verification (lenient to avoid unused stubbing)
lenient().when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); lenient()
.when(keystoreService.getKeyPair("test-key-id"))
.thenReturn(Optional.of(testKeyPair));
// Verify token can be validated // Verify token can be validated
assertDoesNotThrow(() -> jwtService.validateToken(token)); assertDoesNotThrow(() -> jwtService.validateToken(token));
@ -322,7 +347,8 @@ class JwtServiceTest {
// First, generate a token successfully // First, generate a token successfully
when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); when(keystoreService.getActiveKey()).thenReturn(testVerificationKey);
when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair));
when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())).thenReturn(testKeyPair.getPublic()); when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey()))
.thenReturn(testKeyPair.getPublic());
when(authentication.getPrincipal()).thenReturn(userDetails); when(authentication.getPrincipal()).thenReturn(userDetails);
when(userDetails.getUsername()).thenReturn(username); when(userDetails.getUsername()).thenReturn(username);
@ -330,8 +356,10 @@ class JwtServiceTest {
// Now mock the scenario for validation - key not found, but fallback works // Now mock the scenario for validation - key not found, but fallback works
// Create a fallback key pair that can be used // Create a fallback key pair that can be used
JwtVerificationKey fallbackKey = new JwtVerificationKey("fallback-key", JwtVerificationKey fallbackKey =
Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded())); new JwtVerificationKey(
"fallback-key",
Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()));
// Mock the specific key lookup to fail, but the active key should work // Mock the specific key lookup to fail, but the active key should work
when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.empty()); when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.empty());
@ -351,7 +379,8 @@ class JwtServiceTest {
JwtService testService = new JwtService(true, keystoreService); JwtService testService = new JwtService(true, keystoreService);
// Set the secureCookie field using reflection // Set the secureCookie field using reflection
java.lang.reflect.Field secureCookieField = JwtService.class.getDeclaredField("secureCookie"); java.lang.reflect.Field secureCookieField =
JwtService.class.getDeclaredField("secureCookie");
secureCookieField.setAccessible(true); secureCookieField.setAccessible(true);
secureCookieField.set(testService, secureCookie); secureCookieField.set(testService, secureCookie);

View File

@ -1,5 +1,13 @@
package stirling.software.proprietary.security.service; package stirling.software.proprietary.security.service;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
import java.io.IOException; import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
@ -8,6 +16,7 @@ import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.util.Base64; import java.util.Base64;
import java.util.Optional; import java.util.Optional;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@ -19,31 +28,21 @@ import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.cache.CacheManager; import org.springframework.cache.CacheManager;
import org.springframework.cache.concurrent.ConcurrentMapCacheManager; import org.springframework.cache.concurrent.ConcurrentMapCacheManager;
import stirling.software.common.configuration.InstallationPathConfig; import stirling.software.common.configuration.InstallationPathConfig;
import stirling.software.common.model.ApplicationProperties; import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.security.model.JwtVerificationKey; import stirling.software.proprietary.security.model.JwtVerificationKey;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
class KeyPersistenceServiceInterfaceTest { class KeyPersistenceServiceInterfaceTest {
@Mock @Mock private ApplicationProperties applicationProperties;
private ApplicationProperties applicationProperties;
@Mock @Mock private ApplicationProperties.Security security;
private ApplicationProperties.Security security;
@Mock @Mock private ApplicationProperties.Security.Jwt jwtConfig;
private ApplicationProperties.Security.Jwt jwtConfig;
@TempDir @TempDir Path tempDir;
Path tempDir;
private KeyPersistenceService keyPersistenceService; private KeyPersistenceService keyPersistenceService;
private KeyPair testKeyPair; private KeyPair testKeyPair;
@ -67,8 +66,11 @@ class KeyPersistenceServiceInterfaceTest {
void testKeystoreEnabled(boolean keystoreEnabled) { void testKeystoreEnabled(boolean keystoreEnabled) {
when(jwtConfig.isEnableKeystore()).thenReturn(keystoreEnabled); when(jwtConfig.isEnableKeystore()).thenReturn(keystoreEnabled);
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) { try (MockedStatic<InstallationPathConfig> mockedStatic =
mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); mockStatic(InstallationPathConfig.class)) {
mockedStatic
.when(InstallationPathConfig::getPrivateKeyPath)
.thenReturn(tempDir.toString());
keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager);
assertEquals(keystoreEnabled, keyPersistenceService.isKeystoreEnabled()); assertEquals(keystoreEnabled, keyPersistenceService.isKeystoreEnabled());
@ -77,8 +79,11 @@ class KeyPersistenceServiceInterfaceTest {
@Test @Test
void testGetActiveKeypairWhenNoActiveKeyExists() { void testGetActiveKeypairWhenNoActiveKeyExists() {
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) { try (MockedStatic<InstallationPathConfig> mockedStatic =
mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); mockStatic(InstallationPathConfig.class)) {
mockedStatic
.when(InstallationPathConfig::getPrivateKeyPath)
.thenReturn(tempDir.toString());
keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager);
keyPersistenceService.initializeKeystore(); keyPersistenceService.initializeKeystore();
@ -93,16 +98,21 @@ class KeyPersistenceServiceInterfaceTest {
@Test @Test
void testGetActiveKeyPairWithExistingKey() throws Exception { void testGetActiveKeyPairWithExistingKey() throws Exception {
String keyId = "test-key-2024-01-01-120000"; String keyId = "test-key-2024-01-01-120000";
String publicKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()); String publicKeyBase64 =
String privateKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPrivate().getEncoded()); Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded());
String privateKeyBase64 =
Base64.getEncoder().encodeToString(testKeyPair.getPrivate().getEncoded());
JwtVerificationKey existingKey = new JwtVerificationKey(keyId, publicKeyBase64); JwtVerificationKey existingKey = new JwtVerificationKey(keyId, publicKeyBase64);
Path keyFile = tempDir.resolve(keyId + ".key"); Path keyFile = tempDir.resolve(keyId + ".key");
Files.writeString(keyFile, privateKeyBase64); Files.writeString(keyFile, privateKeyBase64);
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) { try (MockedStatic<InstallationPathConfig> mockedStatic =
mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); mockStatic(InstallationPathConfig.class)) {
mockedStatic
.when(InstallationPathConfig::getPrivateKeyPath)
.thenReturn(tempDir.toString());
keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager);
keyPersistenceService.initializeKeystore(); keyPersistenceService.initializeKeystore();
@ -116,19 +126,27 @@ class KeyPersistenceServiceInterfaceTest {
@Test @Test
void testGetKeyPair() throws Exception { void testGetKeyPair() throws Exception {
String keyId = "test-key-123"; String keyId = "test-key-123";
String publicKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()); String publicKeyBase64 =
String privateKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPrivate().getEncoded()); Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded());
String privateKeyBase64 =
Base64.getEncoder().encodeToString(testKeyPair.getPrivate().getEncoded());
JwtVerificationKey signingKey = new JwtVerificationKey(keyId, publicKeyBase64); JwtVerificationKey signingKey = new JwtVerificationKey(keyId, publicKeyBase64);
Path keyFile = tempDir.resolve(keyId + ".key"); Path keyFile = tempDir.resolve(keyId + ".key");
Files.writeString(keyFile, privateKeyBase64); Files.writeString(keyFile, privateKeyBase64);
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) { try (MockedStatic<InstallationPathConfig> mockedStatic =
mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); mockStatic(InstallationPathConfig.class)) {
mockedStatic
.when(InstallationPathConfig::getPrivateKeyPath)
.thenReturn(tempDir.toString());
keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager);
keyPersistenceService.getClass().getDeclaredField("verifyingKeyCache").setAccessible(true); keyPersistenceService
.getClass()
.getDeclaredField("verifyingKeyCache")
.setAccessible(true);
var cache = cacheManager.getCache("verifyingKeys"); var cache = cacheManager.getCache("verifyingKeys");
cache.put(keyId, signingKey); cache.put(keyId, signingKey);
@ -144,8 +162,11 @@ class KeyPersistenceServiceInterfaceTest {
void testGetKeyPairNotFound() { void testGetKeyPairNotFound() {
String keyId = "non-existent-key"; String keyId = "non-existent-key";
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) { try (MockedStatic<InstallationPathConfig> mockedStatic =
mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); mockStatic(InstallationPathConfig.class)) {
mockedStatic
.when(InstallationPathConfig::getPrivateKeyPath)
.thenReturn(tempDir.toString());
keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager);
Optional<KeyPair> result = keyPersistenceService.getKeyPair(keyId); Optional<KeyPair> result = keyPersistenceService.getKeyPair(keyId);
@ -158,8 +179,11 @@ class KeyPersistenceServiceInterfaceTest {
void testGetKeyPairWhenKeystoreDisabled() { void testGetKeyPairWhenKeystoreDisabled() {
when(jwtConfig.isEnableKeystore()).thenReturn(false); when(jwtConfig.isEnableKeystore()).thenReturn(false);
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) { try (MockedStatic<InstallationPathConfig> mockedStatic =
mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); mockStatic(InstallationPathConfig.class)) {
mockedStatic
.when(InstallationPathConfig::getPrivateKeyPath)
.thenReturn(tempDir.toString());
keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager);
Optional<KeyPair> result = keyPersistenceService.getKeyPair("any-key"); Optional<KeyPair> result = keyPersistenceService.getKeyPair("any-key");
@ -170,8 +194,11 @@ class KeyPersistenceServiceInterfaceTest {
@Test @Test
void testInitializeKeystoreCreatesDirectory() throws IOException { void testInitializeKeystoreCreatesDirectory() throws IOException {
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) { try (MockedStatic<InstallationPathConfig> mockedStatic =
mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); mockStatic(InstallationPathConfig.class)) {
mockedStatic
.when(InstallationPathConfig::getPrivateKeyPath)
.thenReturn(tempDir.toString());
keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager);
keyPersistenceService.initializeKeystore(); keyPersistenceService.initializeKeystore();
@ -183,12 +210,16 @@ class KeyPersistenceServiceInterfaceTest {
@Test @Test
void testLoadExistingKeypairWithMissingPrivateKeyFile() throws Exception { void testLoadExistingKeypairWithMissingPrivateKeyFile() throws Exception {
String keyId = "test-key-missing-file"; String keyId = "test-key-missing-file";
String publicKeyBase64 = Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()); String publicKeyBase64 =
Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded());
JwtVerificationKey existingKey = new JwtVerificationKey(keyId, publicKeyBase64); JwtVerificationKey existingKey = new JwtVerificationKey(keyId, publicKeyBase64);
try (MockedStatic<InstallationPathConfig> mockedStatic = mockStatic(InstallationPathConfig.class)) { try (MockedStatic<InstallationPathConfig> mockedStatic =
mockedStatic.when(InstallationPathConfig::getPrivateKeyPath).thenReturn(tempDir.toString()); mockStatic(InstallationPathConfig.class)) {
mockedStatic
.when(InstallationPathConfig::getPrivateKeyPath)
.thenReturn(tempDir.toString());
keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager);
keyPersistenceService.initializeKeystore(); keyPersistenceService.initializeKeystore();