tpex_api/server/
tokens.rs

1use std::{str::FromStr, sync::Arc};
2
3use axum::{http::StatusCode};
4use axum_extra::headers::{authorization::Bearer, Authorization, HeaderMapExt};
5use num_traits::FromPrimitive;
6use tokio::io::{AsyncBufRead, AsyncSeek, AsyncWrite};
7use tpex::PlayerId;
8use crate::shared::*;
9
10use super::state::StateStruct;
11
12impl<T: AsyncBufRead + AsyncWrite + AsyncSeek + Unpin + Send + Sync> axum::extract::FromRequestParts<Arc<StateStruct<T>>> for TokenInfo {
13    type Rejection = StatusCode;
14
15    async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &Arc<StateStruct<T>>) -> Result<Self, Self::Rejection> {
16            let Some(auth) : Option<Authorization<Bearer>> = parts.headers.typed_get()
17            else { return Err(StatusCode::UNAUTHORIZED); };
18
19            let Ok(token) = auth.0.token().parse()
20            else { return Err(StatusCode::UNAUTHORIZED); };
21
22            let Ok(token_info) = state.tokens.get_token(&token).await
23            else { return Err(StatusCode::UNAUTHORIZED); };
24
25            // If the token would need banker perms to make, check that the user is still at that level
26            if token_info.level > TokenLevel::ProxyOne && !state.tpex.read().await.state().is_banker(&token_info.user) {
27                return Err(StatusCode::UNAUTHORIZED)
28            }
29
30            Ok(token_info)
31        }
32}
33
34pub struct TokenHandler {
35    pool: sqlx::SqlitePool
36}
37impl TokenHandler {
38    pub async fn new(url: &str) -> sqlx::Result<TokenHandler> {
39        sqlx::any::install_default_drivers();
40        let opt = sqlx::sqlite::SqliteConnectOptions::from_str(url)?.create_if_missing(true);
41        let ret = TokenHandler{
42            pool: sqlx::SqlitePool::connect_with(opt).await?
43        };
44
45        sqlx::migrate!("./migrations").run(&ret.pool).await?;
46
47        Ok(ret)
48    }
49    pub async fn create_token(&self, level: TokenLevel, user: PlayerId) -> sqlx::Result<Token> {
50        let token = Token::generate();
51
52        let slice = token.0.as_slice();
53        let level = level as i64;
54        let user = user.get_raw_name();
55
56        sqlx::query!(r#"INSERT INTO tokens(token, level, user) VALUES (?, ?, ?)"#, slice, level, user)
57        .execute(&self.pool).await?;
58
59        Ok(token)
60    }
61    pub async fn get_token(&self, token: &Token) -> sqlx::Result<TokenInfo> {
62        let slice = token.0.as_slice();
63        let query =
64            sqlx::query!(r#"SELECT token as "token: Vec<u8>", level, user FROM tokens WHERE token = ?"#, slice)
65            .fetch_one(&self.pool).await?;
66
67        Ok(TokenInfo {
68            token: Token(query.token.try_into().expect("Mismatched token length")),
69            #[allow(deprecated)]
70            user: tpex::PlayerId::assume_username_correct(query.user),
71            level: TokenLevel::from_i64(query.level).expect("Invalid token level")
72        })
73    }
74    pub async fn delete_token(&self, token: &Token) -> sqlx::Result<()> {
75        let slice = token.0.as_slice();
76        sqlx::query!(r#"DELETE FROM tokens WHERE token = ?"#, slice)
77        .execute(&self.pool).await?;
78        Ok(())
79    }
80}