1use super::Error;
7use crate::comms::WebsocketSender;
8use crate::config::get_config;
9use crate::http::{Authorization, Request, Response};
10use crate::view::{ToTemplateValue, Value};
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use time::{Duration, OffsetDateTime};
15
16use std::collections::HashMap;
17use std::fmt::Debug;
18use std::sync::Arc;
19
20#[derive(Clone)]
22pub struct AuthHandler {
23    auth: Arc<Box<dyn Authentication>>,
24}
25
26impl Default for AuthHandler {
27    fn default() -> Self {
28        Self::new(AllowAll {})
29    }
30}
31
32impl AuthHandler {
33    pub fn new(auth: impl Authentication + 'static) -> Self {
35        AuthHandler {
36            auth: Arc::new(Box::new(auth)),
37        }
38    }
39
40    pub fn auth(&self) -> &Box<dyn Authentication> {
42        &self.auth
43    }
44}
45
46#[async_trait]
48#[allow(unused_variables)]
49pub trait Authentication: Sync + Send {
50    async fn authorize(&self, request: &Request) -> Result<bool, Error>;
53
54    async fn denied(&self, request: &Request) -> Result<Response, Error> {
57        Ok(Response::forbidden())
58    }
59
60    fn handler(self) -> AuthHandler
63    where
64        Self: Sized + 'static,
65    {
66        AuthHandler::new(self)
67    }
68}
69
/// Authentication policy that grants access to every request.
///
/// A stateless unit type, so the standard value-type traits are derived (in particular
/// `Debug`, which every public type should implement).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AllowAll;
72
73#[async_trait]
74impl Authentication for AllowAll {
75    async fn authorize(&self, _request: &Request) -> Result<bool, Error> {
76        Ok(true)
77    }
78}
79
/// Authentication policy that refuses every request.
///
/// A stateless unit type, so the standard value-type traits are derived (in particular
/// `Debug`, which every public type should implement).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DenyAll;
85
86#[async_trait]
87impl Authentication for DenyAll {
88    async fn authorize(&self, _request: &Request) -> Result<bool, Error> {
89        Ok(false)
90    }
91}
92
93pub struct BasicAuth {
95    pub user: String,
97    pub password: String,
99}
100
101#[async_trait]
102impl Authentication for BasicAuth {
103    async fn authorize(&self, request: &Request) -> Result<bool, Error> {
104        Ok(
105            if let Some(Authorization::Basic { user, password }) = request.authorization() {
106                self.user == user && self.password == password
107            } else {
108                false
109            },
110        )
111    }
112
113    async fn denied(&self, _request: &Request) -> Result<Response, Error> {
114        Ok(Response::unauthorized("Basic"))
115    }
116}
117
118pub struct Token {
123    pub token: String,
125}
126
127#[async_trait]
128impl Authentication for Token {
129    async fn authorize(&self, request: &Request) -> Result<bool, Error> {
130        Ok(
131            if let Some(Authorization::Token { token }) = request.authorization() {
132                self.token == token
133            } else {
134                false
135            },
136        )
137    }
138}
139
140#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
142pub enum SessionId {
143    Guest(String),
145    Authenticated(i64),
147}
148
149impl SessionId {
150    pub fn authenticated(&self) -> bool {
152        use SessionId::*;
153
154        match self {
155            Guest(_) => false,
156            Authenticated(_) => true,
157        }
158    }
159
160    pub fn guest(&self) -> bool {
162        !self.authenticated()
163    }
164
165    pub fn user_id(&self) -> Option<i64> {
168        match self {
169            SessionId::Authenticated(id) => Some(*id),
170            _ => None,
171        }
172    }
173}
174
175impl std::fmt::Display for SessionId {
176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177        match self {
178            SessionId::Authenticated(id) => write!(f, "{}", id),
179            SessionId::Guest(id) => write!(f, "{}", id),
180        }
181    }
182}
183
184impl Default for SessionId {
185    fn default() -> Self {
186        use rand::{distributions::Alphanumeric, thread_rng, Rng};
187
188        SessionId::Guest(
189            thread_rng()
190                .sample_iter(&Alphanumeric)
191                .take(16)
192                .map(char::from)
193                .collect::<String>(),
194        )
195    }
196}
197
198#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
203pub struct Session {
204    #[serde(rename = "p")]
206    pub payload: serde_json::Value,
207    #[serde(rename = "e")]
209    pub expiration: i64,
210    #[serde(rename = "s")]
212    pub session_id: SessionId,
213}
214
215impl Default for Session {
216    fn default() -> Self {
217        Self::new(serde_json::json!({})).expect("json")
218    }
219}
220
221impl ToTemplateValue for Session {
222    fn to_template_value(&self) -> Result<Value, crate::view::Error> {
223        let mut hash = HashMap::new();
224        hash.insert("expiration".into(), Value::Integer(self.expiration));
225        hash.insert(
226            "session_id".into(),
227            Value::String(self.session_id.to_string()),
228        );
229        hash.insert(
230            "payload".into(),
231            Value::String(serde_json::to_string(&self.payload).unwrap()),
232        );
233
234        Ok(Value::Hash(hash))
235    }
236}
237
238impl Session {
239    pub fn anonymous() -> Self {
241        Self::default()
242    }
243
244    pub fn empty() -> Self {
246        Self::default()
247    }
248
249    pub fn new(payload: impl Serialize) -> Result<Self, Error> {
251        Ok(Self {
252            payload: serde_json::to_value(payload)?,
253            expiration: (OffsetDateTime::now_utc() + get_config().general.session_duration())
254                .unix_timestamp(),
255            session_id: SessionId::default(),
256        })
257    }
258
259    pub fn new_authenticated(payload: impl Serialize, user_id: i64) -> Result<Self, Error> {
261        let mut session = Self::new(payload)?;
262        session.session_id = SessionId::Authenticated(user_id);
263
264        Ok(session)
265    }
266
267    pub fn renew(mut self, renew_for: Duration) -> Self {
269        self.expiration = (OffsetDateTime::now_utc() + renew_for).unix_timestamp();
270        self
271    }
272
273    pub fn should_renew(&self) -> bool {
275        if let Ok(expiration) = OffsetDateTime::from_unix_timestamp(self.expiration) {
276            let now = OffsetDateTime::now_utc();
277            let remains = expiration - now;
278            let session_duration = get_config().general.session_duration();
279            remains < session_duration / 2 && remains.is_positive() } else {
281            true
282        }
283    }
284
285    pub fn expired(&self) -> bool {
287        if let Ok(expiration) = OffsetDateTime::from_unix_timestamp(self.expiration) {
288            let now = OffsetDateTime::now_utc();
289            expiration < now
290        } else {
291            false
292        }
293    }
294
295    pub fn websocket(&self) -> WebsocketSender {
298        use crate::comms::Comms;
299        Comms::websocket(&self.session_id)
300    }
301
302    pub fn authenticated(&self) -> bool {
304        !self.expired() && self.session_id.authenticated()
305    }
306
307    pub fn guest(&self) -> bool {
309        !self.expired() && self.session_id.guest()
310    }
311}
312
/// Authentication that admits requests whose session is authenticated.
///
/// Refused requests are redirected when a target is configured; otherwise they get
/// `Response::forbidden()`.
#[derive(Default)]
pub struct SessionAuth {
    /// Where refused requests are sent; `None` means no redirect.
    redirect: Option<String>,
}

