3 * Copyright (C) 2009-2023 SonarSource SA
4 * mailto:info AT sonarsource DOT com
6 * This program is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 3 of the License, or (at your option) any later version.
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
16 * You should have received a copy of the GNU Lesser General Public License
17 * along with this program; if not, write to the Free Software Foundation,
18 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
20 package org.sonar.server.authentication;
22 import com.tngtech.java.junit.dataprovider.DataProvider;
23 import com.tngtech.java.junit.dataprovider.DataProviderRunner;
24 import com.tngtech.java.junit.dataprovider.UseDataProvider;
25 import java.io.IOException;
26 import java.io.PrintWriter;
27 import javax.servlet.FilterChain;
28 import javax.servlet.FilterConfig;
29 import javax.servlet.ServletException;
30 import javax.servlet.http.HttpServletRequest;
31 import javax.servlet.http.HttpServletResponse;
32 import org.junit.Before;
33 import org.junit.Test;
34 import org.junit.runner.RunWith;
35 import org.mockito.ArgumentCaptor;
36 import org.sonar.api.platform.Server;
38 import static org.assertj.core.api.Assertions.assertThat;
39 import static org.junit.Assert.assertThrows;
40 import static org.mockito.ArgumentMatchers.matches;
41 import static org.mockito.Mockito.doReturn;
42 import static org.mockito.Mockito.mock;
43 import static org.mockito.Mockito.verify;
44 import static org.mockito.Mockito.verifyNoInteractions;
45 import static org.mockito.Mockito.when;
47 @RunWith(DataProviderRunner.class)
48 public class SamlValidationRedirectionFilterTest {
50 SamlValidationRedirectionFilter underTest;
53 public void setup() throws ServletException {
54 Server server = mock(Server.class);
55 doReturn("").when(server).getContextPath();
56 underTest = new SamlValidationRedirectionFilter(server);
57 underTest.init(mock(FilterConfig.class));
61 public void do_get_pattern() {
62 assertThat(underTest.doGetPattern().matches("/oauth2/callback/saml")).isTrue();
63 assertThat(underTest.doGetPattern().matches("/oauth2/callback/")).isFalse();
64 assertThat(underTest.doGetPattern().matches("/oauth2/callback/test")).isFalse();
65 assertThat(underTest.doGetPattern().matches("/oauth2/")).isFalse();
69 public void do_filter_validation_relay_state_with_csrfToken() throws ServletException, IOException {
70 HttpServletRequest servletRequest = mock(HttpServletRequest.class);
71 HttpServletResponse servletResponse = mock(HttpServletResponse.class);
72 FilterChain filterChain = mock(FilterChain.class);
74 String validSample = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890";
75 when(servletRequest.getParameter(matches("SAMLResponse"))).thenReturn(validSample);
76 when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query/CSRF_TOKEN");
77 PrintWriter pw = mock(PrintWriter.class);
78 when(servletResponse.getWriter()).thenReturn(pw);
80 underTest.doFilter(servletRequest, servletResponse, filterChain);
82 verify(servletResponse).setContentType("text/html");
83 ArgumentCaptor<String> htmlProduced = ArgumentCaptor.forClass(String.class);
84 verify(pw).print(htmlProduced.capture());
85 assertThat(htmlProduced.getValue()).contains(validSample);
86 assertThat(htmlProduced.getValue()).contains("action=\"/saml/validation\"");
87 assertThat(htmlProduced.getValue()).contains("value=\"CSRF_TOKEN\"");
91 public void do_filter_validation_relay_state_with_malicious_csrfToken() throws ServletException, IOException {
92 HttpServletRequest servletRequest = mock(HttpServletRequest.class);
93 HttpServletResponse servletResponse = mock(HttpServletResponse.class);
94 FilterChain filterChain = mock(FilterChain.class);
96 String validSample = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890";
97 when(servletRequest.getParameter(matches("SAMLResponse"))).thenReturn(validSample);
99 String maliciousToken = "test\"</input><script>*Malicious Token*</script><input value=\"";
101 when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query/" + maliciousToken);
102 PrintWriter pw = mock(PrintWriter.class);
103 when(servletResponse.getWriter()).thenReturn(pw);
105 underTest.doFilter(servletRequest, servletResponse, filterChain);
107 verify(servletResponse).setContentType("text/html");
108 ArgumentCaptor<String> htmlProduced = ArgumentCaptor.forClass(String.class);
109 verify(pw).print(htmlProduced.capture());
110 assertThat(htmlProduced.getValue()).contains(validSample);
111 assertThat(htmlProduced.getValue()).doesNotContain("<script>/*Malicious Token*/</script>");
116 public void do_filter_validation_wrong_SAML_response() throws ServletException, IOException {
117 HttpServletRequest servletRequest = mock(HttpServletRequest.class);
118 HttpServletResponse servletResponse = mock(HttpServletResponse.class);
119 FilterChain filterChain = mock(FilterChain.class);
121 String maliciousSaml = "test\"</input><script>/*hack website*/</script><input value=\"";
123 when(servletRequest.getParameter(matches("SAMLResponse"))).thenReturn(maliciousSaml);
124 when(servletRequest.getParameter(matches("RelayState"))).thenReturn("validation-query/CSRF_TOKEN");
125 PrintWriter pw = mock(PrintWriter.class);
126 when(servletResponse.getWriter()).thenReturn(pw);
128 underTest.doFilter(servletRequest, servletResponse, filterChain);
130 verify(servletResponse).setContentType("text/html");
131 ArgumentCaptor<String> htmlProduced = ArgumentCaptor.forClass(String.class);
132 verify(pw).print(htmlProduced.capture());
133 assertThat(htmlProduced.getValue()).doesNotContain("<script>/*hack website*/</script>");
134 assertThat(htmlProduced.getValue()).contains("action=\"/saml/validation\"");
138 @UseDataProvider("invalidRelayStateValues")
139 public void do_filter_invalid_relayState_values(String relayStateValue) throws ServletException, IOException {
140 HttpServletRequest servletRequest = mock(HttpServletRequest.class);
141 HttpServletResponse servletResponse = mock(HttpServletResponse.class);
142 FilterChain filterChain = mock(FilterChain.class);
144 doReturn(relayStateValue).when(servletRequest).getParameter("RelayState");
145 underTest.doFilter(servletRequest, servletResponse, filterChain);
147 verifyNoInteractions(servletResponse);
151 public void extract_nonexistant_template() {
152 assertThrows(IllegalStateException.class, () -> underTest.extractTemplate("not-there"));
156 public static Object[] invalidRelayStateValues() {
157 return new Object[]{"random_query", "validation-query", null};