use async_trait::async_trait;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use service_logging::{
log,
Severity::{self, Verbose},
};
use std::fmt;
use wasm_service::{handler_return, Context, Handler, HandlerReturn, Request};
mod encoder;
use encoder::{decode, encode};
/// Required length, in bytes, of `state_secret` and `session_secret`
/// (enforced by `OAuthHandler::init`).
const AES_KEY_BYTES: usize = 32;
/// Content type set on the login-failed HTML page.
const TEXT_HTML: &str = "text/html";
/// Value of the `Accept` header sent with requests to the OAuth provider.
const APPLICATION_JSON: &str = "application/json";
/// Endpoint queried with the access token to fetch the user's profile.
const GITHUB_GET_USER_API: &str = "https://api.github.com/user";
/// Longest `return_url` (in bytes) accepted by `is_valid_return_url`.
const MAX_APP_URL_LEN: usize = 800;
/// Errors produced by this module.
///
/// Only `Config` is constructed in this file and only `TimeoutExpired` is
/// matched on (in `verify_auth_user`).
// NOTE(review): the remaining variants are presumably raised by the `encoder`
// module (encrypt/serialize/decode failures) — confirm against that module.
#[derive(Debug)]
pub enum Error {
/// An `OAuthConfig` field failed validation; the string names the field and the problem.
Config(String),
Encryption(String),
Serde {
msg: &'static str,
e: serde_json::Error,
},
ArrayLen,
/// Returned by `decode` when an encoded value is past its timeout.
TimeoutExpired,
Random(String),
CookieDecode,
}
/// Marker impl so `Error` can be used wherever `std::error::Error` is required.
impl std::error::Error for Error {}
/// `Display` simply reuses the derived `Debug` rendering of the error.
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}
/// Converts any error into a `HandlerReturn` whose text is
/// `"Internal Error: "` followed by the error's `Debug` rendering.
// NOTE(review): the status is 200 even though the text says "Internal Error";
// a 5xx status may have been intended — confirm with callers before changing.
fn config_error_return(e: impl std::error::Error) -> HandlerReturn {
    let text = format!("Internal Error: {:?}", e);
    HandlerReturn { status: 200, text }
}
fn config_error(field: &str, msg: &str) -> Error {
Error::Config(format!("field '{}': {}", field, msg))
}
fn config_field_empty(field: &str) -> Error {
config_error(field, "must not be empty")
}
/// Returns `true` if `name` is a syntactically valid GitHub username.
///
/// A valid name:
/// - is 1 to 39 characters long,
/// - contains only ASCII letters, ASCII digits and `-`,
/// - does not start or end with `-`, and
/// - does not contain two consecutive hyphens.
///
/// Only ASCII alphanumerics are accepted: Unicode letters or digits
/// (for example `é` or Arabic-Indic digits) make the name invalid.
pub fn is_valid_username_token(name: &str) -> bool {
    const MAX_LEN: usize = 39;

    let length_ok = !name.is_empty() && name.len() <= MAX_LEN;
    let hyphens_ok = !(name.starts_with('-') || name.ends_with('-') || name.contains("--"));

    length_ok
        && hyphens_ok
        // Byte-wise check: any non-ASCII character contributes bytes >= 0x80,
        // which are never ASCII alphanumeric, so the name is rejected.
        && name
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-')
}
/// Returns `true` if `return_url` may be used as a post-login redirect target.
///
/// The value must be non-empty, at most `MAX_APP_URL_LEN` bytes long, parse
/// as an absolute URL, and use the `http` or `https` scheme. The host is not
/// restricted by this check.
pub fn is_valid_return_url(return_url: &str) -> bool {
    let length_ok = !return_url.is_empty() && return_url.len() <= MAX_APP_URL_LEN;
    if !length_ok {
        return false;
    }
    reqwest::Url::parse(return_url)
        .map(|parsed| matches!(parsed.scheme(), "http" | "https"))
        .unwrap_or(false)
}
/// JSON body returned by the provider's token endpoint
/// (deserialized in `get_github_token`).
#[derive(Deserialize)]
struct TokenData {
/// The OAuth access token.
access_token: String,
/// Token type; `get_github_token` only accepts `"bearer"`.
token_type: String,
}
/// Displays only a short prefix of the access token so the full secret never
/// ends up in formatted output.
impl fmt::Display for TokenData {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Take at most four *characters*: slicing by byte index would panic on
        // tokens shorter than four bytes or on a non-character boundary.
        let prefix: String = self.access_token.chars().take(4).collect();
        write!(f, "TokenData({}...)", prefix)
    }
}
/// User profile fields deserialized from the `GITHUB_GET_USER_API` response.
#[derive(Debug, Deserialize)]
pub struct UserData {
/// Account login name; validated with `is_valid_username_token` and stored in the session.
pub login: String,
pub id: u64,
pub name: Option<String>,
pub email: Option<String>,
}
/// Payload of the OAuth `state` parameter: the URL to return to after login.
/// It is passed through `encode` on login and recovered with `decode` on the callback.
#[derive(Debug, Serialize, Deserialize)]
struct State(String);
/// Application-defined authorization policy, consulted after the provider has
/// authenticated a user and the login name has passed syntax validation.
#[async_trait]
pub trait AuthCheck {
/// Decides whether `user` may use the app.
///
/// Return `Ok(())` to allow the login to proceed. Return `Err` with the
/// response to send instead (for example a redirect) to reject the user.
async fn check_authorized(
&self,
req: &Request,
ctx: &mut Context,
user: &UserData,
) -> Result<(), HandlerReturn>;
}
/// Contents of the session cookie: field 0 is the user's login name and
/// field 1 is the provider access token (see `handle_oauth_response`).
#[derive(Debug, Serialize, Deserialize)]
pub struct Session(String, String);
/// Configuration for `OAuthHandler`. Validated by `OAuthHandler::init`.
pub struct OAuthConfig {
/// Renders the response for `login_failed_url_path`; receives `app_url` as `return_url`.
pub auth_failed_response: fn(&Request, ctx: &mut Context, return_url: &str),
/// Policy deciding whether an authenticated user may use the app.
pub auth_checker: Box<dyn AuthCheck>,
/// Builds the redirect response used when a step of the OAuth flow fails.
pub auth_error_redirect: fn(ctx: &mut Context, url: Option<&str>) -> HandlerReturn,
/// Application URL; required. Used as the fallback return URL after login.
pub app_url: String,
/// Redirect target after logout when no explicit URL is given.
pub logged_out_app_url: String,
/// Path that starts the login flow.
pub authorize_url_path: String,
/// Path the provider redirects back to with `code` and `state`.
pub code_url_path: String,
/// Path of the login-failed page.
pub login_failed_url_path: String,
/// Path that logs the user out.
pub logout_url_path: String,
/// `User-Agent` header sent to the provider; required.
pub user_agent: String,
/// Allowed CORS origins; must not be empty. The first entry is returned
/// when the request origin is missing or not in the list.
pub cors_origins: Vec<String>,
/// Value of `Access-Control-Allow-Methods`; must contain `GET` and `OPTIONS`.
pub cors_allow_methods: String,
/// Value of `Access-Control-Max-Age`, in seconds (10 minutes to 7 days).
pub cors_allow_age_sec: u64,
// NOTE(review): by its name and default this is the list for
// `Access-Control-Allow-Headers` — confirm it is wired into the response.
pub cors_allow_headers: String,
/// Provider page the browser is redirected to for login.
pub provider_authorize_url: String,
/// Provider endpoint that exchanges the `code` for an access token.
pub provider_token_url: String,
/// OAuth scopes to request; must include `read:user`.
pub oauth_scopes: String,
/// OAuth client id; required.
pub client_id: String,
/// OAuth client secret; required.
pub client_secret: String,
/// 32-byte key used to encode/decode the `state` parameter.
pub state_secret: Vec<u8>,
/// Lifetime of the `state` parameter, in seconds (1 to 10 minutes).
pub state_timeout_sec: u64,
/// 32-byte key used to encode/decode the session cookie.
pub session_secret: Vec<u8>,
/// Session lifetime in seconds (10 minutes to 31 days); also the cookie `Max-Age`.
pub session_timeout_sec: u64,
/// `Path` attribute of the session cookie.
pub session_cookie_path_prefix: String,
}
impl Default for OAuthConfig {
fn default() -> Self {
Self {
client_id: String::new(),
client_secret: String::new(),
state_secret: Vec::new(),
session_secret: Vec::new(),
app_url: String::new(),
logged_out_app_url: "/".to_string(),
session_cookie_path_prefix: "/".to_string(),
oauth_scopes: "read:user".to_string(),
auth_failed_response,
auth_checker: Box::new(AlwaysDeny {}),
auth_error_redirect,
authorize_url_path: "/authorize".to_string(),
code_url_path: "/authorized".to_string(),
login_failed_url_path: "/login-failed".to_string(),
logout_url_path: "/logout".to_string(),
user_agent: "wasm-oauth".to_string(),
cors_origins: vec!["*".to_string()],
cors_allow_methods: "GET,POST,OPTIONS".to_string(),
cors_allow_age_sec: 24 * 60 * 60,
cors_allow_headers: "Content-Type,Origin,Accept,Accept-Language,X-Requested-With"
.to_string(),
provider_authorize_url: "https://github.com/login/oauth/authorize".to_string(),
provider_token_url: "https://github.com/login/oauth/access_token".to_string(),
session_timeout_sec: 3 * 24 * 60 * 60,
state_timeout_sec: 5 * 60,
}
}
}
/// `AuthCheck` policy that authorizes every authenticated user.
pub struct AlwaysAllow {}
#[async_trait]
impl AuthCheck for AlwaysAllow {
/// Always returns `Ok(())`; all arguments are ignored.
async fn check_authorized(
&self,
_req: &Request,
_ctx: &mut Context,
_user: &UserData,
) -> Result<(), HandlerReturn> {
Ok(())
}
}
/// `AuthCheck` policy that rejects every user; this is the default checker
/// in `OAuthConfig::default()`.
pub struct AlwaysDeny {}
#[async_trait]
impl AuthCheck for AlwaysDeny {
/// Always returns a 403 "Not Allowed" response; all arguments are ignored.
async fn check_authorized(
&self,
_req: &Request,
_ctx: &mut Context,
_user: &UserData,
) -> Result<(), HandlerReturn> {
Err(handler_return(403, "Not Allowed"))
}
}
/// Request handler implementing the OAuth login, callback, login-failed and
/// logout routes. Construct it with `OAuthHandler::init`.
pub struct OAuthHandler {
// Configuration that has passed the checks in `init`.
config: OAuthConfig,
}
/// `AuthCheck` policy that admits only an explicit list of login names.
#[derive(Debug)]
pub struct UserAllowList {
/// Login names that are authorized (compared exactly, case-sensitively).
pub allowed_users: Vec<String>,
/// URL rejected users are redirected to; `?user=<login>` is appended.
pub login_failed_url: String,
}
#[async_trait]
impl AuthCheck for UserAllowList {
    /// Allows `user` only if its login appears in `allowed_users`.
    ///
    /// A rejected user gets a redirect to
    /// `login_failed_url?user=<login>` (built by `auth_error_redirect`).
    async fn check_authorized(
        &self,
        _req: &Request,
        mut ctx: &mut Context,
        user: &UserData,
    ) -> Result<(), HandlerReturn> {
        log!(ctx, Verbose, _:"gh_valid_name",name:&user.login);
        if !self.allowed_users.contains(&user.login) {
            log!(ctx, Verbose, _:"user_allowed_not_found", user: &user.login);
            let failed_url = format!("{}?user={}", self.login_failed_url, &user.login);
            return Err(auth_error_redirect(&mut ctx, Some(failed_url.as_str())));
        }
        log!(ctx, Verbose, _:"user_allowed", user: &user.login);
        Ok(())
    }
}
/// Default error redirect: sets `Location` to `url` (or `/login-failed` when
/// `url` is `None`), emits an expired `session` cookie, and returns a 303.
// NOTE(review): the expired cookie uses `Path=/x`, which does not match the
// path the session cookie is issued with; browsers only replace a cookie when
// the path matches, so this may not clear the session — confirm intent.
fn auth_error_redirect(ctx: &mut Context, url: Option<&str>) -> HandlerReturn {
    const EXPIRED_SESSION_COOKIE: &str =
        "session=0; Path=/x; HttpOnly; Secure; SameSite=None; Max-Age=0";

    let location = url.unwrap_or("/login-failed");
    let _ = ctx
        .response()
        .header("Location", location)
        .unwrap()
        .header("Set-Cookie", EXPIRED_SESSION_COOKIE)
        .unwrap();
    handler_return(303, "")
}
fn auth_failed_response(req: &Request, ctx: &mut Context, return_url: &str) {
let msg = if let Some(user) = req.get_query_value("user") {
format!(
r#"<p>Github user '{}' is not authorized to use this app.
Contact this app's administrator if '{}' should be added to the authorized users list,
or log out of <a href="https://github.com">Github</a> to try a different user</p>"#,
user, user
)
} else {
String::default()
};
let body = format!(
r#"<html>
<body><p>This app requires an authorized github user.</p>
{}<p><a href="{}">Return to app</a></p>
</body>
</html>
"#,
msg, return_url
);
ctx.response()
.status(200)
.content_type(TEXT_HTML)
.unwrap()
.text(body);
}
impl OAuthHandler {
/// Validates `config` and builds the handler.
///
/// # Errors
///
/// Returns `Error::Config` naming the first field that fails validation:
/// required strings must be non-empty, both secrets must be
/// `AES_KEY_BYTES` long, timeouts must fall in their documented ranges,
/// `cors_allow_methods` must mention `GET` and `OPTIONS`, and
/// `oauth_scopes` must mention `read:user`.
pub fn init(config: OAuthConfig) -> Result<Self, Error> {
    const MINUTE: u64 = 60;
    const DAY: u64 = 24 * 60 * 60;

    // Required string fields, checked in this order.
    for &(field, value) in &[
        ("app_url", &config.app_url),
        ("client_id", &config.client_id),
        ("client_secret", &config.client_secret),
        ("user_agent", &config.user_agent),
    ] {
        if value.is_empty() {
            return Err(config_field_empty(field));
        }
    }

    if config.state_secret.len() != AES_KEY_BYTES {
        return Err(config_error("state_secret", "must be 32-byte secret key"));
    }
    if config.cors_origins.is_empty() {
        return Err(config_field_empty("cors_origins"));
    }
    if !(10 * MINUTE..=7 * DAY).contains(&config.cors_allow_age_sec) {
        return Err(config_error(
            "cors_allow_age_sec",
            "should be from 10 minutes to 7 days - did you use a value in seconds?",
        ));
    }
    if !(MINUTE..=10 * MINUTE).contains(&config.state_timeout_sec) {
        return Err(config_error("state_timeout_sec", "should be from 1 minute to 10 minutes (Github times out after 10 min). Value is in seconds"));
    }
    if config.session_secret.len() != AES_KEY_BYTES {
        return Err(config_error("session_secret", "must be 32-byte secret key"));
    }
    if !(10 * MINUTE..=31 * DAY).contains(&config.session_timeout_sec) {
        return Err(config_error(
            "session_timeout_sec",
            "Session timeout should be from 10 minutes to 31 days. Value is in seconds",
        ));
    }

    let methods = &config.cors_allow_methods;
    if !(methods.contains("GET") && methods.contains("OPTIONS")) {
        return Err(config_error("cors_allow_methods", "must include at least GET and OPTIONS. Value should be a comma-separated list of options, such as \"GET,POST,OPTIONS\""));
    }
    if !config.oauth_scopes.contains("read:user") {
        return Err(config_error(
            "oauth_scopes",
            "must include at least read:user. Values should be a space-separated list",
        ));
    }

    Ok(OAuthHandler { config })
}
/// Adds the CORS response headers for `req` to the response in `ctx`.
///
/// `Access-Control-Allow-Origin` is chosen by `map_cors_origin` from the
/// request's `origin` header. The allowed methods, max age and allowed
/// headers all come from the configuration.
///
/// # Errors
///
/// Propagates the error from the response builder if a header cannot be set.
pub fn add_cors_headers(
    &self,
    req: &Request,
    ctx: &mut Context,
) -> Result<(), wasm_service::Error> {
    let allow_origin = self.map_cors_origin(req.get_header("origin"));
    ctx.response()
        .header("Access-Control-Allow-Origin", allow_origin)?
        .header(
            "Access-Control-Allow-Methods",
            &self.config.cors_allow_methods,
        )?
        .header(
            "Access-Control-Max-Age",
            self.config.cors_allow_age_sec.to_string(),
        )?
        // Use the configured list rather than a hard-coded copy of its
        // default, so that overriding `cors_allow_headers` takes effect.
        .header(
            "Access-Control-Allow-Headers",
            &self.config.cors_allow_headers,
        )?;
    Ok(())
}
/// Picks the value for `Access-Control-Allow-Origin`.
///
/// Returns the configured origin equal to `origin` when there is one;
/// otherwise (no origin given, or no match) returns the first configured
/// origin.
///
/// # Panics
///
/// Panics if `cors_origins` is empty; `init` rejects such a configuration.
pub fn map_cors_origin(&self, origin: Option<String>) -> &str {
    let allowed = &self.config.cors_origins;
    let matched = origin.and_then(|wanted| {
        allowed
            .iter()
            .find(|candidate| candidate.as_str() == wanted.as_str())
    });
    matched.unwrap_or_else(|| &allowed[0]).as_str()
}
/// Exchanges the OAuth `code` for an access token at `provider_token_url`.
///
/// Returns the access token when the provider answers with a token of type
/// `"bearer"`. On a request/parse failure or any other token type, the
/// problem is logged and the configured error redirect is returned as `Err`.
async fn get_github_token(
    &self,
    mut ctx: &mut Context,
    code: &str,
    state: &str,
) -> Result<String, HandlerReturn> {
    log!(ctx, Verbose, _:"get_token", code: code);

    let form_fields = [
        ("client_id", self.config.client_id.as_str()),
        ("client_secret", self.config.client_secret.as_str()),
        ("code", code),
        ("state", state),
    ];
    let response = reqwest::Client::new()
        .post(&self.config.provider_token_url)
        .form(&form_fields)
        .header("Accept", APPLICATION_JSON)
        .header("Cache-Control", "no-store, max-age=0")
        .header("User-Agent", &self.config.user_agent)
        .send()
        .await;

    let token_data = match self
        .parse_json_response::<TokenData>("get_token", response)
        .await
    {
        Ok(token_data) => token_data,
        Err(msg) => {
            log!(ctx, Severity::Error, text: msg);
            return Err((self.config.auth_error_redirect)(&mut ctx, None));
        }
    };

    if token_data.token_type != "bearer" {
        log!(ctx, Severity::Error, _:"get_token type-err",
            expect:"bearer", actual:&token_data.token_type);
        return Err((self.config.auth_error_redirect)(&mut ctx, None));
    }
    Ok(token_data.access_token)
}
/// Fetches the profile of the user that `oauth_token` belongs to from
/// `GITHUB_GET_USER_API`.
///
/// On a request or parse failure the problem is logged and the configured
/// error redirect is returned as `Err`.
async fn get_github_user(
    &self,
    mut ctx: &mut Context,
    oauth_token: &str,
) -> Result<UserData, HandlerReturn> {
    // Log only a short prefix of the token. Collect by characters rather than
    // slicing by byte index, which would panic on a token shorter than eight
    // bytes or on a non-character boundary.
    let token_prefix: String = oauth_token.chars().take(8).collect();
    log!(ctx, Verbose, _:"get_github_user", token: &token_prefix);
    match self
        .parse_json_response(
            "get_user",
            reqwest::Client::new()
                .get(GITHUB_GET_USER_API)
                .header("Authorization", format!("token {}", oauth_token))
                .header("Accept", APPLICATION_JSON)
                .header("Cache-Control", "no-store, max-age=0")
                .header("User-Agent", &self.config.user_agent)
                .send()
                .await,
        )
        .await
    {
        Ok(user) => Ok(user),
        Err(msg) => {
            log!(ctx, Severity::Error, text: msg);
            Err((self.config.auth_error_redirect)(&mut ctx, None))
        }
    }
}
/// Turns the outcome of a provider request into a deserialized `T`.
///
/// `query` is a short label included in every error message. The error
/// string describes which stage failed: the HTTP request itself, a
/// non-success status, reading the body, or JSON decoding (the latter
/// includes the body text and the response headers).
async fn parse_json_response<T: DeserializeOwned>(
    &self,
    query: &str,
    response: Result<reqwest::Response, reqwest::Error>,
) -> Result<T, String> {
    let response = match response {
        Ok(received) => received,
        Err(e) => return Err(format!("gh_query http-err, q={} error={}", query, e)),
    };

    let status = response.status();
    if !status.is_success() {
        return Err(format!(
            "gh_query status-err, q={} status={}",
            query, status
        ));
    }

    // Capture the headers first: reading the body consumes the response.
    let headers = dump_headers(&response);
    let text = match response.text().await {
        Ok(body) => body,
        Err(e) => return Err(format!("gh_query body-err, q={} error={}", query, e)),
    };

    serde_json::from_str::<T>(&text).map_err(|e| {
        format!(
            "gh_query json-err, q={} body={}, error={}, headers={}",
            query, &text, e, headers
        )
    })
}
fn handle_oauth_login<'req>(
&self,
req: &'req Request,
mut ctx: &mut Context,
) -> Result<(), HandlerReturn> {
let return_url = req
.get_query_value("return_url")
.map(|s| s.to_string())
.filter(|u| is_valid_return_url(u))
.unwrap_or_else(|| self.config.app_url.to_string());
let location = match encode(
State(return_url),
&self.config.state_secret,
self.config.state_timeout_sec as u64,
) {
Ok(state) => Some(format!(
"{}?client_id={}&state={}&scope={}",
self.config.provider_authorize_url,
self.config.client_id,
state,
self.config.oauth_scopes,
)),
Err(e) => {
log!(ctx, Severity::Error, _:"oa1.encode", error:e);
None
}
};
match location {
Some(location) => {
log!(ctx, Verbose, _:"oa1", location: location);
ctx.response()
.status(302)
.header("Location", location)
.unwrap();
Ok(())
}
None => Err((self.config.auth_error_redirect)(&mut ctx, None)),
}
}
/// Handles the provider's callback (the request carrying `state` and `code`).
///
/// Steps, each of which aborts with an error redirect on failure:
/// 1. decode `state` to recover the return URL,
/// 2. exchange `code` for an access token,
/// 3. fetch the user profile,
/// 4. validate the login name and consult the configured `auth_checker`,
/// 5. encode a `Session` and set it as the `session` cookie while
///    redirecting (302) to the return URL.
async fn handle_oauth_response(
    &self,
    req: &Request,
    mut ctx: &mut Context,
) -> Result<(), HandlerReturn> {
    let state = req.get_query_value("state").ok_or_else(|| {
        log!(ctx, Severity::Error, _:"oa2:missing_state");
        (self.config.auth_error_redirect)(&mut ctx, None)
    })?;
    let decoded_state: State =
        decode(state.as_ref(), &self.config.state_secret).map_err(|e| {
            log!(ctx, Severity::Error, _:"oa2:decode", error: e);
            (self.config.auth_error_redirect)(&mut ctx, None)
        })?;
    let return_url = decoded_state.0;
    let code = req.get_query_value("code").ok_or_else(|| {
        log!(ctx, Severity::Error, _:"oa2:missing_code");
        (self.config.auth_error_redirect)(&mut ctx, None)
    })?;
    let token = self
        .get_github_token(&mut ctx, code.as_ref(), state.as_ref())
        .await?;
    let user = self.get_github_user(&mut ctx, &token).await?;
    if is_valid_username_token(&user.login) {
        self.config
            .auth_checker
            .check_authorized(req, &mut ctx, &user)
            .await?;
    } else {
        log!(ctx, Severity::Error, msg:"Invalid chars in github username",name:&user.login);
        return Err(auth_error_redirect(
            &mut ctx,
            Some(&self.config.login_failed_url_path),
        ));
    }
    let session_cookie = encode(
        Session(user.login.clone(), token.clone()),
        &self.config.session_secret,
        self.config.session_timeout_sec,
    )
    .map_err(|e| {
        log!(ctx, Severity::Error, _:"oa2:encode_session", error: e);
        (self.config.auth_error_redirect)(&mut ctx, None)
    })?;
    // Log only a short prefix of the token. Collect by characters rather than
    // slicing by byte index, which would panic on a token shorter than eight
    // bytes or on a non-character boundary.
    let token_prefix: String = token.chars().take(8).collect();
    // NOTE(review): this verbose entry still records the complete session
    // cookie value; consider dropping it from the log.
    log!(ctx, Verbose, _:"oa2:set-session", cookie: session_cookie,
        user: &user.login, token: &token_prefix, return_url: &return_url);
    ctx.response()
        .status(302)
        .header("Location", &return_url)
        .unwrap()
        .header(
            "Set-Cookie",
            format!(
                "session={}; Path={}; HttpOnly; Secure; SameSite=None; Max-Age={}",
                session_cookie,
                self.config.session_cookie_path_prefix,
                self.config.session_timeout_sec
            ),
        )
        .unwrap();
    Ok(())
}
/// Sends the browser back through the login flow.
///
/// Sets `Location` to `authorize_url_path?return_url=<redirect_url>`, emits
/// an expired `session` cookie, and returns a 303.
///
/// The destination is passed as `return_url`, the parameter name the login
/// route reads, and is percent-encoded so that a destination containing its
/// own `?`/`&` survives intact. The expired cookie uses the same `Path` the
/// session cookie is issued with, which is required for the browser to
/// actually discard it.
pub fn re_authorize(&self, ctx: &mut Context, redirect_url: &str) -> HandlerReturn {
    /// Percent-encodes every byte outside the RFC 3986 "unreserved" set.
    fn encode_query_value(raw: &str) -> String {
        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            match byte {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                    encoded.push(char::from(byte))
                }
                other => encoded.push_str(&format!("%{:02X}", other)),
            }
        }
        encoded
    }

    log!(ctx, Verbose, _:"re_authorize", redirect_url: redirect_url);
    ctx.response()
        .header(
            "Location",
            format!(
                "{}?return_url={}",
                self.config.authorize_url_path,
                encode_query_value(redirect_url)
            ),
        )
        .unwrap()
        .header(
            "Set-Cookie",
            format!(
                "session=0; Path={}; HttpOnly; Secure; SameSite=None; Max-Age=0",
                self.config.session_cookie_path_prefix
            ),
        )
        .unwrap();
    handler_return(303, "")
}
pub fn verify_auth_user(
&self,
req: &Request,
mut ctx: &mut Context,
) -> Result<Session, HandlerReturn> {
use wasm_service::Method;
let redirect_url = if req.method() == Method::GET {
req.url().to_string()
} else {
self.config.app_url.to_string()
};
let sess_cookie = req
.get_cookie_value("session")
.ok_or_else(|| self.re_authorize(&mut ctx, &redirect_url))?;
let session: Session = match decode(&sess_cookie, &self.config.session_secret) {
Ok(session) => session,
Err(Error::TimeoutExpired) => {
let url = req.url();
let redirect_url = if req.method() == Method::GET {
url.to_string()
} else {
self.config.app_url.to_string()
};
log!(ctx, Verbose,_:"verify_auth timeout", redirect_url:redirect_url);
return Err(self.re_authorize(&mut ctx, &redirect_url));
}
Err(e) => {
log!(ctx, Severity::Error, _:"session-decode error", error: e);
return Err((self.config.auth_error_redirect)(&mut ctx, None));
}
};
log!(ctx, Verbose, _:"verify_auth success", user: &session.0);
self.add_cors_headers(req, &mut ctx)
.map_err(config_error_return)?;
Ok(session)
}
/// Logs the user out: sets `Location` to `url` (or `logged_out_app_url`
/// when `url` is `None`) and emits an expired `session` cookie.
///
/// The expired cookie uses the same `Path` the session cookie is issued
/// with; a browser only replaces a cookie whose name and path match, so any
/// other path would leave the session in place.
///
/// Always returns `Err` carrying the 303 response, so that a caller using
/// `?` sends the redirect immediately.
fn logout(
    &self,
    _req: &Request,
    ctx: &mut Context,
    url: Option<&str>,
) -> Result<(), HandlerReturn> {
    let expired_cookie = format!(
        "session=0; Path={}; HttpOnly; Secure; SameSite=None; Max-Age=0",
        self.config.session_cookie_path_prefix
    );
    let _ = ctx
        .response()
        .header("Location", url.unwrap_or(&self.config.logged_out_app_url))
        .unwrap()
        .header("Set-Cookie", expired_cookie)
        .unwrap();
    Err(handler_return(303, ""))
}
pub fn would_handle(&self, req: &Request) -> bool {
let path = req.url().path();
let conf = &self.config;
let auth_prefixes = vec![
&conf.authorize_url_path,
&conf.code_url_path,
&conf.login_failed_url_path,
&conf.logout_url_path,
];
req.method() == wasm_service::Method::OPTIONS
|| (req.method() == wasm_service::Method::GET
&& auth_prefixes.iter().any(|&prefix| path.starts_with(prefix)))
}
}
/// Renders the response headers as `(name:value)` entries, one per line.
/// A header value that is not valid visible ASCII is shown as empty.
fn dump_headers(resp: &reqwest::Response) -> String {
    let mut entries: Vec<String> = Vec::new();
    for (name, value) in resp.headers() {
        let shown = value.to_str().unwrap_or("");
        entries.push(format!("({}:{})", name, shown));
    }
    entries.join("\n")
}
#[async_trait(?Send)]
impl Handler for OAuthHandler {
    /// Routes a request to the matching OAuth step.
    ///
    /// - `OPTIONS` (any path): CORS preflight, answered with 204.
    /// - `GET authorize_url_path`: start the login flow.
    /// - `GET code_url_path`: process the provider callback.
    /// - `GET login_failed_url_path`: render the login-failed response.
    /// - `GET logout_url_path`: log out.
    ///
    /// Paths are compared for equality, in the order listed. Anything else
    /// is left untouched and `Ok(())` is returned.
    async fn handle(&self, req: &Request, mut ctx: &mut Context) -> Result<(), HandlerReturn> {
        use wasm_service::Method::{GET, OPTIONS};

        let conf = &self.config;
        let method = req.method();
        let path = req.url().path();

        if method == OPTIONS {
            self.add_cors_headers(req, &mut ctx)
                .map_err(config_error_return)?;
            ctx.response().status(204);
            return Ok(());
        }
        if method != GET {
            return Ok(());
        }

        if path == conf.authorize_url_path {
            self.handle_oauth_login(req, &mut ctx)?;
        } else if path == conf.code_url_path {
            log!(ctx, Verbose, _:"authorized", url:req.url());
            self.handle_oauth_response(req, &mut ctx).await?;
        } else if path == conf.login_failed_url_path {
            (conf.auth_failed_response)(req, &mut ctx, &conf.app_url);
        } else if path == conf.logout_url_path {
            self.logout(req, &mut ctx, None)?;
        }
        Ok(())
    }
}
/// Handler that logs the request and adds anti-framing and no-cache headers
/// to the response.
struct InitHandler {}

