package stirling.software.SPDF.config.security;

import java.io.IOException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateKey;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;

import org.springframework.core.io.Resource;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;

import com.coveo.saml.SamlClient;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import stirling.software.SPDF.SPdfApplication;
import stirling.software.SPDF.config.security.saml2.CertificateUtils;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.ApplicationProperties.Security.SAML2;
import stirling.software.SPDF.model.Provider;
import stirling.software.SPDF.model.provider.UnsupportedProviderException;
import stirling.software.SPDF.utils.UrlUtils;

/**
 * Logout success handler that picks the post-logout redirect based on the type of the
 * {@link Authentication} that was just logged out.
 *
 * <ul>
 *   <li>SAML2 authentications are redirected to the configured identity provider logout URL.
 *   <li>OAuth2 authentications are redirected to a provider-specific logout URL (Keycloak,
 *       GitHub) or to the local login page.
 *   <li>Everything else is redirected to {@code /login?logout=true}.
 * </ul>
 */
@Slf4j
@AllArgsConstructor
public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {

    /** Target used whenever no provider-specific logout redirect applies. */
    private static final String DEFAULT_LOGOUT_TARGET = "/login?logout=true";

    /** Matches every character that is not an ASCII letter, an ASCII digit or a space. */
    private static final Pattern UNSAFE_CHARACTERS = Pattern.compile("[^a-zA-Z0-9 ]");

    private final ApplicationProperties applicationProperties;

    /**
     * Sends the redirect that follows a successful logout. Does nothing when the response has
     * already been committed.
     *
     * @param request the current request; its {@code userIsDisabled} and {@code
     *     oauth2AuthenticationErrorWeb} parameters short-circuit to an error redirect
     * @param response the response the redirect is written to
     * @param authentication the authentication that was logged out, may be {@code null}
     * @throws IOException if sending the redirect fails
     * @throws ServletException declared by the overridden method
     */
    @Override
    public void onLogoutSuccess(
            HttpServletRequest request,
            HttpServletResponse response,
            Authentication authentication)
            throws IOException, ServletException {
        if (response.isCommitted()) {
            return;
        }

        // Logout triggered because the account was disabled
        if (request.getParameter("userIsDisabled") != null) {
            response.sendRedirect(request.getContextPath() + "/login?erroroauth=userIsDisabled");
            return;
        }

        // Logout triggered by an OAuth2 authentication error
        if (request.getParameter("oauth2AuthenticationErrorWeb") != null) {
            response.sendRedirect(
                    request.getContextPath() + "/login?erroroauth=userAlreadyExistsWeb");
            return;
        }

        if (authentication instanceof Saml2Authentication) {
            getRedirect_saml2(request, response, authentication);
            return;
        }

        if (authentication instanceof OAuth2AuthenticationToken) {
            getRedirect_oauth2(request, response, authentication);
            return;
        }

        // Username/password, unknown and missing authentications all end on the login page;
        // only the unknown type is worth reporting.
        if (authentication != null
                && !(authentication instanceof UsernamePasswordAuthenticationToken)) {
            log.error(
                    "authentication class unknown: {}",
                    authentication.getClass().getSimpleName());
        }
        getRedirectStrategy().sendRedirect(request, response, DEFAULT_LOGOUT_TARGET);
    }

    /**
     * Redirects a SAML2 user to the identity provider's single-logout URL. If anything goes wrong
     * while preparing that redirect, the user is sent to the local login page instead.
     *
     * @param request the current request
     * @param response the response the redirect is written to
     * @param authentication the logged-out authentication; must be a {@link Saml2Authentication}
     *     whose principal is a {@link CustomSaml2AuthenticatedPrincipal}
     * @throws IOException if the fallback redirect fails
     */
    private void getRedirect_saml2(
            HttpServletRequest request,
            HttpServletResponse response,
            Authentication authentication)
            throws IOException {
        SAML2 samlConf = applicationProperties.getSecurity().getSaml2();
        String registrationId = samlConf.getRegistrationId();

        Saml2Authentication samlAuthentication = (Saml2Authentication) authentication;
        CustomSaml2AuthenticatedPrincipal principal =
                (CustomSaml2AuthenticatedPrincipal) samlAuthentication.getPrincipal();
        String nameIdValue = principal.getName();

        try {
            // Service provider certificate, also handed to the client as the certificate list
            Resource certificateResource = samlConf.getSpCert();
            X509Certificate certificate = CertificateUtils.readCertificate(certificateResource);
            List<X509Certificate> certificates = new ArrayList<>();
            certificates.add(certificate);

            // URLs identifying this application towards the identity provider
            String serverUrl =
                    SPdfApplication.getStaticBaseUrl() + ":" + SPdfApplication.getStaticPort();
            String relyingPartyIdentifier =
                    serverUrl + "/saml2/service-provider-metadata/" + registrationId;
            String assertionConsumerServiceUrl = serverUrl + "/login/saml2/sso/" + registrationId;
            String idpUrl = samlConf.getIdpSingleLogoutUrl();
            String idpIssuer = samlConf.getIdpIssuer();

            SamlClient samlClient =
                    new SamlClient(
                            relyingPartyIdentifier,
                            assertionConsumerServiceUrl,
                            idpUrl,
                            idpIssuer,
                            certificates,
                            SamlClient.SamlIdpBinding.POST);

            // Register the service provider key pair with the client
            Resource privateKeyResource = samlConf.getPrivateKey();
            RSAPrivateKey privateKey = CertificateUtils.readPrivateKey(privateKeyResource);
            samlClient.setSPKeys(certificate, privateKey);

            samlClient.redirectToIdentityProvider(response, null, nameIdValue);
        } catch (Exception e) {
            // Boundary catch: whatever failed, the user must still land on a usable page.
            log.error("SAML2 logout redirect failed for '{}'", nameIdValue, e);
            getRedirectStrategy().sendRedirect(request, response, DEFAULT_LOGOUT_TARGET);
        }
    }

    /**
     * Redirects an OAuth2 user after logout. Keycloak and GitHub get a provider-specific logout
     * URL; every other provider (including Google) is sent to the local login page.
     *
     * @param request the current request; several of its parameters select the query string
     *     appended to the login page URL
     * @param response the response the redirect is written to
     * @param authentication the logged-out authentication
     * @throws IOException if sending the redirect fails
     */
    private void getRedirect_oauth2(
            HttpServletRequest request,
            HttpServletResponse response,
            Authentication authentication)
            throws IOException {
        String registrationId;
        String issuer = null;
        String clientId = null;
        OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();

        if (authentication instanceof OAuth2AuthenticationToken) {
            OAuth2AuthenticationToken oauthToken = (OAuth2AuthenticationToken) authentication;
            registrationId = oauthToken.getAuthorizedClientRegistrationId();
            try {
                // Look up the provider details configured for this registration
                Provider provider = oauth.getClient().get(registrationId);
                issuer = provider.getIssuer();
                clientId = provider.getClientId();
            } catch (UnsupportedProviderException e) {
                // issuer and clientId stay null; the switch below copes with that
                log.error("No OAuth2 provider configured for '{}'", registrationId, e);
            }
        } else {
            registrationId = oauth.getProvider() != null ? oauth.getProvider() : "";
            issuer = oauth.getIssuer();
            clientId = oauth.getClientId();
        }

        String param = resolveLoginQuery(request);
        String localLoginUrl = request.getContextPath() + "/login?" + param;

        // Locale.ROOT keeps the comparison stable regardless of the JVM default locale.
        String providerKey = registrationId == null ? "" : registrationId.toLowerCase(Locale.ROOT);

        switch (providerKey) {
            case "keycloak":
                if (issuer == null || clientId == null) {
                    // Without both values the Keycloak URL would be malformed.
                    log.warn(
                            "Keycloak issuer or client id is not available, using default logout URL");
                    log.info("Redirecting to default logout URL: {}", localLoginUrl);
                    response.sendRedirect(localLoginUrl);
                    break;
                }
                String postLogoutRedirect = UrlUtils.getOrigin(request) + "/login?" + param;
                String logoutUrl =
                        issuer
                                + "/protocol/openid-connect/logout"
                                + "?client_id="
                                + urlEncode(clientId)
                                + "&post_logout_redirect_uri="
                                + urlEncode(postLogoutRedirect);
                log.info("Redirecting to Keycloak logout URL: {}", logoutUrl);
                response.sendRedirect(logoutUrl);
                break;
            case "github":
                String githubLogoutUrl = "https://github.com/logout";
                log.info("Redirecting to GitHub logout URL: {}", githubLogoutUrl);
                response.sendRedirect(githubLogoutUrl);
                break;
            case "google":
                log.info("Google does not have a specific logout URL");
                // deliberate fall-through: Google users go to the local login page
            default:
                log.info("Redirecting to default logout URL: {}", localLoginUrl);
                response.sendRedirect(localLoginUrl);
                break;
        }
    }

    /**
     * Chooses the query string appended to the login page URL. The first matching request
     * parameter wins; without any match the result is {@code logout=true}.
     *
     * @param request the current request
     * @return a {@code key=value} query string without a leading {@code ?}
     */
    private String resolveLoginQuery(HttpServletRequest request) {
        if (request.getParameter("oauth2AuthenticationErrorWeb") != null) {
            return "erroroauth=oauth2AuthenticationErrorWeb";
        }
        String errorMessage = request.getParameter("error");
        if (errorMessage != null) {
            return "error=" + urlEncode(sanitizeInput(errorMessage));
        }
        errorMessage = request.getParameter("erroroauth");
        if (errorMessage != null) {
            return "erroroauth=" + urlEncode(sanitizeInput(errorMessage));
        }
        if (request.getParameter("oauth2AutoCreateDisabled") != null) {
            return "error=oauth2AutoCreateDisabled";
        }
        if (request.getParameter("oauth2_admin_blocked_user") != null) {
            return "erroroauth=oauth2_admin_blocked_user";
        }
        if (request.getParameter("userIsDisabled") != null) {
            return "erroroauth=userIsDisabled";
        }
        if (request.getParameter("badcredentials") != null) {
            return "error=badcredentials";
        }
        return "logout=true";
    }

    /**
     * Percent-encodes a value for use inside a URL query string.
     *
     * @param value the raw value
     * @return the value encoded as {@code application/x-www-form-urlencoded} using UTF-8
     */
    private static String urlEncode(String value) {
        return URLEncoder.encode(value, StandardCharsets.UTF_8);
    }

    /**
     * Strips every character that is not an ASCII letter, an ASCII digit or a space.
     *
     * @param input the untrusted text, must not be {@code null}
     * @return the text with all other characters removed
     */
    private String sanitizeInput(String input) {
        return UNSAFE_CHARACTERS.matcher(input).replaceAll("");
    }
}