]> source.dussan.org Git - sonarqube.git/commitdiff
SONAR-17694 fix SSF-323
authorMatteo Mara <matteo.mara@sonarsource.com>
Fri, 2 Dec 2022 16:56:14 +0000 (17:56 +0100)
committersonartech <sonartech@sonarsource.com>
Wed, 7 Dec 2022 20:02:57 +0000 (20:02 +0000)
server/sonar-webserver-auth/src/main/java/org/sonar/server/authentication/SamlValidationRedirectionFilter.java
server/sonar-webserver-auth/src/main/resources/validation-redirection.html
server/sonar-webserver-auth/src/test/java/org/sonar/server/authentication/SamlValidationRedirectionFilterTest.java
server/sonar-webserver-webapi/src/main/java/org/sonar/server/saml/ws/ValidationAction.java
server/sonar-webserver-webapi/src/main/java/org/sonar/server/saml/ws/ValidationInitAction.java
server/sonar-webserver-webapi/src/test/java/org/sonar/server/saml/ws/ValidationActionTest.java
server/sonar-webserver-webapi/src/test/java/org/sonar/server/saml/ws/ValidationInitActionTest.java

index 7d66c9047f0607fc8a1689ddf64a20aaf476d0e5..4f242a563ba9f56d3d8ea16057327d80d1b3091c 100644 (file)
@@ -25,12 +25,12 @@ import java.io.IOException;
 import java.net.URI;
 import java.net.URL;
 import java.nio.charset.StandardCharsets;
+import javax.annotation.Nullable;
 import javax.servlet.FilterChain;
 import javax.servlet.FilterConfig;
 import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
-import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import org.apache.commons.lang.StringUtils;
 import org.sonar.api.internal.apachecommons.lang.StringEscapeUtils;
