// tower_async_http/decompression/service.rs

use super::{body::BodyInner, DecompressionBody, DecompressionLayer};
use crate::{
    compression_utils::{AcceptEncoding, CompressionLevel, WrapBody},
    content_encoding::SupportedEncodings,
};
use http::{
    header::{self, ACCEPT_ENCODING},
    Request, Response,
};
use http_body::Body;
use tower_async_service::Service;

13#[derive(Debug, Clone)]
20pub struct Decompression<S> {
21 pub(crate) inner: S,
22 pub(crate) accept: AcceptEncoding,
23}
24
25impl<S> Decompression<S> {
26 pub fn new(service: S) -> Self {
28 Self {
29 inner: service,
30 accept: AcceptEncoding::default(),
31 }
32 }
33
34 define_inner_service_accessors!();
35
36 pub fn layer() -> DecompressionLayer {
40 DecompressionLayer::new()
41 }
42
43 #[cfg(feature = "decompression-gzip")]
45 pub fn gzip(mut self, enable: bool) -> Self {
46 self.accept.set_gzip(enable);
47 self
48 }
49
50 #[cfg(feature = "decompression-deflate")]
52 pub fn deflate(mut self, enable: bool) -> Self {
53 self.accept.set_deflate(enable);
54 self
55 }
56
57 #[cfg(feature = "decompression-br")]
59 pub fn br(mut self, enable: bool) -> Self {
60 self.accept.set_br(enable);
61 self
62 }
63
64 #[cfg(feature = "decompression-zstd")]
66 pub fn zstd(mut self, enable: bool) -> Self {
67 self.accept.set_zstd(enable);
68 self
69 }
70
71 pub fn no_gzip(mut self) -> Self {
75 self.accept.set_gzip(false);
76 self
77 }
78
79 pub fn no_deflate(mut self) -> Self {
83 self.accept.set_deflate(false);
84 self
85 }
86
87 pub fn no_br(mut self) -> Self {
91 self.accept.set_br(false);
92 self
93 }
94
95 pub fn no_zstd(mut self) -> Self {
99 self.accept.set_zstd(false);
100 self
101 }
102}
103
104impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Decompression<S>
105where
106 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
107 ResBody: Body,
108{
109 type Response = Response<DecompressionBody<ResBody>>;
110 type Error = S::Error;
111
112 async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
113 if let header::Entry::Vacant(entry) = req.headers_mut().entry(ACCEPT_ENCODING) {
114 if let Some(accept) = self.accept.to_header_value() {
115 entry.insert(accept);
116 }
117 }
118
119 let res = self.inner.call(req).await?;
120
121 let (mut parts, body) = res.into_parts();
122
123 let res =
124 if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) {
125 let body = match entry.get().as_bytes() {
126 #[cfg(feature = "decompression-gzip")]
127 b"gzip" if self.accept.gzip() => DecompressionBody::new(BodyInner::gzip(
128 WrapBody::new(body, CompressionLevel::default()),
129 )),
130
131 #[cfg(feature = "decompression-deflate")]
132 b"deflate" if self.accept.deflate() => DecompressionBody::new(
133 BodyInner::deflate(WrapBody::new(body, CompressionLevel::default())),
134 ),
135
136 #[cfg(feature = "decompression-br")]
137 b"br" if self.accept.br() => DecompressionBody::new(BodyInner::brotli(
138 WrapBody::new(body, CompressionLevel::default()),
139 )),
140
141 #[cfg(feature = "decompression-zstd")]
142 b"zstd" if self.accept.zstd() => DecompressionBody::new(BodyInner::zstd(
143 WrapBody::new(body, CompressionLevel::default()),
144 )),
145
146 _ => {
147 return Ok(Response::from_parts(
148 parts,
149 DecompressionBody::new(BodyInner::identity(body)),
150 ))
151 }
152 };
153
154 entry.remove();
155 parts.headers.remove(header::CONTENT_LENGTH);
156
157 Response::from_parts(parts, body)
158 } else {
159 Response::from_parts(parts, DecompressionBody::new(BodyInner::identity(body)))
160 };
161
162 Ok(res)
163 }
164}