1use crate::{
4 helper,
5 response::{Rejection, Response, ResponseCode},
6};
7use multer::Multipart;
8use serde::de::DeserializeOwned;
9use std::{borrow::Cow, net::IpAddr, str::FromStr, time::Instant};
10use zino_channel::{CloudEvent, Subscription};
11use zino_core::{
12 application::Agent,
13 error::Error,
14 extension::HeaderMapExt,
15 model::{ModelHooks, Query},
16 trace::{TraceContext, TraceState},
17 warn, JsonValue, Map, SharedString, Uuid,
18};
19use zino_storage::NamedFile;
20
21#[cfg(feature = "auth")]
22use zino_auth::{AccessKeyId, Authentication, ParseSecurityTokenError, SecurityToken, SessionId};
23
24#[cfg(feature = "auth")]
25use zino_core::{datetime::DateTime, extension::JsonObjectExt, validation::Validation};
26
27#[cfg(feature = "cookie")]
28use cookie::{Cookie, SameSite};
29
30#[cfg(feature = "jwt")]
31use jwt_simple::algorithms::MACLike;
32#[cfg(feature = "jwt")]
33use zino_auth::JwtClaims;
34
35#[cfg(any(feature = "cookie", feature = "jwt"))]
36use std::time::Duration;
37
38#[cfg(feature = "i18n")]
39use crate::i18n;
40#[cfg(feature = "i18n")]
41use fluent::FluentArgs;
42#[cfg(feature = "i18n")]
43use unic_langid::LanguageIdentifier;
44
45mod context;
46
47pub use context::Context;
48
49pub trait RequestContext {
51 type Method: AsRef<str>;
53 type Uri;
55
56 fn request_method(&self) -> &Self::Method;
58
59 fn original_uri(&self) -> &Self::Uri;
61
62 fn matched_route(&self) -> Cow<'_, str>;
64
65 fn request_path(&self) -> &str;
67
68 fn get_query_string(&self) -> Option<&str>;
70
71 fn get_header(&self, name: &str) -> Option<&str>;
73
74 fn client_ip(&self) -> Option<IpAddr>;
76
77 fn get_context(&self) -> Option<Context>;
79
80 fn get_data<T: Clone + Send + Sync + 'static>(&self) -> Option<T>;
82
83 fn set_data<T: Clone + Send + Sync + 'static>(&mut self, value: T) -> Option<T>;
86
87 async fn read_body_bytes(&mut self) -> Result<Vec<u8>, Error>;
89
90 #[inline]
92 fn path_segments(&self) -> Vec<&str> {
93 self.request_path().trim_matches('/').split('/').collect()
94 }
95
96 fn new_context(&self) -> Context {
98 #[cfg(feature = "metrics")]
100 {
101 metrics::gauge!("zino_http_requests_in_flight").increment(1.0);
102 metrics::counter!(
103 "zino_http_requests_total",
104 "method" => self.request_method().as_ref().to_owned(),
105 "route" => self.matched_route().into_owned(),
106 )
107 .increment(1);
108 }
109
110 let request_id = self
112 .get_header("x-request-id")
113 .and_then(|s| s.parse().ok())
114 .unwrap_or_else(Uuid::now_v7);
115 let trace_id = self
116 .get_trace_context()
117 .map_or_else(Uuid::now_v7, |t| Uuid::from_u128(t.trace_id()));
118 let session_id = self
119 .get_header("x-session-id")
120 .or_else(|| self.get_header("session_id"))
121 .and_then(|s| s.parse().ok());
122
123 let mut ctx = Context::new(request_id);
125 ctx.set_instance(self.request_path());
126 ctx.set_trace_id(trace_id);
127 ctx.set_session_id(session_id);
128
129 #[cfg(feature = "i18n")]
131 {
132 #[cfg(feature = "cookie")]
133 if let Some(cookie) = self.get_cookie("locale") {
134 ctx.set_locale(cookie.value());
135 return ctx;
136 }
137
138 let supported_locales = i18n::SUPPORTED_LOCALES.as_slice();
139 let locale = self
140 .get_header("accept-language")
141 .and_then(|languages| helper::select_language(languages, supported_locales))
142 .unwrap_or(&i18n::DEFAULT_LOCALE);
143 ctx.set_locale(locale);
144 }
145 ctx
146 }
147
148 #[inline]
150 fn get_trace_context(&self) -> Option<TraceContext> {
151 let traceparent = self.get_header("traceparent")?;
152 let mut trace_context = TraceContext::from_traceparent(traceparent)?;
153 if let Some(tracestate) = self.get_header("tracestate") {
154 *trace_context.trace_state_mut() = TraceState::from_tracestate(tracestate);
155 }
156 Some(trace_context)
157 }
158
159 fn new_trace_context(&self) -> TraceContext {
161 let mut trace_context = self
162 .get_trace_context()
163 .or_else(|| {
164 self.get_context()
165 .map(|ctx| TraceContext::with_trace_id(ctx.trace_id()))
166 })
167 .map(|t| t.child())
168 .unwrap_or_default();
169 trace_context.record_trace_state();
170 trace_context
171 }
172
173 #[cfg(feature = "cookie")]
175 fn new_cookie(
176 &self,
177 name: SharedString,
178 value: SharedString,
179 max_age: Option<Duration>,
180 ) -> Cookie<'static> {
181 let mut cookie_builder = Cookie::build((name, value))
182 .http_only(true)
183 .secure(true)
184 .same_site(SameSite::Lax)
185 .path(self.request_path().to_owned());
186 if let Some(max_age) = max_age.and_then(|d| d.try_into().ok()) {
187 cookie_builder = cookie_builder.max_age(max_age);
188 }
189 cookie_builder.build()
190 }
191
192 #[cfg(feature = "cookie")]
194 fn get_cookie(&self, name: &str) -> Option<Cookie<'_>> {
195 self.get_header("cookie")?.split(';').find_map(|cookie| {
196 if let Some((key, value)) = cookie.split_once('=') {
197 (key == name).then(|| Cookie::new(key, value))
198 } else {
199 None
200 }
201 })
202 }
203
204 #[inline]
206 fn start_time(&self) -> Instant {
207 self.get_context()
208 .map(|ctx| ctx.start_time())
209 .unwrap_or_else(Instant::now)
210 }
211
212 #[inline]
214 fn instance(&self) -> String {
215 self.get_context()
216 .map(|ctx| ctx.instance().to_owned())
217 .unwrap_or_else(|| self.request_path().to_owned())
218 }
219
220 #[inline]
222 fn request_id(&self) -> Uuid {
223 self.get_context()
224 .map(|ctx| ctx.request_id())
225 .unwrap_or_default()
226 }
227
228 #[inline]
230 fn trace_id(&self) -> Uuid {
231 self.get_context()
232 .map(|ctx| ctx.trace_id())
233 .unwrap_or_default()
234 }
235
236 #[inline]
238 fn session_id(&self) -> Option<String> {
239 self.get_context()
240 .and_then(|ctx| ctx.session_id().map(|s| s.to_owned()))
241 }
242
243 #[cfg(feature = "i18n")]
245 #[inline]
246 fn locale(&self) -> Option<LanguageIdentifier> {
247 self.get_context().and_then(|ctx| ctx.locale().cloned())
248 }
249
250 fn data_type(&self) -> Option<&str> {
257 self.get_header("content-type")
258 .map(|content_type| {
259 if let Some((essence, _)) = content_type.split_once(';') {
260 essence
261 } else {
262 content_type
263 }
264 })
265 .map(helper::get_data_type)
266 }
267
268 fn get_param(&self, name: &str) -> Option<&str> {
277 const CAPTURES: [char; 4] = [':', '*', '{', '}'];
278 if let Some(index) = self
279 .matched_route()
280 .split('/')
281 .position(|segment| segment.trim_matches(CAPTURES.as_slice()) == name)
282 {
283 self.request_path().splitn(index + 2, '/').nth(index)
284 } else {
285 None
286 }
287 }
288
289 fn decode_param(&self, name: &str) -> Result<Cow<'_, str>, Rejection> {
291 if let Some(value) = self.get_param(name) {
292 percent_encoding::percent_decode_str(value)
293 .decode_utf8()
294 .map_err(|err| Rejection::from_validation_entry(name.to_owned(), err).context(self))
295 } else {
296 Err(Rejection::from_validation_entry(
297 name.to_owned(),
298 warn!("param `{}` does not exist", name),
299 )
300 .context(self))
301 }
302 }
303
304 fn parse_param<T: FromStr<Err: Into<Error>>>(&self, name: &str) -> Result<T, Rejection> {
307 if let Some(param) = self.get_param(name) {
308 percent_encoding::percent_decode_str(param)
309 .decode_utf8_lossy()
310 .parse::<T>()
311 .map_err(|err| Rejection::from_validation_entry(name.to_owned(), err).context(self))
312 } else {
313 Err(Rejection::from_validation_entry(
314 name.to_owned(),
315 warn!("param `{}` does not exist", name),
316 )
317 .context(self))
318 }
319 }
320
321 fn get_query(&self, name: &str) -> Option<&str> {
329 self.get_query_string()?.split('&').find_map(|param| {
330 if let Some((key, value)) = param.split_once('=') {
331 (key == name).then_some(value)
332 } else {
333 None
334 }
335 })
336 }
337
338 fn decode_query(&self, name: &str) -> Result<Cow<'_, str>, Rejection> {
340 if let Some(value) = self.get_query(name) {
341 percent_encoding::percent_decode_str(value)
342 .decode_utf8()
343 .map_err(|err| Rejection::from_validation_entry(name.to_owned(), err).context(self))
344 } else {
345 Err(Rejection::from_validation_entry(
346 name.to_owned(),
347 warn!("query value `{}` does not exist", name),
348 )
349 .context(self))
350 }
351 }
352
353 fn parse_query<T: Default + DeserializeOwned>(&self) -> Result<T, Rejection> {
357 if let Some(query) = self.get_query_string() {
358 #[cfg(feature = "jwt")]
359 if let Some(timestamp) = self.get_query("timestamp").and_then(|s| s.parse().ok()) {
360 let duration = DateTime::from_timestamp(timestamp).span_between_now();
361 if duration > zino_auth::default_time_tolerance() {
362 let err = warn!("timestamp `{}` can not be trusted", timestamp);
363 let rejection = Rejection::from_validation_entry("timestamp", err);
364 return Err(rejection.context(self));
365 }
366 }
367 serde_qs::from_str::<T>(query)
368 .map_err(|err| Rejection::from_validation_entry("query", err).context(self))
369 } else {
370 Ok(T::default())
371 }
372 }
373
374 async fn parse_body<T: DeserializeOwned>(&mut self) -> Result<T, Rejection> {
384 let data_type = self.data_type().unwrap_or("form");
385 if data_type.contains('/') {
386 let err = warn!(
387 "deserialization of the data type `{}` is unsupported",
388 data_type
389 );
390 let rejection = Rejection::from_validation_entry("data_type", err).context(self);
391 return Err(rejection);
392 }
393
394 let is_form = data_type == "form";
395 let bytes = self
396 .read_body_bytes()
397 .await
398 .map_err(|err| Rejection::from_validation_entry("body", err).context(self))?;
399 if is_form {
400 serde_qs::from_bytes(&bytes)
401 .map_err(|err| Rejection::from_validation_entry("body", err).context(self))
402 } else {
403 serde_json::from_slice(&bytes)
404 .map_err(|err| Rejection::from_validation_entry("body", err).context(self))
405 }
406 }
407
408 async fn parse_multipart(&mut self) -> Result<Multipart, Rejection> {
410 let Some(content_type) = self.get_header("content-type") else {
411 return Err(Rejection::from_validation_entry(
412 "content_type",
413 warn!("invalid `content-type` header"),
414 )
415 .context(self));
416 };
417 match multer::parse_boundary(content_type) {
418 Ok(boundary) => {
419 let result = self.read_body_bytes().await.map_err(|err| err.to_string());
420 let stream = futures::stream::once(async { result });
421 Ok(Multipart::new(stream, boundary))
422 }
423 Err(err) => Err(Rejection::from_validation_entry("boundary", err).context(self)),
424 }
425 }
426
427 async fn parse_file(&mut self) -> Result<NamedFile, Rejection> {
429 let multipart = self.parse_multipart().await?;
430 NamedFile::try_from_multipart(multipart)
431 .await
432 .map_err(|err| Rejection::from_validation_entry("body", err).context(self))
433 }
434
435 async fn parse_files(&mut self) -> Result<Vec<NamedFile>, Rejection> {
437 let multipart = self.parse_multipart().await?;
438 NamedFile::try_collect_from_multipart(multipart)
439 .await
440 .map_err(|err| Rejection::from_validation_entry("body", err).context(self))
441 }
442
443 async fn parse_form<T: DeserializeOwned>(
445 &mut self,
446 name: &str,
447 ) -> Result<(Option<T>, Vec<NamedFile>), Rejection> {
448 let multipart = self.parse_multipart().await?;
449 helper::parse_form(multipart, name)
450 .await
451 .map_err(|err| Rejection::from_validation_entry("body", err).context(self))
452 }
453
454 async fn parse_form_data<T: DeserializeOwned>(
456 &mut self,
457 ) -> Result<(T, Vec<NamedFile>), Rejection> {
458 let multipart = self.parse_multipart().await?;
459 helper::parse_form_data(multipart)
460 .await
461 .map_err(|err| Rejection::from_validation_entry("body", err).context(self))
462 }
463
464 #[cfg(feature = "auth")]
469 fn parse_authentication(&self) -> Result<Authentication, Rejection> {
470 let method = self.request_method();
471 let query = self.parse_query::<Map>().unwrap_or_default();
472 let mut authentication = Authentication::new(method.as_ref());
473 let mut validation = Validation::new();
474 if let Some(signature) = query.get_str("signature") {
475 authentication.set_signature(signature.to_owned());
476 if let Some(access_key_id) = query.parse_string("access_key_id") {
477 authentication.set_access_key_id(access_key_id);
478 } else {
479 validation.record("access_key_id", "should be nonempty");
480 }
481 if let Some(Ok(secs)) = query.parse_i64("expires") {
482 if DateTime::now().timestamp() <= secs {
483 let expires = DateTime::from_timestamp(secs);
484 authentication.set_expires(Some(expires));
485 } else {
486 validation.record("expires", "valid period has expired");
487 }
488 } else {
489 validation.record("expires", "invalid timestamp");
490 }
491 if !validation.is_success() {
492 return Err(Rejection::bad_request(validation).context(self));
493 }
494 } else if let Some(authorization) = self.get_header("authorization") {
495 if let Some((service_name, token)) = authorization.split_once(' ') {
496 authentication.set_service_name(service_name);
497 if let Some((access_key_id, signature)) = token.split_once(':') {
498 authentication.set_access_key_id(access_key_id);
499 authentication.set_signature(signature.to_owned());
500 } else {
501 validation.record("authorization", "invalid header value");
502 }
503 } else {
504 validation.record("authorization", "invalid service name");
505 }
506 if !validation.is_success() {
507 return Err(Rejection::bad_request(validation).context(self));
508 }
509 }
510 if let Some(content_md5) = self.get_header("content-md5") {
511 authentication.set_content_md5(content_md5.to_owned());
512 }
513 if let Some(date) = self.get_header("date") {
514 match DateTime::parse_utc_str(date) {
515 Ok(date) => {
516 #[cfg(feature = "jwt")]
517 if date.span_between_now() <= zino_auth::default_time_tolerance() {
518 authentication.set_date_header("date", date);
519 } else {
520 validation.record("date", "untrusted date");
521 }
522 #[cfg(not(feature = "jwt"))]
523 authentication.set_date_header("date", date);
524 }
525 Err(err) => {
526 validation.record_fail("date", err);
527 return Err(Rejection::bad_request(validation).context(self));
528 }
529 }
530 }
531 authentication.set_content_type(self.get_header("content-type").map(|s| s.to_owned()));
532 authentication.set_resource(self.request_path().to_owned(), None);
533 Ok(authentication)
534 }
535
536 #[cfg(feature = "auth")]
540 fn parse_access_key_id(&self) -> Result<AccessKeyId, Rejection> {
541 if let Some(access_key_id) = self.get_query("access_key_id") {
542 Ok(access_key_id.into())
543 } else {
544 let mut validation = Validation::new();
545 if let Some(authorization) = self.get_header("authorization") {
546 if let Some((_, token)) = authorization.split_once(' ') {
547 let access_key_id = if let Some((access_key_id, _)) = token.split_once(':') {
548 access_key_id
549 } else {
550 token
551 };
552 return Ok(access_key_id.into());
553 } else {
554 validation.record("authorization", "invalid service name");
555 }
556 } else {
557 validation.record("authorization", "invalid value to get the access key id");
558 }
559 Err(Rejection::bad_request(validation).context(self))
560 }
561 }
562
563 #[cfg(feature = "auth")]
566 fn parse_security_token(&self, key: &[u8]) -> Result<SecurityToken, Rejection> {
567 use ParseSecurityTokenError::*;
568 let query = self.parse_query::<Map>()?;
569 let mut validation = Validation::new();
570 if let Some(token) = self
571 .get_header("x-security-token")
572 .or_else(|| query.get_str("security_token"))
573 {
574 match SecurityToken::parse_with(token.to_owned(), key) {
575 Ok(security_token) => {
576 if let Some(access_key_id) = query.get_str("access_key_id") {
577 if security_token.access_key_id().as_str() != access_key_id {
578 validation.record("access_key_id", "untrusted access key ID");
579 }
580 }
581 if let Some(Ok(expires)) = query.parse_i64("expires") {
582 if security_token.expires_at().timestamp() != expires {
583 validation.record("expires", "untrusted timestamp");
584 }
585 }
586 if validation.is_success() {
587 return Ok(security_token);
588 }
589 }
590 Err(err) => {
591 let field = match err {
592 DecodeError(_) | InvalidFormat => "security_token",
593 ParseExpiresError(_) | ValidPeriodExpired(_) => "expires",
594 };
595 validation.record_fail(field, err);
596 }
597 }
598 } else {
599 validation.record("security_token", "should be nonempty");
600 }
601 Err(Rejection::bad_request(validation).context(self))
602 }
603
604 #[cfg(feature = "auth")]
607 fn parse_session_id(&self) -> Result<SessionId, Rejection> {
608 self.get_header("x-session-id")
609 .or_else(|| self.get_header("session-id"))
610 .ok_or_else(|| {
611 Rejection::from_validation_entry(
612 "session_id",
613 warn!("a `session-id` or `x-session-id` header is required"),
614 )
615 .context(self)
616 })
617 .and_then(|session_id| {
618 SessionId::parse(session_id).map_err(|err| {
619 Rejection::from_validation_entry("session_id", err).context(self)
620 })
621 })
622 }
623
624 #[cfg(feature = "jwt")]
628 fn parse_jwt_claims<T, K>(&self, key: &K) -> Result<JwtClaims<T>, Rejection>
629 where
630 T: Default + serde::Serialize + DeserializeOwned,
631 K: MACLike,
632 {
633 let (param, mut token) = match self.get_query("access_token") {
634 Some(access_token) => ("access_token", access_token),
635 None => ("authorization", ""),
636 };
637 if let Some(authorization) = self.get_header("authorization") {
638 token = authorization
639 .strip_prefix("Bearer ")
640 .unwrap_or(authorization);
641 }
642 if token.is_empty() {
643 let mut validation = Validation::new();
644 validation.record(param, "JWT token is absent");
645 return Err(Rejection::bad_request(validation).context(self));
646 }
647
648 let mut options = zino_auth::default_verification_options();
649 options.reject_before = self
650 .get_query("timestamp")
651 .and_then(|s| s.parse().ok())
652 .map(|i| Duration::from_secs(i).into());
653 options.required_nonce = self.get_query("nonce").map(|s| s.to_owned());
654
655 match key.verify_token(token, Some(options)) {
656 Ok(claims) => Ok(claims.into()),
657 Err(err) => {
658 let message = format!("401 Unauthorized: {err}");
659 Err(Rejection::with_message(message).context(self))
660 }
661 }
662 }
663
664 fn query_validation<S>(&self, query: &mut Query) -> Result<Response<S>, Rejection>
667 where
668 Self: Sized,
669 S: ResponseCode,
670 {
671 match self.parse_query() {
672 Ok(data) => {
673 let validation = query.read_map(&data);
674 if validation.is_success() {
675 Ok(Response::with_context(S::OK, self))
676 } else {
677 Err(Rejection::bad_request(validation).context(self))
678 }
679 }
680 Err(rejection) => Err(rejection),
681 }
682 }
683
684 async fn model_validation<M, S>(&mut self, model: &mut M) -> Result<Response<S>, Rejection>
687 where
688 Self: Sized,
689 M: ModelHooks,
690 S: ResponseCode,
691 {
692 let data_type = self.data_type().unwrap_or("form");
693 if data_type.contains('/') {
694 let err = warn!(
695 "deserialization of the data type `{}` is unsupported",
696 data_type
697 );
698 let rejection = Rejection::from_validation_entry("data_type", err).context(self);
699 return Err(rejection);
700 }
701 M::before_extract()
702 .await
703 .map_err(|err| Rejection::from_error(err).context(self))?;
704
705 let is_form = data_type == "form";
706 let bytes = self
707 .read_body_bytes()
708 .await
709 .map_err(|err| Rejection::from_validation_entry("body", err).context(self))?;
710 let extension = self.get_data::<M::Extension>();
711 if is_form {
712 let mut data = serde_qs::from_bytes(&bytes)
713 .map_err(|err| Rejection::from_validation_entry("body", err).context(self))?;
714 match M::before_validation(&mut data, extension.as_ref()).await {
715 Ok(()) => {
716 let validation = model.read_map(&data);
717 model
718 .after_validation(&mut data)
719 .await
720 .map_err(|err| Rejection::from_error(err).context(self))?;
721 if let Some(extension) = extension {
722 model
723 .after_extract(extension)
724 .await
725 .map_err(|err| Rejection::from_error(err).context(self))?;
726 }
727 if validation.is_success() {
728 Ok(Response::with_context(S::OK, self))
729 } else {
730 Err(Rejection::bad_request(validation).context(self))
731 }
732 }
733 Err(err) => Err(Rejection::from_error(err).context(self)),
734 }
735 } else {
736 let mut data = serde_json::from_slice(&bytes)
737 .map_err(|err| Rejection::from_validation_entry("body", err).context(self))?;
738 match M::before_validation(&mut data, extension.as_ref()).await {
739 Ok(()) => {
740 let validation = model.read_map(&data);
741 model
742 .after_validation(&mut data)
743 .await
744 .map_err(|err| Rejection::from_error(err).context(self))?;
745 if let Some(extension) = extension {
746 model
747 .after_extract(extension)
748 .await
749 .map_err(|err| Rejection::from_error(err).context(self))?;
750 }
751 if validation.is_success() {
752 Ok(Response::with_context(S::OK, self))
753 } else {
754 Err(Rejection::bad_request(validation).context(self))
755 }
756 }
757 Err(err) => Err(Rejection::from_error(err).context(self)),
758 }
759 }
760 }
761
762 async fn fetch(&self, url: &str, options: Option<&Map>) -> Result<reqwest::Response, Error> {
764 let trace_context = self.new_trace_context();
765 Agent::request_builder(url, options)?
766 .header("traceparent", trace_context.traceparent())
767 .header("tracestate", trace_context.tracestate())
768 .send()
769 .await
770 .map_err(Error::from)
771 }
772
773 async fn fetch_json<T: DeserializeOwned>(
776 &self,
777 url: &str,
778 options: Option<&Map>,
779 ) -> Result<T, Error> {
780 let response = self.fetch(url, options).await?.error_for_status()?;
781 let data = if response.headers().has_json_content_type() {
782 response.json().await?
783 } else {
784 let text = response.text().await?;
785 serde_json::from_str(&text)?
786 };
787 Ok(data)
788 }
789
790 #[cfg(feature = "i18n")]
792 fn translate(&self, message: &str, args: Option<FluentArgs>) -> Result<SharedString, Error> {
793 if let Some(locale) = self.locale() {
794 i18n::translate(&locale, message, args)
795 } else {
796 let default_locale = i18n::DEFAULT_LOCALE.parse()?;
797 i18n::translate(&default_locale, message, args)
798 }
799 }
800
801 fn subscription(&self) -> Subscription {
803 let mut subscription = self.parse_query::<Subscription>().unwrap_or_default();
804 if subscription.session_id().is_none() {
805 if let Some(session_id) = self.session_id() {
806 subscription.set_session_id(Some(session_id));
807 }
808 }
809 subscription
810 }
811
812 fn cloud_event(&self, event_type: SharedString, data: JsonValue) -> CloudEvent {
814 let id = self.request_id();
815 let source = self.instance();
816 let mut event = CloudEvent::new(id, source, event_type);
817 if let Some(session_id) = self.session_id() {
818 event.set_session_id(session_id);
819 }
820 event.set_data(data);
821 event
822 }
823}