Skip to main content

sylvia_iot_data/routes/
middleware.rs

1//! Provides the authentication middleware by sending the Bearer token to [`sylvia-iot-auth`].
2
3use std::{
4    collections::HashMap,
5    task::{Context, Poll},
6};
7
8use axum::{
9    extract::Request,
10    response::{IntoResponse, Response},
11};
12use futures::future::BoxFuture;
13use reqwest;
14use serde::{self, Deserialize};
15use tower::{Layer, Service};
16
17use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http};
18
/// Token information that the authentication middleware stores in the request extensions
/// after the token has been accepted by the auth service.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The access token (the `Authorization` header value with the leading scheme word removed).
    pub token: String,
    /// The `userId` field reported by the auth service.
    pub user_id: String,
    /// The `account` field reported by the auth service.
    pub account: String,
    /// The `roles` field reported by the auth service, keyed by role name.
    pub roles: HashMap<String, bool>,
    /// The `name` field reported by the auth service.
    pub name: String,
    /// The `clientId` field reported by the auth service.
    pub client_id: String,
    /// The `scopes` field reported by the auth service.
    pub scopes: Vec<String>,
}
30
31#[derive(Clone)]
32pub struct AuthService {
33    client: reqwest::Client,
34    auth_uri: String,
35}
36
37#[derive(Clone)]
38pub struct AuthMiddleware<S> {
39    client: reqwest::Client,
40    auth_uri: String,
41    service: S,
42}
43
44/// The user/client information of the token.
45#[derive(Clone, Deserialize)]
46struct GetTokenInfo {
47    data: GetTokenInfoDataInner,
48}
49
50#[derive(Clone, Deserialize)]
51struct GetTokenInfoDataInner {
52    #[serde(rename = "userId")]
53    pub user_id: String,
54    pub account: String,
55    pub roles: HashMap<String, bool>,
56    pub name: String,
57    #[serde(rename = "clientId")]
58    pub client_id: String,
59    pub scopes: Vec<String>,
60}
61
62impl AuthService {
63    pub fn new(client: reqwest::Client, auth_uri: String) -> Self {
64        AuthService { client, auth_uri }
65    }
66}
67
68impl<S> Layer<S> for AuthService {
69    type Service = AuthMiddleware<S>;
70
71    fn layer(&self, inner: S) -> Self::Service {
72        AuthMiddleware {
73            client: self.client.clone(),
74            auth_uri: self.auth_uri.clone(),
75            service: inner,
76        }
77    }
78}
79
80impl<S> Service<Request> for AuthMiddleware<S>
81where
82    S: Service<Request, Response = Response> + Clone + Send + 'static,
83    S::Future: Send + 'static,
84{
85    type Response = S::Response;
86    type Error = S::Error;
87    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
88
89    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90        self.service.poll_ready(cx)
91    }
92
93    fn call(&mut self, mut req: Request) -> Self::Future {
94        let mut svc = self.service.clone();
95        let client = self.client.clone();
96        let auth_uri = self.auth_uri.clone();
97
98        Box::pin(async move {
99            let token = match sylvia_http::parse_header_auth(&req) {
100                Err(e) => return Ok(e.into_response()),
101                Ok(token) => match token {
102                    None => {
103                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
104                        return Ok(e.into_response());
105                    }
106                    Some(token) => token,
107                },
108            };
109
110            let token_req = match client
111                .request(reqwest::Method::GET, auth_uri.as_str())
112                .header(reqwest::header::AUTHORIZATION, token.as_str())
113                .build()
114            {
115                Err(e) => {
116                    let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
117                    return Ok(e.into_response());
118                }
119                Ok(req) => req,
120            };
121            let resp = match client.execute(token_req).await {
122                Err(e) => {
123                    let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
124                    return Ok(e.into_response());
125                }
126                Ok(resp) => match resp.status() {
127                    reqwest::StatusCode::UNAUTHORIZED => {
128                        return Ok(ErrResp::ErrAuth(None).into_response());
129                    }
130                    reqwest::StatusCode::OK => resp,
131                    _ => {
132                        let e = ErrResp::ErrIntMsg(Some(format!(
133                            "auth error with status code: {}",
134                            resp.status()
135                        )));
136                        return Ok(e.into_response());
137                    }
138                },
139            };
140            let token_info = match resp.json::<GetTokenInfo>().await {
141                Err(e) => {
142                    let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
143                    return Ok(e.into_response());
144                }
145                Ok(info) => info,
146            };
147
148            let mut split = token.split_whitespace();
149            split.next(); // skip "Bearer".
150            let token = match split.next() {
151                None => {
152                    let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
153                    return Ok(e.into_response());
154                }
155                Some(token) => token.to_string(),
156            };
157
158            req.extensions_mut().insert(GetTokenInfoData {
159                token,
160                user_id: token_info.data.user_id,
161                account: token_info.data.account,
162                roles: token_info.data.roles,
163                name: token_info.data.name,
164                client_id: token_info.data.client_id,
165                scopes: token_info.data.scopes,
166            });
167
168            let res = svc.call(req).await?;
169            Ok(res)
170        })
171    }
172}