/*
 * Mirror of https://github.com/Stirling-Tools/Stirling-PDF.git
 * Synced 2025-06-06 18:30:57 +00:00
 * 216 lines, 7.4 KiB, Java
 */
package stirling.software.SPDF.config;
|
|
|
|
import java.lang.reflect.Method;
|
|
import java.util.HashSet;
|
|
import java.util.Map;
|
|
import java.util.Set;
|
|
import java.util.TreeSet;
|
|
|
|
import org.slf4j.Logger;
|
|
import org.slf4j.LoggerFactory;
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
import org.springframework.context.ApplicationContext;
|
|
import org.springframework.context.ApplicationListener;
|
|
import org.springframework.context.event.ContextRefreshedEvent;
|
|
import org.springframework.stereotype.Component;
|
|
import org.springframework.web.bind.annotation.RequestMethod;
|
|
import org.springframework.web.method.HandlerMethod;
|
|
import org.springframework.web.servlet.mvc.method.RequestMappingInfo;
|
|
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;
|
|
|
|
@Component
|
|
public class EndpointInspector implements ApplicationListener<ContextRefreshedEvent> {
|
|
private static final Logger logger = LoggerFactory.getLogger(EndpointInspector.class);
|
|
|
|
private final ApplicationContext applicationContext;
|
|
private final Set<String> validGetEndpoints = new HashSet<>();
|
|
private boolean endpointsDiscovered = false;
|
|
|
|
@Autowired
|
|
public EndpointInspector(ApplicationContext applicationContext) {
|
|
this.applicationContext = applicationContext;
|
|
}
|
|
|
|
@Override
|
|
public void onApplicationEvent(ContextRefreshedEvent event) {
|
|
if (!endpointsDiscovered) {
|
|
discoverEndpoints();
|
|
endpointsDiscovered = true;
|
|
}
|
|
}
|
|
|
|
private void discoverEndpoints() {
|
|
try {
|
|
Map<String, RequestMappingHandlerMapping> mappings =
|
|
applicationContext.getBeansOfType(RequestMappingHandlerMapping.class);
|
|
|
|
for (Map.Entry<String, RequestMappingHandlerMapping> entry : mappings.entrySet()) {
|
|
RequestMappingHandlerMapping mapping = entry.getValue();
|
|
Map<RequestMappingInfo, HandlerMethod> handlerMethods = mapping.getHandlerMethods();
|
|
|
|
for (Map.Entry<RequestMappingInfo, HandlerMethod> handlerEntry :
|
|
handlerMethods.entrySet()) {
|
|
RequestMappingInfo mappingInfo = handlerEntry.getKey();
|
|
HandlerMethod handlerMethod = handlerEntry.getValue();
|
|
|
|
boolean isGetHandler = false;
|
|
try {
|
|
Set<RequestMethod> methods = mappingInfo.getMethodsCondition().getMethods();
|
|
isGetHandler = methods.isEmpty() || methods.contains(RequestMethod.GET);
|
|
} catch (Exception e) {
|
|
isGetHandler = true;
|
|
}
|
|
|
|
if (isGetHandler) {
|
|
Set<String> patterns = extractPatternsUsingDirectPaths(mappingInfo);
|
|
|
|
if (patterns.isEmpty()) {
|
|
patterns = extractPatternsFromString(mappingInfo);
|
|
}
|
|
|
|
validGetEndpoints.addAll(patterns);
|
|
}
|
|
}
|
|
}
|
|
|
|
if (validGetEndpoints.isEmpty()) {
|
|
logger.warn("No endpoints discovered. Adding common endpoints as fallback.");
|
|
validGetEndpoints.add("/");
|
|
validGetEndpoints.add("/api/**");
|
|
validGetEndpoints.add("/**");
|
|
}
|
|
} catch (Exception e) {
|
|
logger.error("Error discovering endpoints", e);
|
|
}
|
|
}
|
|
|
|
private Set<String> extractPatternsUsingDirectPaths(RequestMappingInfo mappingInfo) {
|
|
Set<String> patterns = new HashSet<>();
|
|
|
|
try {
|
|
Method getDirectPathsMethod = mappingInfo.getClass().getMethod("getDirectPaths");
|
|
Object result = getDirectPathsMethod.invoke(mappingInfo);
|
|
if (result instanceof Set) {
|
|
@SuppressWarnings("unchecked")
|
|
Set<String> resultSet = (Set<String>) result;
|
|
patterns.addAll(resultSet);
|
|
}
|
|
} catch (Exception e) {
|
|
// Return empty set if method not found or fails
|
|
}
|
|
|
|
return patterns;
|
|
}
|
|
|
|
private Set<String> extractPatternsFromString(RequestMappingInfo mappingInfo) {
|
|
Set<String> patterns = new HashSet<>();
|
|
try {
|
|
String infoString = mappingInfo.toString();
|
|
if (infoString.contains("{")) {
|
|
String patternsSection =
|
|
infoString.substring(infoString.indexOf("{") + 1, infoString.indexOf("}"));
|
|
|
|
for (String pattern : patternsSection.split(",")) {
|
|
pattern = pattern.trim();
|
|
if (!pattern.isEmpty()) {
|
|
patterns.add(pattern);
|
|
}
|
|
}
|
|
}
|
|
} catch (Exception e) {
|
|
// Return empty set if parsing fails
|
|
}
|
|
return patterns;
|
|
}
|
|
|
|
public boolean isValidGetEndpoint(String uri) {
|
|
if (!endpointsDiscovered) {
|
|
discoverEndpoints();
|
|
endpointsDiscovered = true;
|
|
}
|
|
|
|
if (validGetEndpoints.contains(uri)) {
|
|
return true;
|
|
}
|
|
|
|
if (matchesWildcardOrPathVariable(uri)) {
|
|
return true;
|
|
}
|
|
|
|
if (matchesPathSegments(uri)) {
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
private boolean matchesWildcardOrPathVariable(String uri) {
|
|
for (String pattern : validGetEndpoints) {
|
|
if (pattern.contains("*") || pattern.contains("{")) {
|
|
int wildcardIndex = pattern.indexOf('*');
|
|
int variableIndex = pattern.indexOf('{');
|
|
|
|
int cutoffIndex;
|
|
if (wildcardIndex < 0) {
|
|
cutoffIndex = variableIndex;
|
|
} else if (variableIndex < 0) {
|
|
cutoffIndex = wildcardIndex;
|
|
} else {
|
|
cutoffIndex = Math.min(wildcardIndex, variableIndex);
|
|
}
|
|
|
|
String staticPrefix = pattern.substring(0, cutoffIndex);
|
|
|
|
if (uri.startsWith(staticPrefix)) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
private boolean matchesPathSegments(String uri) {
|
|
for (String pattern : validGetEndpoints) {
|
|
if (!pattern.contains("*") && !pattern.contains("{")) {
|
|
String[] patternSegments = pattern.split("/");
|
|
String[] uriSegments = uri.split("/");
|
|
|
|
if (uriSegments.length < patternSegments.length) {
|
|
continue;
|
|
}
|
|
|
|
boolean match = true;
|
|
for (int i = 0; i < patternSegments.length; i++) {
|
|
if (!patternSegments[i].equals(uriSegments[i])) {
|
|
match = false;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (match) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
public Set<String> getValidGetEndpoints() {
|
|
if (!endpointsDiscovered) {
|
|
discoverEndpoints();
|
|
endpointsDiscovered = true;
|
|
}
|
|
return new HashSet<>(validGetEndpoints);
|
|
}
|
|
|
|
private void logAllEndpoints() {
|
|
Set<String> sortedEndpoints = new TreeSet<>(validGetEndpoints);
|
|
|
|
logger.info("=== BEGIN: All discovered GET endpoints ===");
|
|
for (String endpoint : sortedEndpoints) {
|
|
logger.info("Endpoint: {}", endpoint);
|
|
}
|
|
logger.info("=== END: All discovered GET endpoints ===");
|
|
}
|
|
}
|