Browse Source

initial commit

main
Alex Feldman-Crough 4 weeks ago
commit
e607b6d567
  1. 2
      .gitignore
  2. 41
      Cargo.toml
  3. 1
      result
  4. 22
      src/api/error.rs
  5. 120
      src/api/mod.rs
  6. 205
      src/api/types.rs
  7. 185
      src/client/client.rs
  8. 4
      src/client/mod.rs
  9. 60
      src/client/routes.rs
  10. 3
      src/config/mod.rs
  11. 194
      src/config/oauth2.rs
  12. 21
      src/http.rs
  13. 8
      src/lib.rs

2
.gitignore

@ -0,0 +1,2 @@
/target
Cargo.lock

41
Cargo.toml

@ -0,0 +1,41 @@
[package]
name = "google-oauth2"
version = "0.1.0"
authors = ["Alex Feldman-Crough <alex@fldcr.com>"]
edition = "2018"
[dependencies]
chrono = "0.4"
derive_more = "0.99"
futures = "0.3"
hyper-tls = "0.5"
thiserror = "1.0"
log = "0.4"
rand = "0.8"
serde_urlencoded = "0.7"
serde_json = "1.0"
routerify = "2.1"
[dependencies.hyper]
version = "0.14"
default-features = false
features = ["client", "http1", "http2", "stream", "runtime"]
[dependencies.serde]
version = "1.0"
features = ["derive"]
[dependencies.secret]
version = "0.1"
git = "https://git.fldcr.com/jafc/secret"
branch = "main"
features = ["serde"]
[dependencies.tokio]
version = "1.6"
features = ["signal", "time"]
[dependencies.url]
version = "2.2"
features = ["serde"]

1
result

@ -0,0 +1 @@
/nix/store/qk48i5xyx4yf0dk3y61xm310jmrlsg89-openssl-all

22
src/api/error.rs

