import java.net.URI;
import java.net.URL;
import java.nio.charset.StandardCharsets;
+import javax.annotation.Nullable;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
-import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.lang.StringUtils;
import org.sonar.api.internal.apachecommons.lang.StringEscapeUtils;
public static final String VALIDATION_RELAY_STATE = "validation-query";
public static final String SAML_VALIDATION_CONTROLLER_CONTEXT = "saml";
public static final String SAML_VALIDATION_KEY = "validation";
+ private static final String RELAY_STATE_PARAMETER = "RelayState";
+ private static final String SAML_RESPONSE_PARAMETER = "SAMLResponse";
private String redirectionPageTemplate;
private final Server server;
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
- HttpServletRequest httpRequest = (HttpServletRequest) request;
- if (isSamlValidation(httpRequest)) {
+ String relayState = request.getParameter(RELAY_STATE_PARAMETER);
+
+ if (isSamlValidation(relayState)) {
HttpServletResponse httpResponse = (HttpServletResponse) response;
- String samlResponse = StringEscapeUtils.escapeHtml(request.getParameter("SAMLResponse"));
URI redirectionEndpointUrl = URI.create(server.getContextPath() + "/")
.resolve(SAML_VALIDATION_CONTROLLER_CONTEXT + "/")
.resolve(SAML_VALIDATION_KEY);
+ String samlResponse = StringEscapeUtils.escapeHtml(request.getParameter(SAML_RESPONSE_PARAMETER));
+ String csrfToken = getCsrfTokenFromRelayState(relayState);
String template = StringUtils.replaceEachRepeatedly(redirectionPageTemplate,
- new String[]{"%VALIDATION_URL%", "%SAML_RESPONSE%"},
- new String[]{redirectionEndpointUrl.toString(), samlResponse});
+ new String[]{"%VALIDATION_URL%", "%SAML_RESPONSE%", "%CSRF_TOKEN%"},
+ new String[]{redirectionEndpointUrl.toString(), samlResponse, csrfToken});
httpResponse.setContentType("text/html");
httpResponse.getWriter().print(template);
chain.doFilter(request, response);
}
- private static boolean isSamlValidation(HttpServletRequest request) {
- return VALIDATION_RELAY_STATE.equals(request.getParameter("RelayState"));
+ private static boolean isSamlValidation(@Nullable String relayState) {
+ if (relayState != null) {
+ return VALIDATION_RELAY_STATE.equals(relayState.split("/")[0]) && !getCsrfTokenFromRelayState(relayState).isEmpty();
+ }
+ return false;
+ }
+
+ private static String getCsrfTokenFromRelayState(@Nullable String relayState) {
+ if (relayState != null && relayState.contains("/")) {
+ return StringEscapeUtils.escapeHtml(relayState.split("/")[1]);
+ }
+ return "";
}
}
<h1>SAML Authentication Validation</h1>
<form id="saml_validate" action="%VALIDATION_URL%" method="POST">
<input name="SAMLResponse" value="%SAML_RESPONSE%" type="hidden"/>
+ <input name="CSRFToken" value="%CSRF_TOKEN%" type="hidden"/>
<button>Click Here to See Result</button>
</form>
</body>
package org.sonar.server.authentication;
+import com.tngtech.java.junit.dataprovider.DataProvider;
+import com.tngtech.java.junit.dataprovider.DataProviderRunner;
+import com.tngtech.java.junit.dataprovider.UseDataProvider;
import java.io.IOException;
import java.io.PrintWriter;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before;
import org.junit.Test;
+import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.sonar.api.platform.Server;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
+@RunWith(DataProviderRunner.class)
public class SamlValidationRedirectionFilterTest {
SamlValidationRedirectionFilter underTest;
}
@Test
- public void do_filter_validation_relay_state() throws ServletException, IOException {
+ public void do_filter_validation_relay_state_with_csrfToken() throws ServletException, IOException {
HttpServletRequest servletRequest = mock(HttpServletRequest.class);
HttpServletResponse servletResponse = mock(HttpServletResponse.class);
FilterChain filterChain = mock(FilterChain.class);
String validSample = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890";
when(servletRequest.getParameter(matches("SAMLResponse"))).thenReturn(validSample);
- when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query");
+ when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query/CSRF_TOKEN");
PrintWriter pw = mock(PrintWriter.class);
when(servletResponse.getWriter()).thenReturn(pw);
ArgumentCaptor<String> htmlProduced = ArgumentCaptor.forClass(String.class);
verify(pw).print(htmlProduced.capture());
assertThat(htmlProduced.getValue()).contains(validSample);
+ assertThat(htmlProduced.getValue()).contains("action=\"/saml/validation\"");
+ assertThat(htmlProduced.getValue()).contains("value=\"CSRF_TOKEN\"");
+ }
+
+ @Test
+ public void do_filter_validation_relay_state_with_malicious_csrfToken() throws ServletException, IOException {
+ HttpServletRequest servletRequest = mock(HttpServletRequest.class);
+ HttpServletResponse servletResponse = mock(HttpServletResponse.class);
+ FilterChain filterChain = mock(FilterChain.class);
+
+ String validSample = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890";
+ when(servletRequest.getParameter(matches("SAMLResponse"))).thenReturn(validSample);
+
+ String maliciousToken = "test\"</input><script>*Malicious Token*</script><input value=\"";
+
+ when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query/" + maliciousToken);
+ PrintWriter pw = mock(PrintWriter.class);
+ when(servletResponse.getWriter()).thenReturn(pw);
+
+ underTest.doFilter(servletRequest, servletResponse, filterChain);
+
+ verify(servletResponse).setContentType("text/html");
+ ArgumentCaptor<String> htmlProduced = ArgumentCaptor.forClass(String.class);
+ verify(pw).print(htmlProduced.capture());
+ assertThat(htmlProduced.getValue()).contains(validSample);
+ assertThat(htmlProduced.getValue()).doesNotContain("<script>/*Malicious Token*/</script>");
+
}
@Test
String maliciousSaml = "test\"</input><script>/*hack website*/</script><input value=\"";
when(servletRequest.getParameter(matches("SAMLResponse"))).thenReturn(maliciousSaml);
- when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query");
+ when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query/CSRF_TOKEN");
PrintWriter pw = mock(PrintWriter.class);
when(servletResponse.getWriter()).thenReturn(pw);
}
@Test
- public void do_filter_no_validation_relay_state() throws ServletException, IOException {
+ @UseDataProvider("invalidRelayStateValues")
+ public void do_filter_invalid_relayState_values(String relayStateValue) throws ServletException, IOException {
HttpServletRequest servletRequest = mock(HttpServletRequest.class);
HttpServletResponse servletResponse = mock(HttpServletResponse.class);
FilterChain filterChain = mock(FilterChain.class);
- doReturn("random_query").when(servletRequest).getParameter("RelayState");
+ doReturn(relayStateValue).when(servletRequest).getParameter("RelayState");
underTest.doFilter(servletRequest, servletResponse, filterChain);
verifyNoInteractions(servletResponse);
public void extract_nonexistant_template() {
assertThrows(IllegalStateException.class, () -> underTest.extractTemplate("not-there"));
}
+
+ @DataProvider
+ public static Object[] invalidRelayStateValues() {
+ return new Object[]{"random_query", "validation-query", null};
+ }
+
}
import org.sonar.auth.saml.SamlIdentityProvider;
import org.sonar.server.authentication.AuthenticationError;
import org.sonar.server.authentication.OAuth2ContextFactory;
+import org.sonar.server.authentication.OAuthCsrfVerifier;
import org.sonar.server.authentication.SamlValidationRedirectionFilter;
+import org.sonar.server.authentication.event.AuthenticationException;
import org.sonar.server.user.ThreadLocalUserSession;
import org.sonar.server.ws.ServletFilterHandler;
private final ThreadLocalUserSession userSession;
private final SamlAuthenticator samlAuthenticator;
private final OAuth2ContextFactory oAuth2ContextFactory;
+ private final SamlIdentityProvider samlIdentityProvider;
+ private final OAuthCsrfVerifier oAuthCsrfVerifier;
- public ValidationAction(ThreadLocalUserSession userSession, SamlAuthenticator samlAuthenticator, OAuth2ContextFactory oAuth2ContextFactory) {
+ public ValidationAction(ThreadLocalUserSession userSession, SamlAuthenticator samlAuthenticator, OAuth2ContextFactory oAuth2ContextFactory,
+ SamlIdentityProvider samlIdentityProvider, OAuthCsrfVerifier oAuthCsrfVerifier) {
this.samlAuthenticator = samlAuthenticator;
this.userSession = userSession;
this.oAuth2ContextFactory = oAuth2ContextFactory;
+ this.samlIdentityProvider = samlIdentityProvider;
+ this.oAuthCsrfVerifier = oAuthCsrfVerifier;
}
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
HttpServletResponse httpResponse = (HttpServletResponse) response;
HttpServletRequest httpRequest = (HttpServletRequest) request;
+
+ try {
+ oAuthCsrfVerifier.verifyState(httpRequest, httpResponse, samlIdentityProvider, "CSRFToken");
+ } catch (AuthenticationException exception) {
+ AuthenticationError.handleError(httpRequest, httpResponse, exception.getMessage());
+ return;
+ }
+
if (!userSession.hasSession() || !userSession.isSystemAdministrator()) {
AuthenticationError.handleError(httpRequest, httpResponse, "User needs to be logged in as system administrator to access this page.");
return;
import org.sonar.auth.saml.SamlIdentityProvider;
import org.sonar.server.authentication.AuthenticationError;
import org.sonar.server.authentication.OAuth2ContextFactory;
+import org.sonar.server.authentication.OAuthCsrfVerifier;
import org.sonar.server.exceptions.ForbiddenException;
import org.sonar.server.user.UserSession;
import org.sonar.server.ws.ServletFilterHandler;
public static final String VALIDATION_RELAY_STATE = "validation-query";
public static final String VALIDATION_INIT_KEY = "validation_init";
private final SamlAuthenticator samlAuthenticator;
-
+ private final OAuthCsrfVerifier oAuthCsrfVerifier;
private final OAuth2ContextFactory oAuth2ContextFactory;
private final UserSession userSession;
- public ValidationInitAction(SamlAuthenticator samlAuthenticator, OAuth2ContextFactory oAuth2ContextFactory, UserSession userSession) {
+ public ValidationInitAction(SamlAuthenticator samlAuthenticator, OAuthCsrfVerifier oAuthCsrfVerifier, OAuth2ContextFactory oAuth2ContextFactory, UserSession userSession) {
this.samlAuthenticator = samlAuthenticator;
+ this.oAuthCsrfVerifier = oAuthCsrfVerifier;
this.oAuth2ContextFactory = oAuth2ContextFactory;
this.userSession = userSession;
}
return;
}
+ String csrfState = oAuthCsrfVerifier.generateState(request,response);
+
try {
samlAuthenticator.initLogin(oAuth2ContextFactory.generateCallbackUrl(SamlIdentityProvider.KEY),
- VALIDATION_RELAY_STATE, request, response);
+ VALIDATION_RELAY_STATE + "/" + csrfState, request, response);
} catch (IllegalStateException e) {
response.sendRedirect("/" + SAML_VALIDATION_CONTROLLER_CONTEXT + "/" + SAML_VALIDATION_KEY);
}
import org.junit.Test;
import org.sonar.api.server.ws.WebService;
import org.sonar.auth.saml.SamlAuthenticator;
+import org.sonar.auth.saml.SamlIdentityProvider;
import org.sonar.server.authentication.OAuth2ContextFactory;
+import org.sonar.server.authentication.OAuthCsrfVerifier;
+import org.sonar.server.authentication.event.AuthenticationEvent;
+import org.sonar.server.authentication.event.AuthenticationException;
import org.sonar.server.user.ThreadLocalUserSession;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
private SamlAuthenticator samlAuthenticator;
private ThreadLocalUserSession userSession;
+ private OAuthCsrfVerifier oAuthCsrfVerifier;
+
+ private SamlIdentityProvider samlIdentityProvider;
+
@Before
public void setup() {
samlAuthenticator = mock(SamlAuthenticator.class);
userSession = mock(ThreadLocalUserSession.class);
+ oAuthCsrfVerifier = mock(OAuthCsrfVerifier.class);
+ samlIdentityProvider = mock(SamlIdentityProvider.class);
var oAuth2ContextFactory = mock(OAuth2ContextFactory.class);
- underTest = new ValidationAction(userSession, samlAuthenticator, oAuth2ContextFactory);
+ underTest = new ValidationAction(userSession, samlAuthenticator, oAuth2ContextFactory, samlIdentityProvider, oAuthCsrfVerifier);
}
@Test
verifyNoInteractions(samlAuthenticator);
}
+ @Test
+ public void do_filter_failed_csrf_verification() throws ServletException, IOException {
+ HttpServletRequest servletRequest = spy(HttpServletRequest.class);
+ HttpServletResponse servletResponse = mock(HttpServletResponse.class);
+ StringWriter stringWriter = new StringWriter();
+ doReturn(new PrintWriter(stringWriter)).when(servletResponse).getWriter();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ doReturn("IdentityProviderName").when(samlIdentityProvider).getName();
+ doThrow(AuthenticationException.newBuilder()
+ .setSource(AuthenticationEvent.Source.oauth2(samlIdentityProvider))
+ .setMessage("Cookie is missing").build()).when(oAuthCsrfVerifier).verifyState(any(),any(),any(), any());
+
+ doReturn(true).when(userSession).hasSession();
+ doReturn(true).when(userSession).isSystemAdministrator();
+
+ underTest.doFilter(servletRequest, servletResponse, filterChain);
+
+ verifyNoInteractions(samlAuthenticator);
+ }
+
@Test
public void verify_definition() {
String controllerKey = "foo";
import org.sonar.api.server.ws.WebService;
import org.sonar.auth.saml.SamlAuthenticator;
import org.sonar.server.authentication.OAuth2ContextFactory;
+import org.sonar.server.authentication.OAuthCsrfVerifier;
import org.sonar.server.tester.UserSessionRule;
import static org.assertj.core.api.Assertions.assertThat;
private ValidationInitAction underTest;
private SamlAuthenticator samlAuthenticator;
private OAuth2ContextFactory oAuth2ContextFactory;
+ private OAuthCsrfVerifier oAuthCsrfVerifier;
@Before
public void setUp() throws Exception {
samlAuthenticator = mock(SamlAuthenticator.class);
oAuth2ContextFactory = mock(OAuth2ContextFactory.class);
- underTest = new ValidationInitAction(samlAuthenticator, oAuth2ContextFactory, userSession);
+ oAuthCsrfVerifier = mock(OAuthCsrfVerifier.class);
+ underTest = new ValidationInitAction(samlAuthenticator, oAuthCsrfVerifier, oAuth2ContextFactory, userSession);
}
@Test
HttpServletResponse servletResponse = mock(HttpServletResponse.class);
FilterChain filterChain = mock(FilterChain.class);
String callbackUrl = "http://localhost:9000/api/validation_test";
- when(oAuth2ContextFactory.generateCallbackUrl(anyString()))
- .thenReturn(callbackUrl);
+
+ mockCsrfTokenGeneration(servletRequest, servletResponse);
+ when(oAuth2ContextFactory.generateCallbackUrl(anyString())).thenReturn(callbackUrl);
underTest.doFilter(servletRequest, servletResponse, filterChain);
when(oAuth2ContextFactory.generateCallbackUrl(anyString()))
.thenReturn(callbackUrl);
+ mockCsrfTokenGeneration(servletRequest, servletResponse);
doThrow(new IllegalStateException()).when(samlAuthenticator).initLogin(any(), any(), any(), any());
underTest.doFilter(servletRequest, servletResponse, filterChain);
assertThat(validationInitAction.description()).isNotEmpty();
assertThat(validationInitAction.handler()).isNotNull();
}
+
+ private void mockCsrfTokenGeneration(HttpServletRequest servletRequest, HttpServletResponse servletResponse) {
+ when(oAuthCsrfVerifier.generateState(servletRequest, servletResponse)).thenReturn("CSRF_TOKEN");
+ }
}