@@ -44,6 +44,8 @@ public class SamlValidationRedirectionFilter extends ServletFilter {
   public static final String VALIDATION_RELAY_STATE = "validation-query";
   public static final String SAML_VALIDATION_CONTROLLER_CONTEXT = "saml";
   public static final String SAML_VALIDATION_KEY = "validation";
+  private static final String RELAY_STATE_PARAMETER = "RelayState";
+  private static final String SAML_RESPONSE_PARAMETER = "SAMLResponse";
   private String redirectionPageTemplate;
   private final Server server;
 
@@ -73,18 +75,20 @@ public class SamlValidationRedirectionFilter extends ServletFilter {
 
   @Override
   public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
-    HttpServletRequest httpRequest = (HttpServletRequest) request;
-    if (isSamlValidation(httpRequest)) {
+    String relayState = request.getParameter(RELAY_STATE_PARAMETER);
+
+    if (isSamlValidation(relayState)) {
       HttpServletResponse httpResponse = (HttpServletResponse) response;
 
-      String samlResponse = StringEscapeUtils.escapeHtml(request.getParameter("SAMLResponse"));
       URI redirectionEndpointUrl = URI.create(server.getContextPath() + "/")
         .resolve(SAML_VALIDATION_CONTROLLER_CONTEXT + "/")
         .resolve(SAML_VALIDATION_KEY);
+      String samlResponse = StringEscapeUtils.escapeHtml(request.getParameter(SAML_RESPONSE_PARAMETER));
+      String csrfToken = getCsrfTokenFromRelayState(relayState);
 
       String template = StringUtils.replaceEachRepeatedly(redirectionPageTemplate,
-        new String[]{"%VALIDATION_URL%", "%SAML_RESPONSE%"},
-        new String[]{redirectionEndpointUrl.toString(), samlResponse});
+        new String[]{"%VALIDATION_URL%", "%SAML_RESPONSE%", "%CSRF_TOKEN%"},
+        new String[]{redirectionEndpointUrl.toString(), samlResponse, csrfToken});
 
       httpResponse.setContentType("text/html");
       httpResponse.getWriter().print(template);
@@ -93,7 +97,17 @@ public class SamlValidationRedirectionFilter extends ServletFilter {
     chain.doFilter(request, response);
   }
 
-  private static boolean isSamlValidation(HttpServletRequest request) {
-    return VALIDATION_RELAY_STATE.equals(request.getParameter("RelayState"));
+  private static boolean isSamlValidation(@Nullable String relayState) {
+    if (relayState != null) {
+      return VALIDATION_RELAY_STATE.equals(relayState.split("/")[0]) && !getCsrfTokenFromRelayState(relayState).isEmpty();
+    }
+    return false;
+  }
+
+  private static String getCsrfTokenFromRelayState(@Nullable String relayState) {
+    if (relayState != null && relayState.contains("/")) {
+      return StringEscapeUtils.escapeHtml(relayState.split("/")[1]);
+    }
+    return "";
   }
 }
index df9a48a29d6a26c4875494732e8beb1781661292..f30601364042d34951f91db83ef424b2a68d9fe1 100644 (file)
@@ -10,6 +10,7 @@
     <h1>SAML Authentication Validation</h1>
     <form id="saml_validate" action="%VALIDATION_URL%" method="POST">
         <input name="SAMLResponse" value="%SAML_RESPONSE%" type="hidden"/>
+        <input name="CSRFToken" value="%CSRF_TOKEN%" type="hidden"/>
         <button>Click Here to See Result</button>
     </form>
 </body>
index b27c68a5f1f06c6e2adcd4cbdb12d0e014fba66b..c7ba12282fac6bf99746cf87bf4c276937ba662b 100644 (file)
@@ -20,6 +20,9 @@
 
 package org.sonar.server.authentication;
 
+import com.tngtech.java.junit.dataprovider.DataProvider;
+import com.tngtech.java.junit.dataprovider.DataProviderRunner;
+import com.tngtech.java.junit.dataprovider.UseDataProvider;
 import java.io.IOException;
 import java.io.PrintWriter;
 import javax.servlet.FilterChain;
@@ -29,6 +32,7 @@ import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.sonar.api.platform.Server;
 
@@ -41,6 +45,7 @@ import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
 
+@RunWith(DataProviderRunner.class)
 public class SamlValidationRedirectionFilterTest {
 
   SamlValidationRedirectionFilter underTest;
@@ -62,14 +67,14 @@ public class SamlValidationRedirectionFilterTest {
   }
 
   @Test
-  public void do_filter_validation_relay_state() throws ServletException, IOException {
+  public void do_filter_validation_relay_state_with_csrfToken() throws ServletException, IOException {
     HttpServletRequest servletRequest = mock(HttpServletRequest.class);
     HttpServletResponse servletResponse = mock(HttpServletResponse.class);
     FilterChain filterChain = mock(FilterChain.class);
 
     String validSample = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890";
     when(servletRequest.getParameter(matches("SAMLResponse"))).thenReturn(validSample);
-    when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query");
+    when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query/CSRF_TOKEN");
     PrintWriter pw = mock(PrintWriter.class);
     when(servletResponse.getWriter()).thenReturn(pw);
 
@@ -79,6 +84,33 @@ public class SamlValidationRedirectionFilterTest {
     ArgumentCaptor<String> htmlProduced = ArgumentCaptor.forClass(String.class);
     verify(pw).print(htmlProduced.capture());
     assertThat(htmlProduced.getValue()).contains(validSample);
+    assertThat(htmlProduced.getValue()).contains("action=\"/saml/validation\"");
+    assertThat(htmlProduced.getValue()).contains("value=\"CSRF_TOKEN\"");
+  }
+
+  @Test
+  public void do_filter_validation_relay_state_with_malicious_csrfToken() throws ServletException, IOException {
+    HttpServletRequest servletRequest = mock(HttpServletRequest.class);
+    HttpServletResponse servletResponse = mock(HttpServletResponse.class);
+    FilterChain filterChain = mock(FilterChain.class);
+
+    String validSample = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890";
+    when(servletRequest.getParameter(matches("SAMLResponse"))).thenReturn(validSample);
+
+    String maliciousToken = "test\"</input><script>*Malicious Token*</script><input value=\"";
+
+    when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query/" + maliciousToken);
+    PrintWriter pw = mock(PrintWriter.class);
+    when(servletResponse.getWriter()).thenReturn(pw);
+
+    underTest.doFilter(servletRequest, servletResponse, filterChain);
+
+    verify(servletResponse).setContentType("text/html");
+    ArgumentCaptor<String> htmlProduced = ArgumentCaptor.forClass(String.class);
+    verify(pw).print(htmlProduced.capture());
+    assertThat(htmlProduced.getValue()).contains(validSample);
+    assertThat(htmlProduced.getValue()).doesNotContain("<script>/*Malicious Token*/</script>");
+
   }
 
   @Test
@@ -90,7 +122,7 @@ public class SamlValidationRedirectionFilterTest {
     String maliciousSaml = "test\"</input><script>/*hack website*/</script><input value=\"";
 
     when(servletRequest.getParameter(matches("SAMLResponse"))).thenReturn(maliciousSaml);
-    when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query");
+    when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query/CSRF_TOKEN");
     PrintWriter pw = mock(PrintWriter.class);
     when(servletResponse.getWriter()).thenReturn(pw);
 
@@ -104,12 +136,13 @@ public class SamlValidationRedirectionFilterTest {
   }
 
   @Test
-  public void do_filter_no_validation_relay_state() throws ServletException, IOException {
+  @UseDataProvider("invalidRelayStateValues")
+  public void do_filter_invalid_relayState_values(String relayStateValue) throws ServletException, IOException {
     HttpServletRequest servletRequest = mock(HttpServletRequest.class);
     HttpServletResponse servletResponse = mock(HttpServletResponse.class);
     FilterChain filterChain = mock(FilterChain.class);
 
-    doReturn("random_query").when(servletRequest).getParameter("RelayState");
+    doReturn(relayStateValue).when(servletRequest).getParameter("RelayState");
     underTest.doFilter(servletRequest, servletResponse, filterChain);
 
     verifyNoInteractions(servletResponse);
@@ -119,4 +152,10 @@ public class SamlValidationRedirectionFilterTest {
   public void extract_nonexistant_template() {
     assertThrows(IllegalStateException.class, () -> underTest.extractTemplate("not-there"));
   }
+
+  @DataProvider
+  public static Object[] invalidRelayStateValues() {
+    return new Object[]{"random_query", "validation-query", null};
+  }
+
 }
index 4b58458513c3cac7639e4b180acbf8e6348c715e..3b1c5e37a8689013a51ea6705945887fbdee49be 100644 (file)
@@ -34,7 +34,9 @@ import org.sonar.auth.saml.SamlAuthenticator;
 import org.sonar.auth.saml.SamlIdentityProvider;
 import org.sonar.server.authentication.AuthenticationError;
 import org.sonar.server.authentication.OAuth2ContextFactory;
+import org.sonar.server.authentication.OAuthCsrfVerifier;
 import org.sonar.server.authentication.SamlValidationRedirectionFilter;
+import org.sonar.server.authentication.event.AuthenticationException;
 import org.sonar.server.user.ThreadLocalUserSession;
 import org.sonar.server.ws.ServletFilterHandler;
 
@@ -44,11 +46,16 @@ public class ValidationAction extends ServletFilter implements SamlAction {
   private final ThreadLocalUserSession userSession;
   private final SamlAuthenticator samlAuthenticator;
   private final OAuth2ContextFactory oAuth2ContextFactory;
+  private final SamlIdentityProvider samlIdentityProvider;
+  private final OAuthCsrfVerifier oAuthCsrfVerifier;
 
-  public ValidationAction(ThreadLocalUserSession userSession, SamlAuthenticator samlAuthenticator, OAuth2ContextFactory oAuth2ContextFactory) {
+  public ValidationAction(ThreadLocalUserSession userSession, SamlAuthenticator samlAuthenticator, OAuth2ContextFactory oAuth2ContextFactory,
+    SamlIdentityProvider samlIdentityProvider, OAuthCsrfVerifier oAuthCsrfVerifier) {
     this.samlAuthenticator = samlAuthenticator;
     this.userSession = userSession;
     this.oAuth2ContextFactory = oAuth2ContextFactory;
+    this.samlIdentityProvider = samlIdentityProvider;
+    this.oAuthCsrfVerifier = oAuthCsrfVerifier;
   }
 
   @Override
@@ -60,6 +67,14 @@ public class ValidationAction extends ServletFilter implements SamlAction {
   public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
     HttpServletResponse httpResponse = (HttpServletResponse) response;
     HttpServletRequest httpRequest = (HttpServletRequest) request;
+
+    try {
+      oAuthCsrfVerifier.verifyState(httpRequest, httpResponse, samlIdentityProvider, "CSRFToken");
+    } catch (AuthenticationException exception) {
+      AuthenticationError.handleError(httpRequest, httpResponse, exception.getMessage());
+      return;
+    }
+
     if (!userSession.hasSession() || !userSession.isSystemAdministrator()) {
       AuthenticationError.handleError(httpRequest, httpResponse, "User needs to be logged in as system administrator to access this page.");
       return;
index 4181fe6d83f62b78abc685c24f9a284703821381..d25122d5ea819dc74ba85b6c786e206c0e50e4ea 100644 (file)
@@ -32,6 +32,7 @@ import org.sonar.auth.saml.SamlAuthenticator;
 import org.sonar.auth.saml.SamlIdentityProvider;
 import org.sonar.server.authentication.AuthenticationError;
 import org.sonar.server.authentication.OAuth2ContextFactory;
+import org.sonar.server.authentication.OAuthCsrfVerifier;
 import org.sonar.server.exceptions.ForbiddenException;
 import org.sonar.server.user.UserSession;
 import org.sonar.server.ws.ServletFilterHandler;
@@ -44,12 +45,13 @@ public class ValidationInitAction extends ServletFilter implements SamlAction {
   public static final String VALIDATION_RELAY_STATE = "validation-query";
   public static final String VALIDATION_INIT_KEY = "validation_init";
   private final SamlAuthenticator samlAuthenticator;
-
+  private final OAuthCsrfVerifier oAuthCsrfVerifier;
   private final OAuth2ContextFactory oAuth2ContextFactory;
   private final UserSession userSession;
 
-  public ValidationInitAction(SamlAuthenticator samlAuthenticator, OAuth2ContextFactory oAuth2ContextFactory, UserSession userSession) {
+  public ValidationInitAction(SamlAuthenticator samlAuthenticator, OAuthCsrfVerifier oAuthCsrfVerifier, OAuth2ContextFactory oAuth2ContextFactory, UserSession userSession) {
     this.samlAuthenticator = samlAuthenticator;
+    this.oAuthCsrfVerifier = oAuthCsrfVerifier;
     this.oAuth2ContextFactory = oAuth2ContextFactory;
     this.userSession = userSession;
   }
@@ -82,9 +84,11 @@ public class ValidationInitAction extends ServletFilter implements SamlAction {
       return;
     }
 
+    String csrfState = oAuthCsrfVerifier.generateState(request,response);
+
     try {
       samlAuthenticator.initLogin(oAuth2ContextFactory.generateCallbackUrl(SamlIdentityProvider.KEY),
-        VALIDATION_RELAY_STATE, request, response);
+        VALIDATION_RELAY_STATE + "/" + csrfState, request, response);
     } catch (IllegalStateException e) {
       response.sendRedirect("/" + SAML_VALIDATION_CONTROLLER_CONTEXT + "/" + SAML_VALIDATION_KEY);
     }
index a00371db1d6efc77c299bf3d65a29a703ab522a2..beb73abc6058440022721a4f0912168c8ce761dc 100644 (file)
@@ -31,12 +31,17 @@ import org.junit.Before;
 import org.junit.Test;
 import org.sonar.api.server.ws.WebService;
 import org.sonar.auth.saml.SamlAuthenticator;
+import org.sonar.auth.saml.SamlIdentityProvider;
 import org.sonar.server.authentication.OAuth2ContextFactory;
+import org.sonar.server.authentication.OAuthCsrfVerifier;
+import org.sonar.server.authentication.event.AuthenticationEvent;
+import org.sonar.server.authentication.event.AuthenticationException;
 import org.sonar.server.user.ThreadLocalUserSession;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
@@ -48,12 +53,18 @@ public class ValidationActionTest {
   private SamlAuthenticator samlAuthenticator;
   private ThreadLocalUserSession userSession;
 
+  private OAuthCsrfVerifier oAuthCsrfVerifier;
+
+  private SamlIdentityProvider samlIdentityProvider;
+
   @Before
   public void setup() {
     samlAuthenticator = mock(SamlAuthenticator.class);
     userSession = mock(ThreadLocalUserSession.class);
+    oAuthCsrfVerifier = mock(OAuthCsrfVerifier.class);
+    samlIdentityProvider = mock(SamlIdentityProvider.class);
     var oAuth2ContextFactory = mock(OAuth2ContextFactory.class);
-    underTest = new ValidationAction(userSession, samlAuthenticator, oAuth2ContextFactory);
+    underTest = new ValidationAction(userSession, samlAuthenticator, oAuth2ContextFactory, samlIdentityProvider, oAuthCsrfVerifier);
   }
 
   @Test
@@ -98,6 +109,27 @@ public class ValidationActionTest {
     verifyNoInteractions(samlAuthenticator);
   }
 
+  @Test
+  public void do_filter_failed_csrf_verification() throws ServletException, IOException {
+    HttpServletRequest servletRequest = spy(HttpServletRequest.class);
+    HttpServletResponse servletResponse = mock(HttpServletResponse.class);
+    StringWriter stringWriter = new StringWriter();
+    doReturn(new PrintWriter(stringWriter)).when(servletResponse).getWriter();
+    FilterChain filterChain = mock(FilterChain.class);
+
+    doReturn("IdentityProviderName").when(samlIdentityProvider).getName();
+    doThrow(AuthenticationException.newBuilder()
+      .setSource(AuthenticationEvent.Source.oauth2(samlIdentityProvider))
+      .setMessage("Cookie is missing").build()).when(oAuthCsrfVerifier).verifyState(any(),any(),any(), any());
+
+    doReturn(true).when(userSession).hasSession();
+    doReturn(true).when(userSession).isSystemAdministrator();
+
+    underTest.doFilter(servletRequest, servletResponse, filterChain);
+
+    verifyNoInteractions(samlAuthenticator);
+  }
+
   @Test
   public void verify_definition() {
     String controllerKey = "foo";
index c284e4defec7275f5922a79fc263eca19e857842..a12e56dbb66d9d61647bc30e9da419d0eec83ae4 100644 (file)
@@ -30,6 +30,7 @@ import org.junit.Test;
 import org.sonar.api.server.ws.WebService;
 import org.sonar.auth.saml.SamlAuthenticator;
 import org.sonar.server.authentication.OAuth2ContextFactory;
+import org.sonar.server.authentication.OAuthCsrfVerifier;
 import org.sonar.server.tester.UserSessionRule;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -48,12 +49,14 @@ public class ValidationInitActionTest {
   private ValidationInitAction underTest;
   private SamlAuthenticator samlAuthenticator;
   private OAuth2ContextFactory oAuth2ContextFactory;
+  private OAuthCsrfVerifier oAuthCsrfVerifier;
 
   @Before
   public void setUp() throws Exception {
     samlAuthenticator = mock(SamlAuthenticator.class);
     oAuth2ContextFactory = mock(OAuth2ContextFactory.class);
-    underTest = new ValidationInitAction(samlAuthenticator, oAuth2ContextFactory, userSession);
+    oAuthCsrfVerifier = mock(OAuthCsrfVerifier.class);
+    underTest = new ValidationInitAction(samlAuthenticator, oAuthCsrfVerifier, oAuth2ContextFactory, userSession);
   }
 
   @Test
@@ -71,8 +74,9 @@ public class ValidationInitActionTest {
     HttpServletResponse servletResponse = mock(HttpServletResponse.class);
     FilterChain filterChain = mock(FilterChain.class);
     String callbackUrl = "http://localhost:9000/api/validation_test";
-    when(oAuth2ContextFactory.generateCallbackUrl(anyString()))
-      .thenReturn(callbackUrl);
+
+    mockCsrfTokenGeneration(servletRequest, servletResponse);
+    when(oAuth2ContextFactory.generateCallbackUrl(anyString())).thenReturn(callbackUrl);
 
     underTest.doFilter(servletRequest, servletResponse, filterChain);
 
@@ -91,6 +95,7 @@ public class ValidationInitActionTest {
     when(oAuth2ContextFactory.generateCallbackUrl(anyString()))
       .thenReturn(callbackUrl);
 
+    mockCsrfTokenGeneration(servletRequest, servletResponse);
     doThrow(new IllegalStateException()).when(samlAuthenticator).initLogin(any(), any(), any(), any());
 
     underTest.doFilter(servletRequest, servletResponse, filterChain);
@@ -143,4 +148,8 @@ public class ValidationInitActionTest {
     assertThat(validationInitAction.description()).isNotEmpty();
     assertThat(validationInitAction.handler()).isNotNull();
   }
+
+  private void mockCsrfTokenGeneration(HttpServletRequest servletRequest, HttpServletResponse servletResponse) {
+    when(oAuthCsrfVerifier.generateState(servletRequest, servletResponse)).thenReturn("CSRF_TOKEN");
+  }
 }