Added test

This commit is contained in:
Dario Ghunney Ware 2025-07-18 15:08:15 +01:00
parent ae8980f656
commit f5756944ed
12 changed files with 992 additions and 84 deletions

View File

@ -46,10 +46,9 @@ export class DecryptFile {
formData.append('password', password);
}
// Send decryption request
const response = await fetch('/api/v1/security/remove-password', {
const response = await fetchWithCsrf('/api/v1/security/remove-password', {
method: 'POST',
body: formData,
headers: csrfToken ? {'X-XSRF-TOKEN': csrfToken} : undefined,
});
if (response.ok) {

View File

@ -218,7 +218,7 @@
formData.append('password', password);
// Use handleSingleDownload to send the request
const decryptionResult = await fetch(removePasswordUrl, {method: 'POST', body: formData});
const decryptionResult = await fetchWithCsrf(removePasswordUrl, {method: 'POST', body: formData});
if (decryptionResult && decryptionResult.blob) {
const decryptedBlob = await decryptionResult.blob();

View File

@ -1,3 +1,76 @@
// JWT Management Utility
window.JWTManager = {
JWT_STORAGE_KEY: 'stirling_jwt',
// Store JWT token in localStorage
storeToken: function(token) {
if (token) {
localStorage.setItem(this.JWT_STORAGE_KEY, token);
}
},
// Get JWT token from localStorage
getToken: function() {
return localStorage.getItem(this.JWT_STORAGE_KEY);
},
// Remove JWT token from localStorage
removeToken: function() {
localStorage.removeItem(this.JWT_STORAGE_KEY);
},
// Extract JWT from Authorization header in response
extractTokenFromResponse: function(response) {
const authHeader = response.headers.get('Authorization');
if (authHeader && authHeader.startsWith('Bearer ')) {
const token = authHeader.substring(7); // Remove 'Bearer ' prefix
this.storeToken(token);
return token;
}
return null;
},
// Check if user is authenticated (has valid JWT)
isAuthenticated: function() {
const token = this.getToken();
if (!token) return false;
try {
// Basic JWT expiration check (decode payload)
const payload = JSON.parse(atob(token.split('.')[1]));
const now = Date.now() / 1000;
return payload.exp > now;
} catch (error) {
console.warn('Invalid JWT token:', error);
this.removeToken();
return false;
}
},
// Logout - remove token and redirect to login
logout: function() {
this.removeToken();
// Clear all possible token storage locations
localStorage.removeItem(this.JWT_STORAGE_KEY);
sessionStorage.removeItem(this.JWT_STORAGE_KEY);
// Clear JWT cookie manually (fallback)
document.cookie = 'STIRLING_JWT=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT; SameSite=None; Secure';
// Perform logout request to clear server-side session
fetch('/logout', {
method: 'POST',
credentials: 'include'
}).then(() => {
window.location.href = '/login';
}).catch(() => {
// Even if logout fails, redirect to login
window.location.href = '/login';
});
}
};
window.fetchWithCsrf = async function(url, options = {}) {
function getCsrfToken() {
const cookieValue = document.cookie
@ -24,5 +97,31 @@ window.fetchWithCsrf = async function(url, options = {}) {
fetchOptions.headers['X-XSRF-TOKEN'] = csrfToken;
}
return fetch(url, fetchOptions);
// Add JWT token to Authorization header if available
const jwtToken = window.JWTManager.getToken();
if (jwtToken) {
fetchOptions.headers['Authorization'] = `Bearer ${jwtToken}`;
// Include credentials when JWT is enabled
fetchOptions.credentials = 'include';
}
// Make the request
const response = await fetch(url, fetchOptions);
// Extract JWT from response if present
window.JWTManager.extractTokenFromResponse(response);
// Handle 401 responses (unauthorized)
if (response.status === 401) {
console.warn('Authentication failed, redirecting to login');
window.JWTManager.logout();
return response;
}
return response;
}
// Enhanced fetch function that always includes JWT
window.fetchWithJWT = async function(url, options = {}) {
return window.fetchWithCsrf(url, options);
}

View File

@ -0,0 +1,121 @@
// JWT Initialization Script
// This script handles JWT token extraction during OAuth/Login flows and initializes the JWT manager
(function() {
// Extract JWT token from URL parameters (for OAuth redirects)
function extractTokenFromUrl() {
const urlParams = new URLSearchParams(window.location.search);
const token = urlParams.get('jwt') || urlParams.get('token');
if (token) {
window.JWTManager.storeToken(token);
// Clean up URL by removing token parameter
urlParams.delete('jwt');
urlParams.delete('token');
const newUrl = window.location.pathname + (urlParams.toString() ? '?' + urlParams.toString() : '');
window.history.replaceState({}, '', newUrl);
}
}
// Extract JWT token from cookie on page load (fallback)
function extractTokenFromCookie() {
const cookieValue = document.cookie
.split('; ')
.find(row => row.startsWith('STIRLING_JWT='))
?.split('=')[1];
if (cookieValue) {
window.JWTManager.storeToken(cookieValue);
// Clear the cookie since we're using localStorage with consistent SameSite policy
document.cookie = 'STIRLING_JWT=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT; SameSite=None; Secure';
}
}
// Initialize JWT handling when page loads
function initializeJWT() {
// Try to extract token from URL first (OAuth flow)
extractTokenFromUrl();
// If no token in URL, try cookie (login flow)
if (!window.JWTManager.getToken()) {
extractTokenFromCookie();
}
// Check if user is authenticated
if (window.JWTManager.isAuthenticated()) {
console.log('User is authenticated with JWT');
} else {
console.log('User is not authenticated or token expired');
// Only redirect to login if we're not already on login/register pages
const currentPath = window.location.pathname;
if (!currentPath.includes('/login') &&
!currentPath.includes('/register') &&
!currentPath.includes('/oauth') &&
!currentPath.includes('/saml') &&
!currentPath.includes('/error')) {
// Redirect to login after a short delay to allow other scripts to load
setTimeout(() => {
window.location.href = '/login';
}, 100);
}
}
}
// Override form submissions to include JWT
function enhanceFormSubmissions() {
// Override form submit for login forms
document.addEventListener('submit', function(event) {
const form = event.target;
// Add JWT to form data if available
const jwtToken = window.JWTManager.getToken();
if (jwtToken && form.method && form.method.toLowerCase() !== 'get') {
// Create a hidden input for JWT
const jwtInput = document.createElement('input');
jwtInput.type = 'hidden';
jwtInput.name = 'jwt';
jwtInput.value = jwtToken;
form.appendChild(jwtInput);
}
});
}
// Add logout functionality to logout buttons
function enhanceLogoutButtons() {
document.addEventListener('click', function(event) {
const element = event.target;
// Check if clicked element is a logout button/link
if (element.matches('a[href="/logout"], button[data-action="logout"], .logout-btn')) {
event.preventDefault();
window.JWTManager.logout();
}
});
}
// Initialize when DOM is ready
if (document.readyState === 'loading') {
document.addEventListener('DOMContentLoaded', function() {
initializeJWT();
enhanceFormSubmissions();
enhanceLogoutButtons();
});
} else {
initializeJWT();
enhanceFormSubmissions();
enhanceLogoutButtons();
}
// Handle page visibility changes to check token expiration
document.addEventListener('visibilitychange', function() {
if (!document.hidden && !window.JWTManager.isAuthenticated()) {
// Token expired while page was hidden, redirect to login
const currentPath = window.location.pathname;
if (!currentPath.includes('/login') &&
!currentPath.includes('/register') &&
!currentPath.includes('/oauth') &&
!currentPath.includes('/saml')) {
window.location.href = '/login';
}
}
});
})();

View File

@ -42,39 +42,6 @@ function toolsManager() {
});
}
function setupDropdowns() {
const dropdowns = document.querySelectorAll('.navbar-nav > .nav-item.dropdown');
dropdowns.forEach((dropdown) => {
const toggle = dropdown.querySelector('[data-bs-toggle="dropdown"]');
if (!toggle) return;
// Skip search dropdown, it has its own logic
if (toggle.id === 'searchDropdown') {
return;
}
dropdown.addEventListener('show.bs.dropdown', () => {
// Find all other open dropdowns and hide them
const openDropdowns = document.querySelectorAll('.navbar-nav .dropdown-menu.show');
openDropdowns.forEach((menu) => {
const parentDropdown = menu.closest('.dropdown');
if (parentDropdown && parentDropdown !== dropdown) {
const parentToggle = parentDropdown.querySelector('[data-bs-toggle="dropdown"]');
if (parentToggle) {
// Get or create Bootstrap dropdown instance
let instance = bootstrap.Dropdown.getInstance(parentToggle);
if (!instance) {
instance = new bootstrap.Dropdown(parentToggle);
}
instance.hide();
}
}
});
});
});
}
window.tooltipSetup = () => {
const tooltipElements = document.querySelectorAll('[title]');
@ -89,54 +56,37 @@ window.tooltipSetup = () => {
document.body.appendChild(customTooltip);
element.addEventListener('mouseenter', (event) => {
if (window.innerWidth >= 1200) {
customTooltip.style.display = 'block';
customTooltip.style.left = `${event.pageX + 10}px`;
customTooltip.style.top = `${event.pageY + 10}px`;
}
customTooltip.style.display = 'block';
customTooltip.style.left = `${event.pageX + 10}px`; // Position tooltip slightly away from the cursor
customTooltip.style.top = `${event.pageY + 10}px`;
});
// Update the position of the tooltip as the user moves the mouse
element.addEventListener('mousemove', (event) => {
if (window.innerWidth >= 1200) {
customTooltip.style.left = `${event.pageX + 10}px`;
customTooltip.style.top = `${event.pageY + 10}px`;
}
customTooltip.style.left = `${event.pageX + 10}px`;
customTooltip.style.top = `${event.pageY + 10}px`;
});
// Hide the tooltip when the mouse leaves
element.addEventListener('mouseleave', () => {
customTooltip.style.display = 'none';
});
});
};
// Override the bootstrap dropdown styles for mobile
function fixNavbarDropdownStyles() {
if (window.innerWidth < 1200) {
document.querySelectorAll('.navbar .dropdown-menu').forEach(function(menu) {
menu.style.transform = 'none';
menu.style.transformOrigin = 'none';
menu.style.left = '0';
menu.style.right = '0';
menu.style.maxWidth = '95vw';
menu.style.width = '100vw';
menu.style.marginBottom = '0';
});
} else {
document.querySelectorAll('.navbar .dropdown-menu').forEach(function(menu) {
menu.style.transform = '';
menu.style.transformOrigin = '';
menu.style.left = '';
menu.style.right = '';
menu.style.maxWidth = '';
menu.style.width = '';
menu.style.marginBottom = '';
});
}
}
document.addEventListener('DOMContentLoaded', () => {
tooltipSetup();
setupDropdowns();
fixNavbarDropdownStyles();
// Setup logout button functionality
const logoutButton = document.querySelector('a[href="/logout"]');
if (logoutButton) {
logoutButton.addEventListener('click', function(event) {
event.preventDefault();
if (window.JWTManager) {
window.JWTManager.logout();
} else {
// Fallback if JWTManager is not available
window.location.href = '/logout';
}
});
}
});
window.addEventListener('resize', fixNavbarDropdownStyles);

View File

@ -102,7 +102,7 @@ async function fetchEndpointData() {
refreshBtn.classList.add('refreshing');
refreshBtn.disabled = true;
const response = await fetch('/api/v1/info/load/all');
const response = await fetchWithCsrf('/api/v1/info/load/all');
if (!response.ok) {
throw new Error('Network response was not ok');
}

View File

@ -7,7 +7,6 @@ import java.util.ArrayList;
import java.util.List;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpStatus;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
@ -117,10 +116,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
samlClient.setSPKeys(certificate, privateKey);
// Redirect to identity provider for logout. todo: add relay state
// samlClient.redirectToIdentityProvider(response, null, nameIdValue);
samlClient.processLogoutRequestPostFromIdentityProvider(request, nameIdValue);
samlClient.redirectToIdentityProviderLogout(
response, HttpStatus.OK.name(), nameIdValue);
samlClient.redirectToIdentityProvider(response, null, nameIdValue);
} catch (Exception e) {
log.error(
"Error retrieving logout URL from Provider {} for user {}",

View File

@ -230,7 +230,7 @@ function loadAuditData(targetPage, realPageSize) {
document.getElementById('page-indicator').textContent = `Page ${requestedPage + 1} of ?`;
}
fetch(url)
fetchWithCsrf(url)
.then(response => {
return response.json();
})
@ -302,7 +302,7 @@ function loadStats(days) {
showLoading('user-chart-loading');
showLoading('time-chart-loading');
fetch(`/audit/stats?days=${days}`)
fetchWithCsrf(`/audit/stats?days=${days}`)
.then(response => response.json())
.then(data => {
document.getElementById('total-events').textContent = data.totalEvents;
@ -835,7 +835,7 @@ function hideLoading(id) {
// Load event types from the server for filter dropdowns
function loadEventTypes() {
fetch('/audit/types')
fetchWithCsrf('/audit/types')
.then(response => response.json())
.then(types => {
if (!types || types.length === 0) {

View File

@ -0,0 +1,130 @@
package stirling.software.proprietary.security.saml2;
import java.util.HashMap;
import java.util.Map;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.security.service.JwtServiceInterface;
@Slf4j
public class JwtSaml2AuthenticationRequestRepository
implements Saml2AuthenticationRequestRepository<Saml2PostAuthenticationRequest> {
private final Map<String, String> tokenStore;
private final JwtServiceInterface jwtService;
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token";
public JwtSaml2AuthenticationRequestRepository(
Map<String, String> tokenStore,
JwtServiceInterface jwtService,
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
this.tokenStore = tokenStore;
this.jwtService = jwtService;
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
}
@Override
public void saveAuthenticationRequest(
Saml2PostAuthenticationRequest authRequest,
HttpServletRequest request,
HttpServletResponse response) {
if (authRequest == null) {
removeAuthenticationRequest(request, response);
return;
}
Map<String, Object> claims = serializeSamlRequest(authRequest);
String token = jwtService.generateToken("", claims);
String relayState = authRequest.getRelayState();
tokenStore.put(relayState, token);
request.setAttribute(SAML_REQUEST_TOKEN, relayState);
response.addHeader(SAML_REQUEST_TOKEN, relayState);
log.debug("Saved SAMLRequest token with RelayState: {}", relayState);
}
@Override
public Saml2PostAuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
String token = extractTokenFromStore(request);
if (token == null) {
log.debug("No SAMLResponse token found in RelayState");
return null;
}
Map<String, Object> claims = jwtService.extractAllClaims(token);
return deserializeSamlRequest(claims);
}
@Override
public Saml2PostAuthenticationRequest removeAuthenticationRequest(
HttpServletRequest request, HttpServletResponse response) {
Saml2PostAuthenticationRequest authRequest = loadAuthenticationRequest(request);
String relayStateId = request.getParameter("RelayState");
if (relayStateId != null) {
tokenStore.remove(relayStateId);
log.debug("Removed SAMLRequest token for RelayState ID: {}", relayStateId);
}
return authRequest;
}
private String extractTokenFromStore(HttpServletRequest request) {
String authnRequestId = request.getParameter("RelayState");
if (authnRequestId != null && !authnRequestId.isEmpty()) {
String token = tokenStore.get(authnRequestId);
if (token != null) {
tokenStore.remove(authnRequestId);
log.debug("Retrieved SAMLRequest token for RelayState ID: {}", authnRequestId);
return token;
} else {
log.warn("No SAMLRequest token found for RelayState ID: {}", authnRequestId);
}
}
return null;
}
private Map<String, Object> serializeSamlRequest(Saml2PostAuthenticationRequest authRequest) {
Map<String, Object> claims = new HashMap<>();
claims.put("id", authRequest.getId());
claims.put("relyingPartyRegistrationId", authRequest.getRelyingPartyRegistrationId());
claims.put("authenticationRequestUri", authRequest.getAuthenticationRequestUri());
claims.put("samlRequest", authRequest.getSamlRequest());
claims.put("relayState", authRequest.getRelayState());
return claims;
}
private Saml2PostAuthenticationRequest deserializeSamlRequest(Map<String, Object> claims) {
String relyingPartyRegistrationId = (String) claims.get("relyingPartyRegistrationId");
RelyingPartyRegistration relyingPartyRegistration =
relyingPartyRegistrationRepository.findByRegistrationId(relyingPartyRegistrationId);
if (relyingPartyRegistration == null) {
return null;
}
return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(relyingPartyRegistration)
.id((String) claims.get("id"))
.authenticationRequestUri((String) claims.get("authenticationRequestUri"))
.samlRequest((String) claims.get("samlRequest"))
.relayState((String) claims.get("relayState"))
.build();
}
}

View File

@ -0,0 +1,188 @@
package stirling.software.proprietary.security.saml2;
import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.common.model.ApplicationProperties.Security.SAML2;
import stirling.software.proprietary.security.service.JwtServiceInterface;
@Configuration
@Slf4j
@ConditionalOnProperty(value = "security.saml2.enabled", havingValue = "true")
@RequiredArgsConstructor
public class Saml2Configuration {
private final ApplicationProperties applicationProperties;
@Bean
@ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true")
public RelyingPartyRegistrationRepository relyingPartyRegistrations() throws Exception {
SAML2 samlConf = applicationProperties.getSecurity().getSaml2();
X509Certificate idpCert = CertificateUtils.readCertificate(samlConf.getIdpCert());
Saml2X509Credential verificationCredential = Saml2X509Credential.verification(idpCert);
Resource privateKeyResource = samlConf.getPrivateKey();
Resource certificateResource = samlConf.getSpCert();
Saml2X509Credential signingCredential =
new Saml2X509Credential(
CertificateUtils.readPrivateKey(privateKeyResource),
CertificateUtils.readCertificate(certificateResource),
Saml2X509CredentialType.SIGNING);
RelyingPartyRegistration rp =
RelyingPartyRegistration.withRegistrationId(samlConf.getRegistrationId())
.signingX509Credentials(c -> c.add(signingCredential))
.entityId(samlConf.getIdpIssuer())
.singleLogoutServiceBinding(Saml2MessageBinding.POST)
.singleLogoutServiceLocation(samlConf.getIdpSingleLogoutUrl())
.singleLogoutServiceResponseLocation("http://localhost:8080/login")
.assertionConsumerServiceBinding(Saml2MessageBinding.POST)
.assertionConsumerServiceLocation(
"{baseUrl}/login/saml2/sso/{registrationId}")
.authnRequestsSigned(true)
.assertingPartyMetadata(
metadata ->
metadata.entityId(samlConf.getIdpIssuer())
.verificationX509Credentials(
c -> c.add(verificationCredential))
.singleSignOnServiceBinding(
Saml2MessageBinding.POST)
.singleSignOnServiceLocation(
samlConf.getIdpSingleLoginUrl())
.singleLogoutServiceBinding(
Saml2MessageBinding.POST)
.singleLogoutServiceLocation(
samlConf.getIdpSingleLogoutUrl())
.singleLogoutServiceResponseLocation(
"http://localhost:8080/login")
.wantAuthnRequestsSigned(true))
.build();
return new InMemoryRelyingPartyRegistrationRepository(rp);
}
@Bean
@ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true")
public Saml2AuthenticationRequestRepository<Saml2PostAuthenticationRequest>
saml2AuthenticationRequestRepository(
JwtServiceInterface jwtService,
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
return new JwtSaml2AuthenticationRequestRepository(
new ConcurrentHashMap<>(), jwtService, relyingPartyRegistrationRepository);
}
@Bean
@ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true")
public OpenSaml4AuthenticationRequestResolver authenticationRequestResolver(
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
Saml2AuthenticationRequestRepository<Saml2PostAuthenticationRequest>
saml2AuthenticationRequestRepository) {
OpenSaml4AuthenticationRequestResolver resolver =
new OpenSaml4AuthenticationRequestResolver(relyingPartyRegistrationRepository);
resolver.setAuthnRequestCustomizer(
customizer -> {
HttpServletRequest request = customizer.getRequest();
AuthnRequest authnRequest = customizer.getAuthnRequest();
Saml2PostAuthenticationRequest saml2AuthenticationRequest =
saml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
if (saml2AuthenticationRequest != null) {
String sessionId = request.getSession(false).getId();
log.debug(
"Retrieving SAML 2 authentication request ID from the current HTTP session {}",
sessionId);
String authenticationRequestId = saml2AuthenticationRequest.getId();
if (!authenticationRequestId.isBlank()) {
authnRequest.setID(authenticationRequestId);
} else {
log.warn(
"No authentication request found for HTTP session {}. Generating new ID",
sessionId);
authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1));
}
} else {
log.debug("Generating new authentication request ID");
authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1));
}
logAuthnRequestDetails(authnRequest);
logHttpRequestDetails(request);
});
return resolver;
}
private static void logAuthnRequestDetails(AuthnRequest authnRequest) {
String message =
"""
AuthnRequest:
ID: {}
Issuer: {}
IssueInstant: {}
AssertionConsumerService (ACS) URL: {}
""";
log.debug(
message,
authnRequest.getID(),
authnRequest.getIssuer() != null ? authnRequest.getIssuer().getValue() : null,
authnRequest.getIssueInstant(),
authnRequest.getAssertionConsumerServiceURL());
if (authnRequest.getNameIDPolicy() != null) {
log.debug("NameIDPolicy Format: {}", authnRequest.getNameIDPolicy().getFormat());
}
}
private static void logHttpRequestDetails(HttpServletRequest request) {
log.debug("HTTP Headers: ");
Collections.list(request.getHeaderNames())
.forEach(
headerName ->
log.debug("{}: {}", headerName, request.getHeader(headerName)));
String message =
"""
HTTP Request Method: {}
Session ID: {}
Request Path: {}
Query String: {}
Remote Address: {}
SAML Request Parameters:
SAMLRequest: {}
RelayState: {}
""";
log.debug(
message,
request.getMethod(),
request.getSession().getId(),
request.getRequestURI(),
request.getQueryString(),
request.getRemoteAddr(),
request.getParameter("SAMLRequest"),
request.getParameter("RelayState"));
}
}

View File

@ -0,0 +1,196 @@
package stirling.software.proprietary.security.service;
import java.security.KeyPair;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.http.ResponseCookie;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Service;
import io.github.pixee.security.Newlines;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import io.jsonwebtoken.UnsupportedJwtException;
import io.jsonwebtoken.security.SignatureException;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.security.model.exception.AuthenticationFailureException;
import stirling.software.proprietary.security.saml2.CustomSaml2AuthenticatedPrincipal;
@Slf4j
@Service
public class JwtService implements JwtServiceInterface {
private static final String JWT_COOKIE_NAME = "STIRLING_JWT";
private static final String AUTHORIZATION_HEADER = "Authorization";
private static final String BEARER_PREFIX = "Bearer ";
private static final String ISSUER = "Stirling PDF";
private static final long EXPIRATION = 3600000;
private final KeyPair keyPair;
private final boolean v2Enabled;
public JwtService(@Qualifier("v2Enabled") boolean v2Enabled) {
this.v2Enabled = v2Enabled;
keyPair = Jwts.SIG.RS256.keyPair().build();
}
@Override
public String generateToken(Authentication authentication, Map<String, Object> claims) {
Object principal = authentication.getPrincipal();
String username = "";
if (principal instanceof UserDetails) {
username = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) {
username = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
username = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
}
return generateToken(username, claims);
}
@Override
public String generateToken(String username, Map<String, Object> claims) {
return Jwts.builder()
.claims(claims)
.subject(username)
.issuer(ISSUER)
.issuedAt(new Date())
.expiration(new Date(System.currentTimeMillis() + EXPIRATION))
.signWith(keyPair.getPrivate(), Jwts.SIG.RS256)
.compact();
}
@Override
public void validateToken(String token) throws AuthenticationFailureException {
extractAllClaimsFromToken(token);
// todo: test
if (isTokenExpired(token)) {
throw new AuthenticationFailureException("The token has expired");
}
}
@Override
public String extractUsername(String token) {
return extractClaim(token, Claims::getSubject);
}
@Override
public Map<String, Object> extractAllClaims(String token) {
Claims claims = extractAllClaimsFromToken(token);
return new HashMap<>(claims);
}
@Override
public boolean isTokenExpired(String token) {
return extractExpiration(token).before(new Date());
}
private Date extractExpiration(String token) {
return extractClaim(token, Claims::getExpiration);
}
private <T> T extractClaim(String token, Function<Claims, T> claimsResolver) {
final Claims claims = extractAllClaimsFromToken(token);
return claimsResolver.apply(claims);
}
private Claims extractAllClaimsFromToken(String token) {
try {
return Jwts.parser()
.verifyWith(keyPair.getPublic())
.build()
.parseSignedClaims(token)
.getPayload();
} catch (SignatureException e) {
log.warn("Invalid signature: {}", e.getMessage());
throw new AuthenticationFailureException("Invalid signature", e);
} catch (MalformedJwtException e) {
log.warn("Invalid token: {}", e.getMessage());
throw new AuthenticationFailureException("Invalid token", e);
} catch (ExpiredJwtException e) {
log.warn("The token has expired: {}", e.getMessage());
throw new AuthenticationFailureException("The token has expired", e);
} catch (UnsupportedJwtException e) {
log.warn("The token is unsupported: {}", e.getMessage());
throw new AuthenticationFailureException("The token is unsupported", e);
} catch (IllegalArgumentException e) {
log.warn("Claims are empty: {}", e.getMessage());
throw new AuthenticationFailureException("Claims are empty", e);
}
}
@Override
public String extractTokenFromRequest(HttpServletRequest request) {
String authHeader = request.getHeader(AUTHORIZATION_HEADER);
if (authHeader != null && authHeader.startsWith(BEARER_PREFIX)) {
return authHeader.substring(BEARER_PREFIX.length());
}
Cookie[] cookies = request.getCookies();
if (cookies != null) {
for (Cookie cookie : cookies) {
if (JWT_COOKIE_NAME.equals(cookie.getName())) {
return cookie.getValue();
}
}
}
return null;
}
@Override
public void addTokenToResponse(HttpServletResponse response, String token) {
response.setHeader(AUTHORIZATION_HEADER, Newlines.stripAll(BEARER_PREFIX + token));
ResponseCookie cookie =
ResponseCookie.from(JWT_COOKIE_NAME, Newlines.stripAll(token))
.httpOnly(true)
.secure(true)
.sameSite("None")
.maxAge(EXPIRATION / 1000)
.path("/")
.build();
response.addHeader("Set-Cookie", cookie.toString());
}
@Override
public void clearTokenFromResponse(HttpServletResponse response) {
// Remove Authorization header instead of setting empty string
response.setHeader(AUTHORIZATION_HEADER, null);
ResponseCookie cookie =
ResponseCookie.from(JWT_COOKIE_NAME, "")
.httpOnly(true)
.secure(true)
.sameSite("None")
.maxAge(0)
.path("/")
.build();
response.addHeader("Set-Cookie", cookie.toString());
}
@Override
public boolean isJwtEnabled() {
return v2Enabled;
}
}

View File

@ -0,0 +1,229 @@
package stirling.software.proprietary.security.saml2;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.NullAndEmptySource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.AssertingPartyMetadata;
import stirling.software.proprietary.security.service.JwtServiceInterface;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class JwtSaml2AuthenticationRequestRepositoryTest {
private static final String SAML_REQUEST_TOKEN = "stirling_saml_request_token";
private Map<String, String> tokenStore;
@Mock
private JwtServiceInterface jwtService;
@Mock
private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private JwtSaml2AuthenticationRequestRepository jwtSaml2AuthenticationRequestRepository;
@BeforeEach
void setUp() {
tokenStore = new ConcurrentHashMap<>();
jwtSaml2AuthenticationRequestRepository = new JwtSaml2AuthenticationRequestRepository(
tokenStore, jwtService, relyingPartyRegistrationRepository);
}
@Test
void saveAuthenticationRequest() {
var authRequest = mock(Saml2PostAuthenticationRequest.class);
var request = mock(MockHttpServletRequest.class);
var response = mock(MockHttpServletResponse.class);
String token = "testToken";
String id = "testId";
String relayState = "testRelayState";
String authnRequestUri = "example.com/authnRequest";
Map<String, Object> claims = Map.of();
String samlRequest = "testSamlRequest";
String relyingPartyRegistrationId = "stirling-pdf";
when(authRequest.getRelayState()).thenReturn(relayState);
when(authRequest.getId()).thenReturn(id);
when(authRequest.getAuthenticationRequestUri()).thenReturn(authnRequestUri);
when(authRequest.getSamlRequest()).thenReturn(samlRequest);
when(authRequest.getRelyingPartyRegistrationId()).thenReturn(relyingPartyRegistrationId);
when(jwtService.generateToken(eq(""), anyMap())).thenReturn(token);
jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(authRequest, request, response);
verify(request).setAttribute(SAML_REQUEST_TOKEN, relayState);
verify(response).addHeader(SAML_REQUEST_TOKEN, relayState);
}
@Test
void saveAuthenticationRequestWithNullRequest() {
var request = mock(MockHttpServletRequest.class);
var response = mock(MockHttpServletResponse.class);
jwtSaml2AuthenticationRequestRepository.saveAuthenticationRequest(null, request, response);
assertTrue(tokenStore.isEmpty());
}
@Test
void loadAuthenticationRequest() {
var request = mock(MockHttpServletRequest.class);
var relyingPartyRegistration = mock(RelyingPartyRegistration.class);
var assertingPartyMetadata = mock(AssertingPartyMetadata.class);
String relayState = "testRelayState";
String token = "testToken";
Map<String, Object> claims = Map.of(
"id", "testId",
"relyingPartyRegistrationId", "stirling-pdf",
"authenticationRequestUri", "example.com/authnRequest",
"samlRequest", "testSamlRequest",
"relayState", relayState
);
when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractAllClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(relyingPartyRegistration);
when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf");
when(relyingPartyRegistration.getAssertingPartyMetadata()).thenReturn(assertingPartyMetadata);
when(assertingPartyMetadata.getSingleSignOnServiceLocation()).thenReturn("https://example.com/sso");
tokenStore.put(relayState, token);
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNotNull(result);
assertFalse(tokenStore.containsKey(relayState));
}
@ParameterizedTest
@NullAndEmptySource
void loadAuthenticationRequestWithInvalidRelayState(String relayState) {
var request = mock(MockHttpServletRequest.class);
when(request.getParameter("RelayState")).thenReturn(relayState);
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNull(result);
}
@Test
void loadAuthenticationRequestWithNonExistentToken() {
var request = mock(MockHttpServletRequest.class);
when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState");
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNull(result);
}
@Test
void loadAuthenticationRequestWithNullRelyingPartyRegistration() {
var request = mock(MockHttpServletRequest.class);
String relayState = "testRelayState";
String token = "testToken";
Map<String, Object> claims = Map.of(
"id", "testId",
"relyingPartyRegistrationId", "stirling-pdf",
"authenticationRequestUri", "example.com/authnRequest",
"samlRequest", "testSamlRequest",
"relayState", relayState
);
when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractAllClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(null);
tokenStore.put(relayState, token);
var result = jwtSaml2AuthenticationRequestRepository.loadAuthenticationRequest(request);
assertNull(result);
}
@Test
void removeAuthenticationRequest() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
var relyingPartyRegistration = mock(RelyingPartyRegistration.class);
var assertingPartyMetadata = mock(AssertingPartyMetadata.class);
String relayState = "testRelayState";
String token = "testToken";
Map<String, Object> claims = Map.of(
"id", "testId",
"relyingPartyRegistrationId", "stirling-pdf",
"authenticationRequestUri", "example.com/authnRequest",
"samlRequest", "testSamlRequest",
"relayState", relayState
);
when(request.getParameter("RelayState")).thenReturn(relayState);
when(jwtService.extractAllClaims(token)).thenReturn(claims);
when(relyingPartyRegistrationRepository.findByRegistrationId("stirling-pdf")).thenReturn(relyingPartyRegistration);
when(relyingPartyRegistration.getRegistrationId()).thenReturn("stirling-pdf");
when(relyingPartyRegistration.getAssertingPartyMetadata()).thenReturn(assertingPartyMetadata);
when(assertingPartyMetadata.getSingleSignOnServiceLocation()).thenReturn("https://example.com/sso");
tokenStore.put(relayState, token);
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response);
assertNotNull(result);
assertFalse(tokenStore.containsKey(relayState));
}
@Test
void removeAuthenticationRequestWithNullRelayState() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
when(request.getParameter("RelayState")).thenReturn(null);
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response);
assertNull(result);
}
@Test
void removeAuthenticationRequestWithNonExistentToken() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
when(request.getParameter("RelayState")).thenReturn("nonExistentRelayState");
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response);
assertNull(result);
}
@Test
void removeAuthenticationRequestWithOnlyRelayState() {
var request = mock(HttpServletRequest.class);
var response = mock(HttpServletResponse.class);
String relayState = "testRelayState";
when(request.getParameter("RelayState")).thenReturn(relayState);
var result = jwtSaml2AuthenticationRequestRepository.removeAuthenticationRequest(request, response);
assertNull(result);
assertFalse(tokenStore.containsKey(relayState));
}
}