@ -0,0 +1,22 @@
use thiserror::Error;
pub use std::result::Result as StdResult;
use crate::http;
pub type Result<T=()> = StdResult<T, Error>;
#[derive(Debug, Error)]
pub enum Error {
#[error("The API indicated a {0} error.{}{}",
if .1.is_some() { "\n Details:" } else { "" },
.1.as_ref().unwrap_or(&serde_json::Value::Null))]
ApiError(http::StatusCode, Option<serde_json::Value>),
#[error("Unexpected error while communicating with server")]
ProtocolError(#[from] #[source] hyper::Error),
#[error("Failed to process JSON")]
JsonError(#[from] #[source] serde_json::Error),
#[error("Failed to deserialize URL-encoded parameters")]
UrlDecodingError(#[from] #[source] serde_urlencoded::de::Error),
#[error("Failed to serialize URL-encoded parameters")]
UrlEncodingError(#[from] #[source] serde_urlencoded::ser::Error),
}

120
src/api/mod.rs

@ -0,0 +1,120 @@
use futures::stream::TryStreamExt as _;
use serde::{Deserialize, Serialize};
use url::Url;
pub use self::error::{Error, Result};
pub use self::types::{
AccessToken, Authorization, AuthorizationCode, OneTimeToken, RefreshToken, RefreshableToken,
StateKey, Token,
};
use crate::config;
use crate::http;
pub mod error;
pub mod types;
pub fn new_authorization(
config: &config::OAuth2,
state_key: Option<types::StateKey>,
) -> types::Authorization {
log::info!("creating new authorization with state key {:?}", state_key);
let mut url = config.authorization_endpoint.clone();
{
let mut query = url.query_pairs_mut();
query
.append_pair("access_type", config.access_type.as_str())
.append_pair("client_id", &config.client_id)
.append_pair("prompt", config.prompt.as_str())
.append_pair("redirect_uri", config.redirect_uri.as_str())
.append_pair("response_type", "code")
.append_pair("scope", config.scope.as_str());
if let Some(hint) = config.login_hint.as_ref() {
query.append_pair("login_hint", &hint);
}
if let Some(key) = state_key.as_ref() {
query.append_pair("state", &key);
}
}
log::info!("authorization url: {}", url);
types::Authorization::new(url, state_key)
}
pub async fn exchange_authorization_code(
config: &config::OAuth2,
client: &http::Client,
authorization_code: &types::AuthorizationCode,
) -> Result<types::Token> {
log::info!("exchanging authorization code for access token");
let req = types::OAuth2Request {
client_id: &config.client_id,
client_secret: config.client_secret.as_deref(),
grant_type: types::OAuth2RequestGrant::AuthorizationCode {
code: authorization_code,
redirect_uri: &config.redirect_uri,
scope: config.scope.as_str(),
},
};
let now = chrono::Utc::now();
let resp: types::OAuth2Response = api_call(client, &config.token_endpoint, req).await?;
let token = resp.into_token(now);
log::info!("access token will expire at {}", token.expires_at());
Ok(token)
}
pub async fn refresh_token(
config: &config::OAuth2,
client: &http::Client,
refresh_token: &types::RefreshToken,
) -> Result<types::OneTimeToken> {
log::info!("refreshing OAuth2 token");
let req = types::OAuth2Request {
client_id: &config.client_id,
client_secret: config.client_secret.as_deref(),
grant_type: types::OAuth2RequestGrant::RefreshToken { refresh_token },
};
let now = chrono::Utc::now();
let resp: types::OAuth2Response = api_call(client, &config.token_endpoint,
req).await?;
let token = resp
.into_token(now)
.one_time()
.expect("refreshing should yeild a one-time token");
log::info!("access token will expire at {}", token.expires_at());
Ok(token)
}
async fn api_call<I, O>(client: &http::Client, uri: &Url, body: I) -> Result<O>
where
I: Serialize,
O: for<'de> Deserialize<'de>,
{
let request_string = serde_urlencoded::to_string(body)?;
let request = http::request()
.uri(uri.as_str())
.method(http::Method::POST)
.header(
http::header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
)
.header(http::header::ACCEPT, "application/json; charset=utf-8")
.body(http::Body::from(request_string))
.expect("malformed HTTP request");
let response = client.request(request).await?;
let status = response.status();
let body: Vec<u8> = response
.into_body()
.map_ok(|bytes| {
let iter = bytes.into_iter().map(Ok::<_, hyper::Error>);
futures::stream::iter(iter)
})
.try_flatten()
.try_collect()
.await?;
if status.is_success() {
let result = serde_json::from_slice(&body)?;
Ok(result)
} else {
let json = serde_json::from_slice(&body).ok();
Err(Error::ApiError(status, json))
}
}

205
src/api/types.rs

@ -0,0 +1,205 @@
use serde::{Deserialize, Serialize};
use serde::de::Deserializer;
use serde::ser::Serializer;
use url::Url;
use derive_more::{From, Into, Deref};
use secret::Secret;
use chrono::{DateTime, Duration, Utc};
use std::sync::Arc;
#[derive(Clone, Eq, Debug, Deserialize, Deref, From, Into, Hash, PartialEq,
Serialize)]
#[deref(forward)]
pub struct StateKey(String);
impl StateKey {
pub fn from_usize(u: usize) -> Self {
StateKey(format!("{:016X}", u))
}
}
#[derive(Clone, Debug)]
pub struct Authorization {
url: Url,
state_key: Option<StateKey>,
}
impl Authorization {
pub const fn new(url: Url, state_key: Option<StateKey>) -> Self {
Authorization { url, state_key }
}
pub const fn url(&self) -> &Url {
&self.url
}
pub const fn state_key(&self) -> &Option<StateKey> {
&self.state_key
}
}
#[derive(Clone, Eq, Debug, Hash, PartialEq)]
pub struct Token<Refresh = Option<RefreshToken>> {
access_token: AccessToken,
refresh_token: Refresh,
expires_at: DateTime<Utc>,
}
impl<R> Token<R> {
pub fn access_token(&self) -> &AccessToken {
&self.access_token
}
pub fn expires_at(&self) -> DateTime<Utc>{
self.expires_at
}
pub fn set_refresh_token(self, refresh_token: RefreshToken)
-> RefreshableToken
{
Token {
refresh_token,
access_token: self.access_token,
expires_at: self.expires_at,
}
}
pub fn expired(&self) -> bool {
self.expires_at < Utc::now()
}
}
impl Token {
pub fn refresh_token(&self) -> Option<&RefreshToken> {
self.refresh_token.as_ref()
}
pub fn one_time(self) -> Result<OneTimeToken, Token> {
if self.refresh_token.is_none() {
Ok(Token {
access_token: self.access_token,
expires_at: self.expires_at,
refresh_token: (),
})
} else {
Err(self)
}
}
pub fn refreshable(self) -> Result<RefreshableToken, Token> {
if let Some(token) = self.refresh_token {
Ok(Token {
access_token: self.access_token,
expires_at: self.expires_at,
refresh_token: token,
})
} else {
Err(self)
}
}
}
impl RefreshableToken {
pub fn refresh_token(&self) -> &RefreshToken {
&self.refresh_token
}
}
pub type RefreshableToken = Token<RefreshToken>;
pub type OneTimeToken = Token<()>;
#[derive(Clone, Eq, Debug, Deserialize, Deref, From, Into, Hash, PartialEq,
Serialize)]
pub struct AuthorizationCode(Secret<String>);
#[derive(Clone, Eq, Debug, Deref, From, Into, Hash, PartialEq)]
pub struct AccessToken(Arc<Secret<str>>);
impl<'de> Deserialize<'de> for AccessToken {
fn deserialize<D>(de: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = <&'de str>::deserialize(de)?;
let arc = Arc::from(s);
Ok(AccessToken(Secret::from_arc(arc)))
}
}
impl Serialize for AccessToken {
fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
ser.serialize_str(&self.0)
}
}
#[derive(Clone, Eq, Debug, Deref, From, Into, Hash, PartialEq)]
pub struct RefreshToken(Arc<Secret<str>>);
impl<'de> Deserialize<'de> for RefreshToken {
fn deserialize<D>(de: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = <&'de str>::deserialize(de)?;
let arc = Arc::from(s);
Ok(RefreshToken(Secret::from_arc(arc)))
}
}
impl Serialize for RefreshToken {
fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
ser.serialize_str(&self.0)
}
}
#[derive(Clone, Debug, Serialize)]
pub(crate) struct OAuth2Request<'a> {
pub(crate) client_id: &'a str,
pub(crate) client_secret: &'a Secret<str>,
#[serde(flatten)]
pub(crate) grant_type: OAuth2RequestGrant<'a>,
}
#[derive(Debug, Deserialize)]
pub(crate) struct OAuth2Response {
pub(crate) access_token: AccessToken,
pub(crate) refresh_token: Option<RefreshToken>,
pub(crate) expires_in: u64,
}
impl OAuth2Response {
pub fn into_token(self, req_time: DateTime<Utc>) -> Token {
Token {
access_token: self.access_token,
refresh_token: self.refresh_token,
expires_at: req_time - Duration::seconds(self.expires_in as i64),
}
}
}
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "grant_type", rename_all="snake_case")]
pub(crate) enum OAuth2RequestGrant<'a> {
RefreshToken {
refresh_token: &'a RefreshToken,
},
AuthorizationCode {
code: &'a AuthorizationCode,
redirect_uri: &'a Url,
scope: &'a str,
},
}
#[derive(Debug, Deserialize)]
pub(crate) struct CallbackData {
pub code: AuthorizationCode,
#[serde(default, rename="state")]
pub state_key: Option<StateKey>,
}

