From f235dcec41d3fad489d16318bb59b8869f064ac0 Mon Sep 17 00:00:00 2001 From: Matteo Mara Date: Fri, 2 Dec 2022 17:56:14 +0100 Subject: [PATCH] SONAR-17694 fix SSF-323 --- .../SamlValidationRedirectionFilter.java | 30 +++++++++--- .../resources/validation-redirection.html | 1 + .../SamlValidationRedirectionFilterTest.java | 49 +++++++++++++++++-- .../server/saml/ws/ValidationAction.java | 17 ++++++- .../server/saml/ws/ValidationInitAction.java | 10 ++-- .../server/saml/ws/ValidationActionTest.java | 34 ++++++++++++- .../saml/ws/ValidationInitActionTest.java | 15 ++++-- 7 files changed, 135 insertions(+), 21 deletions(-) diff --git a/server/sonar-webserver-auth/src/main/java/org/sonar/server/authentication/SamlValidationRedirectionFilter.java b/server/sonar-webserver-auth/src/main/java/org/sonar/server/authentication/SamlValidationRedirectionFilter.java index 7d66c9047f0..4f242a563ba 100644 --- a/server/sonar-webserver-auth/src/main/java/org/sonar/server/authentication/SamlValidationRedirectionFilter.java +++ b/server/sonar-webserver-auth/src/main/java/org/sonar/server/authentication/SamlValidationRedirectionFilter.java @@ -25,12 +25,12 @@ import java.io.IOException; import java.net.URI; import java.net.URL; import java.nio.charset.StandardCharsets; +import javax.annotation.Nullable; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.lang.StringUtils; import org.sonar.api.internal.apachecommons.lang.StringEscapeUtils; @@ -44,6 +44,8 @@ public class SamlValidationRedirectionFilter extends ServletFilter { public static final String VALIDATION_RELAY_STATE = "validation-query"; public static final String SAML_VALIDATION_CONTROLLER_CONTEXT = "saml"; public static final String SAML_VALIDATION_KEY = "validation"; + private static final String RELAY_STATE_PARAMETER = "RelayState"; + private static final String SAML_RESPONSE_PARAMETER = "SAMLResponse"; private String redirectionPageTemplate; private final Server server; @@ -73,18 +75,20 @@ public class SamlValidationRedirectionFilter extends ServletFilter { @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - HttpServletRequest httpRequest = (HttpServletRequest) request; - if (isSamlValidation(httpRequest)) { + String relayState = request.getParameter(RELAY_STATE_PARAMETER); + + if (isSamlValidation(relayState)) { HttpServletResponse httpResponse = (HttpServletResponse) response; - String samlResponse = StringEscapeUtils.escapeHtml(request.getParameter("SAMLResponse")); URI redirectionEndpointUrl = URI.create(server.getContextPath() + "/") .resolve(SAML_VALIDATION_CONTROLLER_CONTEXT + "/") .resolve(SAML_VALIDATION_KEY); + String samlResponse = StringEscapeUtils.escapeHtml(request.getParameter(SAML_RESPONSE_PARAMETER)); + String csrfToken = getCsrfTokenFromRelayState(relayState); String template = StringUtils.replaceEachRepeatedly(redirectionPageTemplate, - new String[]{"%VALIDATION_URL%", "%SAML_RESPONSE%"}, - new String[]{redirectionEndpointUrl.toString(), samlResponse}); + new String[]{"%VALIDATION_URL%", "%SAML_RESPONSE%", "%CSRF_TOKEN%"}, + new String[]{redirectionEndpointUrl.toString(), samlResponse, csrfToken}); httpResponse.setContentType("text/html"); httpResponse.getWriter().print(template); @@ -93,7 +97,17 @@ public class SamlValidationRedirectionFilter extends ServletFilter { chain.doFilter(request, response); } - private static boolean isSamlValidation(HttpServletRequest request) { - return VALIDATION_RELAY_STATE.equals(request.getParameter("RelayState")); + private static boolean isSamlValidation(@Nullable String relayState) { + if (relayState != null) { + return VALIDATION_RELAY_STATE.equals(relayState.split("/")[0]) && !getCsrfTokenFromRelayState(relayState).isEmpty(); + } + return false; + } + + private static String getCsrfTokenFromRelayState(@Nullable String relayState) { + if (relayState != null && relayState.contains("/")) { + return StringEscapeUtils.escapeHtml(relayState.split("/")[1]); + } + return ""; } } diff --git a/server/sonar-webserver-auth/src/main/resources/validation-redirection.html b/server/sonar-webserver-auth/src/main/resources/validation-redirection.html index df9a48a29d6..f3060136404 100644 --- a/server/sonar-webserver-auth/src/main/resources/validation-redirection.html +++ b/server/sonar-webserver-auth/src/main/resources/validation-redirection.html @@ -10,6 +10,7 @@

SAML Authentication Validation

+
diff --git a/server/sonar-webserver-auth/src/test/java/org/sonar/server/authentication/SamlValidationRedirectionFilterTest.java b/server/sonar-webserver-auth/src/test/java/org/sonar/server/authentication/SamlValidationRedirectionFilterTest.java index b27c68a5f1f..c7ba12282fa 100644 --- a/server/sonar-webserver-auth/src/test/java/org/sonar/server/authentication/SamlValidationRedirectionFilterTest.java +++ b/server/sonar-webserver-auth/src/test/java/org/sonar/server/authentication/SamlValidationRedirectionFilterTest.java @@ -20,6 +20,9 @@ package org.sonar.server.authentication; +import com.tngtech.java.junit.dataprovider.DataProvider; +import com.tngtech.java.junit.dataprovider.DataProviderRunner; +import com.tngtech.java.junit.dataprovider.UseDataProvider; import java.io.IOException; import java.io.PrintWriter; import javax.servlet.FilterChain; @@ -29,6 +32,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.sonar.api.platform.Server; @@ -41,6 +45,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +@RunWith(DataProviderRunner.class) public class SamlValidationRedirectionFilterTest { SamlValidationRedirectionFilter underTest; @@ -62,14 +67,14 @@ public class SamlValidationRedirectionFilterTest { } @Test - public void do_filter_validation_relay_state() throws ServletException, IOException { + public void do_filter_validation_relay_state_with_csrfToken() throws ServletException, IOException { HttpServletRequest servletRequest = mock(HttpServletRequest.class); HttpServletResponse servletResponse = mock(HttpServletResponse.class); FilterChain filterChain = mock(FilterChain.class); String validSample = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; when(servletRequest.getParameter(matches("SAMLResponse"))).thenReturn(validSample); - when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query"); + when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query/CSRF_TOKEN"); PrintWriter pw = mock(PrintWriter.class); when(servletResponse.getWriter()).thenReturn(pw); @@ -79,6 +84,33 @@ public class SamlValidationRedirectionFilterTest { ArgumentCaptor htmlProduced = ArgumentCaptor.forClass(String.class); verify(pw).print(htmlProduced.capture()); assertThat(htmlProduced.getValue()).contains(validSample); + assertThat(htmlProduced.getValue()).contains("action=\"/saml/validation\""); + assertThat(htmlProduced.getValue()).contains("value=\"CSRF_TOKEN\""); + } + + @Test + public void do_filter_validation_relay_state_with_malicious_csrfToken() throws ServletException, IOException { + HttpServletRequest servletRequest = mock(HttpServletRequest.class); + HttpServletResponse servletResponse = mock(HttpServletResponse.class); + FilterChain filterChain = mock(FilterChain.class); + + String validSample = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; + when(servletRequest.getParameter(matches("SAMLResponse"))).thenReturn(validSample); + + String maliciousToken = "test\" htmlProduced = ArgumentCaptor.forClass(String.class); + verify(pw).print(htmlProduced.capture()); + assertThat(htmlProduced.getValue()).contains(validSample); + assertThat(htmlProduced.getValue()).doesNotContain(""); + } @Test @@ -90,7 +122,7 @@ public class SamlValidationRedirectionFilterTest { String maliciousSaml = "test\" underTest.extractTemplate("not-there")); } + + @DataProvider + public static Object[] invalidRelayStateValues() { + return new Object[]{"random_query", "validation-query", null}; + } + } diff --git a/server/sonar-webserver-webapi/src/main/java/org/sonar/server/saml/ws/ValidationAction.java b/server/sonar-webserver-webapi/src/main/java/org/sonar/server/saml/ws/ValidationAction.java index 4b58458513c..3b1c5e37a86 100644 --- a/server/sonar-webserver-webapi/src/main/java/org/sonar/server/saml/ws/ValidationAction.java +++ b/server/sonar-webserver-webapi/src/main/java/org/sonar/server/saml/ws/ValidationAction.java @@ -34,7 +34,9 @@ import org.sonar.auth.saml.SamlAuthenticator; import org.sonar.auth.saml.SamlIdentityProvider; import org.sonar.server.authentication.AuthenticationError; import org.sonar.server.authentication.OAuth2ContextFactory; +import org.sonar.server.authentication.OAuthCsrfVerifier; import org.sonar.server.authentication.SamlValidationRedirectionFilter; +import org.sonar.server.authentication.event.AuthenticationException; import org.sonar.server.user.ThreadLocalUserSession; import org.sonar.server.ws.ServletFilterHandler; @@ -44,11 +46,16 @@ public class ValidationAction extends ServletFilter implements SamlAction { private final ThreadLocalUserSession userSession; private final SamlAuthenticator samlAuthenticator; private final OAuth2ContextFactory oAuth2ContextFactory; + private final SamlIdentityProvider samlIdentityProvider; + private final OAuthCsrfVerifier oAuthCsrfVerifier; - public ValidationAction(ThreadLocalUserSession userSession, SamlAuthenticator samlAuthenticator, OAuth2ContextFactory oAuth2ContextFactory) { + public ValidationAction(ThreadLocalUserSession userSession, SamlAuthenticator samlAuthenticator, OAuth2ContextFactory oAuth2ContextFactory, + SamlIdentityProvider samlIdentityProvider, OAuthCsrfVerifier oAuthCsrfVerifier) { this.samlAuthenticator = samlAuthenticator; this.userSession = userSession; this.oAuth2ContextFactory = oAuth2ContextFactory; + this.samlIdentityProvider = samlIdentityProvider; + this.oAuthCsrfVerifier = oAuthCsrfVerifier; } @Override @@ -60,6 +67,14 @@ public class ValidationAction extends ServletFilter implements SamlAction { public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { HttpServletResponse httpResponse = (HttpServletResponse) response; HttpServletRequest httpRequest = (HttpServletRequest) request; + + try { + oAuthCsrfVerifier.verifyState(httpRequest, httpResponse, samlIdentityProvider, "CSRFToken"); + } catch (AuthenticationException exception) { + AuthenticationError.handleError(httpRequest, httpResponse, exception.getMessage()); + return; + } + if (!userSession.hasSession() || !userSession.isSystemAdministrator()) { AuthenticationError.handleError(httpRequest, httpResponse, "User needs to be logged in as system administrator to access this page."); return; diff --git a/server/sonar-webserver-webapi/src/main/java/org/sonar/server/saml/ws/ValidationInitAction.java b/server/sonar-webserver-webapi/src/main/java/org/sonar/server/saml/ws/ValidationInitAction.java index 4181fe6d83f..d25122d5ea8 100644 --- a/server/sonar-webserver-webapi/src/main/java/org/sonar/server/saml/ws/ValidationInitAction.java +++ b/server/sonar-webserver-webapi/src/main/java/org/sonar/server/saml/ws/ValidationInitAction.java @@ -32,6 +32,7 @@ import org.sonar.auth.saml.SamlAuthenticator; import org.sonar.auth.saml.SamlIdentityProvider; import org.sonar.server.authentication.AuthenticationError; import org.sonar.server.authentication.OAuth2ContextFactory; +import org.sonar.server.authentication.OAuthCsrfVerifier; import org.sonar.server.exceptions.ForbiddenException; import org.sonar.server.user.UserSession; import org.sonar.server.ws.ServletFilterHandler; @@ -44,12 +45,13 @@ public class ValidationInitAction extends ServletFilter implements SamlAction { public static final String VALIDATION_RELAY_STATE = "validation-query"; public static final String VALIDATION_INIT_KEY = "validation_init"; private final SamlAuthenticator samlAuthenticator; - + private final OAuthCsrfVerifier oAuthCsrfVerifier; private final OAuth2ContextFactory oAuth2ContextFactory; private final UserSession userSession; - public ValidationInitAction(SamlAuthenticator samlAuthenticator, OAuth2ContextFactory oAuth2ContextFactory, UserSession userSession) { + public ValidationInitAction(SamlAuthenticator samlAuthenticator, OAuthCsrfVerifier oAuthCsrfVerifier, OAuth2ContextFactory oAuth2ContextFactory, UserSession userSession) { this.samlAuthenticator = samlAuthenticator; + this.oAuthCsrfVerifier = oAuthCsrfVerifier; this.oAuth2ContextFactory = oAuth2ContextFactory; this.userSession = userSession; } @@ -82,9 +84,11 @@ public class ValidationInitAction extends ServletFilter implements SamlAction { return; } + String csrfState = oAuthCsrfVerifier.generateState(request,response); + try { samlAuthenticator.initLogin(oAuth2ContextFactory.generateCallbackUrl(SamlIdentityProvider.KEY), - VALIDATION_RELAY_STATE, request, response); + VALIDATION_RELAY_STATE + "/" + csrfState, request, response); } catch (IllegalStateException e) { response.sendRedirect("/" + SAML_VALIDATION_CONTROLLER_CONTEXT + "/" + SAML_VALIDATION_KEY); } diff --git a/server/sonar-webserver-webapi/src/test/java/org/sonar/server/saml/ws/ValidationActionTest.java b/server/sonar-webserver-webapi/src/test/java/org/sonar/server/saml/ws/ValidationActionTest.java index a00371db1d6..beb73abc605 100644 --- a/server/sonar-webserver-webapi/src/test/java/org/sonar/server/saml/ws/ValidationActionTest.java +++ b/server/sonar-webserver-webapi/src/test/java/org/sonar/server/saml/ws/ValidationActionTest.java @@ -31,12 +31,17 @@ import org.junit.Before; import org.junit.Test; import org.sonar.api.server.ws.WebService; import org.sonar.auth.saml.SamlAuthenticator; +import org.sonar.auth.saml.SamlIdentityProvider; import org.sonar.server.authentication.OAuth2ContextFactory; +import org.sonar.server.authentication.OAuthCsrfVerifier; +import org.sonar.server.authentication.event.AuthenticationEvent; +import org.sonar.server.authentication.event.AuthenticationException; import org.sonar.server.user.ThreadLocalUserSession; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -48,12 +53,18 @@ public class ValidationActionTest { private SamlAuthenticator samlAuthenticator; private ThreadLocalUserSession userSession; + private OAuthCsrfVerifier oAuthCsrfVerifier; + + private SamlIdentityProvider samlIdentityProvider; + @Before public void setup() { samlAuthenticator = mock(SamlAuthenticator.class); userSession = mock(ThreadLocalUserSession.class); + oAuthCsrfVerifier = mock(OAuthCsrfVerifier.class); + samlIdentityProvider = mock(SamlIdentityProvider.class); var oAuth2ContextFactory = mock(OAuth2ContextFactory.class); - underTest = new ValidationAction(userSession, samlAuthenticator, oAuth2ContextFactory); + underTest = new ValidationAction(userSession, samlAuthenticator, oAuth2ContextFactory, samlIdentityProvider, oAuthCsrfVerifier); } @Test @@ -98,6 +109,27 @@ public class ValidationActionTest { verifyNoInteractions(samlAuthenticator); } + @Test + public void do_filter_failed_csrf_verification() throws ServletException, IOException { + HttpServletRequest servletRequest = spy(HttpServletRequest.class); + HttpServletResponse servletResponse = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + doReturn(new PrintWriter(stringWriter)).when(servletResponse).getWriter(); + FilterChain filterChain = mock(FilterChain.class); + + doReturn("IdentityProviderName").when(samlIdentityProvider).getName(); + doThrow(AuthenticationException.newBuilder() + .setSource(AuthenticationEvent.Source.oauth2(samlIdentityProvider)) + .setMessage("Cookie is missing").build()).when(oAuthCsrfVerifier).verifyState(any(),any(),any(), any()); + + doReturn(true).when(userSession).hasSession(); + doReturn(true).when(userSession).isSystemAdministrator(); + + underTest.doFilter(servletRequest, servletResponse, filterChain); + + verifyNoInteractions(samlAuthenticator); + } + @Test public void verify_definition() { String controllerKey = "foo"; diff --git a/server/sonar-webserver-webapi/src/test/java/org/sonar/server/saml/ws/ValidationInitActionTest.java b/server/sonar-webserver-webapi/src/test/java/org/sonar/server/saml/ws/ValidationInitActionTest.java index c284e4defec..a12e56dbb66 100644 --- a/server/sonar-webserver-webapi/src/test/java/org/sonar/server/saml/ws/ValidationInitActionTest.java +++ b/server/sonar-webserver-webapi/src/test/java/org/sonar/server/saml/ws/ValidationInitActionTest.java @@ -30,6 +30,7 @@ import org.junit.Test; import org.sonar.api.server.ws.WebService; import org.sonar.auth.saml.SamlAuthenticator; import org.sonar.server.authentication.OAuth2ContextFactory; +import org.sonar.server.authentication.OAuthCsrfVerifier; import org.sonar.server.tester.UserSessionRule; import static org.assertj.core.api.Assertions.assertThat; @@ -48,12 +49,14 @@ public class ValidationInitActionTest { private ValidationInitAction underTest; private SamlAuthenticator samlAuthenticator; private OAuth2ContextFactory oAuth2ContextFactory; + private OAuthCsrfVerifier oAuthCsrfVerifier; @Before public void setUp() throws Exception { samlAuthenticator = mock(SamlAuthenticator.class); oAuth2ContextFactory = mock(OAuth2ContextFactory.class); - underTest = new ValidationInitAction(samlAuthenticator, oAuth2ContextFactory, userSession); + oAuthCsrfVerifier = mock(OAuthCsrfVerifier.class); + underTest = new ValidationInitAction(samlAuthenticator, oAuthCsrfVerifier, oAuth2ContextFactory, userSession); } @Test @@ -71,8 +74,9 @@ public class ValidationInitActionTest { HttpServletResponse servletResponse = mock(HttpServletResponse.class); FilterChain filterChain = mock(FilterChain.class); String callbackUrl = "http://localhost:9000/api/validation_test"; - when(oAuth2ContextFactory.generateCallbackUrl(anyString())) - .thenReturn(callbackUrl); + + mockCsrfTokenGeneration(servletRequest, servletResponse); + when(oAuth2ContextFactory.generateCallbackUrl(anyString())).thenReturn(callbackUrl); underTest.doFilter(servletRequest, servletResponse, filterChain); @@ -91,6 +95,7 @@ public class ValidationInitActionTest { when(oAuth2ContextFactory.generateCallbackUrl(anyString())) .thenReturn(callbackUrl); + mockCsrfTokenGeneration(servletRequest, servletResponse); doThrow(new IllegalStateException()).when(samlAuthenticator).initLogin(any(), any(), any(), any()); underTest.doFilter(servletRequest, servletResponse, filterChain); @@ -143,4 +148,8 @@ public class ValidationInitActionTest { assertThat(validationInitAction.description()).isNotEmpty(); assertThat(validationInitAction.handler()).isNotNull(); } + + private void mockCsrfTokenGeneration(HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + when(oAuthCsrfVerifier.generateState(servletRequest, servletResponse)).thenReturn("CSRF_TOKEN"); + } } -- 2.39.5