tower_async_http/decompression/request/service.rs

1use super::layer::RequestDecompressionLayer;
2use crate::compression_utils::CompressionLevel;
3use crate::{
4    compression_utils::AcceptEncoding, decompression::body::BodyInner,
5    decompression::DecompressionBody, BoxError,
6};
7use bytes::Buf;
8use http::{header, HeaderValue, Request, Response, StatusCode};
9use http_body::Body;
10use http_body_util::{combinators::UnsyncBoxBody, BodyExt, Empty};
11use tower_async_service::Service;
12
13#[cfg(any(
14    feature = "decompression-gzip",
15    feature = "decompression-deflate",
16    feature = "decompression-br",
17    feature = "decompression-zstd",
18))]
19use crate::content_encoding::SupportedEncodings;
20
21/// Decompresses request bodies and calls its underlying service.
22///
23/// Transparently decompresses request bodies based on the `Content-Encoding` header.
24/// When the encoding in the `Content-Encoding` header is not accepted an `Unsupported Media Type`
25/// status code will be returned with the accepted encodings in the `Accept-Encoding` header.
26///
27/// Enabling pass-through of unaccepted encodings will not return an `Unsupported Media Type` but
28/// will call the underlying service with the unmodified request if the encoding is not supported.
29/// This is disabled by default.
30///
31/// See the [module docs](crate::decompression) for more details.
32#[derive(Debug, Clone)]
33pub struct RequestDecompression<S> {
34    pub(super) inner: S,
35    pub(super) accept: AcceptEncoding,
36    pub(super) pass_through_unaccepted: bool,
37}
38
39impl<S, ReqBody, ResBody, D> Service<Request<ReqBody>> for RequestDecompression<S>
40where
41    S: Service<Request<DecompressionBody<ReqBody>>, Response = Response<ResBody>>,
42    ReqBody: Body,
43    ResBody: Body<Data = D> + Send + 'static,
44    S::Error: Into<BoxError>,
45    <ResBody as Body>::Error: Into<BoxError>,
46    D: Buf + 'static,
47{
48    type Response = Response<UnsyncBoxBody<D, BoxError>>;
49    type Error = BoxError;
50
51    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
52        let (mut parts, body) = req.into_parts();
53
54        let body =
55            if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) {
56                match entry.get().as_bytes() {
57                    #[cfg(feature = "decompression-gzip")]
58                    b"gzip" if self.accept.gzip() => {
59                        entry.remove();
60                        parts.headers.remove(header::CONTENT_LENGTH);
61                        BodyInner::gzip(crate::compression_utils::WrapBody::new(
62                            body,
63                            CompressionLevel::default(),
64                        ))
65                    }
66                    #[cfg(feature = "decompression-deflate")]
67                    b"deflate" if self.accept.deflate() => {
68                        entry.remove();
69                        parts.headers.remove(header::CONTENT_LENGTH);
70                        BodyInner::deflate(crate::compression_utils::WrapBody::new(
71                            body,
72                            CompressionLevel::default(),
73                        ))
74                    }
75                    #[cfg(feature = "decompression-br")]
76                    b"br" if self.accept.br() => {
77                        entry.remove();
78                        parts.headers.remove(header::CONTENT_LENGTH);
79                        BodyInner::brotli(crate::compression_utils::WrapBody::new(
80                            body,
81                            CompressionLevel::default(),
82                        ))
83                    }
84                    #[cfg(feature = "decompression-zstd")]
85                    b"zstd" if self.accept.zstd() => {
86                        entry.remove();
87                        parts.headers.remove(header::CONTENT_LENGTH);
88                        BodyInner::zstd(crate::compression_utils::WrapBody::new(
89                            body,
90                            CompressionLevel::default(),
91                        ))
92                    }
93                    b"identity" => BodyInner::identity(body),
94                    _ if self.pass_through_unaccepted => BodyInner::identity(body),
95                    _ => return unsupported_encoding(self.accept).await,
96                }
97            } else {
98                BodyInner::identity(body)
99            };
100        let body = DecompressionBody::new(body);
101        let req = Request::from_parts(parts, body);
102        self.inner
103            .call(req)
104            .await
105            .map(|res| res.map(|body| body.map_err(Into::into).boxed_unsync()))
106            .map_err(Into::into)
107    }
108}
109
110async fn unsupported_encoding<D>(
111    accept: AcceptEncoding,
112) -> Result<Response<UnsyncBoxBody<D, BoxError>>, BoxError>
113where
114    D: Buf + 'static,
115{
116    let res = Response::builder()
117        .header(
118            header::ACCEPT_ENCODING,
119            accept
120                .to_header_value()
121                .unwrap_or(HeaderValue::from_static("identity")),
122        )
123        .status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
124        .body(Empty::new().map_err(Into::into).boxed_unsync())
125        .unwrap();
126    Ok(res)
127}
128
129impl<S> RequestDecompression<S> {
130    /// Creates a new `RequestDecompression` wrapping the `service`.
131    pub fn new(service: S) -> Self {
132        Self {
133            inner: service,
134            accept: AcceptEncoding::default(),
135            pass_through_unaccepted: false,
136        }
137    }
138
139    define_inner_service_accessors!();
140
141    /// Returns a new [`Layer`] that wraps services with a `RequestDecompression` middleware.
142    ///
143    /// [`Layer`]: tower_async_layer::Layer
144    pub fn layer() -> RequestDecompressionLayer {
145        RequestDecompressionLayer::new()
146    }
147
148    /// Passes through the request even when the encoding is not supported.
149    ///
150    /// By default pass-through is disabled.
151    pub fn pass_through_unaccepted(mut self, enabled: bool) -> Self {
152        self.pass_through_unaccepted = enabled;
153        self
154    }
155
156    /// Sets whether to support gzip encoding.
157    #[cfg(feature = "decompression-gzip")]
158    pub fn gzip(mut self, enable: bool) -> Self {
159        self.accept.set_gzip(enable);
160        self
161    }
162
163    /// Sets whether to support Deflate encoding.
164    #[cfg(feature = "decompression-deflate")]
165    pub fn deflate(mut self, enable: bool) -> Self {
166        self.accept.set_deflate(enable);
167        self
168    }
169
170    /// Sets whether to support Brotli encoding.
171    #[cfg(feature = "decompression-br")]
172    pub fn br(mut self, enable: bool) -> Self {
173        self.accept.set_br(enable);
174        self
175    }
176
177    /// Sets whether to support Zstd encoding.
178    #[cfg(feature = "decompression-zstd")]
179    pub fn zstd(mut self, enable: bool) -> Self {
180        self.accept.set_zstd(enable);
181        self
182    }
183
184    /// Disables support for gzip encoding.
185    ///
186    /// This method is available even if the `gzip` crate feature is disabled.
187    pub fn no_gzip(mut self) -> Self {
188        self.accept.set_gzip(false);
189        self
190    }
191
192    /// Disables support for Deflate encoding.
193    ///
194    /// This method is available even if the `deflate` crate feature is disabled.
195    pub fn no_deflate(mut self) -> Self {
196        self.accept.set_deflate(false);
197        self
198    }
199
200    /// Disables support for Brotli encoding.
201    ///
202    /// This method is available even if the `br` crate feature is disabled.
203    pub fn no_br(mut self) -> Self {
204        self.accept.set_br(false);
205        self
206    }
207
208    /// Disables support for Zstd encoding.
209    ///
210    /// This method is available even if the `zstd` crate feature is disabled.
211    pub fn no_zstd(mut self) -> Self {
212        self.accept.set_zstd(false);
213        self
214    }
215}