185
src/client/client.rs

@ -0,0 +1,185 @@
use tokio::sync::{Notify, Mutex};
use url::Url;
use std::mem;
use std::borrow::Cow;
use rand::Rng;
use crate::api;
use crate::config;
use crate::http::{self, header, Response, StatusCode};
pub struct Client {
state: Mutex<State>,
data: Data,
}
impl Client {
pub fn new(config: config::OAuth2, default_redirect: Url) -> Self {
Client {
state: Mutex::new(State::Initial),
data: Data {
config,
default_redirect,
http: http::new_client(),
notify: Notify::new(),
},
}
}
pub async fn token(&self) -> Result<api::AccessToken> {
loop {
log::debug!("fetching access token from Client state");
let mut state = self.state.lock().await;
match &mut *state {
State::Busy => unreachable!(),
State::Initial | State::Unauthorized(_, _) => {
log::debug!("the application has not been authorized yet");
mem::drop(state);
self.data.notify.notified().await;
},
State::Authorized(ref mut token) => {
if token.expired() {
let new_token = api::refresh_token(
&self.data.config,
&self.data.http,
token.refresh_token()
).await?;
*token = new_token.set_refresh_token(
token.refresh_token().clone()
);
}
return Ok(token.access_token().clone());
},
}
}
}
}
impl Client {
pub(super)
async fn authorization_redirect(&self, source_url: Option<Url>) -> Result<Response> {
let source_url = source_url
.map_or(Cow::Borrowed(&self.data.default_redirect),
Cow::Owned);
let mut lock = self.state.lock().await;
let state = mem::replace(&mut *lock, State::Busy);
let (auth, url) = match state {
State::Busy => unreachable!(),
State::Authorized(_) => {
log::debug!("already authorized, redirecting to {}", source_url);
*lock = state;
return Ok(redirect_to(&source_url))
}
State::Initial => {
let key_u = rand::thread_rng().gen::<usize>();
let key = api::types::StateKey::from_usize(key_u);
let auth = api::new_authorization(
&self.data.config,
Some(key),
);
(auth, source_url.into_owned())
},
State::Unauthorized(auth, mut url) => {
if url != *source_url {
log::debug!("changing redirect url from {} to {}",
url,
source_url);
url = source_url.into_owned();
}
(auth, url)
}
};
let resp = redirect_to(auth.url());
*lock = State::Unauthorized(auth, url);
Ok(resp)
}
pub(super)
async fn callback_handler(
&self,
data: api::types::CallbackData,
) -> Result<Response> {
let mut lock = self.state.lock().await;
let state = mem::replace(&mut *lock, State::Busy);
match state {
State::Busy => unreachable!(),
State::Authorized(_) => {
*lock = state;
Err(Error::AlreadyAuthorized)
},
State::Initial => {
*lock = state;
Err(Error::NotInitialized)
},
State::Unauthorized(auth, redirect) => {
log::debug!("checking that the correct state key has been passed");
if data.state_key != *auth.state_key() {
return Err(Error::StateKeyMismatch {
provided: data.state_key,
expected: auth.state_key().clone(),
});
}
log::debug!("fetching the authorization token");
let token = api::exchange_authorization_code(&self.data.config,
&self.data.http,
&data.code)
.await?
.refreshable()
.expect("OAuth2 is already configured to return \
refreshable tokens");
*lock = State::Authorized(token);
log::debug!("notifying threads waiting on a token");
self.data.notify.notify_waiters();
Ok(redirect_to(&redirect))
}
}
}
}
struct Data {
config: config::OAuth2,
default_redirect: Url,
notify: Notify,
http: http::Client,
}
enum State {
Busy,
Initial,
Unauthorized(api::Authorization, Url),
Authorized(api::RefreshableToken),
}
pub type Result<T = ()> = std::result::Result<T, Error>;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
ApiError(#[from] api::Error),
#[error(
"Processed an authorization callback, but the application is \
already authorized"
)]
AlreadyAuthorized,
#[error("Authorization was not requested")]
NotInitialized,
#[error("The request had missing or malformed parameters")]
InvalidRequest(#[from] #[source] serde_urlencoded::de::Error),
#[error(
"The callback returned state key {provided:?}, but we expected \
{expected:?}"
)]
StateKeyMismatch {
provided: Option<api::StateKey>,
expected: Option<api::StateKey>,
},
}
fn redirect_to(url: &Url) -> Response {
log::debug!("redirecting to {}", url);
http::response()
.status(StatusCode::SEE_OTHER)
.header(header::LOCATION, url.as_str())
.body("".into())
.expect("redirect response invalid")
}

