// Copyright (C) 2013 The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package com.gitblit.auth.github; import com.google.common.base.Charsets; import com.google.common.base.Strings; import com.google.gson.Gson; import com.google.gson.JsonElement; import com.google.gson.JsonObject; import com.google.inject.Inject; import org.apache.commons.codec.binary.Base64; import org.apache.http.HttpResponse; import org.apache.http.NameValuePair; import org.apache.http.client.ClientProtocolException; import org.apache.http.client.HttpClient; import org.apache.http.client.entity.UrlEncodedFormEntity; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPost; import org.apache.http.message.BasicNameValuePair; import org.apache.http.util.EntityUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.UnsupportedEncodingException; import java.net.HttpURLConnection; import java.net.URLEncoder; import java.security.SecureRandom; import java.util.ArrayList; import java.util.List; import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; class OAuthProtocol { private static final String ME_SEPARATOR = ","; private static final Logger log = LoggerFactory .getLogger(OAuthProtocol.class); private final GitHubOAuthConfig config; private final HttpClient http; private final Gson gson; private final String state; @Inject OAuthProtocol(GitHubOAuthConfig config, HttpClient http, Gson gson) { this.config = config; this.http = http; this.gson = gson; this.state = generateRandomState(); } void loginPhase1(HttpServletRequest request, HttpServletResponse response) throws IOException { log.debug("Initiating GitHub Login for ClientId=" + config.gitHubClientId); response.sendRedirect(String.format( "%s?client_id=%s&redirect_uri=%s&state=%s%s", config.gitHubOAuthUrl, config.gitHubClientId, getURLEncoded(config.oAuthFinalRedirectUrl), me(), getURLEncoded(request.getRequestURI().toString()))); } String loginPhase2(HttpServletRequest request, HttpServletResponse response) throws IOException { HttpPost post = new HttpPost(config.gitHubOAuthAccessTokenUrl); post.setHeader("Accept", "application/json"); List nvps = new ArrayList<>(3); nvps.add(new BasicNameValuePair("client_id", config.gitHubClientId)); nvps.add(new BasicNameValuePair("client_secret", config.gitHubClientSecret)); nvps.add(new BasicNameValuePair("code", request.getParameter("code"))); post.setEntity(new UrlEncodedFormEntity(nvps)); try { HttpResponse postResponse = http.execute(post); if (postResponse.getStatusLine().getStatusCode() != HttpURLConnection.HTTP_OK) { log.error("POST " + config.gitHubOAuthAccessTokenUrl + " request for access token failed with status " + postResponse.getStatusLine()); response.sendError(HttpURLConnection.HTTP_UNAUTHORIZED, "Request for access token not authorised"); EntityUtils.consume(postResponse.getEntity()); return null; } return getAccessToken(getAccessTokenJson(postResponse)); } catch (IOException e) { log.error("POST " + config.gitHubOAuthAccessTokenUrl + " request for access token failed", e); response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Request for access token not authorised"); return null; } } String retrieveUser(String authToken) throws IOException { HttpGet get = new HttpGet(config.gitHubUserUrl); get.setHeader("Authorization", String.format("token %s", authToken)); try { return getLogin(getUserJson(httpGetGitHubUserInfo(get))); } catch (IOException e) { log.error("GET {} with authToken {} request failed", config.gitHubUserUrl, config.gitHubOAuthAccessTokenUrl, e); return null; } } private InputStream httpGetGitHubUserInfo(HttpGet get) throws IOException, ClientProtocolException { HttpResponse resp = http.execute(get); int statusCode = resp.getStatusLine().getStatusCode(); if (statusCode == HttpServletResponse.SC_OK) { return resp.getEntity().getContent(); } else { throw new IOException(String.format( "Invalid HTTP status code %s returned from %s", statusCode, get.getURI())); } } private String getAccessToken(JsonElement accessTokenJson) throws IOException { JsonElement accessTokenString = accessTokenJson.getAsJsonObject().get("access_token"); if (accessTokenString != null) { return accessTokenString.getAsString(); } else { throw new IOException(String.format( "Invalid JSON '%s': cannot find access_token field", accessTokenJson)); } } private JsonObject getAccessTokenJson(HttpResponse postResponse) throws UnsupportedEncodingException, IOException { JsonElement accessTokenJson = gson.fromJson(new InputStreamReader(postResponse.getEntity() .getContent(), Charsets.UTF_8), JsonElement.class); if (accessTokenJson.isJsonObject()) { return accessTokenJson.getAsJsonObject(); } else { throw new IOException(String.format( "Invalid JSON '%s': not a JSON Object")); } } boolean isOAuthRequest(HttpServletRequest httpRequest) { return OAuthProtocol.isGerritLogin(httpRequest) || OAuthProtocol.isOAuthFinal(httpRequest); } String getTargetUrl(ServletRequest request) { String requestState = state(request); int meEnd = requestState.indexOf(ME_SEPARATOR); if (meEnd >= 0 && requestState.subSequence(0, meEnd).equals(state)) { return requestState.substring(meEnd + 1); } else { log.warn("Illegal request state '" + requestState + "' on OAuthProtocol " + this); return null; } } private String me() { return state + ME_SEPARATOR; } private JsonObject getUserJson(InputStream userContentStream) throws IOException { JsonElement userJson = gson.fromJson(new InputStreamReader(userContentStream, Charsets.UTF_8), JsonElement.class); if (userJson.isJsonObject()) { return userJson.getAsJsonObject(); } else { throw new IOException(String.format( "Invalid JSON '%s': not a JSON Object", userJson)); } } static boolean isOAuthFinal(HttpServletRequest request) { return Strings.emptyToNull(request.getParameter("code")) != null; } static boolean isGerritLogin(HttpServletRequest request) { return request.getRequestURI().indexOf( GitHubOAuthConfig.GITHUB_LOGIN) >= 0; } private static String getLogin(JsonElement userJson) throws IOException { JsonElement userString = userJson.getAsJsonObject().get("login"); if (userString != null) { return userString.getAsString(); } else { throw new IOException(String.format( "Invalid JSON '%s': cannot find login field", userJson)); } } private static String generateRandomState() { byte[] randomState = new byte[32]; new SecureRandom().nextBytes(randomState); return Base64.encodeBase64URLSafeString(randomState); } private static String getURLEncoded(String url) { try { return URLEncoder.encode(url, Charsets.UTF_8.name()); } catch (UnsupportedEncodingException e) { // UTF-8 is hardcoded, cannot fail return null; } } private static String state(ServletRequest request) { return Strings.nullToEmpty(request.getParameter("state")); } @Override public String toString() { return "OAuthProtocol/" + state; } }