1use crate::storage::Storage;
2use http_types::{Result, StatusCode};
3
4use std::any::Any;
5
6#[async_trait::async_trait]
7pub trait Scheme<User: Send + Sync + 'static> {
8 type Request: Any + Send + Sync;
9
10 async fn authenticate<S>(&self, state: &S, auth_param: &str) -> Result<Option<User>>
11 where
12 S: Storage<User, Self::Request> + Send + Sync + 'static;
13
14 fn should_401_on_multiple_values() -> bool {
15 true
16 }
17 fn should_403_on_bad_auth() -> bool {
18 true
19 }
20
21 fn header_name() -> &'static str {
22 "Authorization"
23 }
24 fn scheme_name() -> &'static str;
25}
26
/// Stateless marker type implementing the HTTP `Basic` authentication scheme.
#[derive(Debug, Default)]
pub struct BasicAuthScheme;
29
/// Credentials decoded from a `Basic` authentication parameter.
///
/// NOTE(review): the derived `Debug` prints the password in clear text;
/// avoid logging this value.
#[derive(Debug)]
pub struct BasicAuthRequest {
    /// The user-id portion (text before the first colon).
    pub username: String,
    /// The password portion.
    pub password: String,
}
35
36#[async_trait::async_trait]
37impl<User: Send + Sync + 'static> Scheme<User> for BasicAuthScheme {
38 type Request = BasicAuthRequest;
39
40 async fn authenticate<S>(&self, state: &S, auth_param: &str) -> Result<Option<User>>
41 where
42 S: Storage<User, Self::Request> + Send + Sync + 'static,
43 {
44 let bytes = base64::decode(auth_param);
45 if bytes.is_err() {
46 return Err(http_types::Error::from_str(
48 StatusCode::Unauthorized,
49 "Basic auth param must be valid base64.",
50 ));
51 }
52
53 let as_utf8 = String::from_utf8(bytes.unwrap());
54 if as_utf8.is_err() {
55 return Err(http_types::Error::from_str(
57 StatusCode::Unauthorized,
58 "Basic auth param base64 must contain valid utf-8.",
59 ));
60 }
61
62 let as_utf8 = as_utf8.unwrap();
63 let parts: Vec<_> = as_utf8.split(':').collect();
64
65 if parts.len() < 2 {
66 return Ok(None);
67 }
68
69 let (username, password) = (parts[0], parts[1]);
70
71 let user = state
72 .get_user(BasicAuthRequest {
73 username: username.to_owned(),
74 password: password.to_owned(),
75 })
76 .await?;
77
78 Ok(user)
79 }
80
81 fn scheme_name() -> &'static str {
82 "Basic "
83 }
84}
85
/// The HTTP `Bearer` authentication scheme.
///
/// `prefix` is the text the authentication parameter must begin with. It is
/// stripped before the remainder is used as the token. `Default` yields an
/// empty prefix, which accepts every parameter unchanged.
#[derive(Debug, Default)]
pub struct BearerAuthScheme {
    prefix: String,
}
90
/// Token extracted from a `Bearer` authentication parameter.
///
/// Derives `Debug` for consistency with `BasicAuthRequest` and so containing
/// types can derive it too. NOTE(review): the derived output includes the raw
/// token; avoid logging this value.
#[derive(Debug)]
pub struct BearerAuthRequest {
    /// The parameter text remaining after the scheme's prefix was removed.
    pub token: String,
}
94
95#[async_trait::async_trait]
96impl<User: Send + Sync + 'static> Scheme<User> for BearerAuthScheme {
97 type Request = BearerAuthRequest;
98
99 async fn authenticate<S>(&self, state: &S, auth_param: &str) -> Result<Option<User>>
100 where
101 S: Storage<User, Self::Request> + Send + Sync + 'static,
102 {
103 if !auth_param.starts_with(self.prefix.as_str()) {
104 return Ok(None);
105 }
106
107 let user = state
109 .get_user(BearerAuthRequest {
110 token: (&auth_param[self.prefix.len()..]).to_owned(),
111 })
112 .await?;
113 Ok(user)
114 }
115
116 fn scheme_name() -> &'static str {
117 "Bearer "
118 }
119}