tower_async_http/decompression/service.rs

1use super::{body::BodyInner, DecompressionBody, DecompressionLayer};
2use crate::{
3    compression_utils::{AcceptEncoding, CompressionLevel, WrapBody},
4    content_encoding::SupportedEncodings,
5};
6use http::{
7    header::{self, ACCEPT_ENCODING},
8    Request, Response,
9};
10use http_body::Body;
11use tower_async_service::Service;
12
13/// Decompresses response bodies of the underlying service.
14///
15/// This adds the `Accept-Encoding` header to requests and transparently decompresses response
16/// bodies based on the `Content-Encoding` header.
17///
18/// See the [module docs](crate::decompression) for more details.
19#[derive(Debug, Clone)]
20pub struct Decompression<S> {
21    pub(crate) inner: S,
22    pub(crate) accept: AcceptEncoding,
23}
24
25impl<S> Decompression<S> {
26    /// Creates a new `Decompression` wrapping the `service`.
27    pub fn new(service: S) -> Self {
28        Self {
29            inner: service,
30            accept: AcceptEncoding::default(),
31        }
32    }
33
34    define_inner_service_accessors!();
35
36    /// Returns a new [`Layer`] that wraps services with a `Decompression` middleware.
37    ///
38    /// [`Layer`]: tower_async_layer::Layer
39    pub fn layer() -> DecompressionLayer {
40        DecompressionLayer::new()
41    }
42
43    /// Sets whether to request the gzip encoding.
44    #[cfg(feature = "decompression-gzip")]
45    pub fn gzip(mut self, enable: bool) -> Self {
46        self.accept.set_gzip(enable);
47        self
48    }
49
50    /// Sets whether to request the Deflate encoding.
51    #[cfg(feature = "decompression-deflate")]
52    pub fn deflate(mut self, enable: bool) -> Self {
53        self.accept.set_deflate(enable);
54        self
55    }
56
57    /// Sets whether to request the Brotli encoding.
58    #[cfg(feature = "decompression-br")]
59    pub fn br(mut self, enable: bool) -> Self {
60        self.accept.set_br(enable);
61        self
62    }
63
64    /// Sets whether to request the Zstd encoding.
65    #[cfg(feature = "decompression-zstd")]
66    pub fn zstd(mut self, enable: bool) -> Self {
67        self.accept.set_zstd(enable);
68        self
69    }
70
71    /// Disables the gzip encoding.
72    ///
73    /// This method is available even if the `gzip` crate feature is disabled.
74    pub fn no_gzip(mut self) -> Self {
75        self.accept.set_gzip(false);
76        self
77    }
78
79    /// Disables the Deflate encoding.
80    ///
81    /// This method is available even if the `deflate` crate feature is disabled.
82    pub fn no_deflate(mut self) -> Self {
83        self.accept.set_deflate(false);
84        self
85    }
86
87    /// Disables the Brotli encoding.
88    ///
89    /// This method is available even if the `br` crate feature is disabled.
90    pub fn no_br(mut self) -> Self {
91        self.accept.set_br(false);
92        self
93    }
94
95    /// Disables the Zstd encoding.
96    ///
97    /// This method is available even if the `zstd` crate feature is disabled.
98    pub fn no_zstd(mut self) -> Self {
99        self.accept.set_zstd(false);
100        self
101    }
102}
103
104impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Decompression<S>
105where
106    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
107    ResBody: Body,
108{
109    type Response = Response<DecompressionBody<ResBody>>;
110    type Error = S::Error;
111
112    async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
113        if let header::Entry::Vacant(entry) = req.headers_mut().entry(ACCEPT_ENCODING) {
114            if let Some(accept) = self.accept.to_header_value() {
115                entry.insert(accept);
116            }
117        }
118
119        let res = self.inner.call(req).await?;
120
121        let (mut parts, body) = res.into_parts();
122
123        let res =
124            if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) {
125                let body = match entry.get().as_bytes() {
126                    #[cfg(feature = "decompression-gzip")]
127                    b"gzip" if self.accept.gzip() => DecompressionBody::new(BodyInner::gzip(
128                        WrapBody::new(body, CompressionLevel::default()),
129                    )),
130
131                    #[cfg(feature = "decompression-deflate")]
132                    b"deflate" if self.accept.deflate() => DecompressionBody::new(
133                        BodyInner::deflate(WrapBody::new(body, CompressionLevel::default())),
134                    ),
135
136                    #[cfg(feature = "decompression-br")]
137                    b"br" if self.accept.br() => DecompressionBody::new(BodyInner::brotli(
138                        WrapBody::new(body, CompressionLevel::default()),
139                    )),
140
141                    #[cfg(feature = "decompression-zstd")]
142                    b"zstd" if self.accept.zstd() => DecompressionBody::new(BodyInner::zstd(
143                        WrapBody::new(body, CompressionLevel::default()),
144                    )),
145
146                    _ => {
147                        return Ok(Response::from_parts(
148                            parts,
149                            DecompressionBody::new(BodyInner::identity(body)),
150                        ))
151                    }
152                };
153
154                entry.remove();
155                parts.headers.remove(header::CONTENT_LENGTH);
156
157                Response::from_parts(parts, body)
158            } else {
159                Response::from_parts(parts, DecompressionBody::new(BodyInner::identity(body)))
160            };
161
162        Ok(res)
163    }
164}