This commit is contained in:
Dario Ghunney Ware 2025-02-19 20:56:21 +00:00
parent 9fda276aad
commit 086e4e0e15
14 changed files with 42 additions and 61 deletions

View File

@ -6,7 +6,6 @@ import java.security.interfaces.RSAPrivateKey;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.io.Resource;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
@ -29,6 +28,7 @@ import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrin
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.ApplicationProperties.Security.SAML2;
import stirling.software.SPDF.model.provider.KeycloakProvider;
import stirling.software.SPDF.utils.UrlUtils;
@Slf4j
@ -63,7 +63,8 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
}
} else {
// Redirect to login page after logout
getRedirectStrategy().sendRedirect(request, response, LOGOUT_PATH);
String path = checkForErrors(request);
getRedirectStrategy().sendRedirect(request, response, path);
}
}
}
@ -80,7 +81,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
CustomSaml2AuthenticatedPrincipal principal =
(CustomSaml2AuthenticatedPrincipal) samlAuthentication.getPrincipal();
String nameIdValue = principal.getName();
String nameIdValue = principal.name();
try {
// Read certificate from the resource
@ -103,11 +104,16 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
// Redirect to identity provider for logout
samlClient.redirectToIdentityProvider(response, null, nameIdValue);
} catch (Exception e) {
log.error(nameIdValue, e);
log.error(
"Error retrieving logout URL from Provider {} for user {}",
samlConf.getProvider(),
nameIdValue,
e);
getRedirectStrategy().sendRedirect(request, response, LOGOUT_PATH);
}
}
// Redirect for OAuth2 authentication logout
private void getRedirect_oauth2(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException {
@ -126,11 +132,12 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
// Redirect based on OAuth2 provider
switch (registrationId.toLowerCase()) {
case "keycloak" -> {
KeycloakProvider keycloak = oauth.getClient().getKeycloak();
String logoutUrl =
oauth.getIssuer()
keycloak.getIssuer()
+ "/protocol/openid-connect/logout"
+ "?client_id="
+ oauth.getClientId()
+ keycloak.getClientId()
+ "&post_logout_redirect_uri="
+ response.encodeRedirectURL(redirectUrl);
log.info("Redirecting to Keycloak logout URL: {}", logoutUrl);
@ -144,20 +151,12 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
response.sendRedirect(redirectUrl);
}
default -> {
String logoutUrl = oauth.getLogoutUrl();
if (StringUtils.isNotBlank(logoutUrl)) {
log.info("Redirecting to logout URL: {}", logoutUrl);
response.sendRedirect(logoutUrl);
} else {
log.info("Redirecting to default logout URL: {}", redirectUrl);
response.sendRedirect(redirectUrl);
}
log.info("Redirecting to default logout URL: {}", redirectUrl);
response.sendRedirect(redirectUrl);
}
}
}
// Redirect for OAuth2 authentication logout
private static SamlClient getSamlClient(
String registrationId, SAML2 samlConf, List<X509Certificate> certificates)
throws SamlException {
@ -169,7 +168,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
String assertionConsumerServiceUrl = serverUrl + "/login/saml2/sso/" + registrationId;
String idpUrl = samlConf.getIdpSingleLogoutUrl();
String idpSLOUrl = samlConf.getIdpSingleLogoutUrl();
String idpIssuer = samlConf.getIdpIssuer();
@ -177,7 +176,7 @@ public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
return new SamlClient(
relyingPartyIdentifier,
assertionConsumerServiceUrl,
idpUrl,
idpSLOUrl,
idpIssuer,
certificates,
SamlClient.SamlIdpBinding.POST);

View File

@ -150,7 +150,7 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
OAUTH2 oAuth = securityProp.getOauth2();
blockRegistration = oAuth != null && oAuth.getBlockRegistration();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
username = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
username = ((CustomSaml2AuthenticatedPrincipal) principal).name();
loginMethod = LoginMethod.SAML2USER;
SAML2 saml2 = securityProp.getSaml2();
blockRegistration = saml2 != null && saml2.getBlockRegistration();

View File

@ -78,20 +78,18 @@ public class UserService implements UserServiceInterface {
}
// Handle OAUTH2 login and user auto creation.
public boolean processSSOPostLogin(String username, boolean autoCreateUser)
public void processSSOPostLogin(String username, boolean autoCreateUser)
throws IllegalArgumentException, SQLException, UnsupportedProviderException {
if (!isUsernameValid(username)) {
return false;
return;
}
Optional<User> existingUser = findByUsernameIgnoreCase(username);
if (existingUser.isPresent()) {
return true;
return;
}
if (autoCreateUser) {
saveUser(username, AuthenticationType.SSO);
return true;
}
return false;
}
public Authentication getAuthentication(String apiKey) {
@ -382,7 +380,7 @@ public class UserService implements UserServiceInterface {
} else if (principal instanceof OAuth2User oAuth2User) {
usernameP = oAuth2User.getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal saml2User) {
usernameP = saml2User.getName();
usernameP = saml2User.name();
} else if (principal instanceof String) {
usernameP = (String) principal;
}
@ -403,7 +401,7 @@ public class UserService implements UserServiceInterface {
.getAttribute(
applicationProperties.getSecurity().getOauth2().getUseAsUsername());
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
return ((CustomSaml2AuthenticatedPrincipal) principal).getName();
return ((CustomSaml2AuthenticatedPrincipal) principal).name();
} else {
return principal.toString();
}

View File

@ -79,18 +79,21 @@ public class CustomOAuth2AuthenticationSuccessHandler
if (userService.isUserDisabled(username)) {
getRedirectStrategy()
.sendRedirect(request, response, "/logout?userIsDisabled=true");
return;
}
if (userService.usernameExistsIgnoreCase(username)
&& userService.hasPassword(username)
&& !userService.isAuthenticationTypeByUsername(username, AuthenticationType.SSO)
&& oAuth.getAutoCreateUser()) {
response.sendRedirect(contextPath + "/logout?oAuth2AuthenticationErrorWeb=true");
return;
}
try {
if (oAuth.getBlockRegistration()
&& !userService.usernameExistsIgnoreCase(username)) {
response.sendRedirect(contextPath + "/logout?oAuth2AdminBlockedUser=true");
return;
}
if (principal instanceof OAuth2User) {
userService.processSSOPostLogin(username, oAuth.getAutoCreateUser());

View File

@ -13,8 +13,10 @@ import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.bouncycastle.util.io.pem.PemObject;
import org.bouncycastle.util.io.pem.PemReader;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.core.io.Resource;
@ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true")
public class CertificateUtils {
public static X509Certificate readCertificate(Resource certificateResource) throws Exception {

View File

@ -4,27 +4,13 @@ import java.io.Serializable;
import java.util.List;
import java.util.Map;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
public class CustomSaml2AuthenticatedPrincipal
@ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true")
public record CustomSaml2AuthenticatedPrincipal(String name, Map<String, List<Object>> attributes, String nameId, List<String> sessionIndexes)
implements Saml2AuthenticatedPrincipal, Serializable {
private final String name;
private final Map<String, List<Object>> attributes;
private final String nameId;
private final List<String> sessionIndexes;
public CustomSaml2AuthenticatedPrincipal(
String name,
Map<String, List<Object>> attributes,
String nameId,
List<String> sessionIndexes) {
this.name = name;
this.attributes = attributes;
this.nameId = nameId;
this.sessionIndexes = sessionIndexes;
}
@Override
public String getName() {
return this.name;
@ -35,11 +21,4 @@ public class CustomSaml2AuthenticatedPrincipal
return this.attributes;
}
public String getNameId() {
return this.nameId;
}
public List<String> getSessionIndexes() {
return this.sessionIndexes;
}
}

View File

@ -2,6 +2,7 @@ package stirling.software.SPDF.config.security.saml2;
import java.io.IOException;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.security.authentication.ProviderNotFoundException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.saml2.core.Saml2Error;
@ -14,6 +15,7 @@ import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
@Slf4j
@ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true")
public class CustomSaml2AuthenticationFailureHandler extends SimpleUrlAuthenticationFailureHandler {
@Override

View File

@ -42,7 +42,7 @@ public class CustomSaml2AuthenticationSuccessHandler
log.debug("Starting SAML2 authentication success handling");
if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
String username = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
String username = ((CustomSaml2AuthenticatedPrincipal) principal).name();
log.debug("Authenticated principal found for user: {}", username);
HttpSession session = request.getSession(false);

View File

@ -7,6 +7,7 @@ import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AuthnStatement;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken;
@ -18,6 +19,7 @@ import stirling.software.SPDF.config.security.UserService;
import stirling.software.SPDF.model.User;
@Slf4j
@ConditionalOnProperty(name = "security.saml2.enabled", havingValue = "true")
public class CustomSaml2ResponseAuthenticationConverter
implements Converter<ResponseToken, Saml2Authentication> {

View File

@ -48,7 +48,7 @@ public class SessionPersistentRegistry implements SessionRegistry {
} else if (principal instanceof OAuth2User) {
principalName = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
principalName = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
principalName = ((CustomSaml2AuthenticatedPrincipal) principal).name();
} else if (principal instanceof String) {
principalName = (String) principal;
}
@ -79,7 +79,7 @@ public class SessionPersistentRegistry implements SessionRegistry {
} else if (principal instanceof OAuth2User) {
principalName = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
principalName = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
principalName = ((CustomSaml2AuthenticatedPrincipal) principal).name();
} else if (principal instanceof String) {
principalName = (String) principal;
}

View File

@ -302,7 +302,7 @@ public class UserController {
} else if (principal instanceof OAuth2User) {
userNameP = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
userNameP = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
userNameP = ((CustomSaml2AuthenticatedPrincipal) principal).name();
} else if (principal instanceof String) {
userNameP = (String) principal;
}

View File

@ -345,7 +345,7 @@ public class AccountWebController {
model.addAttribute("oAuth2Login", true);
}
if (principal instanceof CustomSaml2AuthenticatedPrincipal userDetails) {
username = userDetails.getName();
username = userDetails.name();
model.addAttribute("saml2Login", true);
}
if (username != null) {

View File

@ -83,7 +83,7 @@ public class Provider {
private UsernameAttribute validateKeycloakUsernameAttribute(
UsernameAttribute usernameAttribute) {
switch (usernameAttribute) {
case EMAIL, PREFERRED_NAME -> {
case EMAIL, PREFERRED_USERNAME -> {
return usernameAttribute;
}
default ->

View File

@ -5,13 +5,9 @@ import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import stirling.software.SPDF.model.ApplicationProperties;
import static org.mockito.Mockito.mock;
@ -31,7 +27,7 @@ class CustomLogoutSuccessHandlerTest {
void testSuccessfulLogout() throws IOException {
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
String logoutPath = "/login?logout=true";
String logoutPath = "logout=true";
when(response.isCommitted()).thenReturn(false);
when(request.getContextPath()).thenReturn("");