tower_async_http/decompression/request/
service.rs1use super::layer::RequestDecompressionLayer;
2use crate::compression_utils::CompressionLevel;
3use crate::{
4 compression_utils::AcceptEncoding, decompression::body::BodyInner,
5 decompression::DecompressionBody, BoxError,
6};
7use bytes::Buf;
8use http::{header, HeaderValue, Request, Response, StatusCode};
9use http_body::Body;
10use http_body_util::{combinators::UnsyncBoxBody, BodyExt, Empty};
11use tower_async_service::Service;
12
13#[cfg(any(
14 feature = "decompression-gzip",
15 feature = "decompression-deflate",
16 feature = "decompression-br",
17 feature = "decompression-zstd",
18))]
19use crate::content_encoding::SupportedEncodings;
20
21#[derive(Debug, Clone)]
33pub struct RequestDecompression<S> {
34 pub(super) inner: S,
35 pub(super) accept: AcceptEncoding,
36 pub(super) pass_through_unaccepted: bool,
37}
38
39impl<S, ReqBody, ResBody, D> Service<Request<ReqBody>> for RequestDecompression<S>
40where
41 S: Service<Request<DecompressionBody<ReqBody>>, Response = Response<ResBody>>,
42 ReqBody: Body,
43 ResBody: Body<Data = D> + Send + 'static,
44 S::Error: Into<BoxError>,
45 <ResBody as Body>::Error: Into<BoxError>,
46 D: Buf + 'static,
47{
48 type Response = Response<UnsyncBoxBody<D, BoxError>>;
49 type Error = BoxError;
50
51 async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
52 let (mut parts, body) = req.into_parts();
53
54 let body =
55 if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) {
56 match entry.get().as_bytes() {
57 #[cfg(feature = "decompression-gzip")]
58 b"gzip" if self.accept.gzip() => {
59 entry.remove();
60 parts.headers.remove(header::CONTENT_LENGTH);
61 BodyInner::gzip(crate::compression_utils::WrapBody::new(
62 body,
63 CompressionLevel::default(),
64 ))
65 }
66 #[cfg(feature = "decompression-deflate")]
67 b"deflate" if self.accept.deflate() => {
68 entry.remove();
69 parts.headers.remove(header::CONTENT_LENGTH);
70 BodyInner::deflate(crate::compression_utils::WrapBody::new(
71 body,
72 CompressionLevel::default(),
73 ))
74 }
75 #[cfg(feature = "decompression-br")]
76 b"br" if self.accept.br() => {
77 entry.remove();
78 parts.headers.remove(header::CONTENT_LENGTH);
79 BodyInner::brotli(crate::compression_utils::WrapBody::new(
80 body,
81 CompressionLevel::default(),
82 ))
83 }
84 #[cfg(feature = "decompression-zstd")]
85 b"zstd" if self.accept.zstd() => {
86 entry.remove();
87 parts.headers.remove(header::CONTENT_LENGTH);
88 BodyInner::zstd(crate::compression_utils::WrapBody::new(
89 body,
90 CompressionLevel::default(),
91 ))
92 }
93 b"identity" => BodyInner::identity(body),
94 _ if self.pass_through_unaccepted => BodyInner::identity(body),
95 _ => return unsupported_encoding(self.accept).await,
96 }
97 } else {
98 BodyInner::identity(body)
99 };
100 let body = DecompressionBody::new(body);
101 let req = Request::from_parts(parts, body);
102 self.inner
103 .call(req)
104 .await
105 .map(|res| res.map(|body| body.map_err(Into::into).boxed_unsync()))
106 .map_err(Into::into)
107 }
108}
109
110async fn unsupported_encoding<D>(
111 accept: AcceptEncoding,
112) -> Result<Response<UnsyncBoxBody<D, BoxError>>, BoxError>
113where
114 D: Buf + 'static,
115{
116 let res = Response::builder()
117 .header(
118 header::ACCEPT_ENCODING,
119 accept
120 .to_header_value()
121 .unwrap_or(HeaderValue::from_static("identity")),
122 )
123 .status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
124 .body(Empty::new().map_err(Into::into).boxed_unsync())
125 .unwrap();
126 Ok(res)
127}
128
129impl<S> RequestDecompression<S> {
130 pub fn new(service: S) -> Self {
132 Self {
133 inner: service,
134 accept: AcceptEncoding::default(),
135 pass_through_unaccepted: false,
136 }
137 }
138
139 define_inner_service_accessors!();
140
141 pub fn layer() -> RequestDecompressionLayer {
145 RequestDecompressionLayer::new()
146 }
147
148 pub fn pass_through_unaccepted(mut self, enabled: bool) -> Self {
152 self.pass_through_unaccepted = enabled;
153 self
154 }
155
156 #[cfg(feature = "decompression-gzip")]
158 pub fn gzip(mut self, enable: bool) -> Self {
159 self.accept.set_gzip(enable);
160 self
161 }
162
163 #[cfg(feature = "decompression-deflate")]
165 pub fn deflate(mut self, enable: bool) -> Self {
166 self.accept.set_deflate(enable);
167 self
168 }
169
170 #[cfg(feature = "decompression-br")]
172 pub fn br(mut self, enable: bool) -> Self {
173 self.accept.set_br(enable);
174 self
175 }
176
177 #[cfg(feature = "decompression-zstd")]
179 pub fn zstd(mut self, enable: bool) -> Self {
180 self.accept.set_zstd(enable);
181 self
182 }
183
184 pub fn no_gzip(mut self) -> Self {
188 self.accept.set_gzip(false);
189 self
190 }
191
192 pub fn no_deflate(mut self) -> Self {
196 self.accept.set_deflate(false);
197 self
198 }
199
200 pub fn no_br(mut self) -> Self {
204 self.accept.set_br(false);
205 self
206 }
207
208 pub fn no_zstd(mut self) -> Self {
212 self.accept.set_zstd(false);
213 self
214 }
215}