diff --git a/app/core/src/main/resources/static/js/DecryptFiles.js b/app/core/src/main/resources/static/js/DecryptFiles.js index 67349a012..0e5b58a92 100644 --- a/app/core/src/main/resources/static/js/DecryptFiles.js +++ b/app/core/src/main/resources/static/js/DecryptFiles.js @@ -46,10 +46,9 @@ export class DecryptFile { formData.append('password', password); } // Send decryption request - const response = await fetch('/api/v1/security/remove-password', { + const response = await fetchWithCsrf('/api/v1/security/remove-password', { method: 'POST', body: formData, - headers: csrfToken ? {'X-XSRF-TOKEN': csrfToken} : undefined, }); if (response.ok) { diff --git a/app/core/src/main/resources/static/js/downloader.js b/app/core/src/main/resources/static/js/downloader.js index 42ba0c357..b5324dd82 100644 --- a/app/core/src/main/resources/static/js/downloader.js +++ b/app/core/src/main/resources/static/js/downloader.js @@ -218,7 +218,7 @@ formData.append('password', password); // Use handleSingleDownload to send the request - const decryptionResult = await fetch(removePasswordUrl, {method: 'POST', body: formData}); + const decryptionResult = await fetchWithCsrf(removePasswordUrl, {method: 'POST', body: formData}); if (decryptionResult && decryptionResult.blob) { const decryptedBlob = await decryptionResult.blob(); diff --git a/app/core/src/main/resources/static/js/fetch-utils.js b/app/core/src/main/resources/static/js/fetch-utils.js index dfe2604a8..3d202e47b 100644 --- a/app/core/src/main/resources/static/js/fetch-utils.js +++ b/app/core/src/main/resources/static/js/fetch-utils.js @@ -1,3 +1,76 @@ +// JWT Management Utility +window.JWTManager = { + JWT_STORAGE_KEY: 'stirling_jwt', + + // Store JWT token in localStorage + storeToken: function(token) { + if (token) { + localStorage.setItem(this.JWT_STORAGE_KEY, token); + } + }, + + // Get JWT token from localStorage + getToken: function() { + return localStorage.getItem(this.JWT_STORAGE_KEY); + }, + + // Remove JWT token from localStorage + removeToken: function() { + localStorage.removeItem(this.JWT_STORAGE_KEY); + }, + + // Extract JWT from Authorization header in response + extractTokenFromResponse: function(response) { + const authHeader = response.headers.get('Authorization'); + if (authHeader && authHeader.startsWith('Bearer ')) { + const token = authHeader.substring(7); // Remove 'Bearer ' prefix + this.storeToken(token); + return token; + } + return null; + }, + + // Check if user is authenticated (has valid JWT) + isAuthenticated: function() { + const token = this.getToken(); + if (!token) return false; + + try { + // Basic JWT expiration check (decode payload) + const payload = JSON.parse(atob(token.split('.')[1])); + const now = Date.now() / 1000; + return payload.exp > now; + } catch (error) { + console.warn('Invalid JWT token:', error); + this.removeToken(); + return false; + } + }, + + // Logout - remove token and redirect to login + logout: function() { + this.removeToken(); + + // Clear all possible token storage locations + localStorage.removeItem(this.JWT_STORAGE_KEY); + sessionStorage.removeItem(this.JWT_STORAGE_KEY); + + // Clear JWT cookie manually (fallback) + document.cookie = 'STIRLING_JWT=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT; SameSite=None; Secure'; + + // Perform logout request to clear server-side session + fetch('/logout', { + method: 'POST', + credentials: 'include' + }).then(() => { + window.location.href = '/login'; + }).catch(() => { + // Even if logout fails, redirect to login + window.location.href = '/login'; + }); + } +}; + window.fetchWithCsrf = async function(url, options = {}) { function getCsrfToken() { const cookieValue = document.cookie @@ -24,5 +97,31 @@ window.fetchWithCsrf = async function(url, options = {}) { fetchOptions.headers['X-XSRF-TOKEN'] = csrfToken; } - return fetch(url, fetchOptions); + // Add JWT token to Authorization header if available + const jwtToken = window.JWTManager.getToken(); + if (jwtToken) { + fetchOptions.headers['Authorization'] = `Bearer ${jwtToken}`; + // Include credentials when JWT is enabled + fetchOptions.credentials = 'include'; + } + + // Make the request + const response = await fetch(url, fetchOptions); + + // Extract JWT from response if present + window.JWTManager.extractTokenFromResponse(response); + + // Handle 401 responses (unauthorized) + if (response.status === 401) { + console.warn('Authentication failed, redirecting to login'); + window.JWTManager.logout(); + return response; + } + + return response; +} + +// Enhanced fetch function that always includes JWT +window.fetchWithJWT = async function(url, options = {}) { + return window.fetchWithCsrf(url, options); } diff --git a/app/core/src/main/resources/static/js/jwt-init.js b/app/core/src/main/resources/static/js/jwt-init.js new file mode 100644 index 000000000..72741013b --- /dev/null +++ b/app/core/src/main/resources/static/js/jwt-init.js @@ -0,0 +1,121 @@ +// JWT Initialization Script +// This script handles JWT token extraction during OAuth/Login flows and initializes the JWT manager + +(function() { + // Extract JWT token from URL parameters (for OAuth redirects) + function extractTokenFromUrl() { + const urlParams = new URLSearchParams(window.location.search); + const token = urlParams.get('jwt') || urlParams.get('token'); + if (token) { + window.JWTManager.storeToken(token); + // Clean up URL by removing token parameter + urlParams.delete('jwt'); + urlParams.delete('token'); + const newUrl = window.location.pathname + (urlParams.toString() ? '?' + urlParams.toString() : ''); + window.history.replaceState({}, '', newUrl); + } + } + + // Extract JWT token from cookie on page load (fallback) + function extractTokenFromCookie() { + const cookieValue = document.cookie + .split('; ') + .find(row => row.startsWith('STIRLING_JWT=')) + ?.split('=')[1]; + + if (cookieValue) { + window.JWTManager.storeToken(cookieValue); + // Clear the cookie since we're using localStorage with consistent SameSite policy + document.cookie = 'STIRLING_JWT=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT; SameSite=None; Secure'; + } + } + + // Initialize JWT handling when page loads + function initializeJWT() { + // Try to extract token from URL first (OAuth flow) + extractTokenFromUrl(); + + // If no token in URL, try cookie (login flow) + if (!window.JWTManager.getToken()) { + extractTokenFromCookie(); + } + + // Check if user is authenticated + if (window.JWTManager.isAuthenticated()) { + console.log('User is authenticated with JWT'); + } else { + console.log('User is not authenticated or token expired'); + // Only redirect to login if we're not already on login/register pages + const currentPath = window.location.pathname; + if (!currentPath.includes('/login') && + !currentPath.includes('/register') && + !currentPath.includes('/oauth') && + !currentPath.includes('/saml') && + !currentPath.includes('/error')) { + // Redirect to login after a short delay to allow other scripts to load + setTimeout(() => { + window.location.href = '/login'; + }, 100); + } + } + } + + // Override form submissions to include JWT + function enhanceFormSubmissions() { + // Override form submit for login forms + document.addEventListener('submit', function(event) { + const form = event.target; + + // Add JWT to form data if available + const jwtToken = window.JWTManager.getToken(); + if (jwtToken && form.method && form.method.toLowerCase() !== 'get') { + // Create a hidden input for JWT + const jwtInput = document.createElement('input'); + jwtInput.type = 'hidden'; + jwtInput.name = 'jwt'; + jwtInput.value = jwtToken; + form.appendChild(jwtInput); + } + }); + } + + // Add logout functionality to logout buttons + function enhanceLogoutButtons() { + document.addEventListener('click', function(event) { + const element = event.target; + + // Check if clicked element is a logout button/link + if (element.matches('a[href="/logout"], button[data-action="logout"], .logout-btn')) { + event.preventDefault(); + window.JWTManager.logout(); + } + }); + } + + // Initialize when DOM is ready + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', function() { + initializeJWT(); + enhanceFormSubmissions(); + enhanceLogoutButtons(); + }); + } else { + initializeJWT(); + enhanceFormSubmissions(); + enhanceLogoutButtons(); + } + + // Handle page visibility changes to check token expiration + document.addEventListener('visibilitychange', function() { + if (!document.hidden && !window.JWTManager.isAuthenticated()) { + // Token expired while page was hidden, redirect to login + const currentPath = window.location.pathname; + if (!currentPath.includes('/login') && + !currentPath.includes('/register') && + !currentPath.includes('/oauth') && + !currentPath.includes('/saml')) { + window.location.href = '/login'; + } + } + }); +})(); \ No newline at end of file diff --git a/app/core/src/main/resources/static/js/navbar.js b/app/core/src/main/resources/static/js/navbar.js index a95ff1639..57f916f8a 100644 --- a/app/core/src/main/resources/static/js/navbar.js +++ b/app/core/src/main/resources/static/js/navbar.js @@ -42,39 +42,6 @@ function toolsManager() { }); } -function setupDropdowns() { - const dropdowns = document.querySelectorAll('.navbar-nav > .nav-item.dropdown'); - - dropdowns.forEach((dropdown) => { - const toggle = dropdown.querySelector('[data-bs-toggle="dropdown"]'); - if (!toggle) return; - - // Skip search dropdown, it has its own logic - if (toggle.id === 'searchDropdown') { - return; - } - - dropdown.addEventListener('show.bs.dropdown', () => { - // Find all other open dropdowns and hide them - const openDropdowns = document.querySelectorAll('.navbar-nav .dropdown-menu.show'); - openDropdowns.forEach((menu) => { - const parentDropdown = menu.closest('.dropdown'); - if (parentDropdown && parentDropdown !== dropdown) { - const parentToggle = parentDropdown.querySelector('[data-bs-toggle="dropdown"]'); - if (parentToggle) { - // Get or create Bootstrap dropdown instance - let instance = bootstrap.Dropdown.getInstance(parentToggle); - if (!instance) { - instance = new bootstrap.Dropdown(parentToggle); - } - instance.hide(); - } - } - }); - }); - }); -} - window.tooltipSetup = () => { const tooltipElements = document.querySelectorAll('[title]'); @@ -89,54 +56,37 @@ window.tooltipSetup = () => { document.body.appendChild(customTooltip); element.addEventListener('mouseenter', (event) => { - if (window.innerWidth >= 1200) { - customTooltip.style.display = 'block'; - customTooltip.style.left = `${event.pageX + 10}px`; - customTooltip.style.top = `${event.pageY + 10}px`; - } + customTooltip.style.display = 'block'; + customTooltip.style.left = `${event.pageX + 10}px`; // Position tooltip slightly away from the cursor + customTooltip.style.top = `${event.pageY + 10}px`; }); + // Update the position of the tooltip as the user moves the mouse element.addEventListener('mousemove', (event) => { - if (window.innerWidth >= 1200) { - customTooltip.style.left = `${event.pageX + 10}px`; - customTooltip.style.top = `${event.pageY + 10}px`; - } + customTooltip.style.left = `${event.pageX + 10}px`; + customTooltip.style.top = `${event.pageY + 10}px`; }); + // Hide the tooltip when the mouse leaves element.addEventListener('mouseleave', () => { customTooltip.style.display = 'none'; }); }); }; - -// Override the bootstrap dropdown styles for mobile -function fixNavbarDropdownStyles() { - if (window.innerWidth < 1200) { - document.querySelectorAll('.navbar .dropdown-menu').forEach(function(menu) { - menu.style.transform = 'none'; - menu.style.transformOrigin = 'none'; - menu.style.left = '0'; - menu.style.right = '0'; - menu.style.maxWidth = '95vw'; - menu.style.width = '100vw'; - menu.style.marginBottom = '0'; - }); - } else { - document.querySelectorAll('.navbar .dropdown-menu').forEach(function(menu) { - menu.style.transform = ''; - menu.style.transformOrigin = ''; - menu.style.left = ''; - menu.style.right = ''; - menu.style.maxWidth = ''; - menu.style.width = ''; - menu.style.marginBottom = ''; - }); - } -} - document.addEventListener('DOMContentLoaded', () => { tooltipSetup(); - setupDropdowns(); - fixNavbarDropdownStyles(); + + // Setup logout button functionality + const logoutButton = document.querySelector('a[href="/logout"]'); + if (logoutButton) { + logoutButton.addEventListener('click', function(event) { + event.preventDefault(); + if (window.JWTManager) { + window.JWTManager.logout(); + } else { + // Fallback if JWTManager is not available + window.location.href = '/logout'; + } + }); + } }); -window.addEventListener('resize', fixNavbarDropdownStyles); diff --git a/app/core/src/main/resources/static/js/usage.js b/app/core/src/main/resources/static/js/usage.js index 624e4ec78..443a27ce1 100644 --- a/app/core/src/main/resources/static/js/usage.js +++ b/app/core/src/main/resources/static/js/usage.js @@ -102,7 +102,7 @@ async function fetchEndpointData() { refreshBtn.classList.add('refreshing'); refreshBtn.disabled = true; - const response = await fetch('/api/v1/info/load/all'); + const response = await fetchWithCsrf('/api/v1/info/load/all'); if (!response.ok) { throw new Error('Network response was not ok'); } diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java b/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java index 45869b05c..aaf6ea2c3 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/security/CustomLogoutSuccessHandler.java @@ -7,7 +7,6 @@ import java.util.ArrayList; import java.util.List; import org.springframework.core.io.Resource; -import org.springframework.http.HttpStatus; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; @@ -117,10 +116,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler { samlClient.setSPKeys(certificate, privateKey); // Redirect to identity provider for logout. todo: add relay state - // samlClient.redirectToIdentityProvider(response, null, nameIdValue); - samlClient.processLogoutRequestPostFromIdentityProvider(request, nameIdValue); - samlClient.redirectToIdentityProviderLogout( - response, HttpStatus.OK.name(), nameIdValue); + samlClient.redirectToIdentityProvider(response, null, nameIdValue); } catch (Exception e) { log.error( "Error retrieving logout URL from Provider {} for user {}", diff --git a/app/proprietary/src/main/resources/static/js/audit/dashboard.js b/app/proprietary/src/main/resources/static/js/audit/dashboard.js index 5cc670908..c0b93bd8e 100644 --- a/app/proprietary/src/main/resources/static/js/audit/dashboard.js +++ b/app/proprietary/src/main/resources/static/js/audit/dashboard.js @@ -230,7 +230,7 @@ function loadAuditData(targetPage, realPageSize) { document.getElementById('page-indicator').textContent = `Page ${requestedPage + 1} of ?`; } - fetch(url) + fetchWithCsrf(url) .then(response => { return response.json(); }) @@ -302,7 +302,7 @@ function loadStats(days) { showLoading('user-chart-loading'); showLoading('time-chart-loading'); - fetch(`/audit/stats?days=${days}`) + fetchWithCsrf(`/audit/stats?days=${days}`) .then(response => response.json()) .then(data => { document.getElementById('total-events').textContent = data.totalEvents; @@ -835,7 +835,7 @@ function hideLoading(id) { // Load event types from the server for filter dropdowns function loadEventTypes() { - fetch('/audit/types') + fetchWithCsrf('/audit/types') .then(response => response.json()) .then(types => { if (!types || types.length === 0) { diff --git a/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java b/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java new file mode 100644 index 000000000..3e6d4491a --- /dev/null +++ b/proprietary/src/main/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepository.java @@ -0,0 +1,130 @@ +package stirling.software.proprietary.security.saml2; + +import java.util.HashMap; +import java.util.Map; + +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.proprietary.security.service.JwtServiceInterface; + +@Slf4j +public class JwtSaml2AuthenticationRequestRepository + implements Saml2AuthenticationRequestRepository { + private final Map tokenStore; + private final JwtServiceInterface jwtService; + private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + + private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token"; + + public JwtSaml2AuthenticationRequestRepository( + Map tokenStore, + JwtServiceInterface jwtService, + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + this.tokenStore = tokenStore; + this.jwtService = jwtService; + this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; + } + + @Override + public void saveAuthenticationRequest( + Saml2PostAuthenticationRequest authRequest, + HttpServletRequest request, + HttpServletResponse response) { + if (authRequest == null) { + removeAuthenticationRequest(request, response); + return; + } + + Map claims = serializeSamlRequest(authRequest); + String token = jwtService.generateToken("", claims); + String relayState = authRequest.getRelayState(); + + tokenStore.put(relayState, token); + request.setAttribute(SAML_REQUEST_TOKEN, relayState); + response.addHeader(SAML_REQUEST_TOKEN, relayState); + + log.debug("Saved SAMLRequest token with RelayState: {}", relayState); + } + + @Override + public Saml2PostAuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) { + String token = extractTokenFromStore(request); + + if (token == null) { + log.debug("No SAMLResponse token found in RelayState"); + return null; + } + + Map claims = jwtService.extractAllClaims(token); + return deserializeSamlRequest(claims); + } + + @Override + public Saml2PostAuthenticationRequest removeAuthenticationRequest( + HttpServletRequest request, HttpServletResponse response) { + Saml2PostAuthenticationRequest authRequest = loadAuthenticationRequest(request); + + String relayStateId = request.getParameter("RelayState"); + if (relayStateId != null) { + tokenStore.remove(relayStateId); + log.debug("Removed SAMLRequest token for RelayState ID: {}", relayStateId); + } + + return authRequest; + } + + private String extractTokenFromStore(HttpServletRequest request) { + String authnRequestId = request.getParameter("RelayState"); + + if (authnRequestId != null && !authnRequestId.isEmpty()) { + String token = tokenStore.get(authnRequestId); + + if (token != null) { + tokenStore.remove(authnRequestId); + log.debug("Retrieved SAMLRequest token for RelayState ID: {}", authnRequestId); + return token; + } else { + log.warn("No SAMLRequest token found for RelayState ID: {}", authnRequestId); + } + } + + return null; + } + + private Map serializeSamlRequest(Saml2PostAuthenticationRequest authRequest) { + Map claims = new HashMap<>(); + + claims.put("id", authRequest.getId()); + claims.put("relyingPartyRegistrationId", authRequest.getRelyingPartyRegistrationId()); + claims.put("authenticationRequestUri", authRequest.getAuthenticationRequestUri()); + claims.put("samlRequest", authRequest.getSamlRequest()); + claims.put("relayState", authRequest.getRelayState()); + + return claims; + } + + private Saml2PostAuthenticationRequest deserializeSamlRequest(Map claims) { + String relyingPartyRegistrationId = (String) claims.get("relyingPartyRegistrationId"); + RelyingPartyRegistration relyingPartyRegistration = + relyingPartyRegistrationRepository.findByRegistrationId(relyingPartyRegistrationId); + + if (relyingPartyRegistration == null) { + return null; + } + + return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(relyingPartyRegistration) + .id((String) claims.get("id")) + .authenticationRequestUri((String) claims.get("authenticationRequestUri")) + .samlRequest((String) claims.get("samlRequest")) + .relayState((String) claims.get("relayState")) + .build(); + } +} diff --git a/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java b/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java new file mode 100644 index 000000000..9d21f88a3 --- /dev/null +++ b/proprietary/src/main/java/stirling/software/proprietary/security/saml2/Saml2Configuration.java @@ -0,0 +1,188 @@ +package stirling.software.proprietary.security.saml2; + +import java.security.cert.X509Certificate; +import java.util.Collections; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +import org.opensaml.saml.saml2.core.AuthnRequest; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.io.Resource; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; +import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver; + +import jakarta.servlet.http.HttpServletRequest; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +import stirling.software.common.model.ApplicationProperties; +import stirling.software.common.model.ApplicationProperties.Security.SAML2; +import stirling.software.proprietary.security.service.JwtServiceInterface; + +@Configuration +@Slf4j +@ConditionalOnProperty(value = "security.saml2.enabled", havingValue = "true") +@RequiredArgsConstructor +public class Saml2Configuration { + + private final ApplicationProperties applicationProperties; + + @Bean + @ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true") + public RelyingPartyRegistrationRepository relyingPartyRegistrations() throws Exception { + SAML2 samlConf = applicationProperties.getSecurity().getSaml2(); + X509Certificate idpCert = CertificateUtils.readCertificate(samlConf.getIdpCert()); + Saml2X509Credential verificationCredential = Saml2X509Credential.verification(idpCert); + Resource privateKeyResource = samlConf.getPrivateKey(); + Resource certificateResource = samlConf.getSpCert(); + Saml2X509Credential signingCredential = + new Saml2X509Credential( + CertificateUtils.readPrivateKey(privateKeyResource), + CertificateUtils.readCertificate(certificateResource), + Saml2X509CredentialType.SIGNING); + RelyingPartyRegistration rp = + RelyingPartyRegistration.withRegistrationId(samlConf.getRegistrationId()) + .signingX509Credentials(c -> c.add(signingCredential)) + .entityId(samlConf.getIdpIssuer()) + .singleLogoutServiceBinding(Saml2MessageBinding.POST) + .singleLogoutServiceLocation(samlConf.getIdpSingleLogoutUrl()) + .singleLogoutServiceResponseLocation("http://localhost:8080/login") + .assertionConsumerServiceBinding(Saml2MessageBinding.POST) + .assertionConsumerServiceLocation( + "{baseUrl}/login/saml2/sso/{registrationId}") + .authnRequestsSigned(true) + .assertingPartyMetadata( + metadata -> + metadata.entityId(samlConf.getIdpIssuer()) + .verificationX509Credentials( + c -> c.add(verificationCredential)) + .singleSignOnServiceBinding( + Saml2MessageBinding.POST) + .singleSignOnServiceLocation( + samlConf.getIdpSingleLoginUrl()) + .singleLogoutServiceBinding( + Saml2MessageBinding.POST) + .singleLogoutServiceLocation( + samlConf.getIdpSingleLogoutUrl()) + .singleLogoutServiceResponseLocation( + "http://localhost:8080/login") + .wantAuthnRequestsSigned(true)) + .build(); + return new InMemoryRelyingPartyRegistrationRepository(rp); + } + + @Bean + @ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true") + public Saml2AuthenticationRequestRepository + saml2AuthenticationRequestRepository( + JwtServiceInterface jwtService, + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + return new JwtSaml2AuthenticationRequestRepository( + new ConcurrentHashMap<>(), jwtService, relyingPartyRegistrationRepository); + } + + @Bean + @ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true") + public OpenSaml4AuthenticationRequestResolver authenticationRequestResolver( + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, + Saml2AuthenticationRequestRepository + saml2AuthenticationRequestRepository) { + OpenSaml4AuthenticationRequestResolver resolver = + new OpenSaml4AuthenticationRequestResolver(relyingPartyRegistrationRepository); + + resolver.setAuthnRequestCustomizer( + customizer -> { + HttpServletRequest request = customizer.getRequest(); + AuthnRequest authnRequest = customizer.getAuthnRequest(); + Saml2PostAuthenticationRequest saml2AuthenticationRequest = + saml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + if (saml2AuthenticationRequest != null) { + String sessionId = request.getSession(false).getId(); + + log.debug( + "Retrieving SAML 2 authentication request ID from the current HTTP session {}", + sessionId); + + String authenticationRequestId = saml2AuthenticationRequest.getId(); + + if (!authenticationRequestId.isBlank()) { + authnRequest.setID(authenticationRequestId); + } else { + log.warn( + "No authentication request found for HTTP session {}. Generating new ID", + sessionId); + authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1)); + } + } else { + log.debug("Generating new authentication request ID"); + authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1)); + } + logAuthnRequestDetails(authnRequest); + logHttpRequestDetails(request); + }); + return resolver; + } + + private static void logAuthnRequestDetails(AuthnRequest authnRequest) { + String message = + """ + AuthnRequest: + + ID: {} + Issuer: {} + IssueInstant: {} + AssertionConsumerService (ACS) URL: {} + """; + log.debug( + message, + authnRequest.getID(), + authnRequest.getIssuer() != null ? authnRequest.getIssuer().getValue() : null, + authnRequest.getIssueInstant(), + authnRequest.getAssertionConsumerServiceURL()); + + if (authnRequest.getNameIDPolicy() != null) { + log.debug("NameIDPolicy Format: {}", authnRequest.getNameIDPolicy().getFormat()); + } + } + + private static void logHttpRequestDetails(HttpServletRequest request) { + log.debug("HTTP Headers: "); + Collections.list(request.getHeaderNames()) + .forEach( + headerName -> + log.debug("{}: {}", headerName, request.getHeader(headerName))); + String message = + """ + HTTP Request Method: {} + Session ID: {} + Request Path: {} + Query String: {} + Remote Address: {} + + SAML Request Parameters: + + SAMLRequest: {} + RelayState: {} + """; + log.debug( + message, + request.getMethod(), + request.getSession().getId(), + request.getRequestURI(), + request.getQueryString(), + request.getRemoteAddr(), + request.getParameter("SAMLRequest"), + request.getParameter("RelayState")); + } +} diff --git a/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtService.java b/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtService.java new file mode 100644 index 000000000..5b1bb8ec2 --- /dev/null +++ b/proprietary/src/main/java/stirling/software/proprietary/security/service/JwtService.java @@ -0,0 +1,196 @@ +package stirling.software.proprietary.security.service; + +import java.security.KeyPair; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.http.ResponseCookie; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.stereotype.Service; + +import io.github.pixee.security.Newlines; +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.ExpiredJwtException; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.MalformedJwtException; +import io.jsonwebtoken.UnsupportedJwtException; +import io.jsonwebtoken.security.SignatureException; + +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.proprietary.security.model.exception.AuthenticationFailureException; +import stirling.software.proprietary.security.saml2.CustomSaml2AuthenticatedPrincipal; + +@Slf4j +@Service +public class JwtService implements JwtServiceInterface { + + private static final String JWT_COOKIE_NAME = "STIRLING_JWT"; + private static final String AUTHORIZATION_HEADER = "Authorization"; + private static final String BEARER_PREFIX = "Bearer "; + private static final String ISSUER = "Stirling PDF"; + private static final long EXPIRATION = 3600000; + + private final KeyPair keyPair; + private final boolean v2Enabled; + + public JwtService(@Qualifier("v2Enabled") boolean v2Enabled) { + this.v2Enabled = v2Enabled; + keyPair = Jwts.SIG.RS256.keyPair().build(); + } + + @Override + public String generateToken(Authentication authentication, Map claims) { + Object principal = authentication.getPrincipal(); + String username = ""; + + if (principal instanceof UserDetails) { + username = ((UserDetails) principal).getUsername(); + } else if (principal instanceof OAuth2User) { + username = ((OAuth2User) principal).getName(); + } else if (principal instanceof CustomSaml2AuthenticatedPrincipal) { + username = ((CustomSaml2AuthenticatedPrincipal) principal).getName(); + } + + return generateToken(username, claims); + } + + @Override + public String generateToken(String username, Map claims) { + return Jwts.builder() + .claims(claims) + .subject(username) + .issuer(ISSUER) + .issuedAt(new Date()) + .expiration(new Date(System.currentTimeMillis() + EXPIRATION)) + .signWith(keyPair.getPrivate(), Jwts.SIG.RS256) + .compact(); + } + + @Override + public void validateToken(String token) throws AuthenticationFailureException { + extractAllClaimsFromToken(token); + + // todo: test + if (isTokenExpired(token)) { + throw new AuthenticationFailureException("The token has expired"); + } + } + + @Override + public String extractUsername(String token) { + return extractClaim(token, Claims::getSubject); + } + + @Override + public Map extractAllClaims(String token) { + Claims claims = extractAllClaimsFromToken(token); + return new HashMap<>(claims); + } + + @Override + public boolean isTokenExpired(String token) { + return extractExpiration(token).before(new Date()); + } + + private Date extractExpiration(String token) { + return extractClaim(token, Claims::getExpiration); + } + + private T extractClaim(String token, Function claimsResolver) { + final Claims claims = extractAllClaimsFromToken(token); + return claimsResolver.apply(claims); + } + + private Claims extractAllClaimsFromToken(String token) { + try { + return Jwts.parser() + .verifyWith(keyPair.getPublic()) + .build() + .parseSignedClaims(token) + .getPayload(); + } catch (SignatureException e) { + log.warn("Invalid signature: {}", e.getMessage()); + throw new AuthenticationFailureException("Invalid signature", e); + } catch (MalformedJwtException e) { + log.warn("Invalid token: {}", e.getMessage()); + throw new AuthenticationFailureException("Invalid token", e); + } catch (ExpiredJwtException e) { + log.warn("The token has expired: {}", e.getMessage()); + throw new AuthenticationFailureException("The token has expired", e); + } catch (UnsupportedJwtException e) { + log.warn("The token is unsupported: {}", e.getMessage()); + throw new AuthenticationFailureException("The token is unsupported", e); + } catch (IllegalArgumentException e) { + log.warn("Claims are empty: {}", e.getMessage()); + throw new AuthenticationFailureException("Claims are empty", e); + } + } + + @Override + public String extractTokenFromRequest(HttpServletRequest request) { + String authHeader = request.getHeader(AUTHORIZATION_HEADER); + + if (authHeader != null && authHeader.startsWith(BEARER_PREFIX)) { + return authHeader.substring(BEARER_PREFIX.length()); + } + + Cookie[] cookies = request.getCookies(); + if (cookies != null) { + for (Cookie cookie : cookies) { + if (JWT_COOKIE_NAME.equals(cookie.getName())) { + return cookie.getValue(); + } + } + } + + return null; + } + + @Override + public void addTokenToResponse(HttpServletResponse response, String token) { + response.setHeader(AUTHORIZATION_HEADER, Newlines.stripAll(BEARER_PREFIX + token)); + + ResponseCookie cookie = + ResponseCookie.from(JWT_COOKIE_NAME, Newlines.stripAll(token)) + .httpOnly(true) + .secure(true) + .sameSite("None") + .maxAge(EXPIRATION / 1000) + .path("/") + .build(); + + response.addHeader("Set-Cookie", cookie.toString()); + } + + @Override + public void clearTokenFromResponse(HttpServletResponse response) { + // Remove Authorization header instead of setting empty string + response.setHeader(AUTHORIZATION_HEADER, null); + + ResponseCookie cookie = + ResponseCookie.from(JWT_COOKIE_NAME, "") + .httpOnly(true) + .secure(true) + .sameSite("None") + .maxAge(0) + .path("/") + .build(); + + response.addHeader("Set-Cookie", cookie.toString()); + } + + @Override + public boolean isJwtEnabled() { + return v2Enabled; + } +} diff --git a/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java b/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java new file mode 100644 index 000000000..11f3c00d9 --- /dev/null +++ b/proprietary/src/test/java/stirling/software/proprietary/security/saml2/JwtSaml2AuthenticationRequestRepositoryTest.java @@ -0,0 +1,229 @@ +package stirling.software.proprietary.security.saml2; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullAndEmptySource; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata; +import stirling.software.proprietary.security.service.JwtServiceInterface; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class JwtSaml2AuthenticationRequestRepositoryTest { + + private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token"; + + private Map tokenStore; + + @Mock + private JwtServiceInterface jwtService; + + @Mock + private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + + private JwtSaml2AuthenticationRequestRepository jwtSaml2AuthenticationRequestRepository; + + @BeforeEach + void setUp() { + tokenStore = new ConcurrentHashMap<>(); + jwtSaml2AuthenticationRequestRepository = new JwtSaml2AuthenticationRequestRepository( + tokenStore, jwtService, relyingPartyRegistrationRepository); + } + + @Test + void saveAuthenticationRequest() { + var authRequest = mock(Saml2PostAuthenticationRequest.class); + var request = mock(MockHttpServletRequest.class); + var response = mock(MockHttpServletResponse.class); + String token = "testToken"; + String id = "testId"; + String relayState = "testRelayState"; + String authnRequestUri = "example.com/authnRequest"; + Map claims = Map.of(); + String samlRequest = "testSamlRequest"; + String relyingPartyRegistrationId = "stirling-pdf"; + + when(authRequest.getRelayState()).thenReturn(relayState); + when(authRequest.getId()).thenReturn(id); + when(authRequest.getAuthenticationRequestUri()).thenReturn(authnRequestUri); + when(authRequest.getSamlRequest()).thenReturn(samlRequest); + when(authRequest.getRelyingPartyRegistrationId()).thenReturn(relyingPartyRegistrationId); + when(jwtService.generateToken(eq(""), anyMap())).thenReturn(token); + + jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(authRequest, request, response); + + verify(request).setAttribute(SAML_REQUEST_TOKEN, relayState); + verify(response).addHeader(SAML_REQUEST_TOKEN, relayState); + } + + @Test + void saveAuthenticationRequestWithNullRequest() { + var request = mock(MockHttpServletRequest.class); + var response = mock(MockHttpServletResponse.class); + + jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(null, request, response); + + assertTrue(tokenStore.isEmpty()); + } + + @Test + void loadAuthenticationRequest() { + var request = mock(MockHttpServletRequest.class); + var relyingPartyRegistration = mock(RelyingPartyRegistration.class); + var assertingPartyMetadata = mock(AssertingPartyMetadata.class); + String relayState = "testRelayState"; + String token = "testToken"; + Map claims = Map.of( + "id", "testId", + "relyingPartyRegistrationId", "stirling-pdf", + "authenticationRequestUri", "example.com/authnRequest", + "samlRequest", "testSamlRequest", + "relayState", relayState + ); + + when(request.getParameter("RelayState")).thenReturn(relayState); + when(jwtService.extractAllClaims(token)).thenReturn(claims); + when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(relyingPartyRegistration); + when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf"); + when(relyingPartyRegistration.getAssertingPartyMetadata()).thenReturn(assertingPartyMetadata); + when(assertingPartyMetadata.getSingleSignOnServiceLocation()).thenReturn("https://example.com/sso"); + tokenStore.put(relayState, token); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNotNull(result); + assertFalse(tokenStore.containsKey(relayState)); + } + + @ParameterizedTest + @NullAndEmptySource + void loadAuthenticationRequestWithInvalidRelayState(String relayState) { + var request = mock(MockHttpServletRequest.class); + when(request.getParameter("RelayState")).thenReturn(relayState); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNull(result); + } + + @Test + void loadAuthenticationRequestWithNonExistentToken() { + var request = mock(MockHttpServletRequest.class); + when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState"); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNull(result); + } + + @Test + void loadAuthenticationRequestWithNullRelyingPartyRegistration() { + var request = mock(MockHttpServletRequest.class); + String relayState = "testRelayState"; + String token = "testToken"; + Map claims = Map.of( + "id", "testId", + "relyingPartyRegistrationId", "stirling-pdf", + "authenticationRequestUri", "example.com/authnRequest", + "samlRequest", "testSamlRequest", + "relayState", relayState + ); + + when(request.getParameter("RelayState")).thenReturn(relayState); + when(jwtService.extractAllClaims(token)).thenReturn(claims); + when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(null); + tokenStore.put(relayState, token); + + var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request); + + assertNull(result); + } + + @Test + void removeAuthenticationRequest() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + var relyingPartyRegistration = mock(RelyingPartyRegistration.class); + var assertingPartyMetadata = mock(AssertingPartyMetadata.class); + String relayState = "testRelayState"; + String token = "testToken"; + Map claims = Map.of( + "id", "testId", + "relyingPartyRegistrationId", "stirling-pdf", + "authenticationRequestUri", "example.com/authnRequest", + "samlRequest", "testSamlRequest", + "relayState", relayState + ); + + when(request.getParameter("RelayState")).thenReturn(relayState); + when(jwtService.extractAllClaims(token)).thenReturn(claims); + when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(relyingPartyRegistration); + when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf"); + when(relyingPartyRegistration.getAssertingPartyMetadata()).thenReturn(assertingPartyMetadata); + when(assertingPartyMetadata.getSingleSignOnServiceLocation()).thenReturn("https://example.com/sso"); + tokenStore.put(relayState, token); + + var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); + + assertNotNull(result); + assertFalse(tokenStore.containsKey(relayState)); + } + + @Test + void removeAuthenticationRequestWithNullRelayState() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + when(request.getParameter("RelayState")).thenReturn(null); + + var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); + + assertNull(result); + } + + @Test + void removeAuthenticationRequestWithNonExistentToken() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState"); + + var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); + + assertNull(result); + } + + @Test + void removeAuthenticationRequestWithOnlyRelayState() { + var request = mock(HttpServletRequest.class); + var response = mock(HttpServletResponse.class); + String relayState = "testRelayState"; + + when(request.getParameter("RelayState")).thenReturn(relayState); + + var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response); + + assertNull(result); + assertFalse(tokenStore.containsKey(relayState)); + } +}