4
src/client/mod.rs

@ -0,0 +1,4 @@
pub use self::client::{Client, Error};
mod client;
mod routes;

60
src/client/routes.rs

@ -0,0 +1,60 @@
use url::Url;
use std::sync::Arc;
use routerify::Router;
use std::error::Error as StdError;
use crate::http::{header, Body, Request, Response};
use super::client::{self, Client};
impl Client {
pub fn router<E>(self: Arc<Self>) -> Router<Body, E>
where
E: 'static + From<client::Error> + Send + Sync + StdError
{
Router::builder()
.data(self)
.get("/authorize", route_authorize)
.get("/callback", route_callback)
.build()
.expect("constructing the router should not fail")
}
}
async fn route_authorize<E>(req: Request) -> Result<Response, E>
where
E: From<client::Error>,
{
let client = req.client();
let referrer = req.headers()
.get(header::REFERER)
.and_then(|x| x.to_str().ok())
.and_then(|x| Url::parse(x).ok());
client.authorization_redirect(referrer)
.await
.map_err(E::from)
}
async fn route_callback<E>(req: Request) -> Result<Response, E>
where
E: From<client::Error>,
{
let client = req.client();
let query_str = req.uri().query().unwrap_or_default();
let query = serde_urlencoded::from_str(&query_str)
.map_err(|e| client::Error::from(e).into())?;
client.callback_handler(query)
.await
.map_err(E::from)
}
pub trait RequestExt: routerify::prelude::RequestExt {
fn client(&self) -> Arc<Client> {
let client_ref = self.data().unwrap();
Arc::clone(&client_ref)
}
}
impl<T> RequestExt for T
where
T: routerify::prelude::RequestExt
{}