impl SessionAuth {
    /// Builds a `SessionAuth` whose refused requests are redirected to `url`.
    pub fn redirect(url: impl ToString) -> Self {
        let target = url.to_string();
        Self {
            redirect: Some(target),
        }
    }
}
328
329#[async_trait]
330impl Authentication for SessionAuth {
331    async fn authorize(&self, request: &Request) -> Result<bool, Error> {
332        Ok(request.session().authenticated())
333    }
334
335    async fn denied(&self, _request: &Request) -> Result<Response, Error> {
336        if let Some(ref redirect) = self.redirect {
337            Ok(Response::new().redirect(redirect))
338        } else {
339            Ok(Response::forbidden())
340        }
341    }
342}
343
#[cfg(test)]
mod test {
    use super::*;

    /// Unix timestamp `offset` away from the current instant.
    fn timestamp_in(offset: Duration) -> i64 {
        (OffsetDateTime::now_utc() + offset).unix_timestamp()
    }

    #[test]
    fn test_should_renew() {
        // A brand-new session still has its full duration left.
        let mut session = Session::default();
        assert!(!session.should_renew());

        // The thresholds below assume a four-week session duration.
        assert_eq!(get_config().general.session_duration(), Duration::weeks(4));

        // Just under half the duration left: renewal is due.
        session.expiration = timestamp_in(Duration::weeks(2) - Duration::seconds(5));
        assert!(session.should_renew());

        // Just over half the duration left: not yet.
        session.expiration = timestamp_in(Duration::weeks(2) + Duration::seconds(5));
        assert!(!session.should_renew());
    }
}