diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 6e006423a..bc5358b85 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -5,7 +5,13 @@ "Bash(mkdir:*)", "Bash(./gradlew:*)", "Bash(grep:*)", - "Bash(cat:*)" + "Bash(cat:*)", + "Bash(find:*)", + "Bash(grep:*)", + "Bash(rg:*)", + "Bash(strings:*)", + "Bash(pkill:*)", + "Bash(true)" ], "deny": [] } diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d87e478d3..c229ee40e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -147,7 +147,9 @@ jobs: - name: Generate OpenAPI documentation run: ./gradlew :stirling-pdf:generateOpenApiDocs - + env: + DISABLE_ADDITIONAL_FEATURES: true + - name: Upload OpenAPI Documentation uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: diff --git a/app/common/src/main/java/stirling/software/common/configuration/AppConfig.java b/app/common/src/main/java/stirling/software/common/configuration/AppConfig.java index f611f42ca..e24a92d6a 100644 --- a/app/common/src/main/java/stirling/software/common/configuration/AppConfig.java +++ b/app/common/src/main/java/stirling/software/common/configuration/AppConfig.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Locale; import java.util.Properties; import java.util.function.Predicate; +import java.util.stream.Stream; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -51,6 +52,14 @@ public class AppConfig { @Value("${server.port:8080}") private String serverPort; + @Value("${v2}") + public boolean v2Enabled; + + @Bean + public boolean v2Enabled() { + return v2Enabled; + } + @Bean @ConditionalOnProperty(name = "system.customHTMLFiles", havingValue = "true") public SpringTemplateEngine templateEngine(ResourceLoader resourceLoader) { @@ -120,7 +129,7 @@ public class AppConfig { public boolean rateLimit() { String rateLimit = System.getProperty("rateLimit"); if (rateLimit == null) rateLimit = System.getenv("rateLimit"); - return (rateLimit != null) ? Boolean.valueOf(rateLimit) : false; + return Boolean.parseBoolean(rateLimit); } @Bean(name = "RunningInDocker") @@ -140,8 +149,8 @@ public class AppConfig { if (!Files.exists(mountInfo)) { return true; } - try { - return Files.lines(mountInfo).anyMatch(line -> line.contains(" /configs ")); + try (Stream lines = Files.lines(mountInfo)) { + return lines.anyMatch(line -> line.contains(" /configs ")); } catch (IOException e) { return false; } diff --git a/app/common/src/main/java/stirling/software/common/configuration/InstallationPathConfig.java b/app/common/src/main/java/stirling/software/common/configuration/InstallationPathConfig.java index 247a012ad..64fbc41b7 100644 --- a/app/common/src/main/java/stirling/software/common/configuration/InstallationPathConfig.java +++ b/app/common/src/main/java/stirling/software/common/configuration/InstallationPathConfig.java @@ -25,6 +25,7 @@ public class InstallationPathConfig { private static final String STATIC_PATH; private static final String TEMPLATES_PATH; private static final String SIGNATURES_PATH; + private static final String PRIVATE_KEY_PATH; static { BASE_PATH = initializeBasePath(); @@ -45,6 +46,7 @@ public class InstallationPathConfig { STATIC_PATH = CUSTOM_FILES_PATH + "static" + File.separator; TEMPLATES_PATH = CUSTOM_FILES_PATH + "templates" + File.separator; SIGNATURES_PATH = CUSTOM_FILES_PATH + "signatures" + File.separator; + PRIVATE_KEY_PATH = CONFIG_PATH + "db" + File.separator + "keys" + File.separator; } private static String initializeBasePath() { @@ -120,4 +122,8 @@ public class InstallationPathConfig { public static String getSignaturesPath() { return SIGNATURES_PATH; } + + public static String getPrivateKeyPath() { + return PRIVATE_KEY_PATH; + } } diff --git a/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java b/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java index ee893c575..5845c6d16 100644 --- a/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java +++ b/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java @@ -119,6 +119,7 @@ public class ApplicationProperties { private long loginResetTimeMinutes; private String loginMethod = "all"; private String customGlobalAPIKey; + private Jwt jwt = new Jwt(); public Boolean isAltLogin() { return saml2.getEnabled() || oauth2.getEnabled(); @@ -298,6 +299,15 @@ public class ApplicationProperties { } } } + + @Data + public static class Jwt { + private boolean enableKeystore = true; + private boolean enableKeyRotation = false; + private boolean enableKeyCleanup = true; + private int keyRetentionDays = 7; + private boolean secureCookie; + } } @Data diff --git a/app/common/src/main/java/stirling/software/common/util/RequestUriUtils.java b/app/common/src/main/java/stirling/software/common/util/RequestUriUtils.java index 654c78fe9..239976b66 100644 --- a/app/common/src/main/java/stirling/software/common/util/RequestUriUtils.java +++ b/app/common/src/main/java/stirling/software/common/util/RequestUriUtils.java @@ -14,8 +14,10 @@ public class RequestUriUtils { || requestURI.startsWith(contextPath + "/images/") || requestURI.startsWith(contextPath + "/public/") || requestURI.startsWith(contextPath + "/pdfjs/") + || requestURI.startsWith(contextPath + "/pdfjs-legacy/") || requestURI.startsWith(contextPath + "/login") || requestURI.startsWith(contextPath + "/error") + || requestURI.startsWith(contextPath + "/favicon") || requestURI.endsWith(".svg") || requestURI.endsWith(".png") || requestURI.endsWith(".ico") diff --git a/app/core/src/main/resources/application.properties b/app/core/src/main/resources/application.properties index ea30bf78e..0ca864985 100644 --- a/app/core/src/main/resources/application.properties +++ b/app/core/src/main/resources/application.properties @@ -5,7 +5,7 @@ logging.level.org.eclipse.jetty=WARN #logging.level.org.springframework.security.saml2=TRACE #logging.level.org.springframework.security=DEBUG #logging.level.org.opensaml=DEBUG -#logging.level.stirling.software.SPDF.config.security: DEBUG +#logging.level.stirling.software.proprietary.security=DEBUG logging.level.com.zaxxer.hikari=WARN spring.jpa.open-in-view=false server.forward-headers-strategy=NATIVE @@ -47,4 +47,7 @@ posthog.host=https://eu.i.posthog.com spring.main.allow-bean-definition-overriding=true # Set up a consistent temporary directory location -java.io.tmpdir=${stirling.tempfiles.directory:${java.io.tmpdir}/stirling-pdf} \ No newline at end of file +java.io.tmpdir=${stirling.tempfiles.directory:${java.io.tmpdir}/stirling-pdf} + +# V2 features +v2=false diff --git a/app/core/src/main/resources/messages_en_GB.properties b/app/core/src/main/resources/messages_en_GB.properties index d6056e856..599dd0989 100644 --- a/app/core/src/main/resources/messages_en_GB.properties +++ b/app/core/src/main/resources/messages_en_GB.properties @@ -893,7 +893,7 @@ login.rememberme=Remember me login.invalid=Invalid username or password. login.locked=Your account has been locked. login.signinTitle=Please sign in -login.ssoSignIn=Login via Single Sign-on +login.ssoSignIn=Login via Single Sign-On login.oAuth2AutoCreateDisabled=OAUTH2 Auto-Create User Disabled login.oAuth2AdminBlockedUser=Registration or logging in of non-registered users is currently blocked. Please contact the administrator. login.oauth2RequestNotFound=Authorization request not found @@ -908,6 +908,7 @@ login.alreadyLoggedIn=You are already logged in to login.alreadyLoggedIn2=devices. Please log out of the devices and try again. login.toManySessions=You have too many active sessions login.logoutMessage=You have been logged out. +login.invalidInResponseTo=The requested SAML response is invalid or has expired. Please contact the administrator. #auto-redact autoRedact.title=Auto Redact diff --git a/app/core/src/main/resources/settings.yml.template b/app/core/src/main/resources/settings.yml.template index 1af95f852..bbbac5fcd 100644 --- a/app/core/src/main/resources/settings.yml.template +++ b/app/core/src/main/resources/settings.yml.template @@ -59,12 +59,17 @@ security: idpCert: classpath:okta.cert # The certificate your Provider will use to authenticate your app's SAML authentication requests. Provided by your Provider privateKey: classpath:saml-private-key.key # Your private key. Generated from your keypair spCert: classpath:saml-public-cert.crt # Your signing certificate. Generated from your keypair + jwt: # This feature is currently under development and not yet fully supported. Do not use in production. + persistence: true # Set to 'true' to enable JWT key store + enableKeyRotation: true # Set to 'true' to enable key pair rotation + enableKeyCleanup: true # Set to 'true' to enable key pair cleanup + keyRetentionDays: 7 # Number of days to retain old keys. The default is 7 days. + secureCookie: false # Set to 'true' to use secure cookies for JWTs premium: key: 00000000-0000-0000-0000-000000000000 enabled: false # Enable license key checks for pro/enterprise features proFeatures: - database: true # Enable database features SSOAutoLogin: false CustomMetadata: autoUpdateMetadata: false diff --git a/app/core/src/main/resources/static/js/DecryptFiles.js b/app/core/src/main/resources/static/js/DecryptFiles.js index 67349a012..0e5b58a92 100644 --- a/app/core/src/main/resources/static/js/DecryptFiles.js +++ b/app/core/src/main/resources/static/js/DecryptFiles.js @@ -46,10 +46,9 @@ export class DecryptFile { formData.append('password', password); } // Send decryption request - const response = await fetch('/api/v1/security/remove-password', { + const response = await fetchWithCsrf('/api/v1/security/remove-password', { method: 'POST', body: formData, - headers: csrfToken ? {'X-XSRF-TOKEN': csrfToken} : undefined, }); if (response.ok) { diff --git a/app/core/src/main/resources/static/js/downloader.js b/app/core/src/main/resources/static/js/downloader.js index 42ba0c357..b5324dd82 100644 --- a/app/core/src/main/resources/static/js/downloader.js +++ b/app/core/src/main/resources/static/js/downloader.js @@ -218,7 +218,7 @@ formData.append('password', password); // Use handleSingleDownload to send the request - const decryptionResult = await fetch(removePasswordUrl, {method: 'POST', body: formData}); + const decryptionResult = await fetchWithCsrf(removePasswordUrl, {method: 'POST', body: formData}); if (decryptionResult && decryptionResult.blob) { const decryptedBlob = await decryptionResult.blob(); diff --git a/app/core/src/main/resources/static/js/fetch-utils.js b/app/core/src/main/resources/static/js/fetch-utils.js index dfe2604a8..2a2fe894c 100644 --- a/app/core/src/main/resources/static/js/fetch-utils.js +++ b/app/core/src/main/resources/static/js/fetch-utils.js @@ -1,3 +1,29 @@ +// Authentication utility for cookie-based JWT +window.JWTManager = { + + // Logout - clear cookies and redirect to login + logout: function() { + + // Clear JWT cookie manually (fallback) + document.cookie = 'stirling_jwt=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT; SameSite=None; Secure'; + + // Perform logout request to clear server-side session + fetch('/logout', { + method: 'POST', + credentials: 'include' + }).then(response => { + if (response.redirected) { + window.location.href = response.url; + } else { + window.location.href = '/login?logout=true'; + } + }).catch(() => { + // If logout fails, let server handle it + window.location.href = '/logout'; + }); + } +}; + window.fetchWithCsrf = async function(url, options = {}) { function getCsrfToken() { const cookieValue = document.cookie @@ -24,5 +50,19 @@ window.fetchWithCsrf = async function(url, options = {}) { fetchOptions.headers['X-XSRF-TOKEN'] = csrfToken; } - return fetch(url, fetchOptions); + // Always include credentials to send JWT cookies + fetchOptions.credentials = 'include'; + + // Make the request + const response = await fetch(url, fetchOptions); + + // Handle 401 responses (unauthorized) + if (response.status === 401) { + console.warn('Authentication failed, redirecting to login'); + window.JWTManager.logout(); + return response; + } + + return response; } + diff --git a/app/core/src/main/resources/static/js/jwt-init.js b/app/core/src/main/resources/static/js/jwt-init.js new file mode 100644 index 000000000..8cd63e189 --- /dev/null +++ b/app/core/src/main/resources/static/js/jwt-init.js @@ -0,0 +1,44 @@ +// JWT Authentication Management Script +// This script handles cookie-based JWT authentication and page access control + +(function() { + // Clean up JWT token from URL parameters after OAuth/Login flows + function cleanupTokenFromUrl() { + const urlParams = new URLSearchParams(window.location.search); + const hasToken = urlParams.get('jwt') || urlParams.get('token'); + if (hasToken) { + // Clean up URL by removing token parameter + // Token should now be set as cookie by server + urlParams.delete('jwt'); + urlParams.delete('token'); + const newUrl = window.location.pathname + (urlParams.toString() ? '?' + urlParams.toString() : ''); + window.history.replaceState({}, '', newUrl); + } + } + + // Initialize JWT handling when page loads + function initializeJWT() { + // Clean up any JWT tokens from URL (OAuth flow) + cleanupTokenFromUrl(); + + // Authentication is handled server-side + // If user is not authenticated, server will redirect to login + console.log('JWT initialization complete - authentication handled server-side'); + } + + // No form enhancement needed for cookie-based JWT + // Cookies are automatically sent with form submissions + function enhanceFormSubmissions() { + // Cookie-based JWT is automatically included in form submissions + // No additional processing needed + } + + // Initialize when DOM is ready + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', function() { + initializeJWT(); + }); + } else { + initializeJWT(); + } +})(); \ No newline at end of file diff --git a/app/core/src/main/resources/static/js/navbar.js b/app/core/src/main/resources/static/js/navbar.js index a95ff1639..1fd46ed70 100644 --- a/app/core/src/main/resources/static/js/navbar.js +++ b/app/core/src/main/resources/static/js/navbar.js @@ -138,5 +138,19 @@ document.addEventListener('DOMContentLoaded', () => { tooltipSetup(); setupDropdowns(); fixNavbarDropdownStyles(); + // Setup logout button functionality + const logoutButton = document.querySelector('a[href="/logout"]'); + if (logoutButton) { + logoutButton.addEventListener('click', function(event) { + event.preventDefault(); + if (window.JWTManager) { + window.JWTManager.logout(); + } else { + // Fallback if JWTManager is not available + window.location.href = '/logout'; + } + }); + } + }); window.addEventListener('resize', fixNavbarDropdownStyles); diff --git a/app/core/src/main/resources/static/js/usage.js b/app/core/src/main/resources/static/js/usage.js index 624e4ec78..443a27ce1 100644 --- a/app/core/src/main/resources/static/js/usage.js +++ b/app/core/src/main/resources/static/js/usage.js @@ -102,7 +102,7 @@ async function fetchEndpointData() { refreshBtn.classList.add('refreshing'); refreshBtn.disabled = true; - const response = await fetch('/api/v1/info/load/all'); + const response = await fetchWithCsrf('/api/v1/info/load/all'); if (!response.ok) { throw new Error('Network response was not ok'); } diff --git a/app/core/src/main/resources/templates/account.html b/app/core/src/main/resources/templates/account.html index 33a0d9f47..db48bb3a5 100644 --- a/app/core/src/main/resources/templates/account.html +++ b/app/core/src/main/resources/templates/account.html @@ -390,8 +390,13 @@ key.includes('clientSubmissionOrder') || key.includes('lastSubmitTime') || key.includes('lastClientId') || - - + key.includes('stirling_jwt') || + key.includes('JSESSIONID') || + key.includes('XSRF-TOKEN') || + key.includes('remember-me') || + key.includes('auth') || + key.includes('token') || + key.includes('session') || key.includes('posthog') || key.includes('ssoRedirectAttempts') || key.includes('lastRedirectAttempt') || key.includes('surveyVersion') || key.includes('pageViews'); } diff --git a/app/proprietary/build.gradle b/app/proprietary/build.gradle index 719f74127..b8862bdd8 100644 --- a/app/proprietary/build.gradle +++ b/app/proprietary/build.gradle @@ -1,9 +1,15 @@ repositories { maven { url = "https://build.shibboleth.net/maven/releases" } } + +ext { + jwtVersion = '0.12.6' +} + bootRun { enabled = false } + spotless { java { target 'src/**/java/**/*.java' @@ -41,6 +47,8 @@ dependencies { api 'org.springframework.boot:spring-boot-starter-data-jpa' api 'org.springframework.boot:spring-boot-starter-oauth2-client' api 'org.springframework.boot:spring-boot-starter-mail' + api 'org.springframework.boot:spring-boot-starter-cache' + api 'com.github.ben-manes.caffeine:caffeine' api 'io.swagger.core.v3:swagger-core-jakarta:2.2.35' implementation 'com.bucket4j:bucket4j_jdk17-core:8.14.0' @@ -50,6 +58,10 @@ dependencies { implementation 'org.thymeleaf.extras:thymeleaf-extras-springsecurity5:3.1.3.RELEASE' api 'io.micrometer:micrometer-registry-prometheus' implementation 'com.unboundid.product.scim2:scim2-sdk-client:4.0.0' + + api "io.jsonwebtoken:jjwt-api:$jwtVersion" + runtimeOnly "io.jsonwebtoken:jjwt-impl:$jwtVersion" + runtimeOnly "io.jsonwebtoken:jjwt-jackson:$jwtVersion" runtimeOnly 'com.h2database:h2:2.3.232' // Don't upgrade h2database runtimeOnly 'org.postgresql:postgresql:42.7.7' constraints { diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomAuthenticationSuccessHandler.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomAuthenticationSuccessHandler.java index d5180c321..51908ef03 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomAuthenticationSuccessHandler.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomAuthenticationSuccessHandler.java @@ -1,6 +1,7 @@ package stirling.software.proprietary.security; import java.io.IOException; +import java.util.Map; import org.springframework.security.core.Authentication; import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler; @@ -17,6 +18,8 @@ import stirling.software.common.util.RequestUriUtils; import stirling.software.proprietary.audit.AuditEventType; import stirling.software.proprietary.audit.AuditLevel; import stirling.software.proprietary.audit.Audited; +import stirling.software.proprietary.security.model.AuthenticationType; +import stirling.software.proprietary.security.service.JwtServiceInterface; import stirling.software.proprietary.security.service.LoginAttemptService; import stirling.software.proprietary.security.service.UserService; @@ -24,13 +27,17 @@ import stirling.software.proprietary.security.service.UserService; public class CustomAuthenticationSuccessHandler extends SavedRequestAwareAuthenticationSuccessHandler { - private LoginAttemptService loginAttemptService; - private UserService userService; + private final LoginAttemptService loginAttemptService; + private final UserService userService; + private final JwtServiceInterface jwtService; public CustomAuthenticationSuccessHandler( - LoginAttemptService loginAttemptService, UserService userService) { + LoginAttemptService loginAttemptService, + UserService userService, + JwtServiceInterface jwtService) { this.loginAttemptService = loginAttemptService; this.userService = userService; + this.jwtService = jwtService; } @Override @@ -46,23 +53,31 @@ public class CustomAuthenticationSuccessHandler } loginAttemptService.loginSucceeded(userName); - // Get the saved request - HttpSession session = request.getSession(false); - SavedRequest savedRequest = - (session != null) - ? (SavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST") - : null; + if (jwtService.isJwtEnabled()) { + String jwt = + jwtService.generateToken( + authentication, Map.of("authType", AuthenticationType.WEB)); + jwtService.addToken(response, jwt); + log.debug("JWT generated for user: {}", userName); - if (savedRequest != null - && !RequestUriUtils.isStaticResource( - request.getContextPath(), savedRequest.getRedirectUrl())) { - // Redirect to the original destination - super.onAuthenticationSuccess(request, response, authentication); - } else { - // Redirect to the root URL (considering context path) getRedirectStrategy().sendRedirect(request, response, "/"); - } + } else { + // Get the saved request + HttpSession session = request.getSession(false); + SavedRequest savedRequest = + (session != null) + ? (SavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST") + : null; - // super.onAuthenticationSuccess(request, response, authentication); + if (savedRequest != null + && !RequestUriUtils.isStaticResource( + request.getContextPath(), savedRequest.getRedirectUrl())) { + // Redirect to the original destination + super.onAuthenticationSuccess(request, response, authentication); + } else { + // No saved request or it's a static resource, redirect to home page + getRedirectStrategy().sendRedirect(request, response, "/"); + } + } } } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java index 033ea913c..136120528 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java @@ -33,6 +33,7 @@ import stirling.software.proprietary.audit.AuditLevel; import stirling.software.proprietary.audit.Audited; import stirling.software.proprietary.security.saml2.CertificateUtils; import stirling.software.proprietary.security.saml2.CustomSaml2AuthenticatedPrincipal; +import stirling.software.proprietary.security.service.JwtServiceInterface; @Slf4j @RequiredArgsConstructor @@ -40,15 +41,18 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler { public static final String LOGOUT_PATH = "/login?logout=true"; - private final ApplicationProperties applicationProperties; + private final ApplicationProperties.Security securityProperties; private final AppConfig appConfig; + private final JwtServiceInterface jwtService; + @Override @Audited(type = AuditEventType.USER_LOGOUT, level = AuditLevel.BASIC) public void onLogoutSuccess( HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException { + if (!response.isCommitted()) { if (authentication != null) { if (authentication instanceof Saml2Authentication samlAuthentication) { @@ -67,6 +71,9 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler { authentication.getClass().getSimpleName()); getRedirectStrategy().sendRedirect(request, response, LOGOUT_PATH); } + } else if (!jwtService.extractToken(request).isBlank()) { + jwtService.clearToken(response); + getRedirectStrategy().sendRedirect(request, response, LOGOUT_PATH); } else { // Redirect to login page after logout String path = checkForErrors(request); @@ -82,7 +89,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler { Saml2Authentication samlAuthentication) throws IOException { - SAML2 samlConf = applicationProperties.getSecurity().getSaml2(); + SAML2 samlConf = securityProperties.getSaml2(); String registrationId = samlConf.getRegistrationId(); CustomSaml2AuthenticatedPrincipal principal = @@ -127,7 +134,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler { OAuth2AuthenticationToken oAuthToken) throws IOException { String registrationId; - OAUTH2 oauth = applicationProperties.getSecurity().getOauth2(); + OAUTH2 oauth = securityProperties.getOauth2(); String path = checkForErrors(request); String redirectUrl = UrlUtils.getOrigin(request) + "/login?" + path; diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/InitialSecuritySetup.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/InitialSecuritySetup.java index 4b09fe0e9..e145e2754 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/InitialSecuritySetup.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/InitialSecuritySetup.java @@ -43,7 +43,6 @@ public class InitialSecuritySetup { } } - userService.migrateOauth2ToSSO(); assignUsersToDefaultTeamIfMissing(); initializeInternalApiUser(); } catch (IllegalArgumentException | SQLException | UnsupportedProviderException e) { diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/JwtAuthenticationEntryPoint.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/JwtAuthenticationEntryPoint.java new file mode 100644 index 000000000..6805bcb54 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/JwtAuthenticationEntryPoint.java @@ -0,0 +1,22 @@ +package stirling.software.proprietary.security; + +import java.io.IOException; + +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.stereotype.Component; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +@Component +public class JwtAuthenticationEntryPoint implements AuthenticationEntryPoint { + @Override + public void commence( + HttpServletRequest request, + HttpServletResponse response, + AuthenticationException authException) + throws IOException { + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, authException.getMessage()); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/config/AccountWebController.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/config/AccountWebController.java index 0d846fc3d..46d0e7d3d 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/config/AccountWebController.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/config/AccountWebController.java @@ -77,8 +77,11 @@ public class AccountWebController { @GetMapping("/login") public String login(HttpServletRequest request, Model model, Authentication authentication) { - // If the user is already authenticated, redirect them to the home page. - if (authentication != null && authentication.isAuthenticated()) { + // If the user is already authenticated and it's not a logout scenario, redirect them to the + // home page. + if (authentication != null + && authentication.isAuthenticated() + && request.getParameter("logout") == null) { return "redirect:/"; } @@ -184,7 +187,7 @@ public class AccountWebController { errorOAuth = "login.relyingPartyRegistrationNotFound"; // Valid InResponseTo was not available from the validation context, unable to // evaluate - case "invalid_in_response_to" -> errorOAuth = "login.invalid_in_response_to"; + case "invalid_in_response_to" -> errorOAuth = "login.invalidInResponseTo"; case "not_authentication_provider_found" -> errorOAuth = "login.not_authentication_provider_found"; } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/CacheConfig.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/CacheConfig.java new file mode 100644 index 000000000..ba074a5da --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/CacheConfig.java @@ -0,0 +1,31 @@ +package stirling.software.proprietary.security.configuration; + +import java.time.Duration; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.cache.CacheManager; +import org.springframework.cache.annotation.EnableCaching; +import org.springframework.cache.caffeine.CaffeineCacheManager; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import com.github.benmanes.caffeine.cache.Caffeine; + +@Configuration +@EnableCaching +public class CacheConfig { + + @Value("${security.jwt.keyRetentionDays}") + private int keyRetentionDays; + + @Bean + public CacheManager cacheManager() { + CaffeineCacheManager cacheManager = new CaffeineCacheManager(); + cacheManager.setCaffeine( + Caffeine.newBuilder() + .maximumSize(1000) // Make configurable? + .expireAfterWrite(Duration.ofDays(keyRetentionDays)) + .recordStats()); + return cacheManager; + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java index ab809a037..aceb3b712 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/configuration/SecurityConfiguration.java @@ -13,6 +13,7 @@ import org.springframework.security.authentication.dao.DaoAuthenticationProvider import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer; import org.springframework.security.config.http.SessionCreationPolicy; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; @@ -35,10 +36,12 @@ import stirling.software.common.model.ApplicationProperties; import stirling.software.proprietary.security.CustomAuthenticationFailureHandler; import stirling.software.proprietary.security.CustomAuthenticationSuccessHandler; import stirling.software.proprietary.security.CustomLogoutSuccessHandler; +import stirling.software.proprietary.security.JwtAuthenticationEntryPoint; import stirling.software.proprietary.security.database.repository.JPATokenRepositoryImpl; import stirling.software.proprietary.security.database.repository.PersistentLoginRepository; import stirling.software.proprietary.security.filter.FirstLoginFilter; import stirling.software.proprietary.security.filter.IPRateLimitingFilter; +import stirling.software.proprietary.security.filter.JwtAuthenticationFilter; import stirling.software.proprietary.security.filter.UserAuthenticationFilter; import stirling.software.proprietary.security.model.User; import stirling.software.proprietary.security.oauth2.CustomOAuth2AuthenticationFailureHandler; @@ -48,6 +51,7 @@ import stirling.software.proprietary.security.saml2.CustomSaml2AuthenticationSuc import stirling.software.proprietary.security.saml2.CustomSaml2ResponseAuthenticationConverter; import stirling.software.proprietary.security.service.CustomOAuth2UserService; import stirling.software.proprietary.security.service.CustomUserDetailsService; +import stirling.software.proprietary.security.service.JwtServiceInterface; import stirling.software.proprietary.security.service.LoginAttemptService; import stirling.software.proprietary.security.service.UserService; import stirling.software.proprietary.security.session.SessionPersistentRegistry; @@ -64,9 +68,11 @@ public class SecurityConfiguration { private final boolean loginEnabledValue; private final boolean runningProOrHigher; - private final ApplicationProperties applicationProperties; + private final ApplicationProperties.Security securityProperties; private final AppConfig appConfig; private final UserAuthenticationFilter userAuthenticationFilter; + private final JwtServiceInterface jwtService; + private final JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint; private final LoginAttemptService loginAttemptService; private final FirstLoginFilter firstLoginFilter; private final SessionPersistentRegistry sessionRegistry; @@ -82,8 +88,10 @@ public class SecurityConfiguration { @Qualifier("loginEnabled") boolean loginEnabledValue, @Qualifier("runningProOrHigher") boolean runningProOrHigher, AppConfig appConfig, - ApplicationProperties applicationProperties, + ApplicationProperties.Security securityProperties, UserAuthenticationFilter userAuthenticationFilter, + JwtServiceInterface jwtService, + JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint, LoginAttemptService loginAttemptService, FirstLoginFilter firstLoginFilter, SessionPersistentRegistry sessionRegistry, @@ -97,8 +105,10 @@ public class SecurityConfiguration { this.loginEnabledValue = loginEnabledValue; this.runningProOrHigher = runningProOrHigher; this.appConfig = appConfig; - this.applicationProperties = applicationProperties; + this.securityProperties = securityProperties; this.userAuthenticationFilter = userAuthenticationFilter; + this.jwtService = jwtService; + this.jwtAuthenticationEntryPoint = jwtAuthenticationEntryPoint; this.loginAttemptService = loginAttemptService; this.firstLoginFilter = firstLoginFilter; this.sessionRegistry = sessionRegistry; @@ -115,14 +125,28 @@ public class SecurityConfiguration { @Bean public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { - if (applicationProperties.getSecurity().getCsrfDisabled() || !loginEnabledValue) { - http.csrf(csrf -> csrf.disable()); + if (securityProperties.getCsrfDisabled() || !loginEnabledValue) { + http.csrf(CsrfConfigurer::disable); } if (loginEnabledValue) { + boolean v2Enabled = appConfig.v2Enabled(); + + if (v2Enabled) { + http.addFilterBefore( + jwtAuthenticationFilter(), + UsernamePasswordAuthenticationFilter.class) + .exceptionHandling( + exceptionHandling -> + exceptionHandling.authenticationEntryPoint( + jwtAuthenticationEntryPoint)); + } http.addFilterBefore( - userAuthenticationFilter, UsernamePasswordAuthenticationFilter.class); - if (!applicationProperties.getSecurity().getCsrfDisabled()) { + userAuthenticationFilter, UsernamePasswordAuthenticationFilter.class) + .addFilterAfter(rateLimitingFilter(), UserAuthenticationFilter.class) + .addFilterAfter(firstLoginFilter, UsernamePasswordAuthenticationFilter.class); + + if (!securityProperties.getCsrfDisabled()) { CookieCsrfTokenRepository cookieRepo = CookieCsrfTokenRepository.withHttpOnlyFalse(); CsrfTokenRequestAttributeHandler requestHandler = @@ -156,16 +180,21 @@ public class SecurityConfiguration { .csrfTokenRepository(cookieRepo) .csrfTokenRequestHandler(requestHandler)); } - http.addFilterBefore(rateLimitingFilter(), UsernamePasswordAuthenticationFilter.class); - http.addFilterAfter(firstLoginFilter, UsernamePasswordAuthenticationFilter.class); + http.sessionManagement( - sessionManagement -> + sessionManagement -> { + if (v2Enabled) { + sessionManagement.sessionCreationPolicy( + SessionCreationPolicy.STATELESS); + } else { sessionManagement .sessionCreationPolicy(SessionCreationPolicy.IF_REQUIRED) .maximumSessions(10) .maxSessionsPreventsLogin(false) .sessionRegistry(sessionRegistry) - .expiredUrl("/login?logout=true")); + .expiredUrl("/login?logout=true"); + } + }); http.authenticationProvider(daoAuthenticationProvider()); http.requestCache(requestCache -> requestCache.requestCache(new NullRequestCache())); http.logout( @@ -175,10 +204,10 @@ public class SecurityConfiguration { .matcher("/logout")) .logoutSuccessHandler( new CustomLogoutSuccessHandler( - applicationProperties, appConfig)) + securityProperties, appConfig, jwtService)) .clearAuthentication(true) .invalidateHttpSession(true) - .deleteCookies("JSESSIONID", "remember-me")); + .deleteCookies("JSESSIONID", "remember-me", "stirling_jwt")); http.rememberMe( rememberMeConfigurer -> // Use the configurator directly rememberMeConfigurer @@ -200,6 +229,7 @@ public class SecurityConfiguration { req -> { String uri = req.getRequestURI(); String contextPath = req.getContextPath(); + // Remove the context path from the URI String trimmedUri = uri.startsWith(contextPath) @@ -217,29 +247,35 @@ public class SecurityConfiguration { || trimmedUri.startsWith("/css/") || trimmedUri.startsWith("/fonts/") || trimmedUri.startsWith("/js/") + || trimmedUri.startsWith("/pdfjs/") + || trimmedUri.startsWith("/pdfjs-legacy/") + || trimmedUri.startsWith("/favicon") || trimmedUri.startsWith( - "/api/v1/info/status"); + "/api/v1/info/status") + || trimmedUri.startsWith("/v1/api-docs") + || uri.contains("/v1/api-docs"); }) .permitAll() .anyRequest() .authenticated()); // Handle User/Password Logins - if (applicationProperties.getSecurity().isUserPass()) { + if (securityProperties.isUserPass()) { http.formLogin( formLogin -> formLogin .loginPage("/login") .successHandler( new CustomAuthenticationSuccessHandler( - loginAttemptService, userService)) + loginAttemptService, + userService, + jwtService)) .failureHandler( new CustomAuthenticationFailureHandler( loginAttemptService, userService)) - .defaultSuccessUrl("/") .permitAll()); } // Handle OAUTH2 Logins - if (applicationProperties.getSecurity().isOauth2Active()) { + if (securityProperties.isOauth2Active()) { http.oauth2Login( oauth2 -> oauth2.loginPage("/oauth2") @@ -251,17 +287,18 @@ public class SecurityConfiguration { .successHandler( new CustomOAuth2AuthenticationSuccessHandler( loginAttemptService, - applicationProperties, - userService)) + securityProperties.getOauth2(), + userService, + jwtService)) .failureHandler( new CustomOAuth2AuthenticationFailureHandler()) - . // Add existing Authorities from the database - userInfoEndpoint( + // Add existing Authorities from the database + .userInfoEndpoint( userInfoEndpoint -> userInfoEndpoint .oidcUserService( new CustomOAuth2UserService( - applicationProperties, + securityProperties, userService, loginAttemptService)) .userAuthoritiesMapper( @@ -269,8 +306,7 @@ public class SecurityConfiguration { .permitAll()); } // Handle SAML - if (applicationProperties.getSecurity().isSaml2Active() && runningProOrHigher) { - // Configure the authentication provider + if (securityProperties.isSaml2Active() && runningProOrHigher) { OpenSaml4AuthenticationProvider authenticationProvider = new OpenSaml4AuthenticationProvider(); authenticationProvider.setResponseAuthenticationConverter( @@ -287,8 +323,9 @@ public class SecurityConfiguration { .successHandler( new CustomSaml2AuthenticationSuccessHandler( loginAttemptService, - applicationProperties, - userService)) + securityProperties.getSaml2(), + userService, + jwtService)) .failureHandler( new CustomSaml2AuthenticationFailureHandler()) .authenticationRequestResolver( @@ -323,4 +360,14 @@ public class SecurityConfiguration { public PersistentTokenRepository persistentTokenRepository() { return new JPATokenRepositoryImpl(persistentLoginRepository); } + + @Bean + public JwtAuthenticationFilter jwtAuthenticationFilter() { + return new JwtAuthenticationFilter( + jwtService, + userService, + userDetailsService, + jwtAuthenticationEntryPoint, + securityProperties); + } } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilter.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilter.java new file mode 100644 index 000000000..faf50832f --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilter.java @@ -0,0 +1,204 @@ +package stirling.software.proprietary.security.filter; + +import static stirling.software.common.util.RequestUriUtils.isStaticResource; +import static stirling.software.proprietary.security.model.AuthenticationType.*; +import static stirling.software.proprietary.security.model.AuthenticationType.SAML2; + +import java.io.IOException; +import java.sql.SQLException; +import java.util.Map; +import java.util.Optional; + +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.web.filter.OncePerRequestFilter; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.common.model.ApplicationProperties; +import stirling.software.common.model.exception.UnsupportedProviderException; +import stirling.software.proprietary.security.model.ApiKeyAuthenticationToken; +import stirling.software.proprietary.security.model.AuthenticationType; +import stirling.software.proprietary.security.model.User; +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; +import stirling.software.proprietary.security.service.CustomUserDetailsService; +import stirling.software.proprietary.security.service.JwtServiceInterface; +import stirling.software.proprietary.security.service.UserService; + +@Slf4j +public class JwtAuthenticationFilter extends OncePerRequestFilter { + + private final JwtServiceInterface jwtService; + private final UserService userService; + private final CustomUserDetailsService userDetailsService; + private final AuthenticationEntryPoint authenticationEntryPoint; + private final ApplicationProperties.Security securityProperties; + + public JwtAuthenticationFilter( + JwtServiceInterface jwtService, + UserService userService, + CustomUserDetailsService userDetailsService, + AuthenticationEntryPoint authenticationEntryPoint, + ApplicationProperties.Security securityProperties) { + this.jwtService = jwtService; + this.userService = userService; + this.userDetailsService = userDetailsService; + this.authenticationEntryPoint = authenticationEntryPoint; + this.securityProperties = securityProperties; + } + + @Override + protected void doFilterInternal( + HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + if (!jwtService.isJwtEnabled()) { + filterChain.doFilter(request, response); + return; + } + if (isStaticResource(request.getContextPath(), request.getRequestURI())) { + filterChain.doFilter(request, response); + return; + } + + if (!apiKeyExists(request, response)) { + String jwtToken = jwtService.extractToken(request); + + if (jwtToken == null) { + // Any unauthenticated requests should redirect to /login + String requestURI = request.getRequestURI(); + String contextPath = request.getContextPath(); + + if (!requestURI.startsWith(contextPath + "/login")) { + response.sendRedirect("/login"); + return; + } + } + + try { + jwtService.validateToken(jwtToken); + } catch (AuthenticationFailureException e) { + jwtService.clearToken(response); + handleAuthenticationFailure(request, response, e); + return; + } + + Map claims = jwtService.extractClaims(jwtToken); + String tokenUsername = claims.get("sub").toString(); + + try { + authenticate(request, claims); + } catch (SQLException | UnsupportedProviderException e) { + log.error("Error processing user authentication for user: {}", tokenUsername, e); + handleAuthenticationFailure( + request, + response, + new AuthenticationFailureException( + "Error processing user authentication", e)); + return; + } + } + + filterChain.doFilter(request, response); + } + + private boolean apiKeyExists(HttpServletRequest request, HttpServletResponse response) + throws IOException, ServletException { + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + + if (authentication == null || !authentication.isAuthenticated()) { + String apiKey = request.getHeader("X-API-KEY"); + + if (apiKey != null && !apiKey.isBlank()) { + try { + Optional user = userService.getUserByApiKey(apiKey); + + if (user.isEmpty()) { + handleAuthenticationFailure( + request, + response, + new AuthenticationFailureException("Invalid API Key")); + return false; + } + + authentication = + new ApiKeyAuthenticationToken( + user.get(), apiKey, user.get().getAuthorities()); + SecurityContextHolder.getContext().setAuthentication(authentication); + return true; + } catch (AuthenticationException e) { + handleAuthenticationFailure( + request, + response, + new AuthenticationFailureException("Invalid API Key", e)); + return false; + } + } + + return false; + } + + return true; + } + + private void authenticate(HttpServletRequest request, Map claims) + throws SQLException, UnsupportedProviderException { + String username = claims.get("sub").toString(); + + if (username != null && SecurityContextHolder.getContext().getAuthentication() == null) { + processUserAuthenticationType(claims, username); + UserDetails userDetails = userDetailsService.loadUserByUsername(username); + + if (userDetails != null) { + UsernamePasswordAuthenticationToken authToken = + new UsernamePasswordAuthenticationToken( + userDetails, null, userDetails.getAuthorities()); + + authToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request)); + SecurityContextHolder.getContext().setAuthentication(authToken); + } else { + throw new UsernameNotFoundException("User not found: " + username); + } + } + } + + private void processUserAuthenticationType(Map claims, String username) + throws SQLException, UnsupportedProviderException { + AuthenticationType authenticationType = + AuthenticationType.valueOf(claims.getOrDefault("authType", WEB).toString()); + log.debug("Processing {} login for {} user", authenticationType, username); + + switch (authenticationType) { + case OAUTH2 -> { + ApplicationProperties.Security.OAUTH2 oauth2Properties = + securityProperties.getOauth2(); + userService.processSSOPostLogin( + username, oauth2Properties.getAutoCreateUser(), OAUTH2); + } + case SAML2 -> { + ApplicationProperties.Security.SAML2 saml2Properties = + securityProperties.getSaml2(); + userService.processSSOPostLogin( + username, saml2Properties.getAutoCreateUser(), SAML2); + } + } + } + + private void handleAuthenticationFailure( + HttpServletRequest request, + HttpServletResponse response, + AuthenticationException authException) + throws IOException, ServletException { + authenticationEntryPoint.commence(request, response, authException); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/UserAuthenticationFilter.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/UserAuthenticationFilter.java index e9addd239..f51a9d543 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/UserAuthenticationFilter.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/filter/UserAuthenticationFilter.java @@ -9,7 +9,6 @@ import org.springframework.context.annotation.Lazy; import org.springframework.http.HttpStatus; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.session.SessionInformation; import org.springframework.security.core.userdetails.UserDetails; @@ -64,6 +63,7 @@ public class UserAuthenticationFilter extends OncePerRequestFilter { return; } String requestURI = request.getRequestURI(); + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); // Check for session expiration (unsure if needed) @@ -92,14 +92,9 @@ public class UserAuthenticationFilter extends OncePerRequestFilter { response.getWriter().write("Invalid API Key."); return; } - List authorities = - user.get().getAuthorities().stream() - .map( - authority -> - new SimpleGrantedAuthority( - authority.getAuthority())) - .toList(); - authentication = new ApiKeyAuthenticationToken(user.get(), apiKey, authorities); + authentication = + new ApiKeyAuthenticationToken( + user.get(), apiKey, user.get().getAuthorities()); SecurityContextHolder.getContext().setAuthentication(authentication); } catch (AuthenticationException e) { // If API key authentication fails, deny the request @@ -115,20 +110,19 @@ public class UserAuthenticationFilter extends OncePerRequestFilter { String method = request.getMethod(); String contextPath = request.getContextPath(); - if ("GET".equalsIgnoreCase(method) && !(contextPath + "/login").equals(requestURI)) { + if ("GET".equalsIgnoreCase(method) && !requestURI.startsWith(contextPath + "/login")) { response.sendRedirect(contextPath + "/login"); // redirect to the login page - return; } else { response.setStatus(HttpStatus.UNAUTHORIZED.value()); response.getWriter() .write( - "Authentication required. Please provide a X-API-KEY in request" - + " header.\n" - + "This is found in Settings -> Account Settings -> API Key\n" - + "Alternatively you can disable authentication if this is" - + " unexpected"); - return; + """ + Authentication required. Please provide a X-API-KEY in request header. + This is found in Settings -> Account Settings -> API Key + Alternatively you can disable authentication if this is unexpected. + """); } + return; } // Check if the authenticated user is disabled and invalidate their session if so @@ -226,11 +220,12 @@ public class UserAuthenticationFilter extends OncePerRequestFilter { } @Override - protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { + protected boolean shouldNotFilter(HttpServletRequest request) { String uri = request.getRequestURI(); String contextPath = request.getContextPath(); String[] permitAllPatterns = { contextPath + "/login", + contextPath + "/signup", contextPath + "/register", contextPath + "/error", contextPath + "/images/", @@ -247,6 +242,7 @@ public class UserAuthenticationFilter extends OncePerRequestFilter { for (String pattern : permitAllPatterns) { if (uri.startsWith(pattern) || uri.endsWith(".svg") + || uri.endsWith(".mjs") || uri.endsWith(".png") || uri.endsWith(".ico")) { return true; diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/AuthenticationType.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/AuthenticationType.java index ca8140bca..c92c1655e 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/AuthenticationType.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/AuthenticationType.java @@ -2,5 +2,8 @@ package stirling.software.proprietary.security.model; public enum AuthenticationType { WEB, - SSO + @Deprecated(since = "1.0.2") + SSO, + OAUTH2, + SAML2 } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/Authority.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/Authority.java index 382d3a71e..a32e7d7ca 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/Authority.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/Authority.java @@ -2,6 +2,8 @@ package stirling.software.proprietary.security.model; import java.io.Serializable; +import org.springframework.security.core.GrantedAuthority; + import jakarta.persistence.Column; import jakarta.persistence.Entity; import jakarta.persistence.GeneratedValue; @@ -18,7 +20,7 @@ import lombok.Setter; @Table(name = "authorities") @Getter @Setter -public class Authority implements Serializable { +public class Authority implements GrantedAuthority, Serializable { private static final long serialVersionUID = 1L; diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/JwtVerificationKey.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/JwtVerificationKey.java new file mode 100644 index 000000000..632c5f13a --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/JwtVerificationKey.java @@ -0,0 +1,33 @@ +package stirling.software.proprietary.security.model; + +import java.io.Serial; +import java.io.Serializable; +import java.time.LocalDateTime; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.ToString; + +@Getter +@Setter +@NoArgsConstructor +@ToString(onlyExplicitlyIncluded = true) +@EqualsAndHashCode(onlyExplicitlyIncluded = true) +public class JwtVerificationKey implements Serializable { + + @Serial private static final long serialVersionUID = 1L; + + @ToString.Include private String keyId; + + private String verifyingKey; + + @ToString.Include private LocalDateTime createdAt; + + public JwtVerificationKey(String keyId, String verifyingKey) { + this.keyId = keyId; + this.verifyingKey = verifyingKey; + this.createdAt = LocalDateTime.now(); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/User.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/User.java index d3e232f61..7d1b235cd 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/User.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/User.java @@ -7,6 +7,8 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import org.springframework.security.core.userdetails.UserDetails; + import jakarta.persistence.*; import lombok.EqualsAndHashCode; @@ -25,7 +27,7 @@ import stirling.software.proprietary.model.Team; @Setter @EqualsAndHashCode(onlyExplicitlyIncluded = true) @ToString(onlyExplicitlyIncluded = true) -public class User implements Serializable { +public class User implements UserDetails, Serializable { private static final long serialVersionUID = 1L; diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/model/exception/AuthenticationFailureException.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/exception/AuthenticationFailureException.java new file mode 100644 index 000000000..f2cd5e242 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/model/exception/AuthenticationFailureException.java @@ -0,0 +1,13 @@ +package stirling.software.proprietary.security.model.exception; + +import org.springframework.security.core.AuthenticationException; + +public class AuthenticationFailureException extends AuthenticationException { + public AuthenticationFailureException(String message) { + super(message); + } + + public AuthenticationFailureException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java index 71bd42a85..4e7ed9d9e 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/CustomOAuth2AuthenticationSuccessHandler.java @@ -1,7 +1,11 @@ package stirling.software.proprietary.security.oauth2; +import static stirling.software.proprietary.security.model.AuthenticationType.OAUTH2; +import static stirling.software.proprietary.security.model.AuthenticationType.SSO; + import java.io.IOException; import java.sql.SQLException; +import java.util.Map; import org.springframework.security.authentication.LockedException; import org.springframework.security.core.Authentication; @@ -18,10 +22,10 @@ import jakarta.servlet.http.HttpSession; import lombok.RequiredArgsConstructor; import stirling.software.common.model.ApplicationProperties; -import stirling.software.common.model.ApplicationProperties.Security.OAUTH2; import stirling.software.common.model.exception.UnsupportedProviderException; import stirling.software.common.util.RequestUriUtils; import stirling.software.proprietary.security.model.AuthenticationType; +import stirling.software.proprietary.security.service.JwtServiceInterface; import stirling.software.proprietary.security.service.LoginAttemptService; import stirling.software.proprietary.security.service.UserService; @@ -30,8 +34,9 @@ public class CustomOAuth2AuthenticationSuccessHandler extends SavedRequestAwareAuthenticationSuccessHandler { private final LoginAttemptService loginAttemptService; - private final ApplicationProperties applicationProperties; + private final ApplicationProperties.Security.OAUTH2 oauth2Properties; private final UserService userService; + private final JwtServiceInterface jwtService; @Override public void onAuthenticationSuccess( @@ -60,8 +65,6 @@ public class CustomOAuth2AuthenticationSuccessHandler // Redirect to the original destination super.onAuthenticationSuccess(request, response, authentication); } else { - OAUTH2 oAuth = applicationProperties.getSecurity().getOauth2(); - if (loginAttemptService.isBlocked(username)) { if (session != null) { session.removeAttribute("SPRING_SECURITY_SAVED_REQUEST"); @@ -69,7 +72,12 @@ public class CustomOAuth2AuthenticationSuccessHandler throw new LockedException( "Your account has been locked due to too many failed login attempts."); } - + if (jwtService.isJwtEnabled()) { + String jwt = + jwtService.generateToken( + authentication, Map.of("authType", AuthenticationType.OAUTH2)); + jwtService.addToken(response, jwt); + } if (userService.isUserDisabled(username)) { getRedirectStrategy() .sendRedirect(request, response, "/logout?userIsDisabled=true"); @@ -77,20 +85,22 @@ public class CustomOAuth2AuthenticationSuccessHandler } if (userService.usernameExistsIgnoreCase(username) && userService.hasPassword(username) - && !userService.isAuthenticationTypeByUsername(username, AuthenticationType.SSO) - && oAuth.getAutoCreateUser()) { + && (!userService.isAuthenticationTypeByUsername(username, SSO) + || !userService.isAuthenticationTypeByUsername(username, OAUTH2)) + && oauth2Properties.getAutoCreateUser()) { response.sendRedirect(contextPath + "/logout?oAuth2AuthenticationErrorWeb=true"); return; } try { - if (oAuth.getBlockRegistration() + if (oauth2Properties.getBlockRegistration() && !userService.usernameExistsIgnoreCase(username)) { response.sendRedirect(contextPath + "/logout?oAuth2AdminBlockedUser=true"); return; } if (principal instanceof OAuth2User) { - userService.processSSOPostLogin(username, oAuth.getAutoCreateUser()); + userService.processSSOPostLogin( + username, oauth2Properties.getAutoCreateUser(), OAUTH2); } response.sendRedirect(contextPath + "/"); } catch (IllegalArgumentException | SQLException | UnsupportedProviderException e) { diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/OAuth2Configuration.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/OAuth2Configuration.java index 6516cc7d7..913dc458a 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/OAuth2Configuration.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/oauth2/OAuth2Configuration.java @@ -34,6 +34,7 @@ import stirling.software.common.model.oauth2.GitHubProvider; import stirling.software.common.model.oauth2.GoogleProvider; import stirling.software.common.model.oauth2.KeycloakProvider; import stirling.software.common.model.oauth2.Provider; +import stirling.software.proprietary.security.model.Authority; import stirling.software.proprietary.security.model.User; import stirling.software.proprietary.security.model.exception.NoProviderFoundException; import stirling.software.proprietary.security.service.UserService; @@ -239,12 +240,14 @@ public class OAuth2Configuration { Optional userOpt = userService.findByUsernameIgnoreCase( (String) oAuth2Auth.getAttributes().get(useAsUsername)); - if (userOpt.isPresent()) { - User user = userOpt.get(); - mappedAuthorities.add( - new SimpleGrantedAuthority( - userService.findRole(user).getAuthority())); - } + userOpt.ifPresent( + user -> + mappedAuthorities.add( + new Authority( + userService + .findRole(user) + .getAuthority(), + user))); } }); return mappedAuthorities; diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/CustomSaml2AuthenticationSuccessHandler.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/CustomSaml2AuthenticationSuccessHandler.java index 2170a9632..3255cbc15 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/CustomSaml2AuthenticationSuccessHandler.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/CustomSaml2AuthenticationSuccessHandler.java @@ -1,7 +1,11 @@ package stirling.software.proprietary.security.saml2; +import static stirling.software.proprietary.security.model.AuthenticationType.SAML2; +import static stirling.software.proprietary.security.model.AuthenticationType.SSO; + import java.io.IOException; import java.sql.SQLException; +import java.util.Map; import org.springframework.security.authentication.LockedException; import org.springframework.security.core.Authentication; @@ -17,10 +21,10 @@ import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import stirling.software.common.model.ApplicationProperties; -import stirling.software.common.model.ApplicationProperties.Security.SAML2; import stirling.software.common.model.exception.UnsupportedProviderException; import stirling.software.common.util.RequestUriUtils; import stirling.software.proprietary.security.model.AuthenticationType; +import stirling.software.proprietary.security.service.JwtServiceInterface; import stirling.software.proprietary.security.service.LoginAttemptService; import stirling.software.proprietary.security.service.UserService; @@ -30,8 +34,9 @@ public class CustomSaml2AuthenticationSuccessHandler extends SavedRequestAwareAuthenticationSuccessHandler { private LoginAttemptService loginAttemptService; - private ApplicationProperties applicationProperties; + private ApplicationProperties.Security.SAML2 saml2Properties; private UserService userService; + private final JwtServiceInterface jwtService; @Override public void onAuthenticationSuccess( @@ -65,10 +70,9 @@ public class CustomSaml2AuthenticationSuccessHandler savedRequest.getRedirectUrl()); super.onAuthenticationSuccess(request, response, authentication); } else { - SAML2 saml2 = applicationProperties.getSecurity().getSaml2(); log.debug( "Processing SAML2 authentication with autoCreateUser: {}", - saml2.getAutoCreateUser()); + saml2Properties.getAutoCreateUser()); if (loginAttemptService.isBlocked(username)) { log.debug("User {} is blocked due to too many login attempts", username); @@ -82,17 +86,21 @@ public class CustomSaml2AuthenticationSuccessHandler boolean userExists = userService.usernameExistsIgnoreCase(username); boolean hasPassword = userExists && userService.hasPassword(username); boolean isSSOUser = - userExists - && userService.isAuthenticationTypeByUsername( - username, AuthenticationType.SSO); + userExists && userService.isAuthenticationTypeByUsername(username, SSO); + boolean isSAML2User = + userExists && userService.isAuthenticationTypeByUsername(username, SAML2); log.debug( - "User status - Exists: {}, Has password: {}, Is SSO user: {}", + "User status - Exists: {}, Has password: {}, Is SSO user: {}, Is SAML2 user: {}", userExists, hasPassword, - isSSOUser); + isSSOUser, + isSAML2User); - if (userExists && hasPassword && !isSSOUser && saml2.getAutoCreateUser()) { + if (userExists + && hasPassword + && (!isSSOUser || !isSAML2User) + && saml2Properties.getAutoCreateUser()) { log.debug( "User {} exists with password but is not SSO user, redirecting to logout", username); @@ -102,15 +110,18 @@ public class CustomSaml2AuthenticationSuccessHandler } try { - if (saml2.getBlockRegistration() && !userExists) { + if (!userExists || saml2Properties.getBlockRegistration()) { log.debug("Registration blocked for new user: {}", username); response.sendRedirect( contextPath + "/login?errorOAuth=oAuth2AdminBlockedUser"); return; } log.debug("Processing SSO post-login for user: {}", username); - userService.processSSOPostLogin(username, saml2.getAutoCreateUser()); + userService.processSSOPostLogin( + username, saml2Properties.getAutoCreateUser(), SAML2); log.debug("Successfully processed authentication for user: {}", username); + + generateJwt(response, authentication); response.sendRedirect(contextPath + "/"); } catch (IllegalArgumentException | SQLException | UnsupportedProviderException e) { log.debug( @@ -124,4 +135,13 @@ public class CustomSaml2AuthenticationSuccessHandler super.onAuthenticationSuccess(request, response, authentication); } } + + private void generateJwt(HttpServletResponse response, Authentication authentication) { + if (jwtService.isJwtEnabled()) { + String jwt = + jwtService.generateToken( + authentication, Map.of("authType", AuthenticationType.SAML2)); + jwtService.addToken(response, jwt); + } + } } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java new file mode 100644 index 000000000..d0508151c --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java @@ -0,0 +1,135 @@ +package stirling.software.proprietary.security.saml2; + +import java.util.HashMap; +import java.util.Map; + +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.proprietary.security.service.JwtServiceInterface; + +@Slf4j +public class JwtSaml2AuthenticationRequestRepository + implements Saml2AuthenticationRequestRepository { + private final Map tokenStore; + private final JwtServiceInterface jwtService; + private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + + private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token"; + + public JwtSaml2AuthenticationRequestRepository( + Map tokenStore, + JwtServiceInterface jwtService, + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + this.tokenStore = tokenStore; + this.jwtService = jwtService; + this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; + } + + @Override + public void saveAuthenticationRequest( + Saml2PostAuthenticationRequest authRequest, + HttpServletRequest request, + HttpServletResponse response) { + if (!jwtService.isJwtEnabled()) { + log.debug("V2 is not enabled, skipping SAMLRequest token storage"); + return; + } + + if (authRequest == null) { + removeAuthenticationRequest(request, response); + return; + } + + Map claims = serializeSamlRequest(authRequest); + String token = jwtService.generateToken("", claims); + String relayState = authRequest.getRelayState(); + + tokenStore.put(relayState, token); + request.setAttribute(SAML_REQUEST_TOKEN, relayState); + response.addHeader(SAML_REQUEST_TOKEN, relayState); + + log.debug("Saved SAMLRequest token with RelayState: {}", relayState); + } + + @Override + public Saml2PostAuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) { + String token = extractTokenFromStore(request); + + if (token == null) { + log.debug("No SAMLResponse token found in RelayState"); + return null; + } + + Map claims = jwtService.extractClaims(token); + return deserializeSamlRequest(claims); + } + + @Override + public Saml2PostAuthenticationRequest removeAuthenticationRequest( + HttpServletRequest request, HttpServletResponse response) { + Saml2PostAuthenticationRequest authRequest = loadAuthenticationRequest(request); + + String relayStateId = request.getParameter("RelayState"); + if (relayStateId != null) { + tokenStore.remove(relayStateId); + log.debug("Removed SAMLRequest token for RelayState ID: {}", relayStateId); + } + + return authRequest; + } + + private String extractTokenFromStore(HttpServletRequest request) { + String authnRequestId = request.getParameter("RelayState"); + + if (authnRequestId != null && !authnRequestId.isEmpty()) { + String token = tokenStore.get(authnRequestId); + + if (token != null) { + tokenStore.remove(authnRequestId); + log.debug("Retrieved SAMLRequest token for RelayState ID: {}", authnRequestId); + return token; + } else { + log.warn("No SAMLRequest token found for RelayState ID: {}", authnRequestId); + } + } + + return null; + } + + private Map serializeSamlRequest(Saml2PostAuthenticationRequest authRequest) { + Map claims = new HashMap<>(); + + claims.put("id", authRequest.getId()); + claims.put("relyingPartyRegistrationId", authRequest.getRelyingPartyRegistrationId()); + claims.put("authenticationRequestUri", authRequest.getAuthenticationRequestUri()); + claims.put("samlRequest", authRequest.getSamlRequest()); + claims.put("relayState", authRequest.getRelayState()); + + return claims; + } + + private Saml2PostAuthenticationRequest deserializeSamlRequest(Map claims) { + String relyingPartyRegistrationId = (String) claims.get("relyingPartyRegistrationId"); + RelyingPartyRegistration relyingPartyRegistration = + relyingPartyRegistrationRepository.findByRegistrationId(relyingPartyRegistrationId); + + if (relyingPartyRegistration == null) { + return null; + } + + return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(relyingPartyRegistration) + .id((String) claims.get("id")) + .authenticationRequestUri((String) claims.get("authenticationRequestUri")) + .samlRequest((String) claims.get("samlRequest")) + .relayState((String) claims.get("relayState")) + .build(); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/SAML2Configuration.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java similarity index 85% rename from app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/SAML2Configuration.java rename to app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java index 7fd4768b3..9d21f88a3 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/SAML2Configuration.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java @@ -3,6 +3,7 @@ package stirling.software.proprietary.security.saml2; import java.security.cert.X509Certificate; import java.util.Collections; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import org.opensaml.saml.saml2.core.AuthnRequest; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; @@ -11,12 +12,12 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.io.Resource; import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType; -import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; -import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver; import jakarta.servlet.http.HttpServletRequest; @@ -26,12 +27,13 @@ import lombok.extern.slf4j.Slf4j; import stirling.software.common.model.ApplicationProperties; import stirling.software.common.model.ApplicationProperties.Security.SAML2; +import stirling.software.proprietary.security.service.JwtServiceInterface; @Configuration @Slf4j @ConditionalOnProperty(value = "security.saml2.enabled", havingValue = "true") @RequiredArgsConstructor -public class SAML2Configuration { +public class Saml2Configuration { private final ApplicationProperties applicationProperties; @@ -58,6 +60,7 @@ public class SAML2Configuration { .assertionConsumerServiceBinding(Saml2MessageBinding.POST) .assertionConsumerServiceLocation( "{baseUrl}/login/saml2/sso/{registrationId}") + .authnRequestsSigned(true) .assertingPartyMetadata( metadata -> metadata.entityId(samlConf.getIdpIssuer()) @@ -71,15 +74,29 @@ public class SAML2Configuration { Saml2MessageBinding.POST) .singleLogoutServiceLocation( samlConf.getIdpSingleLogoutUrl()) + .singleLogoutServiceResponseLocation( + "http://localhost:8080/login") .wantAuthnRequestsSigned(true)) .build(); return new InMemoryRelyingPartyRegistrationRepository(rp); } + @Bean + @ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true") + public Saml2AuthenticationRequestRepository + saml2AuthenticationRequestRepository( + JwtServiceInterface jwtService, + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + return new JwtSaml2AuthenticationRequestRepository( + new ConcurrentHashMap<>(), jwtService, relyingPartyRegistrationRepository); + } + @Bean @ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true") public OpenSaml4AuthenticationRequestResolver authenticationRequestResolver( - RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, + Saml2AuthenticationRequestRepository + saml2AuthenticationRequestRepository) { OpenSaml4AuthenticationRequestResolver resolver = new OpenSaml4AuthenticationRequestResolver(relyingPartyRegistrationRepository); @@ -87,10 +104,8 @@ public class SAML2Configuration { customizer -> { HttpServletRequest request = customizer.getRequest(); AuthnRequest authnRequest = customizer.getAuthnRequest(); - HttpSessionSaml2AuthenticationRequestRepository requestRepository = - new HttpSessionSaml2AuthenticationRequestRepository(); - AbstractSaml2AuthenticationRequest saml2AuthenticationRequest = - requestRepository.loadAuthenticationRequest(request); + Saml2PostAuthenticationRequest saml2AuthenticationRequest = + saml2AuthenticationRequestRepository.loadAuthenticationRequest(request); if (saml2AuthenticationRequest != null) { String sessionId = request.getSession(false).getId(); @@ -113,7 +128,6 @@ public class SAML2Configuration { log.debug("Generating new authentication request ID"); authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1)); } - logAuthnRequestDetails(authnRequest); logHttpRequestDetails(request); }); diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/CustomOAuth2UserService.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/CustomOAuth2UserService.java index 0b286e894..8f9afbe3d 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/CustomOAuth2UserService.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/CustomOAuth2UserService.java @@ -27,13 +27,13 @@ public class CustomOAuth2UserService implements OAuth2UserService new UsernameNotFoundException( "No user found with username: " + username)); + if (loginAttemptService.isBlocked(username)) { throw new LockedException( "Your account has been locked due to too many failed login attempts."); } - if (!user.hasPassword()) { + + AuthenticationType userAuthenticationType = + AuthenticationType.valueOf(user.getAuthenticationType().toUpperCase()); + if (!user.hasPassword() && userAuthenticationType == AuthenticationType.WEB) { throw new IllegalArgumentException("Password must not be null"); } - return new org.springframework.security.core.userdetails.User( - user.getUsername(), - user.getPassword(), - user.isEnabled(), - true, - true, - true, - getAuthorities(user.getAuthorities())); - } - private Collection getAuthorities(Set authorities) { - return authorities.stream() - .map(authority -> new SimpleGrantedAuthority(authority.getAuthority())) - .toList(); + return user; } } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtService.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtService.java new file mode 100644 index 000000000..8724da9a8 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtService.java @@ -0,0 +1,330 @@ +package stirling.software.proprietary.security.service; + +import java.security.KeyPair; +import java.security.NoSuchAlgorithmException; +import java.security.PublicKey; +import java.security.spec.InvalidKeySpecException; +import java.time.LocalDateTime; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.http.ResponseCookie; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.stereotype.Service; + +import io.github.pixee.security.Newlines; +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.ExpiredJwtException; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.MalformedJwtException; +import io.jsonwebtoken.UnsupportedJwtException; +import io.jsonwebtoken.security.SignatureException; + +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.proprietary.security.model.JwtVerificationKey; +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; +import stirling.software.proprietary.security.saml2.CustomSaml2AuthenticatedPrincipal; + +@Slf4j +@Service +public class JwtService implements JwtServiceInterface { + + private static final String JWT_COOKIE_NAME = "stirling_jwt"; + private static final String ISSUER = "Stirling PDF"; + private static final long EXPIRATION = 3600000; + + @Value("${stirling.security.jwt.secureCookie:true}") + private boolean secureCookie; + + private final KeyPersistenceServiceInterface keyPersistenceService; + private final boolean v2Enabled; + + @Autowired + public JwtService( + @Qualifier("v2Enabled") boolean v2Enabled, + KeyPersistenceServiceInterface keyPersistenceService) { + this.v2Enabled = v2Enabled; + this.keyPersistenceService = keyPersistenceService; + } + + @Override + public String generateToken(Authentication authentication, Map claims) { + Object principal = authentication.getPrincipal(); + String username = ""; + + if (principal instanceof UserDetails) { + username = ((UserDetails) principal).getUsername(); + } else if (principal instanceof OAuth2User) { + username = ((OAuth2User) principal).getName(); + } else if (principal instanceof CustomSaml2AuthenticatedPrincipal) { + username = ((CustomSaml2AuthenticatedPrincipal) principal).getName(); + } + + return generateToken(username, claims); + } + + @Override + public String generateToken(String username, Map claims) { + try { + JwtVerificationKey activeKey = keyPersistenceService.getActiveKey(); + Optional keyPairOpt = keyPersistenceService.getKeyPair(activeKey.getKeyId()); + + if (keyPairOpt.isEmpty()) { + throw new RuntimeException("Unable to retrieve key pair for active key"); + } + + KeyPair keyPair = keyPairOpt.get(); + + var builder = + Jwts.builder() + .claims(claims) + .subject(username) + .issuer(ISSUER) + .issuedAt(new Date()) + .expiration(new Date(System.currentTimeMillis() + EXPIRATION)) + .signWith(keyPair.getPrivate(), Jwts.SIG.RS256); + + String keyId = activeKey.getKeyId(); + if (keyId != null) { + builder.header().keyId(keyId); + } + + return builder.compact(); + } catch (Exception e) { + throw new RuntimeException("Failed to generate token", e); + } + } + + @Override + public void validateToken(String token) throws AuthenticationFailureException { + extractAllClaims(token); + + if (isTokenExpired(token)) { + throw new AuthenticationFailureException("The token has expired"); + } + } + + @Override + public String extractUsername(String token) { + return extractClaim(token, Claims::getSubject); + } + + @Override + public Map extractClaims(String token) { + Claims claims = extractAllClaims(token); + return new HashMap<>(claims); + } + + @Override + public boolean isTokenExpired(String token) { + return extractExpiration(token).before(new Date()); + } + + private Date extractExpiration(String token) { + return extractClaim(token, Claims::getExpiration); + } + + private T extractClaim(String token, Function claimsResolver) { + final Claims claims = extractAllClaims(token); + return claimsResolver.apply(claims); + } + + private Claims extractAllClaims(String token) { + try { + String keyId = extractKeyId(token); + KeyPair keyPair; + + if (keyId != null) { + Optional specificKeyPair = keyPersistenceService.getKeyPair(keyId); + + if (specificKeyPair.isPresent()) { + keyPair = specificKeyPair.get(); + } else { + log.warn( + "Key ID {} not found in keystore, token may have been signed with an expired key", + keyId); + + if (keyId.equals(keyPersistenceService.getActiveKey().getKeyId())) { + JwtVerificationKey verificationKey = + keyPersistenceService.refreshActiveKeyPair(); + Optional refreshedKeyPair = + keyPersistenceService.getKeyPair(verificationKey.getKeyId()); + if (refreshedKeyPair.isPresent()) { + keyPair = refreshedKeyPair.get(); + } else { + throw new AuthenticationFailureException( + "Failed to retrieve refreshed key pair"); + } + } else { + // Try to use active key as fallback + JwtVerificationKey activeKey = keyPersistenceService.getActiveKey(); + Optional activeKeyPair = + keyPersistenceService.getKeyPair(activeKey.getKeyId()); + if (activeKeyPair.isPresent()) { + keyPair = activeKeyPair.get(); + } else { + throw new AuthenticationFailureException( + "Failed to retrieve active key pair"); + } + } + } + } else { + log.debug("No key ID in token header, trying all available keys"); + // Try all available keys when no keyId is present + return tryAllKeys(token); + } + + return Jwts.parser() + .verifyWith(keyPair.getPublic()) + .build() + .parseSignedClaims(token) + .getPayload(); + } catch (SignatureException e) { + log.warn("Invalid signature: {}", e.getMessage()); + throw new AuthenticationFailureException("Invalid signature", e); + } catch (MalformedJwtException e) { + log.warn("Invalid token: {}", e.getMessage()); + throw new AuthenticationFailureException("Invalid token", e); + } catch (ExpiredJwtException e) { + log.warn("The token has expired: {}", e.getMessage()); + throw new AuthenticationFailureException("The token has expired", e); + } catch (UnsupportedJwtException e) { + log.warn("The token is unsupported: {}", e.getMessage()); + throw new AuthenticationFailureException("The token is unsupported", e); + } catch (IllegalArgumentException e) { + log.warn("Claims are empty: {}", e.getMessage()); + throw new AuthenticationFailureException("Claims are empty", e); + } + } + + private Claims tryAllKeys(String token) throws AuthenticationFailureException { + // First try the active key + try { + JwtVerificationKey activeKey = keyPersistenceService.getActiveKey(); + PublicKey publicKey = + keyPersistenceService.decodePublicKey(activeKey.getVerifyingKey()); + return Jwts.parser() + .verifyWith(publicKey) + .build() + .parseSignedClaims(token) + .getPayload(); + } catch (SignatureException + | NoSuchAlgorithmException + | InvalidKeySpecException activeKeyException) { + log.debug("Active key failed, trying all available keys from cache"); + + // If active key fails, try all available keys from cache + List allKeys = + keyPersistenceService.getKeysEligibleForCleanup( + LocalDateTime.now().plusDays(1)); + + for (JwtVerificationKey verificationKey : allKeys) { + try { + PublicKey publicKey = + keyPersistenceService.decodePublicKey( + verificationKey.getVerifyingKey()); + return Jwts.parser() + .verifyWith(publicKey) + .build() + .parseSignedClaims(token) + .getPayload(); + } catch (SignatureException + | NoSuchAlgorithmException + | InvalidKeySpecException e) { + log.debug( + "Key {} failed to verify token, trying next key", + verificationKey.getKeyId()); + // Continue to next key + } + } + + throw new AuthenticationFailureException( + "Token signature could not be verified with any available key", + activeKeyException); + } + } + + @Override + public String extractToken(HttpServletRequest request) { + Cookie[] cookies = request.getCookies(); + + if (cookies != null) { + for (Cookie cookie : cookies) { + if (JWT_COOKIE_NAME.equals(cookie.getName())) { + return cookie.getValue(); + } + } + } + + return null; + } + + @Override + public void addToken(HttpServletResponse response, String token) { + ResponseCookie cookie = + ResponseCookie.from(JWT_COOKIE_NAME, Newlines.stripAll(token)) + .httpOnly(true) + .secure(secureCookie) + .sameSite("Strict") + .maxAge(EXPIRATION / 1000) + .path("/") + .build(); + + response.addHeader("Set-Cookie", cookie.toString()); + } + + @Override + public void clearToken(HttpServletResponse response) { + ResponseCookie cookie = + ResponseCookie.from(JWT_COOKIE_NAME, "") + .httpOnly(true) + .secure(secureCookie) + .sameSite("None") + .maxAge(0) + .path("/") + .build(); + + response.addHeader("Set-Cookie", cookie.toString()); + } + + @Override + public boolean isJwtEnabled() { + return v2Enabled; + } + + private String extractKeyId(String token) { + try { + PublicKey signingKey = + keyPersistenceService.decodePublicKey( + keyPersistenceService.getActiveKey().getVerifyingKey()); + + String keyId = + (String) + Jwts.parser() + .verifyWith(signingKey) + .build() + .parse(token) + .getHeader() + .get("kid"); + log.debug("Extracted key ID from token: {}", keyId); + return keyId; + } catch (Exception e) { + log.warn("Failed to extract key ID from token header: {}", e.getMessage()); + return null; + } + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtServiceInterface.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtServiceInterface.java new file mode 100644 index 000000000..7cdca8209 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtServiceInterface.java @@ -0,0 +1,90 @@ +package stirling.software.proprietary.security.service; + +import java.util.Map; + +import org.springframework.security.core.Authentication; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public interface JwtServiceInterface { + + /** + * Generate a JWT token for the authenticated user + * + * @param authentication Spring Security authentication object + * @return JWT token as a string + */ + String generateToken(Authentication authentication, Map claims); + + /** + * Generate a JWT token for a specific username + * + * @param username the username for which to generate the token + * @param claims additional claims to include in the token + * @return JWT token as a string + */ + String generateToken(String username, Map claims); + + /** + * Validate a JWT token + * + * @param token the JWT token to validate + * @return true if token is valid, false otherwise + */ + void validateToken(String token); + + /** + * Extract username from JWT token + * + * @param token the JWT token + * @return username extracted from token + */ + String extractUsername(String token); + + /** + * Extract all claims from JWT token + * + * @param token the JWT token + * @return map of claims + */ + Map extractClaims(String token); + + /** + * Check if token is expired + * + * @param token the JWT token + * @return true if token is expired, false otherwise + */ + boolean isTokenExpired(String token); + + /** + * Extract JWT token from HTTP request (header or cookie) + * + * @param request HTTP servlet request + * @return JWT token if found, null otherwise + */ + String extractToken(HttpServletRequest request); + + /** + * Add JWT token to HTTP response (header and cookie) + * + * @param response HTTP servlet response + * @param token JWT token to add + */ + void addToken(HttpServletResponse response, String token); + + /** + * Clear JWT token from HTTP response (remove cookie) + * + * @param response HTTP servlet response + */ + void clearToken(HttpServletResponse response); + + /** + * Check if JWT authentication is enabled + * + * @return true if JWT is enabled, false otherwise + */ + boolean isJwtEnabled(); +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPairCleanupService.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPairCleanupService.java new file mode 100644 index 000000000..b419f78fe --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPairCleanupService.java @@ -0,0 +1,88 @@ +package stirling.software.proprietary.security.service; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.LocalDateTime; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBooleanProperty; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import jakarta.annotation.PostConstruct; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.common.configuration.InstallationPathConfig; +import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.model.JwtVerificationKey; + +@Slf4j +@Service +@ConditionalOnBooleanProperty("v2") +public class KeyPairCleanupService { + + private final KeyPersistenceService keyPersistenceService; + private final ApplicationProperties.Security.Jwt jwtProperties; + + @Autowired + public KeyPairCleanupService( + KeyPersistenceService keyPersistenceService, + ApplicationProperties applicationProperties) { + this.keyPersistenceService = keyPersistenceService; + this.jwtProperties = applicationProperties.getSecurity().getJwt(); + } + + @Transactional + @PostConstruct + @Scheduled(fixedDelay = 1, timeUnit = TimeUnit.DAYS) + public void cleanup() { + if (!jwtProperties.isEnableKeyCleanup() || !keyPersistenceService.isKeystoreEnabled()) { + return; + } + + LocalDateTime cutoffDate = + LocalDateTime.now().minusDays(jwtProperties.getKeyRetentionDays()); + + List eligibleKeys = + keyPersistenceService.getKeysEligibleForCleanup(cutoffDate); + if (eligibleKeys.isEmpty()) { + return; + } + + log.info("Removing keys older than retention period"); + removeKeys(eligibleKeys); + keyPersistenceService.refreshActiveKeyPair(); + } + + private void removeKeys(List keys) { + keys.forEach( + key -> { + try { + keyPersistenceService.removeKey(key.getKeyId()); + removePrivateKey(key.getKeyId()); + } catch (IOException e) { + log.warn("Failed to remove key: {}", key.getKeyId(), e); + } + }); + } + + private void removePrivateKey(String keyId) throws IOException { + if (!keyPersistenceService.isKeystoreEnabled()) { + return; + } + + Path privateKeyDirectory = Paths.get(InstallationPathConfig.getPrivateKeyPath()); + Path keyFile = privateKeyDirectory.resolve(keyId + KeyPersistenceService.KEY_SUFFIX); + + if (Files.exists(keyFile)) { + Files.delete(keyFile); + log.debug("Deleted private key: {}", keyFile); + } + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPersistenceService.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPersistenceService.java new file mode 100644 index 000000000..48bcddac0 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPersistenceService.java @@ -0,0 +1,243 @@ +package stirling.software.proprietary.security.service; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.KeyFactory; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.PKCS8EncodedKeySpec; +import java.security.spec.X509EncodedKeySpec; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.Base64; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.cache.Cache; +import org.springframework.cache.CacheManager; +import org.springframework.cache.annotation.CacheEvict; +import org.springframework.cache.caffeine.CaffeineCache; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import jakarta.annotation.PostConstruct; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.common.configuration.InstallationPathConfig; +import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.model.JwtVerificationKey; + +@Slf4j +@Service +public class KeyPersistenceService implements KeyPersistenceServiceInterface { + + public static final String KEY_SUFFIX = ".key"; + + private final ApplicationProperties.Security.Jwt jwtProperties; + private final CacheManager cacheManager; + private final Cache verifyingKeyCache; + + private volatile JwtVerificationKey activeKey; + + @Autowired + public KeyPersistenceService( + ApplicationProperties applicationProperties, CacheManager cacheManager) { + this.jwtProperties = applicationProperties.getSecurity().getJwt(); + this.cacheManager = cacheManager; + this.verifyingKeyCache = cacheManager.getCache("verifyingKeys"); + } + + @PostConstruct + public void initializeKeystore() { + if (!isKeystoreEnabled()) { + return; + } + + try { + ensurePrivateKeyDirectoryExists(); + loadKeyPair(); + } catch (Exception e) { + log.error("Failed to initialize keystore, using in-memory generation", e); + } + } + + private void loadKeyPair() { + if (activeKey == null) { + generateAndStoreKeypair(); + } + } + + @Transactional + private JwtVerificationKey generateAndStoreKeypair() { + JwtVerificationKey verifyingKey = null; + + try { + KeyPair keyPair = generateRSAKeypair(); + String keyId = generateKeyId(); + + storePrivateKey(keyId, keyPair.getPrivate()); + verifyingKey = new JwtVerificationKey(keyId, encodePublicKey(keyPair.getPublic())); + verifyingKeyCache.put(keyId, verifyingKey); + activeKey = verifyingKey; + } catch (IOException e) { + log.error("Failed to generate and store keypair", e); + } + + return verifyingKey; + } + + @Override + public JwtVerificationKey getActiveKey() { + if (activeKey == null) { + return generateAndStoreKeypair(); + } + return activeKey; + } + + @Override + public Optional getKeyPair(String keyId) { + if (!isKeystoreEnabled()) { + return Optional.empty(); + } + + try { + JwtVerificationKey verifyingKey = + verifyingKeyCache.get(keyId, JwtVerificationKey.class); + + if (verifyingKey == null) { + log.warn("No signing key found in database for keyId: {}", keyId); + return Optional.empty(); + } + + PrivateKey privateKey = loadPrivateKey(keyId); + PublicKey publicKey = decodePublicKey(verifyingKey.getVerifyingKey()); + + return Optional.of(new KeyPair(publicKey, privateKey)); + } catch (Exception e) { + log.error("Failed to load keypair for keyId: {}", keyId, e); + return Optional.empty(); + } + } + + @Override + public boolean isKeystoreEnabled() { + return jwtProperties.isEnableKeystore(); + } + + @Override + public JwtVerificationKey refreshActiveKeyPair() { + return generateAndStoreKeypair(); + } + + @Override + @CacheEvict( + value = {"verifyingKeys"}, + key = "#keyId", + condition = "#root.target.isKeystoreEnabled()") + public void removeKey(String keyId) { + verifyingKeyCache.evict(keyId); + } + + @Override + public List getKeysEligibleForCleanup(LocalDateTime cutoffDate) { + CaffeineCache caffeineCache = (CaffeineCache) verifyingKeyCache; + com.github.benmanes.caffeine.cache.Cache nativeCache = + caffeineCache.getNativeCache(); + + log.debug( + "Cache size: {}, Checking {} keys for cleanup", + nativeCache.estimatedSize(), + nativeCache.asMap().size()); + + return nativeCache.asMap().values().stream() + .filter(value -> value instanceof JwtVerificationKey) + .map(value -> (JwtVerificationKey) value) + .filter( + key -> { + boolean eligible = key.getCreatedAt().isBefore(cutoffDate); + log.debug( + "Key {} created at {}, eligible for cleanup: {}", + key.getKeyId(), + key.getCreatedAt(), + eligible); + return eligible; + }) + .collect(Collectors.toList()); + } + + private String generateKeyId() { + return "jwt-key-" + + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd-HHmmss")); + } + + private KeyPair generateRSAKeypair() { + KeyPairGenerator keyPairGenerator = null; + + try { + keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + } catch (NoSuchAlgorithmException e) { + log.error("Failed to initialize RSA key pair generator", e); + } + + return keyPairGenerator.generateKeyPair(); + } + + private void ensurePrivateKeyDirectoryExists() throws IOException { + Path keyPath = Paths.get(InstallationPathConfig.getPrivateKeyPath()); + + if (!Files.exists(keyPath)) { + Files.createDirectories(keyPath); + } + } + + private void storePrivateKey(String keyId, PrivateKey privateKey) throws IOException { + Path keyFile = + Paths.get(InstallationPathConfig.getPrivateKeyPath()).resolve(keyId + KEY_SUFFIX); + String encodedKey = Base64.getEncoder().encodeToString(privateKey.getEncoded()); + Files.writeString(keyFile, encodedKey); + + // Set read/write to only the owner + keyFile.toFile().setReadable(true, true); + keyFile.toFile().setWritable(true, true); + keyFile.toFile().setExecutable(false, false); + } + + private PrivateKey loadPrivateKey(String keyId) + throws IOException, NoSuchAlgorithmException, InvalidKeySpecException { + Path keyFile = + Paths.get(InstallationPathConfig.getPrivateKeyPath()).resolve(keyId + KEY_SUFFIX); + + if (!Files.exists(keyFile)) { + throw new IOException("Private key not found: " + keyFile); + } + + String encodedKey = Files.readString(keyFile); + byte[] keyBytes = Base64.getDecoder().decode(encodedKey); + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(keyBytes); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + + return keyFactory.generatePrivate(keySpec); + } + + private String encodePublicKey(PublicKey publicKey) { + return Base64.getEncoder().encodeToString(publicKey.getEncoded()); + } + + public PublicKey decodePublicKey(String encodedKey) + throws NoSuchAlgorithmException, InvalidKeySpecException { + byte[] keyBytes = Base64.getDecoder().decode(encodedKey); + X509EncodedKeySpec keySpec = new X509EncodedKeySpec(keyBytes); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + return keyFactory.generatePublic(keySpec); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPersistenceServiceInterface.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPersistenceServiceInterface.java new file mode 100644 index 000000000..f3050472e --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/KeyPersistenceServiceInterface.java @@ -0,0 +1,29 @@ +package stirling.software.proprietary.security.service; + +import java.security.KeyPair; +import java.security.NoSuchAlgorithmException; +import java.security.PublicKey; +import java.security.spec.InvalidKeySpecException; +import java.time.LocalDateTime; +import java.util.List; +import java.util.Optional; + +import stirling.software.proprietary.security.model.JwtVerificationKey; + +public interface KeyPersistenceServiceInterface { + + JwtVerificationKey getActiveKey(); + + Optional getKeyPair(String keyId); + + boolean isKeystoreEnabled(); + + JwtVerificationKey refreshActiveKeyPair(); + + List getKeysEligibleForCleanup(LocalDateTime cutoffDate); + + void removeKey(String keyId); + + PublicKey decodePublicKey(String encodedKey) + throws NoSuchAlgorithmException, InvalidKeySpecException; +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java index 50c8027f6..6f213b25e 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/service/UserService.java @@ -15,7 +15,6 @@ import org.springframework.context.i18n.LocaleContextHolder; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.session.SessionInformation; import org.springframework.security.core.userdetails.UserDetails; @@ -61,19 +60,9 @@ public class UserService implements UserServiceInterface { private final ApplicationProperties.Security.OAUTH2 oAuth2; - @Transactional - public void migrateOauth2ToSSO() { - userRepository - .findByAuthenticationTypeIgnoreCase("OAUTH2") - .forEach( - user -> { - user.setAuthenticationType(AuthenticationType.SSO); - userRepository.save(user); - }); - } - // Handle OAUTH2 login and user auto creation. - public void processSSOPostLogin(String username, boolean autoCreateUser) + public void processSSOPostLogin( + String username, boolean autoCreateUser, AuthenticationType type) throws IllegalArgumentException, SQLException, UnsupportedProviderException { if (!isUsernameValid(username)) { return; @@ -83,7 +72,7 @@ public class UserService implements UserServiceInterface { return; } if (autoCreateUser) { - saveUser(username, AuthenticationType.SSO); + saveUser(username, type); } } @@ -100,10 +89,7 @@ public class UserService implements UserServiceInterface { } private Collection getAuthorities(User user) { - // Convert each Authority object into a SimpleGrantedAuthority object. - return user.getAuthorities().stream() - .map((Authority authority) -> new SimpleGrantedAuthority(authority.getAuthority())) - .toList(); + return user.getAuthorities(); } private String generateApiKey() { diff --git a/app/proprietary/src/main/resources/static/js/audit/dashboard.js b/app/proprietary/src/main/resources/static/js/audit/dashboard.js index 5cc670908..c0b93bd8e 100644 --- a/app/proprietary/src/main/resources/static/js/audit/dashboard.js +++ b/app/proprietary/src/main/resources/static/js/audit/dashboard.js @@ -230,7 +230,7 @@ function loadAuditData(targetPage, realPageSize) { document.getElementById('page-indicator').textContent = `Page ${requestedPage + 1} of ?`; } - fetch(url) + fetchWithCsrf(url) .then(response => { return response.json(); }) @@ -302,7 +302,7 @@ function loadStats(days) { showLoading('user-chart-loading'); showLoading('time-chart-loading'); - fetch(`/audit/stats?days=${days}`) + fetchWithCsrf(`/audit/stats?days=${days}`) .then(response => response.json()) .then(data => { document.getElementById('total-events').textContent = data.totalEvents; @@ -835,7 +835,7 @@ function hideLoading(id) { // Load event types from the server for filter dropdowns function loadEventTypes() { - fetch('/audit/types') + fetchWithCsrf('/audit/types') .then(response => response.json()) .then(types => { if (!types || types.length === 0) { diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/CustomLogoutSuccessHandlerTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/CustomLogoutSuccessHandlerTest.java index 04ca4c35f..7a4076260 100644 --- a/app/proprietary/src/test/java/stirling/software/proprietary/security/CustomLogoutSuccessHandlerTest.java +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/CustomLogoutSuccessHandlerTest.java @@ -14,12 +14,18 @@ import org.springframework.security.oauth2.client.authentication.OAuth2Authentic import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import stirling.software.common.configuration.AppConfig; import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.service.JwtServiceInterface; @ExtendWith(MockitoExtension.class) class CustomLogoutSuccessHandlerTest { - @Mock private ApplicationProperties applicationProperties; + @Mock private ApplicationProperties.Security securityProperties; + + @Mock private AppConfig appConfig; + + @Mock private JwtServiceInterface jwtService; @InjectMocks private CustomLogoutSuccessHandler customLogoutSuccessHandler; @@ -27,9 +33,12 @@ class CustomLogoutSuccessHandlerTest { void testSuccessfulLogout() throws IOException { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); - String logoutPath = "logout=true"; + String token = "token"; + String logoutPath = "/login?logout=true"; when(response.isCommitted()).thenReturn(false); + when(jwtService.extractToken(request)).thenReturn(token); + doNothing().when(jwtService).clearToken(response); when(request.getContextPath()).thenReturn(""); when(response.encodeRedirectURL(logoutPath)).thenReturn(logoutPath); @@ -38,12 +47,30 @@ class CustomLogoutSuccessHandlerTest { verify(response).sendRedirect(logoutPath); } + @Test + void testSuccessfulLogoutViaJWT() throws IOException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + String logoutPath = "/login?logout=true"; + String token = "token"; + + when(response.isCommitted()).thenReturn(false); + when(jwtService.extractToken(request)).thenReturn(token); + doNothing().when(jwtService).clearToken(response); + when(request.getContextPath()).thenReturn(""); + when(response.encodeRedirectURL(logoutPath)).thenReturn(logoutPath); + + customLogoutSuccessHandler.onLogoutSuccess(request, response, null); + + verify(response).sendRedirect(logoutPath); + verify(jwtService).clearToken(response); + } + @Test void testSuccessfulLogoutViaOAuth2() throws IOException { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken oAuth2AuthenticationToken = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -54,8 +81,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(oAuth2AuthenticationToken.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, oAuth2AuthenticationToken); @@ -70,7 +96,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -84,8 +109,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -101,7 +125,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -111,8 +134,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -127,7 +149,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -138,8 +159,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -154,7 +174,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -167,8 +186,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -183,7 +201,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -198,8 +215,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -214,7 +230,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -230,8 +245,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -246,7 +260,6 @@ class CustomLogoutSuccessHandlerTest { HttpServletRequest request = mock(HttpServletRequest.class); HttpServletResponse response = mock(HttpServletResponse.class); OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - ApplicationProperties.Security security = mock(ApplicationProperties.Security.class); ApplicationProperties.Security.OAUTH2 oauth = mock(ApplicationProperties.Security.OAUTH2.class); @@ -259,8 +272,7 @@ class CustomLogoutSuccessHandlerTest { when(request.getServerName()).thenReturn("localhost"); when(request.getServerPort()).thenReturn(8080); when(request.getContextPath()).thenReturn(""); - when(applicationProperties.getSecurity()).thenReturn(security); - when(security.getOauth2()).thenReturn(oauth); + when(securityProperties.getOauth2()).thenReturn(oauth); when(authentication.getAuthorizedClientRegistrationId()).thenReturn("test"); customLogoutSuccessHandler.onLogoutSuccess(request, response, authentication); diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/JwtAuthenticationEntryPointTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/JwtAuthenticationEntryPointTest.java new file mode 100644 index 000000000..a47f45318 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/JwtAuthenticationEntryPointTest.java @@ -0,0 +1,38 @@ +package stirling.software.proprietary.security; + +import static org.mockito.Mockito.*; + +import java.io.IOException; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; + +@ExtendWith(MockitoExtension.class) +class JwtAuthenticationEntryPointTest { + + @Mock private HttpServletRequest request; + + @Mock private HttpServletResponse response; + + @Mock private AuthenticationFailureException authException; + + @InjectMocks private JwtAuthenticationEntryPoint jwtAuthenticationEntryPoint; + + @Test + void testCommence() throws IOException { + String errorMessage = "Authentication failed"; + when(authException.getMessage()).thenReturn(errorMessage); + + jwtAuthenticationEntryPoint.commence(request, response, authException); + + verify(response).sendError(HttpServletResponse.SC_UNAUTHORIZED, errorMessage); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilterTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilterTest.java new file mode 100644 index 000000000..d3f484486 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/filter/JwtAuthenticationFilterTest.java @@ -0,0 +1,242 @@ +package stirling.software.proprietary.security.filter; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.security.web.AuthenticationEntryPoint; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; +import stirling.software.proprietary.security.service.CustomUserDetailsService; +import stirling.software.proprietary.security.service.JwtServiceInterface; +import stirling.software.proprietary.security.service.UserService; + +@Disabled +@ExtendWith(MockitoExtension.class) +class JwtAuthenticationFilterTest { + + @Mock private JwtServiceInterface jwtService; + + @Mock private CustomUserDetailsService userDetailsService; + + @Mock private UserService userService; + + @Mock private ApplicationProperties.Security securityProperties; + + @Mock private HttpServletRequest request; + + @Mock private HttpServletResponse response; + + @Mock private FilterChain filterChain; + + @Mock private UserDetails userDetails; + + @Mock private SecurityContext securityContext; + + @Mock private AuthenticationEntryPoint authenticationEntryPoint; + + @InjectMocks private JwtAuthenticationFilter jwtAuthenticationFilter; + + @Test + void shouldNotAuthenticateWhenJwtDisabled() throws ServletException, IOException { + when(jwtService.isJwtEnabled()).thenReturn(false); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(filterChain).doFilter(request, response); + verify(jwtService, never()).extractToken(any()); + } + + @Test + void shouldNotFilterWhenPageIsLogin() throws ServletException, IOException { + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/login"); + when(request.getContextPath()).thenReturn("/login"); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(filterChain, never()).doFilter(request, response); + } + + @Test + void testDoFilterInternal() throws ServletException, IOException { + String token = "valid-jwt-token"; + String newToken = "new-jwt-token"; + String username = "testuser"; + Map claims = Map.of("sub", username, "authType", "WEB"); + + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getContextPath()).thenReturn("/"); + when(request.getRequestURI()).thenReturn("/protected"); + when(jwtService.extractToken(request)).thenReturn(token); + doNothing().when(jwtService).validateToken(token); + when(jwtService.extractClaims(token)).thenReturn(claims); + when(userDetails.getAuthorities()).thenReturn(Collections.emptyList()); + when(userDetailsService.loadUserByUsername(username)).thenReturn(userDetails); + + try (MockedStatic mockedSecurityContextHolder = + mockStatic(SecurityContextHolder.class)) { + UsernamePasswordAuthenticationToken authToken = + new UsernamePasswordAuthenticationToken( + userDetails, null, userDetails.getAuthorities()); + + when(securityContext.getAuthentication()).thenReturn(null).thenReturn(authToken); + mockedSecurityContextHolder + .when(SecurityContextHolder::getContext) + .thenReturn(securityContext); + when(jwtService.generateToken( + any(UsernamePasswordAuthenticationToken.class), eq(claims))) + .thenReturn(newToken); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(jwtService).validateToken(token); + verify(jwtService).extractClaims(token); + verify(userDetailsService).loadUserByUsername(username); + verify(securityContext) + .setAuthentication(any(UsernamePasswordAuthenticationToken.class)); + verify(jwtService) + .generateToken(any(UsernamePasswordAuthenticationToken.class), eq(claims)); + verify(jwtService).addToken(response, newToken); + verify(filterChain).doFilter(request, response); + } + } + + @Test + void testDoFilterInternalWithMissingTokenForRootPath() throws ServletException, IOException { + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/"); + when(request.getMethod()).thenReturn("GET"); + when(jwtService.extractToken(request)).thenReturn(null); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(response).sendRedirect("/login"); + verify(filterChain, never()).doFilter(request, response); + } + + @Test + void validationFailsWithInvalidToken() throws ServletException, IOException { + String token = "invalid-jwt-token"; + + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/protected"); + when(request.getContextPath()).thenReturn("/"); + when(jwtService.extractToken(request)).thenReturn(token); + doThrow(new AuthenticationFailureException("Invalid token")) + .when(jwtService) + .validateToken(token); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(jwtService).validateToken(token); + verify(authenticationEntryPoint) + .commence(eq(request), eq(response), any(AuthenticationFailureException.class)); + verify(filterChain, never()).doFilter(request, response); + } + + @Test + void validationFailsWithExpiredToken() throws ServletException, IOException { + String token = "expired-jwt-token"; + + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/protected"); + when(request.getContextPath()).thenReturn("/"); + when(jwtService.extractToken(request)).thenReturn(token); + doThrow(new AuthenticationFailureException("The token has expired")) + .when(jwtService) + .validateToken(token); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(jwtService).validateToken(token); + verify(authenticationEntryPoint).commence(eq(request), eq(response), any()); + verify(filterChain, never()).doFilter(request, response); + } + + @Test + void exceptionThrown_WhenUserNotFound() throws ServletException, IOException { + String token = "valid-jwt-token"; + String username = "nonexistentuser"; + Map claims = Map.of("sub", username, "authType", "WEB"); + + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/protected"); + when(request.getContextPath()).thenReturn("/"); + when(jwtService.extractToken(request)).thenReturn(token); + doNothing().when(jwtService).validateToken(token); + when(jwtService.extractClaims(token)).thenReturn(claims); + when(userDetailsService.loadUserByUsername(username)).thenReturn(null); + + try (MockedStatic mockedSecurityContextHolder = + mockStatic(SecurityContextHolder.class)) { + when(securityContext.getAuthentication()).thenReturn(null); + mockedSecurityContextHolder + .when(SecurityContextHolder::getContext) + .thenReturn(securityContext); + + UsernameNotFoundException result = + assertThrows( + UsernameNotFoundException.class, + () -> + jwtAuthenticationFilter.doFilterInternal( + request, response, filterChain)); + + assertEquals("User not found: " + username, result.getMessage()); + verify(userDetailsService).loadUserByUsername(username); + verify(filterChain, never()).doFilter(request, response); + } + } + + @Test + void testAuthenticationEntryPointCalledWithCorrectException() + throws ServletException, IOException { + when(jwtService.isJwtEnabled()).thenReturn(true); + when(request.getRequestURI()).thenReturn("/protected"); + when(request.getContextPath()).thenReturn("/"); + when(jwtService.extractToken(request)).thenReturn(null); + + jwtAuthenticationFilter.doFilterInternal(request, response, filterChain); + + verify(authenticationEntryPoint) + .commence( + eq(request), + eq(response), + argThat( + exception -> + exception + .getMessage() + .equals("JWT is missing from the request"))); + verify(filterChain, never()).doFilter(request, response); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java new file mode 100644 index 000000000..1aa083cc0 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java @@ -0,0 +1,247 @@ +package stirling.software.proprietary.security.saml2; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullAndEmptySource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import stirling.software.proprietary.security.service.JwtServiceInterface; + +@ExtendWith(MockitoExtension.class) +class JwtSaml2AuthenticationRequestRepositoryTest { + + private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token"; + + private Map tokenStore; + + @Mock private JwtServiceInterface jwtService; + + @Mock private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + + private JwtSaml2AuthenticationRequestRepository jwtSaml2AuthenticationRequestRepository; + + @BeforeEach + void setUp() { + tokenStore = new ConcurrentHashMap<>(); + jwtSaml2AuthenticationRequestRepository = + new JwtSaml2AuthenticationRequestRepository( + tokenStore, jwtService, relyingPartyRegistrationRepository); + } + + @Test + void saveAuthenticationRequest() { + var authRequest = mock(Saml2PostAuthenticationRequest.class); + var request = mock(MockHttpServletRequest.class); + var response = mock(MockHttpServletResponse.class); + String token = "testToken"; + String id = "testId"; + String relayState = "testRelayState"; + String authnRequestUri = "example.com/authnRequest"; + Map claims = Map.of(); + String samlRequest = "testSamlRequest"; + String relyingPartyRegistrationId = "stirling-pdf"; + + when(jwtService.isJwtEnabled()).thenReturn(true); + when(authRequest.getRelayState()).thenReturn(relayState); + when(authRequest.getId()).thenReturn(id); + when(authRequest.getAuthenticationRequestUri()).thenReturn(authnRequestUri); + when(authRequest.getSamlRequest()).thenReturn(samlRequest); + when(authRequest.getRelyingPartyRegistrationId()).thenReturn(relyingPartyRegistrationId); + when(jwtService.generateToken(eq(""), anyMap())).thenReturn(token); + + jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest( + authRequest, request, response); + + verify(request).setAttribute(SAML_REQUEST_TOKEN, relayState); + verify(response).addHeader(SAML_REQUEST_TOKEN, relayState); + } + + @Test + void saveAuthenticationRequestWithNullRequest() { + var request = mock(MockHttpServletRequest.class); + var response = mock(MockHttpServletResponse.class); + + jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(null, request, response); + + assertTrue(tokenStore.isEmpty()); + } + + @Test + void loadAuthenticationRequest() { + var request = mock(MockHttpServletRequest.class); + var relyingPartyRegistration = mock(RelyingPartyRegistration.class); + var assertingPartyMetadata = mock(AssertingPartyMetadata.class); + String relayState = "testRelayState"; + String token = "testToken"; + Map claims = + Map.of( + "id", "testId", + "relyingPartyRegistrationId", "stirling-pdf", + "authenticationRequestUri", "example.com/authnRequest", + "samlRequest", "testSamlRequest", + "relayState", relayState); + + when(request.getParameter("RelayState")).thenReturn(relayState); + when(jwtService.extractClaims(token)).thenReturn(claims); + when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")) + .thenReturn(relyingPartyRegistration); + when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf"); + when(relyingPartyRegistration.getAssertingPartyMetadata()) + .thenReturn(assertingPartyMetadata); + when(assertingPartyMetadata.getSingleSignOnServiceLocation()) + .thenReturn("https://example.com/sso"); + tokenStore.put(relayState, token); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNotNull(result); + assertFalse(tokenStore.containsKey(relayState)); + } + + @ParameterizedTest + @NullAndEmptySource + void loadAuthenticationRequestWithInvalidRelayState(String relayState) { + var request = mock(MockHttpServletRequest.class); + when(request.getParameter("RelayState")).thenReturn(relayState); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNull(result); + } + + @Test + void loadAuthenticationRequestWithNonExistentToken() { + var request = mock(MockHttpServletRequest.class); + when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState"); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNull(result); + } + + @Test + void loadAuthenticationRequestWithNullRelyingPartyRegistration() { + var request = mock(MockHttpServletRequest.class); + String relayState = "testRelayState"; + String token = "testToken"; + Map claims = + Map.of( + "id", "testId", + "relyingPartyRegistrationId", "stirling-pdf", + "authenticationRequestUri", "example.com/authnRequest", + "samlRequest", "testSamlRequest", + "relayState", relayState); + + when(request.getParameter("RelayState")).thenReturn(relayState); + when(jwtService.extractClaims(token)).thenReturn(claims); + when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")) + .thenReturn(null); + tokenStore.put(relayState, token); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNull(result); + } + + @Test + void removeAuthenticationRequest() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + var relyingPartyRegistration = mock(RelyingPartyRegistration.class); + var assertingPartyMetadata = mock(AssertingPartyMetadata.class); + String relayState = "testRelayState"; + String token = "testToken"; + Map claims = + Map.of( + "id", "testId", + "relyingPartyRegistrationId", "stirling-pdf", + "authenticationRequestUri", "example.com/authnRequest", + "samlRequest", "testSamlRequest", + "relayState", relayState); + + when(request.getParameter("RelayState")).thenReturn(relayState); + when(jwtService.extractClaims(token)).thenReturn(claims); + when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")) + .thenReturn(relyingPartyRegistration); + when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf"); + when(relyingPartyRegistration.getAssertingPartyMetadata()) + .thenReturn(assertingPartyMetadata); + when(assertingPartyMetadata.getSingleSignOnServiceLocation()) + .thenReturn("https://example.com/sso"); + tokenStore.put(relayState, token); + + var result = + jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest( + request, response); + + assertNotNull(result); + assertFalse(tokenStore.containsKey(relayState)); + } + + @Test + void removeAuthenticationRequestWithNullRelayState() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + when(request.getParameter("RelayState")).thenReturn(null); + + var result = + jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest( + request, response); + + assertNull(result); + } + + @Test + void removeAuthenticationRequestWithNonExistentToken() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState"); + + var result = + jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest( + request, response); + + assertNull(result); + } + + @Test + void removeAuthenticationRequestWithOnlyRelayState() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + String relayState = "testRelayState"; + + when(request.getParameter("RelayState")).thenReturn(relayState); + + var result = + jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest( + request, response); + + assertNull(result); + assertFalse(tokenStore.containsKey(relayState)); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/service/JwtServiceTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/service/JwtServiceTest.java new file mode 100644 index 000000000..6f9af4c54 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/service/JwtServiceTest.java @@ -0,0 +1,389 @@ +package stirling.software.proprietary.security.service; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.contains; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.security.core.Authentication; + +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import stirling.software.proprietary.security.model.JwtVerificationKey; +import stirling.software.proprietary.security.model.User; +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; + +@ExtendWith(MockitoExtension.class) +class JwtServiceTest { + + @Mock private Authentication authentication; + + @Mock private User userDetails; + + @Mock private HttpServletRequest request; + + @Mock private HttpServletResponse response; + + @Mock private KeyPersistenceServiceInterface keystoreService; + + private JwtService jwtService; + private KeyPair testKeyPair; + private JwtVerificationKey testVerificationKey; + + @BeforeEach + void setUp() throws NoSuchAlgorithmException { + // Generate a test keypair + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + testKeyPair = keyPairGenerator.generateKeyPair(); + + // Create test verification key + String encodedPublicKey = + Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()); + testVerificationKey = new JwtVerificationKey("test-key-id", encodedPublicKey); + + jwtService = new JwtService(true, keystoreService); + } + + @Test + void testGenerateTokenWithAuthentication() throws Exception { + String username = "testuser"; + + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, Collections.emptyMap()); + + assertNotNull(token); + assertFalse(token.isEmpty()); + assertEquals(username, jwtService.extractUsername(token)); + } + + @Test + void testGenerateTokenWithUsernameAndClaims() throws Exception { + String username = "testuser"; + Map claims = new HashMap<>(); + claims.put("role", "admin"); + claims.put("department", "IT"); + + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, claims); + + assertNotNull(token); + assertFalse(token.isEmpty()); + assertEquals(username, jwtService.extractUsername(token)); + + Map extractedClaims = jwtService.extractClaims(token); + assertEquals("admin", extractedClaims.get("role")); + assertEquals("IT", extractedClaims.get("department")); + } + + @Test + void testValidateTokenSuccess() throws Exception { + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn("testuser"); + + String token = jwtService.generateToken(authentication, new HashMap<>()); + + assertDoesNotThrow(() -> jwtService.validateToken(token)); + } + + @Test + void testValidateTokenWithInvalidToken() throws Exception { + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + + assertThrows( + AuthenticationFailureException.class, + () -> { + jwtService.validateToken("invalid-token"); + }); + } + + @Test + void testValidateTokenWithMalformedToken() throws Exception { + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + + AuthenticationFailureException exception = + assertThrows( + AuthenticationFailureException.class, + () -> { + jwtService.validateToken("malformed.token"); + }); + + assertTrue(exception.getMessage().contains("Invalid")); + } + + @Test + void testValidateTokenWithEmptyToken() throws Exception { + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + + AuthenticationFailureException exception = + assertThrows( + AuthenticationFailureException.class, + () -> { + jwtService.validateToken(""); + }); + + assertTrue( + exception.getMessage().contains("Claims are empty") + || exception.getMessage().contains("Invalid")); + } + + @Test + void testExtractUsername() throws Exception { + String username = "testuser"; + User user = mock(User.class); + Map claims = Map.of("sub", "testuser", "authType", "WEB"); + + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + when(authentication.getPrincipal()).thenReturn(user); + when(user.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, claims); + + assertEquals(username, jwtService.extractUsername(token)); + } + + @Test + void testExtractUsernameWithInvalidToken() throws Exception { + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + + assertThrows( + AuthenticationFailureException.class, + () -> jwtService.extractUsername("invalid-token")); + } + + @Test + void testExtractClaims() throws Exception { + String username = "testuser"; + Map claims = Map.of("role", "admin", "department", "IT"); + + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, claims); + Map extractedClaims = jwtService.extractClaims(token); + + assertEquals("admin", extractedClaims.get("role")); + assertEquals("IT", extractedClaims.get("department")); + assertEquals(username, extractedClaims.get("sub")); + assertEquals("Stirling PDF", extractedClaims.get("iss")); + } + + @Test + void testExtractClaimsWithInvalidToken() throws Exception { + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + + assertThrows( + AuthenticationFailureException.class, + () -> jwtService.extractClaims("invalid-token")); + } + + @Test + void testExtractTokenWithCookie() { + String token = "test-token"; + Cookie[] cookies = {new Cookie("stirling_jwt", token)}; + when(request.getCookies()).thenReturn(cookies); + + assertEquals(token, jwtService.extractToken(request)); + } + + @Test + void testExtractTokenWithNoCookies() { + when(request.getCookies()).thenReturn(null); + + assertNull(jwtService.extractToken(request)); + } + + @Test + void testExtractTokenWithWrongCookie() { + Cookie[] cookies = {new Cookie("OTHER_COOKIE", "value")}; + when(request.getCookies()).thenReturn(cookies); + + assertNull(jwtService.extractToken(request)); + } + + @Test + void testExtractTokenWithInvalidAuthorizationHeader() { + when(request.getCookies()).thenReturn(null); + + assertNull(jwtService.extractToken(request)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testAddToken(boolean secureCookie) throws Exception { + String token = "test-token"; + + // Create new JwtService instance with the secureCookie parameter + JwtService testJwtService = createJwtServiceWithSecureCookie(secureCookie); + + testJwtService.addToken(response, token); + + verify(response).addHeader(eq("Set-Cookie"), contains("stirling_jwt=" + token)); + verify(response).addHeader(eq("Set-Cookie"), contains("HttpOnly")); + + if (secureCookie) { + verify(response).addHeader(eq("Set-Cookie"), contains("Secure")); + } + } + + @Test + void testClearToken() { + jwtService.clearToken(response); + + verify(response).addHeader(eq("Set-Cookie"), contains("stirling_jwt=")); + verify(response).addHeader(eq("Set-Cookie"), contains("Max-Age=0")); + } + + @Test + void testGenerateTokenWithKeyId() throws Exception { + String username = "testuser"; + Map claims = new HashMap<>(); + + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, claims); + + assertNotNull(token); + assertFalse(token.isEmpty()); + // Verify that the keystore service was called + verify(keystoreService).getActiveKey(); + verify(keystoreService).getKeyPair("test-key-id"); + } + + @Test + void testTokenVerificationWithSpecificKeyId() throws Exception { + String username = "testuser"; + Map claims = new HashMap<>(); + + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + // Generate token with key ID + String token = jwtService.generateToken(authentication, claims); + + // Mock extraction of key ID and verification (lenient to avoid unused stubbing) + lenient() + .when(keystoreService.getKeyPair("test-key-id")) + .thenReturn(Optional.of(testKeyPair)); + + // Verify token can be validated + assertDoesNotThrow(() -> jwtService.validateToken(token)); + assertEquals(username, jwtService.extractUsername(token)); + } + + @Test + void testTokenVerificationFallsBackToActiveKeyWhenKeyIdNotFound() throws Exception { + String username = "testuser"; + Map claims = new HashMap<>(); + + // First, generate a token successfully + when(keystoreService.getActiveKey()).thenReturn(testVerificationKey); + when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.of(testKeyPair)); + when(keystoreService.decodePublicKey(testVerificationKey.getVerifyingKey())) + .thenReturn(testKeyPair.getPublic()); + when(authentication.getPrincipal()).thenReturn(userDetails); + when(userDetails.getUsername()).thenReturn(username); + + String token = jwtService.generateToken(authentication, claims); + + // Now mock the scenario for validation - key not found, but fallback works + // Create a fallback key pair that can be used + JwtVerificationKey fallbackKey = + new JwtVerificationKey( + "fallback-key", + Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded())); + + // Mock the specific key lookup to fail, but the active key should work + when(keystoreService.getKeyPair("test-key-id")).thenReturn(Optional.empty()); + when(keystoreService.refreshActiveKeyPair()).thenReturn(fallbackKey); + when(keystoreService.getKeyPair("fallback-key")).thenReturn(Optional.of(testKeyPair)); + + // Should still work by falling back to the active keypair + assertDoesNotThrow(() -> jwtService.validateToken(token)); + assertEquals(username, jwtService.extractUsername(token)); + + // Verify fallback logic was used + verify(keystoreService, atLeast(1)).getActiveKey(); + } + + private JwtService createJwtServiceWithSecureCookie(boolean secureCookie) throws Exception { + // Use reflection to create JwtService with custom secureCookie value + JwtService testService = new JwtService(true, keystoreService); + + // Set the secureCookie field using reflection + java.lang.reflect.Field secureCookieField = + JwtService.class.getDeclaredField("secureCookie"); + secureCookieField.setAccessible(true); + secureCookieField.set(testService, secureCookie); + + return testService; + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/security/service/KeyPersistenceServiceInterfaceTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/security/service/KeyPersistenceServiceInterfaceTest.java new file mode 100644 index 000000000..33b971e5a --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/security/service/KeyPersistenceServiceInterfaceTest.java @@ -0,0 +1,232 @@ +package stirling.software.proprietary.security.service; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import java.util.Optional; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.cache.CacheManager; +import org.springframework.cache.concurrent.ConcurrentMapCacheManager; + +import stirling.software.common.configuration.InstallationPathConfig; +import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.security.model.JwtVerificationKey; + +@ExtendWith(MockitoExtension.class) +class KeyPersistenceServiceInterfaceTest { + + @Mock private ApplicationProperties applicationProperties; + + @Mock private ApplicationProperties.Security security; + + @Mock private ApplicationProperties.Security.Jwt jwtConfig; + + @TempDir Path tempDir; + + private KeyPersistenceService keyPersistenceService; + private KeyPair testKeyPair; + private CacheManager cacheManager; + + @BeforeEach + void setUp() throws NoSuchAlgorithmException { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + testKeyPair = keyPairGenerator.generateKeyPair(); + + cacheManager = new ConcurrentMapCacheManager("verifyingKeys"); + + lenient().when(applicationProperties.getSecurity()).thenReturn(security); + lenient().when(security.getJwt()).thenReturn(jwtConfig); + lenient().when(jwtConfig.isEnableKeystore()).thenReturn(true); // Default value + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testKeystoreEnabled(boolean keystoreEnabled) { + when(jwtConfig.isEnableKeystore()).thenReturn(keystoreEnabled); + + try (MockedStatic mockedStatic = + mockStatic(InstallationPathConfig.class)) { + mockedStatic + .when(InstallationPathConfig::getPrivateKeyPath) + .thenReturn(tempDir.toString()); + keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); + + assertEquals(keystoreEnabled, keyPersistenceService.isKeystoreEnabled()); + } + } + + @Test + void testGetActiveKeypairWhenNoActiveKeyExists() { + try (MockedStatic mockedStatic = + mockStatic(InstallationPathConfig.class)) { + mockedStatic + .when(InstallationPathConfig::getPrivateKeyPath) + .thenReturn(tempDir.toString()); + keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); + keyPersistenceService.initializeKeystore(); + + JwtVerificationKey result = keyPersistenceService.getActiveKey(); + + assertNotNull(result); + assertNotNull(result.getKeyId()); + assertNotNull(result.getVerifyingKey()); + } + } + + @Test + void testGetActiveKeyPairWithExistingKey() throws Exception { + String keyId = "test-key-2024-01-01-120000"; + String publicKeyBase64 = + Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()); + String privateKeyBase64 = + Base64.getEncoder().encodeToString(testKeyPair.getPrivate().getEncoded()); + + JwtVerificationKey existingKey = new JwtVerificationKey(keyId, publicKeyBase64); + + Path keyFile = tempDir.resolve(keyId + ".key"); + Files.writeString(keyFile, privateKeyBase64); + + try (MockedStatic mockedStatic = + mockStatic(InstallationPathConfig.class)) { + mockedStatic + .when(InstallationPathConfig::getPrivateKeyPath) + .thenReturn(tempDir.toString()); + keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); + keyPersistenceService.initializeKeystore(); + + JwtVerificationKey result = keyPersistenceService.getActiveKey(); + + assertNotNull(result); + assertNotNull(result.getKeyId()); + } + } + + @Test + void testGetKeyPair() throws Exception { + String keyId = "test-key-123"; + String publicKeyBase64 = + Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()); + String privateKeyBase64 = + Base64.getEncoder().encodeToString(testKeyPair.getPrivate().getEncoded()); + + JwtVerificationKey signingKey = new JwtVerificationKey(keyId, publicKeyBase64); + + Path keyFile = tempDir.resolve(keyId + ".key"); + Files.writeString(keyFile, privateKeyBase64); + + try (MockedStatic mockedStatic = + mockStatic(InstallationPathConfig.class)) { + mockedStatic + .when(InstallationPathConfig::getPrivateKeyPath) + .thenReturn(tempDir.toString()); + keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); + + keyPersistenceService + .getClass() + .getDeclaredField("verifyingKeyCache") + .setAccessible(true); + var cache = cacheManager.getCache("verifyingKeys"); + cache.put(keyId, signingKey); + + Optional result = keyPersistenceService.getKeyPair(keyId); + + assertTrue(result.isPresent()); + assertNotNull(result.get().getPublic()); + assertNotNull(result.get().getPrivate()); + } + } + + @Test + void testGetKeyPairNotFound() { + String keyId = "non-existent-key"; + + try (MockedStatic mockedStatic = + mockStatic(InstallationPathConfig.class)) { + mockedStatic + .when(InstallationPathConfig::getPrivateKeyPath) + .thenReturn(tempDir.toString()); + keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); + + Optional result = keyPersistenceService.getKeyPair(keyId); + + assertFalse(result.isPresent()); + } + } + + @Test + void testGetKeyPairWhenKeystoreDisabled() { + when(jwtConfig.isEnableKeystore()).thenReturn(false); + + try (MockedStatic mockedStatic = + mockStatic(InstallationPathConfig.class)) { + mockedStatic + .when(InstallationPathConfig::getPrivateKeyPath) + .thenReturn(tempDir.toString()); + keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); + + Optional result = keyPersistenceService.getKeyPair("any-key"); + + assertFalse(result.isPresent()); + } + } + + @Test + void testInitializeKeystoreCreatesDirectory() throws IOException { + try (MockedStatic mockedStatic = + mockStatic(InstallationPathConfig.class)) { + mockedStatic + .when(InstallationPathConfig::getPrivateKeyPath) + .thenReturn(tempDir.toString()); + keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); + keyPersistenceService.initializeKeystore(); + + assertTrue(Files.exists(tempDir)); + assertTrue(Files.isDirectory(tempDir)); + } + } + + @Test + void testLoadExistingKeypairWithMissingPrivateKeyFile() throws Exception { + String keyId = "test-key-missing-file"; + String publicKeyBase64 = + Base64.getEncoder().encodeToString(testKeyPair.getPublic().getEncoded()); + + JwtVerificationKey existingKey = new JwtVerificationKey(keyId, publicKeyBase64); + + try (MockedStatic mockedStatic = + mockStatic(InstallationPathConfig.class)) { + mockedStatic + .when(InstallationPathConfig::getPrivateKeyPath) + .thenReturn(tempDir.toString()); + keyPersistenceService = new KeyPersistenceService(applicationProperties, cacheManager); + keyPersistenceService.initializeKeystore(); + + JwtVerificationKey result = keyPersistenceService.getActiveKey(); + assertNotNull(result); + assertNotNull(result.getKeyId()); + assertNotNull(result.getVerifyingKey()); + } + } +} diff --git a/exampleYmlFiles/test_cicd.yml b/exampleYmlFiles/test_cicd.yml index 31e24da48..086f862d5 100644 --- a/exampleYmlFiles/test_cicd.yml +++ b/exampleYmlFiles/test_cicd.yml @@ -20,6 +20,7 @@ services: environment: DISABLE_ADDITIONAL_FEATURES: "false" SECURITY_ENABLELOGIN: "true" + V2: "false" PUID: 1002 PGID: 1002 UMASK: "022"