sonya_meta/
api.rs

1use crate::config::Secure;
2use actix_web::dev::{HttpServiceFactory, RequestHead};
3use actix_web::guard::Guard;
4use actix_web::rt::time::sleep;
5use actix_web::web::Data;
6use actix_web::{web, HttpResponse};
7use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
8use log::error;
9use serde::de::DeserializeOwned;
10use serde::{Deserialize, Serialize};
11use std::future::Future;
12use std::time::{Duration, SystemTime};
13
/// Authentication-scheme prefix expected at the start of the `Authorization` header value
/// (the trailing space separates the scheme from the token itself).
const BEARER: &str = "Bearer ";
15
/// Builds the `/queue` actix-web scope with every queue route registered.
///
/// Arguments, in order, are the handler functions for: creating a queue, sending to a
/// queue, closing a queue, subscribing by id over WebSocket, subscribing by id via
/// long-polling, subscribing over WebSocket, subscribing via long-polling; followed by
/// `$secure`, an expression that is matched against `None` / `Some(st)`.
///
/// * `None` — all routes are registered without any authentication guard.
/// * `Some(st)` — `st` is handed to `api::service_token_guard` / `api::jwt_token_guard`,
///   which take `&Secure`, so `$secure` is presumably an `Option<&Secure>` (TODO confirm at
///   the call site). Create/send/close and the id-less listeners require the service token;
///   the per-id listeners require a JWT whose request path ends with `/{iss}/{sub}`; and
///   `/generate_jwt/{queue}/{uniq_id}` is additionally mounted so JWTs can be minted.
///
/// All actix-web paths are written absolutely (`::actix_web::web`) so callers do not have to
/// import `web` into the scope where the macro is invoked.
#[macro_export]
macro_rules! queue_scope_factory {
    (   $create_queue:ident,
        $send_to_queue:ident,
        $close_queue:ident,
        $subscribe_queue_by_id_ws:ident,
        $subscribe_queue_by_id_longpoll:ident,
        $subscribe_queue_ws:ident,
        $subscribe_queue_longpoll:ident,
        $secure:expr $(,)?
    ) => {
        match $secure {
            None => ::actix_web::web::scope("/queue")
                .route(
                    "/create/{queue_name}",
                    ::actix_web::web::post().to($create_queue),
                )
                .route(
                    "/send/{queue_name}",
                    ::actix_web::web::post().to($send_to_queue),
                )
                .route(
                    "/close/{queue_name}",
                    ::actix_web::web::post().to($close_queue),
                )
                .service(
                    ::actix_web::web::scope("/listen")
                        .route(
                            "/longpoll/{queue_name}",
                            ::actix_web::web::get().to($subscribe_queue_longpoll),
                        )
                        .service(
                            ::actix_web::web::resource("/ws/{queue_name}")
                                .to($subscribe_queue_ws),
                        )
                        .route(
                            "/longpoll/{queue_name}/{uniq_id}",
                            ::actix_web::web::get().to($subscribe_queue_by_id_longpoll),
                        )
                        .service(
                            ::actix_web::web::resource("/ws/{queue_name}/{uniq_id}")
                                .to($subscribe_queue_by_id_ws),
                        ),
                ),
            Some(st) => ::actix_web::web::scope("/queue")
                .route(
                    "/create/{queue_name}",
                    ::actix_web::web::post()
                        .guard($crate::api::service_token_guard(st))
                        .to($create_queue),
                )
                .route(
                    "/send/{queue_name}",
                    ::actix_web::web::post()
                        .guard($crate::api::service_token_guard(st))
                        .to($send_to_queue),
                )
                .route(
                    "/close/{queue_name}",
                    ::actix_web::web::post()
                        .guard($crate::api::service_token_guard(st))
                        .to($close_queue),
                )
                // Endpoint that mints per-subscriber JWTs for the `/{uniq_id}` listeners below.
                .service($crate::api::generate_jwt_method_factory(st.clone()))
                .service(
                    ::actix_web::web::scope("/listen")
                        .route(
                            "/longpoll/{queue_name}",
                            ::actix_web::web::get()
                                .guard($crate::api::service_token_guard(st))
                                .to($subscribe_queue_longpoll),
                        )
                        .service(
                            ::actix_web::web::resource("/ws/{queue_name}")
                                .guard($crate::api::service_token_guard(st))
                                .to($subscribe_queue_ws),
                        )
                        // Per-id listeners are authorised by a JWT rather than the
                        // service token, so subscribers never need the shared secret.
                        .route(
                            "/longpoll/{queue_name}/{uniq_id}",
                            ::actix_web::web::get()
                                .guard($crate::api::jwt_token_guard(st))
                                .to($subscribe_queue_by_id_longpoll),
                        )
                        .service(
                            ::actix_web::web::resource("/ws/{queue_name}/{uniq_id}")
                                .guard($crate::api::jwt_token_guard(st))
                                .to($subscribe_queue_by_id_ws),
                        ),
                ),
        }
    };
}
96
97pub fn service_token_guard(secure: &Secure) -> impl Guard {
98    let service_token = secure.service_token.clone();
99    actix_web::guard::fn_guard(move |head| {
100        extract_access_token(head)
101            .filter(|token| *token == service_token)
102            .is_some()
103    })
104}
105
106pub fn jwt_token_guard(secure: &Secure) -> impl Guard {
107    let service_token = secure.service_token.clone();
108    actix_web::guard::fn_guard(move |head| {
109        extract_access_token(head)
110            .and_then(|token| {
111                decode::<Claims>(
112                    &token,
113                    &DecodingKey::from_secret(service_token.as_bytes()),
114                    &Validation::default(),
115                )
116                .ok()
117                .filter(|c| {
118                    head.uri
119                        .path()
120                        .ends_with(&format!("/{}/{}", c.claims.iss, c.claims.sub))
121                })
122            })
123            .is_some()
124    })
125}
126
127fn extract_access_token(head: &RequestHead) -> Option<String> {
128    extract_access_token_from_header(head).or_else(|| extract_access_token_from_query(head))
129}
130
131fn extract_access_token_from_query(head: &RequestHead) -> Option<String> {
132    extract_any_data_from_query::<AccessTokenQuery>(head).map(|a| a.access_token)
133}
134
135pub fn extract_any_data_from_query<T: DeserializeOwned>(head: &RequestHead) -> Option<T> {
136    head.uri.query().and_then(|v| {
137        serde_urlencoded::from_str(v)
138            .map_err(|e| {
139                error!("extracting sequence error: {}", e);
140                e
141            })
142            .ok()
143    })
144}
145
146fn extract_access_token_from_header(head: &RequestHead) -> Option<String> {
147    head.headers.get("Authorization").and_then(|head| {
148        let token = head
149            .to_str()
150            .ok()
151            .filter(|t| t.starts_with(BEARER))?
152            .trim_start_matches(BEARER);
153        Some(token.to_string())
154    })
155}
156
157#[derive(Debug, Serialize, Deserialize)]
158pub struct AccessTokenQuery {
159    pub access_token: String,
160}
161
162#[derive(Debug, Serialize, Deserialize)]
163struct Claims {
164    sub: String,
165    exp: usize,
166    iss: String,
167}
168
169pub fn generate_jwt_method_factory(secure: Secure) -> impl HttpServiceFactory {
170    web::resource("/generate_jwt/{queue}/{uniq_id}")
171        .guard(service_token_guard(&secure))
172        .app_data(Data::new(secure))
173        .route(web::post().to(
174            move |secure: web::Data<Secure>, info: web::Path<(String, String)>| async move {
175                let (queue_name, id) = info.into_inner();
176                let expiration_res = SystemTime::now()
177                    .checked_add(Duration::from_secs(secure.jwt_token_expiration))
178                    .expect("could not to add minute to current time")
179                    .duration_since(SystemTime::UNIX_EPOCH);
180
181                let expiration = match expiration_res {
182                    Ok(e) => e.as_secs() as usize,
183                    Err(e) => return Err(actix_web::error::ErrorInternalServerError(e)),
184                };
185
186                let claims = Claims {
187                    sub: id,
188                    iss: queue_name,
189                    exp: expiration,
190                };
191                let token = encode(
192                    &Header::default(),
193                    &claims,
194                    &EncodingKey::from_secret(secure.service_token.as_bytes()),
195                );
196                match token {
197                    Ok(token) => {
198                        Ok(HttpResponse::Ok().json(JwtTokenResponse { token, expiration }))
199                    }
200                    Err(e) => Err(actix_web::error::ErrorForbidden(e)),
201                }
202            },
203        ))
204}
205
206#[derive(Debug, Serialize, Deserialize)]
207struct JwtTokenResponse {
208    token: String,
209    expiration: usize,
210}
211
/// Maximum number of reconnection attempts.
///
/// Nothing in this module enforces the limit. NOTE(review): presumably callers stop retrying
/// once their attempt counter reaches this value — confirm at the use sites.
pub const MAX_RECONNECT_ATTEMPTS: u8 = 10;
213
214/// Calculate sleep time with formula `seconds = 1.5 * sqrt(attempts)`
215/// To getting increasing time intervals between reconnections.
216pub fn sleep_between_reconnects(attempt: u8) -> impl Future<Output = ()> {
217    sleep(Duration::from_secs((1.5 * (attempt as f32)).sqrt() as u64))
218}