/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.security.web.firewall;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpMethod;
import org.springframework.security.web.firewall.FirewalledRequest;
import org.springframework.security.web.firewall.FirewalledResponse;
import org.springframework.security.web.firewall.HttpFirewall;
import org.springframework.security.web.firewall.RequestRejectedException;

/**
 * A strict {@link HttpFirewall} that rejects requests whose HTTP method, URL or server name
 * look suspicious.
 * <p>
 * By default {@link #getFirewalledRequest(HttpServletRequest)} throws a
 * {@link RequestRejectedException} when:
 * <ul>
 * <li>the HTTP method is not one of DELETE, GET, HEAD, OPTIONS, PATCH, POST or PUT (see
 * {@link #setAllowedHttpMethods(Collection)} and {@link #setUnsafeAllowAnyHttpMethod(boolean)});</li>
 * <li>the URL contains a semicolon, a URL-encoded slash, a backslash, a URL-encoded period, a
 * URL-encoded percent sign, a NUL character, a carriage return or a line feed. Each of these
 * can be allowed individually through its setter. Encoded forms are looked for in the context
 * path and request URI; decoded forms are looked for in the servlet path and path info;</li>
 * <li>the server name is rejected by {@link #setAllowedHostnames(Predicate)} (every name is
 * accepted by default);</li>
 * <li>the request URI, context path, servlet path or path info is not normalized, i.e. it
 * contains {@code //}, a {@code .} segment or a {@code ..} segment;</li>
 * <li>the request URI contains a character outside the printable ASCII range
 * ({@code ' '} to {@code '~'}).</li>
 * </ul>
 * <p>
 * This class performs no synchronization: the setters mutate plain {@code HashSet}s, so
 * configure an instance before sharing it between threads.
 */
