1#![doc(html_logo_url = "https://avatars.githubusercontent.com/u/43955412")]
2use rocket::{
7    fairing::{self, Fairing, Info, Kind},
8    http::{self, uri::Origin, Method, Status},
9    request::{self, FromRequest},
10    route, Build, Data, Request, Rocket, Route,
11};
12use sentinel_core::EntryBuilder;
13use std::sync::Mutex;
14
/// Boxed, thread-safe error trait object.
///
/// Used as the error type of the [`SentinelGuard`] request guard.
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
16
/// Function that derives the Sentinel resource name from a request.
///
/// The returned `String` is passed to `EntryBuilder::new` by both the guard
/// and the fairing.
pub type Extractor = fn(&Request<'_>) -> String;
20
/// Function invoked when building the Sentinel entry for a request fails.
///
/// It receives the request together with the error returned by Sentinel and
/// produces a value of type `R` (a request outcome for the guard, `()` for
/// the fairing).
pub type Fallback<R> = fn(&Request<'_>, sentinel_core::Error) -> R;
23
24fn default_extractor(req: &Request<'_>) -> String {
25    req.uri().path().to_string()
26}
27
28fn default_fallback_for_guard(
29    _request: &Request<'_>,
30    err: sentinel_core::Error,
31) -> request::Outcome<SentinelGuard, BoxError> {
32    request::Outcome::Failure((Status::TooManyRequests, err.into()))
33}
34
/// Configuration type looked up from Rocket's managed state by
/// [`SentinelGuard`]; its fallback returns a request outcome.
pub type SentinelConfigForGuard = SentinelConfig<request::Outcome<SentinelGuard, BoxError>>;
/// Configuration type used by [`SentinelFairing`]; its fallback returns `()`.
pub type SentinelConfigForFairing = SentinelConfig<()>;
37
/// Optional customisation of how requests are mapped to Sentinel resources
/// and what happens when a request is blocked.
///
/// `R` is the return type of the fallback function. Both fields are optional;
/// consumers of this configuration substitute their own defaults for `None`.
pub struct SentinelConfig<R> {
    /// Custom resource-name extractor, if any.
    pub extractor: Option<Extractor>,
    /// Custom handler for blocked requests, if any.
    pub fallback: Option<Fallback<R>>,
}
51
52impl<R> SentinelConfig<R> {
53    pub fn with_extractor(mut self, extractor: Extractor) -> Self {
54        self.extractor = Some(extractor);
55        self
56    }
57
58    pub fn with_fallback(mut self, fallback: Fallback<R>) -> Self {
59        self.fallback = Some(fallback);
60        self
61    }
62}
63
64impl<R> Clone for SentinelConfig<R> {
67    fn clone(&self) -> Self {
68        Self {
69            extractor: self.extractor.clone(),
70            fallback: self.fallback.clone(),
71        }
72    }
73}
74
75impl<R> Default for SentinelConfig<R> {
78    fn default() -> Self {
79        Self {
80            extractor: None,
81            fallback: None,
82        }
83    }
84}
85
/// Rocket request guard that passes a request through Sentinel.
///
/// The guard succeeds when a Sentinel entry can be built for the request's
/// resource; otherwise the configured (or default) fallback decides the
/// outcome. Configuration is read from Rocket's managed state as a
/// [`SentinelConfigForGuard`].
#[derive(Debug)]
pub struct SentinelGuard;
96
97#[rocket::async_trait]
98impl<'r> FromRequest<'r> for SentinelGuard {
99    type Error = BoxError;
100
101    async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
102        let empty_config = SentinelConfig::default();
103        let config = req
104            .rocket()
105            .state::<SentinelConfig<request::Outcome<SentinelGuard, BoxError>>>()
107            .unwrap_or(&empty_config);
108        let extractor = config.extractor.unwrap_or(default_extractor);
109        let fallback = config.fallback.unwrap_or(default_fallback_for_guard);
110
111        let resource = extractor(req);
112        let entry_builder = EntryBuilder::new(resource)
113            .with_traffic_type(sentinel_core::base::TrafficType::Inbound);
114
115        match entry_builder.build() {
116            Ok(entry) => {
117                entry.exit();
118                request::Outcome::Success(SentinelGuard {})
119            }
120            Err(err) => fallback(req, err),
121        }
122    }
123}
124
/// State managed by the Sentinel fairing and shared with its route handler.
#[derive(Debug)]
pub struct SentinelFairingState {
    /// Message describing a blocked request; the fairing overwrites it each
    /// time it blocks a request without a configured fallback.
    pub msg: Mutex<String>,
    /// Path that blocked requests are rewritten to.
    pub uri: String,
}

impl SentinelFairingState {
    /// Creates the state for the given redirect path with an empty message.
    pub fn new(uri: String) -> Self {
        let msg = Mutex::default();
        Self { msg, uri }
    }
}
146
/// Signature of a custom handler for requests served by
/// [`SentinelFairingHandler`].
type FairingHandler = for<'r> fn(&'r Request<'_>, Data<'r>) -> route::Outcome<'r>;
148
/// Route handler mounted by [`SentinelFairing`] at its configured path.
///
/// Wraps an optional custom handler; when it is `None`, a built-in handler
/// is used that fails the request with `429 Too Many Requests` (or
/// `500 Internal Server Error` if the fairing state is not managed).
#[derive(Clone, Default)]
pub struct SentinelFairingHandler(Option<FairingHandler>);
151
152impl SentinelFairingHandler {
153    pub fn new(h: FairingHandler) -> Self {
154        Self(Some(h))
155    }
156}
157
158#[rocket::async_trait]
159impl route::Handler for SentinelFairingHandler {
160    async fn handle<'r>(&self, req: &'r Request<'_>, data: Data<'r>) -> route::Outcome<'r> {
161        fn default_handler<'r>(req: &'r Request<'_>, _data: Data<'r>) -> route::Outcome<'r> {
162            match req.rocket().state::<SentinelFairingState>() {
163                Some(_) => route::Outcome::Failure(Status::TooManyRequests),
164                None => route::Outcome::Failure(Status::InternalServerError),
165            }
166        }
167
168        let h = self.0.unwrap_or(default_handler);
169        h(req, data)
170    }
171}
172
173impl Into<Vec<Route>> for SentinelFairingHandler {
174    fn into(self) -> Vec<Route> {
175        vec![Route::new(Method::Get, "/", self)]
176    }
177}
178
/// Rocket fairing that passes every incoming request through Sentinel.
///
/// Blocked requests are either handed to the configured fallback or rewritten
/// to `uri`, where `handler` is mounted.
#[derive(Default)]
pub struct SentinelFairing {
    /// Path the handler is mounted at and blocked requests are rewritten to.
    uri: String,
    /// Handler serving requests that arrive at `uri`.
    handler: SentinelFairingHandler,
    /// Fairing-local extractor/fallback; takes precedence over a
    /// `SentinelConfig<()>` found in Rocket's managed state.
    config: SentinelConfig<()>,
}
195
196impl SentinelFairing {
197    pub fn new(uri: &'static str) -> Result<Self, http::uri::Error> {
198        Ok(SentinelFairing::default().with_uri(uri)?)
199    }
200
201    pub fn with_extractor(mut self, extractor: Extractor) -> Self {
202        self.config = self.config.with_extractor(extractor);
203        self
204    }
205
206    pub fn with_fallback(mut self, fallback: Fallback<()>) -> Self {
207        self.config = self.config.with_fallback(fallback);
208        self
209    }
210
211    pub fn with_handler(mut self, h: FairingHandler) -> Self {
212        self.handler = SentinelFairingHandler::new(h);
213        self
214    }
215
216    pub fn with_uri(mut self, uri: &'static str) -> Result<Self, http::uri::Error> {
217        let origin = Origin::parse(uri)?;
218        self.uri = origin.path().to_string();
219        Ok(self)
220    }
221}
222
223#[rocket::async_trait]
224impl Fairing for SentinelFairing {
225    fn info(&self) -> Info {
226        Info {
227            name: "Sentinel Fairing",
228            kind: Kind::Ignite | Kind::Request,
229        }
230    }
231
232    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
233        let handler = self.handler.clone();
234        Ok(rocket
235            .manage(SentinelFairingState::new(self.uri.clone()))
236            .mount(self.uri.clone(), handler))
237    }
238
239    async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
240        let empty_config = SentinelConfig::default();
241        let config = req
242            .rocket()
243            .state::<SentinelConfig<()>>()
244            .unwrap_or(&empty_config);
245        let extractor = self
246            .config
247            .extractor
248            .unwrap_or(config.extractor.unwrap_or(default_extractor));
249        let fallback = self.config.fallback.or(config.fallback);
250
251        let resource = extractor(&req);
252        let entry_builder = EntryBuilder::new(resource)
253            .with_traffic_type(sentinel_core::base::TrafficType::Inbound);
254
255        match entry_builder.build() {
256            Ok(entry) => {
257                entry.exit();
258            }
259            Err(err) => {
260                match fallback {
261                    Some(fallback) => fallback(req, err),
262                    None => {
263                        if let Some(state) = req.rocket().state::<SentinelFairingState>() {
264                            if let Ok(mut msg) = state.msg.lock() {
265                                *msg = format!(
266                                    "Request to {:?} blocked by sentinel: {:?}",
267                                    req.uri().path(),
268                                    err
269                                );
270                            }
271                            req.set_uri(Origin::parse(&state.uri).unwrap());
273                        }
274                    }
275                }
276            }
277        };
278    }
279}