#[async_trait(?Send)]
impl Handler for InitHandler {
    /// Logs method and URL, then sets `X-Frame-Options: DENY` and a
    /// `Cache-Control` header that forbids storing or caching the response.
    async fn handle(&self, req: &Request, ctx: &mut Context) -> Result<(), HandlerReturn> {
        const NO_CACHE: &str = "no-store, no-cache, must-revalidate, proxy-revalidate";

        log!(ctx, Verbose, _:"handler", method: req.method(), url: req.url());
        ctx.response()
            .header("X-Frame-Options", "DENY")
            .unwrap()
            .header("Cache-Control", NO_CACHE)
            .unwrap();
        Ok(())
    }
}
/// Table-driven check of the username syntax rules.
#[test]
fn test_gh_username() {
    let accepted = ["alice", "mid-hyphen"];
    let rejected = [
        "",
        "1234567890123456789012345678901234567890",
        "joe$",
        "bob%",
        "s p a c e",
        "hyphen--s",
        "-start0hyphen",
        "end-hyphen-",
        "nonascii▲",
    ];

    for name in accepted.iter() {
        assert!(is_valid_username_token(name));
    }
    for name in rejected.iter() {
        assert!(!is_valid_username_token(name));
    }
}
/// Checks both accepted and rejected return URLs.
#[test]
fn test_return_url() {
    // Accepted: absolute http(s) URLs within the length limit.
    assert!(is_valid_return_url("https://api.example.com"));
    assert!(is_valid_return_url(
        "https://api.example.com/path/_x-y.html?this=that&other=this"
    ));
    assert!(is_valid_return_url("http://localhost:8080/app"));

    // Rejected: empty, not an absolute URL, or a non-http(s) scheme.
    assert!(!is_valid_return_url(""));
    assert!(!is_valid_return_url("/relative/path"));
    assert!(!is_valid_return_url("ftp://example.com/file"));
    assert!(!is_valid_return_url("javascript:alert(1)"));

    // Rejected: longer than MAX_APP_URL_LEN.
    let too_long = format!("https://example.com/{}", "a".repeat(MAX_APP_URL_LEN));
    assert!(!is_valid_return_url(&too_long));
}