public class StrictHttpFirewall
implements HttpFirewall {
    /** Identity sentinel meaning "do not restrict the HTTP method at all". */
    private static final Set<String> ALLOW_ANY_HTTP_METHOD = Collections.unmodifiableSet(Collections.emptySet());
    private static final String ENCODED_PERCENT = "%25";
    private static final String PERCENT = "%";
    private static final List<String> FORBIDDEN_ENCODED_PERIOD = Collections.unmodifiableList(Arrays.asList("%2e", "%2E"));
    private static final List<String> FORBIDDEN_SEMICOLON = Collections.unmodifiableList(Arrays.asList(";", "%3b", "%3B"));
    private static final List<String> FORBIDDEN_FORWARDSLASH = Collections.unmodifiableList(Arrays.asList("%2f", "%2F"));
    private static final List<String> FORBIDDEN_BACKSLASH = Collections.unmodifiableList(Arrays.asList("\\", "%5c", "%5C"));
    private static final List<String> FORBIDDEN_NULL = Collections.unmodifiableList(Arrays.asList("\0", "%00"));
    private static final List<String> FORBIDDEN_LF = Collections.unmodifiableList(Arrays.asList("\n", "%0a", "%0A"));
    private static final List<String> FORBIDDEN_CR = Collections.unmodifiableList(Arrays.asList("\r", "%0d", "%0D"));

    /** Strings rejected when found in the context path or the (still encoded) request URI. */
    private final Set<String> encodedUrlBlacklist = new HashSet<>();
    /** Strings rejected when found in the (container-decoded) servlet path or path info. */
    private final Set<String> decodedUrlBlacklist = new HashSet<>();
    private Set<String> allowedHttpMethods = StrictHttpFirewall.createDefaultAllowedHttpMethods();
    private Predicate<String> allowedHostnames = hostname -> true;

    /** Creates a firewall with every check enabled and every host name allowed. */
    public StrictHttpFirewall() {
        this.urlBlacklistsAddAll(FORBIDDEN_SEMICOLON);
        this.urlBlacklistsAddAll(FORBIDDEN_FORWARDSLASH);
        this.urlBlacklistsAddAll(FORBIDDEN_BACKSLASH);
        // Encoded control characters would otherwise pass the printable-ASCII check on the raw
        // request URI and only show up, decoded, in the servlet path / path info.
        this.urlBlacklistsAddAll(FORBIDDEN_NULL);
        this.urlBlacklistsAddAll(FORBIDDEN_LF);
        this.urlBlacklistsAddAll(FORBIDDEN_CR);
        this.encodedUrlBlacklist.add(ENCODED_PERCENT);
        this.encodedUrlBlacklist.addAll(FORBIDDEN_ENCODED_PERIOD);
        this.decodedUrlBlacklist.add(PERCENT);
    }

    /**
     * Disables or re-enables HTTP method filtering.
     *
     * @param unsafeAllowAnyHttpMethod {@code true} to accept every HTTP method; {@code false}
     * to restore the default whitelist (DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT). Note
     * that {@code false} discards any whitelist configured through
     * {@link #setAllowedHttpMethods(Collection)}.
     */
    public void setUnsafeAllowAnyHttpMethod(boolean unsafeAllowAnyHttpMethod) {
        this.allowedHttpMethods = unsafeAllowAnyHttpMethod ? ALLOW_ANY_HTTP_METHOD : StrictHttpFirewall.createDefaultAllowedHttpMethods();
    }

    /**
     * Replaces the HTTP method whitelist with a copy of the given collection. Method names
     * are compared case-sensitively; an empty collection rejects every request.
     *
     * @param allowedHttpMethods the accepted HTTP method names
     * @throws IllegalArgumentException if {@code allowedHttpMethods} is {@code null}
     */
    public void setAllowedHttpMethods(Collection<String> allowedHttpMethods) {
        if (allowedHttpMethods == null) {
            throw new IllegalArgumentException("allowedHttpMethods cannot be null");
        }
        this.allowedHttpMethods = allowedHttpMethods == ALLOW_ANY_HTTP_METHOD ? ALLOW_ANY_HTTP_METHOD : new HashSet<>(allowedHttpMethods);
    }

    /**
     * @param allowSemicolon whether {@code ;}, {@code %3b} and {@code %3B} are allowed in the URL
     */
    public void setAllowSemicolon(boolean allowSemicolon) {
        if (allowSemicolon) {
            this.urlBlacklistsRemoveAll(FORBIDDEN_SEMICOLON);
        } else {
            this.urlBlacklistsAddAll(FORBIDDEN_SEMICOLON);
        }
    }

    /**
     * @param allowUrlEncodedSlash whether {@code %2f} and {@code %2F} are allowed in the URL
     */
    public void setAllowUrlEncodedSlash(boolean allowUrlEncodedSlash) {
        if (allowUrlEncodedSlash) {
            this.urlBlacklistsRemoveAll(FORBIDDEN_FORWARDSLASH);
        } else {
            this.urlBlacklistsAddAll(FORBIDDEN_FORWARDSLASH);
        }
    }

    /**
     * @param allowUrlEncodedPeriod whether {@code %2e} and {@code %2E} are allowed in the
     * context path and request URI
     */
    public void setAllowUrlEncodedPeriod(boolean allowUrlEncodedPeriod) {
        if (allowUrlEncodedPeriod) {
            this.encodedUrlBlacklist.removeAll(FORBIDDEN_ENCODED_PERIOD);
        } else {
            this.encodedUrlBlacklist.addAll(FORBIDDEN_ENCODED_PERIOD);
        }
    }

    /**
     * @param allowBackSlash whether a backslash, {@code %5c} and {@code %5C} are allowed in the URL
     */
    public void setAllowBackSlash(boolean allowBackSlash) {
        if (allowBackSlash) {
            this.urlBlacklistsRemoveAll(FORBIDDEN_BACKSLASH);
        } else {
            this.urlBlacklistsAddAll(FORBIDDEN_BACKSLASH);
        }
    }

    /**
     * @param allowUrlEncodedPercent whether {@code %25} is allowed in the context path and
     * request URI and {@code %} is allowed in the servlet path and path info
     */
    public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
        if (allowUrlEncodedPercent) {
            this.encodedUrlBlacklist.remove(ENCODED_PERCENT);
            this.decodedUrlBlacklist.remove(PERCENT);
        } else {
            this.encodedUrlBlacklist.add(ENCODED_PERCENT);
            this.decodedUrlBlacklist.add(PERCENT);
        }
    }

    /**
     * @param allowNull whether a NUL character and {@code %00} are allowed in the URL
     */
    public void setAllowNull(boolean allowNull) {
        if (allowNull) {
            this.urlBlacklistsRemoveAll(FORBIDDEN_NULL);
        } else {
            this.urlBlacklistsAddAll(FORBIDDEN_NULL);
        }
    }

    /**
     * @param allowUrlEncodedLineFeed whether a line feed, {@code %0a} and {@code %0A} are
     * allowed in the URL
     */
    public void setAllowUrlEncodedLineFeed(boolean allowUrlEncodedLineFeed) {
        if (allowUrlEncodedLineFeed) {
            this.urlBlacklistsRemoveAll(FORBIDDEN_LF);
        } else {
            this.urlBlacklistsAddAll(FORBIDDEN_LF);
        }
    }

    /**
     * @param allowUrlEncodedCarriageReturn whether a carriage return, {@code %0d} and
     * {@code %0D} are allowed in the URL
     */
    public void setAllowUrlEncodedCarriageReturn(boolean allowUrlEncodedCarriageReturn) {
        if (allowUrlEncodedCarriageReturn) {
            this.urlBlacklistsRemoveAll(FORBIDDEN_CR);
        } else {
            this.urlBlacklistsAddAll(FORBIDDEN_CR);
        }
    }

    /**
     * Sets the predicate used to accept {@link HttpServletRequest#getServerName()}. Requests
     * whose server name is {@code null} are not checked.
     *
     * @param allowedHostnames the host name filter
     * @throws IllegalArgumentException if {@code allowedHostnames} is {@code null}
     */
    public void setAllowedHostnames(Predicate<String> allowedHostnames) {
        if (allowedHostnames == null) {
            throw new IllegalArgumentException("allowedHostnames cannot be null");
        }
        this.allowedHostnames = allowedHostnames;
    }

    /** Adds {@code values} to both the encoded and the decoded blacklist. */
    private void urlBlacklistsAddAll(Collection<String> values) {
        this.encodedUrlBlacklist.addAll(values);
        this.decodedUrlBlacklist.addAll(values);
    }

    /** Removes {@code values} from both the encoded and the decoded blacklist. */
    private void urlBlacklistsRemoveAll(Collection<String> values) {
        this.encodedUrlBlacklist.removeAll(values);
        this.decodedUrlBlacklist.removeAll(values);
    }

    /**
     * Validates the request and wraps it.
     *
     * @param request the incoming request
     * @return a {@link FirewalledRequest} wrapping {@code request} whose {@code reset()} is a no-op
     * @throws RequestRejectedException if any of the checks described on this class fails
     */
    @Override
    public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException {
        this.rejectForbiddenHttpMethod(request);
        this.rejectedBlacklistedUrls(request);
        this.rejectedUntrustedHosts(request);
        if (!StrictHttpFirewall.isNormalized(request)) {
            throw new RequestRejectedException("The request was rejected because the URL was not normalized.");
        }
        String requestUri = request.getRequestURI();
        if (!StrictHttpFirewall.containsOnlyPrintableAsciiCharacters(requestUri)) {
            throw new RequestRejectedException("The requestURI was rejected because it can only contain printable ASCII characters.");
        }
        return new FirewalledRequest(request) {

            @Override
            public void reset() {
                // Intentionally a no-op.
            }
        };
    }

    /** Rejects the request unless method filtering is disabled or its method is whitelisted. */
    private void rejectForbiddenHttpMethod(HttpServletRequest request) {
        if (this.allowedHttpMethods == ALLOW_ANY_HTTP_METHOD) {
            return;
        }
        if (!this.allowedHttpMethods.contains(request.getMethod())) {
            throw new RequestRejectedException("The request was rejected because the HTTP method \"" + request.getMethod() + "\" was not included within the whitelist " + this.allowedHttpMethods);
        }
    }

    /** Rejects the request if any blacklisted string appears in the corresponding URL parts. */
    private void rejectedBlacklistedUrls(HttpServletRequest request) {
        for (String forbidden : this.encodedUrlBlacklist) {
            if (StrictHttpFirewall.encodedUrlContains(request, forbidden)) {
                throw new RequestRejectedException("The request was rejected because the URL contained a potentially malicious String \"" + forbidden + "\"");
            }
        }
        for (String forbidden : this.decodedUrlBlacklist) {
            if (StrictHttpFirewall.decodedUrlContains(request, forbidden)) {
                throw new RequestRejectedException("The request was rejected because the URL contained a potentially malicious String \"" + forbidden + "\"");
            }
        }
    }

    /** Rejects the request if its non-null server name fails the host name predicate. */
    private void rejectedUntrustedHosts(HttpServletRequest request) {
        String serverName = request.getServerName();
        if (serverName != null && !this.allowedHostnames.test(serverName)) {
            throw new RequestRejectedException("The request was rejected because the domain " + serverName + " is untrusted.");
        }
    }

    /**
     * @param response the outgoing response
     * @return {@code response} wrapped in a {@link FirewalledResponse}
     */
    @Override
    public HttpServletResponse getFirewalledResponse(HttpServletResponse response) {
        return new FirewalledResponse(response);
    }

    /** Returns a new mutable set holding DELETE, GET, HEAD, OPTIONS, PATCH, POST and PUT. */
    private static Set<String> createDefaultAllowedHttpMethods() {
        Set<String> result = new HashSet<>();
        result.add(HttpMethod.DELETE.name());
        result.add(HttpMethod.GET.name());
        result.add(HttpMethod.HEAD.name());
        result.add(HttpMethod.OPTIONS.name());
        result.add(HttpMethod.PATCH.name());
        result.add(HttpMethod.POST.name());
        result.add(HttpMethod.PUT.name());
        return result;
    }

    /** Returns whether the request URI, context path, servlet path and path info are all normalized. */
    private static boolean isNormalized(HttpServletRequest request) {
        return StrictHttpFirewall.isNormalized(request.getRequestURI())
                && StrictHttpFirewall.isNormalized(request.getContextPath())
                && StrictHttpFirewall.isNormalized(request.getServletPath())
                && StrictHttpFirewall.isNormalized(request.getPathInfo());
    }

    /** Checks the context path and the request URI for {@code value}. */
    private static boolean encodedUrlContains(HttpServletRequest request, String value) {
        if (StrictHttpFirewall.valueContains(request.getContextPath(), value)) {
            return true;
        }
        return StrictHttpFirewall.valueContains(request.getRequestURI(), value);
    }

    /** Checks the servlet path and the path info for {@code value}. */
    private static boolean decodedUrlContains(HttpServletRequest request, String value) {
        if (StrictHttpFirewall.valueContains(request.getServletPath(), value)) {
            return true;
        }
        return StrictHttpFirewall.valueContains(request.getPathInfo(), value);
    }

    /**
     * Returns whether every character of {@code uri} lies in {@code ' '..'~'}. A {@code null}
     * URI is accepted, consistent with how {@link #isNormalized(String)} treats {@code null}.
     */
    private static boolean containsOnlyPrintableAsciiCharacters(String uri) {
        if (uri == null) {
            return true;
        }
        int length = uri.length();
        for (int i = 0; i < length; ++i) {
            char c = uri.charAt(i);
            if (c < ' ' || c > '~') {
                return false;
            }
        }
        return true;
    }

    /** Null-safe {@link String#contains(CharSequence)}; {@code null} never contains anything. */
    private static boolean valueContains(String value, String contains) {
        return value != null && value.contains(contains);
    }

    /**
     * Returns whether {@code path} is free of empty ({@code //}), {@code .} and {@code ..}
     * segments. {@code null} counts as normalized.
     */
    private static boolean isNormalized(String path) {
        if (path == null) {
            return true;
        }
        if (path.contains("//")) {
            return false;
        }
        // Walk the segments from right to left: j is the exclusive end of the current segment
        // and i the index of the '/' before it (-1 for a leading segment with no slash), so
        // gap - 1 is the segment's length.
        int j = path.length();
        while (j > 0) {
            int i = path.lastIndexOf('/', j - 1);
            int gap = j - i;
            if (gap == 2 && path.charAt(i + 1) == '.') {
                return false;
            }
            if (gap == 3 && path.charAt(i + 1) == '.' && path.charAt(i + 2) == '.') {
                return false;
            }
            j = i;
        }
        return true;
    }
}