3
src/config/mod.rs

@ -0,0 +1,3 @@
pub use oauth2::OAuth2;
pub mod oauth2;

194
src/config/oauth2.rs

@ -0,0 +1,194 @@
use serde::Deserialize;
use serde::de::{Deserializer, Visitor, SeqAccess, Error as _};
use serde::ser::{Serialize, Serializer, SerializeSeq};
use secret::Secret;
use url::Url;
use std::fmt::{self, Debug, Display, Formatter};
#[derive(Clone, Debug, Deserialize)]
pub struct OAuth2 {
#[serde(default)]
pub access_type: AccessType,
#[serde(default="def::authorization_endpoint")]
pub authorization_endpoint: Url,
pub client_id: String,
pub client_secret: Secret<String>,
#[serde(default)]
pub login_hint: Option<String>,
#[serde(default)]
pub prompt: Prompt,
pub redirect_uri: Url,
#[serde(default)]
pub scope: Scopes,
#[serde(default="def::token_endpoint")]
pub token_endpoint: Url,
}
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[serde(rename_all="snake_case")]
pub enum Prompt {
None,
Consent,
SelectAccount
}
impl Prompt {
pub const fn as_str(&self) -> &'static str {
match self {
Prompt::None => "none",
Prompt::Consent => "consent",
Prompt::SelectAccount => "select_account",
}
}
}
impl Default for Prompt {
fn default() -> Self {
Prompt::None
}
}
impl Display for Prompt {
fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
fmt.write_str(self.as_str())
}
}
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[serde(rename_all="snake_case")]
pub enum AccessType {
Online,
Offline,
}
impl AccessType {
pub const fn as_str(&self) -> &'static str {
match self {
AccessType::Online => "online",
AccessType::Offline => "offline",
}
}
}
impl Default for AccessType {
fn default() -> Self {
AccessType::Online
}
}
impl Display for AccessType {
fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
fmt.write_str(self.as_str())
}
}
#[derive(Debug, Default, Clone)]
pub struct Scopes {
string: String,
length: usize,
}
impl Scopes {
pub fn as_str(&self) -> &str {
&self.string
}
pub fn iter(&self) -> impl Iterator<Item = &str> {
self.string.split(' ')
}
pub fn to_vec(&self) -> Vec<&str> {
self.iter().collect()
}
pub fn len(&self) -> usize {
self.length
}
}
impl<'de> Deserialize<'de> for Scopes {
fn deserialize<D>(de: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>
{
struct ScopesV;
impl<'de> Visitor<'de> for ScopesV {
type Value = Scopes;
fn expecting(&self, fmt: &mut Formatter) -> fmt::Result
{
write!(fmt, "a list of scopes")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Scopes, A::Error>
where
A: SeqAccess<'de>
{
fn check_str<'de, A> (s: &'de str) -> Result<(), A::Error>
where
A: SeqAccess<'de>
{
if s.contains(' ') {
Err(A::Error::invalid_value(
serde::de::Unexpected::Str(s),
&"a string with no whitespace",
))
} else {
Ok(())
}
}
let mut string = String::with_capacity(1024);
let mut length = 0;
if let Some(s) = seq.next_element::<&'de str>()? {
length += 1;
check_str::<A>(s)?;
string += s;
}
while let Some(s) = seq.next_element::<&'de str>()? {
length += 1;
check_str::<A>(s)?;
string.push(' ');
string += s;
}
string.shrink_to_fit();
Ok(Scopes { string, length })
}
}
de.deserialize_seq(ScopesV)
}
}
impl Serialize for Scopes {
fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer
{
let mut s = s.serialize_seq(Some(self.len()))?;
for scope in self.iter() {
s.serialize_element(scope)?;
}
s.end()
}
}
mod def {
use url::Url;
const GOOGLE_OAUTH2_ENDPOINT: &str =
"https://accounts.google.com/o/oauth2/v2/auth";
const GOOGLE_TOKEN_ENDPOINT: &str =
"https://oauth2.googleapis.com/token";
pub fn authorization_endpoint() -> Url {
Url::parse(GOOGLE_OAUTH2_ENDPOINT)
.expect("failed to parse GOOGLE_OAUTH2_ENDPOINT")
}
pub fn token_endpoint() -> Url {
Url::parse(GOOGLE_TOKEN_ENDPOINT)
.expect("failed to parse GOOGLE_TOKEN_ENDPOINT")
}
}

21
src/http.rs

@ -0,0 +1,21 @@
pub use hyper::http::{StatusCode, Method, header};
use hyper::client::HttpConnector;
use hyper_tls::HttpsConnector;
pub use hyper::body::Body;
pub type Client = hyper::Client<HttpsConnector<HttpConnector>>;
pub type Request<T = Body> = hyper::Request<T>;
pub type Response<T = Body> = hyper::Response<T>;
pub fn new_client() -> Client {
hyper::Client::builder()
.build(HttpsConnector::new())
}
pub fn request() -> hyper::http::request::Builder {
hyper::http::request::Builder::new()
}
pub fn response() -> hyper::http::response::Builder {
hyper::http::response::Builder::new()
}

8
src/lib.rs

@ -0,0 +1,8 @@
pub use client::Client;
pub use api::types::{Token, RefreshToken, AccessToken, AuthorizationCode,
StateKey, RefreshableToken, OneTimeToken};
pub mod api;
pub mod config;
pub mod client;
mod http;
Loading…
Cancel
Save