s2_api/
data.rs

1use std::str::FromStr;
2
3use base64ct::{Base64, Encoding as _};
4use bytes::Bytes;
5use s2_common::types::ValidationError;
6
/// Newtype around a value exchanged as a JSON body.
///
/// With the `axum` feature enabled, it converts into a response by delegating
/// serialization to [`axum::Json`].
#[derive(Debug)]
pub struct Json<T>(pub T);

#[cfg(feature = "axum")]
impl<T: serde::Serialize> axum::response::IntoResponse for Json<T> {
    fn into_response(self) -> axum::response::Response {
        // Reuse axum's JSON responder for serialization and content-type handling.
        axum::Json(self.0).into_response()
    }
}
20
/// Newtype around a protobuf message exchanged as an HTTP body.
///
/// With the `axum` feature enabled, it converts into a response whose body is the
/// message's prost encoding, labelled `Content-Type: application/protobuf`.
#[derive(Debug)]
pub struct Proto<T>(pub T);

#[cfg(feature = "axum")]
impl<T: prost::Message> axum::response::IntoResponse for Proto<T> {
    fn into_response(self) -> axum::response::Response {
        let Self(message) = self;
        let encoded: Vec<u8> = message.encode_to_vec();
        let content_type = http::header::HeaderValue::from_static("application/protobuf");
        ([(http::header::CONTENT_TYPE, content_type)], encoded).into_response()
    }
}
38
39#[rustfmt::skip]
40#[derive(Debug, Default, Clone, Copy)]
41#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
42pub enum Format {
43    #[default]
44    #[cfg_attr(feature = "utoipa", schema(rename = "raw"))]
45    Raw,
46    #[cfg_attr(feature = "utoipa", schema(rename = "base64"))]
47    Base64,
48}
49
50impl s2_common::http::ParseableHeader for Format {
51    fn name() -> &'static http::HeaderName {
52        &FORMAT_HEADER
53    }
54}
55
56impl Format {
57    pub fn encode(self, bytes: &[u8]) -> String {
58        match self {
59            Format::Raw => String::from_utf8_lossy(bytes).into_owned(),
60            Format::Base64 => Base64::encode_string(bytes),
61        }
62    }
63
64    pub fn decode(self, s: String) -> Result<Bytes, ValidationError> {
65        Ok(match self {
66            Format::Raw => s.into_bytes().into(),
67            Format::Base64 => Base64::decode_vec(&s)
68                .map_err(|_| ValidationError("invalid Base64 encoding".to_owned()))?
69                .into(),
70        })
71    }
72}
73
74impl FromStr for Format {
75    type Err = ValidationError;
76
77    fn from_str(s: &str) -> Result<Self, Self::Err> {
78        match s.trim() {
79            "raw" | "json" => Ok(Self::Raw),
80            "base64" | "json-binsafe" => Ok(Self::Base64),
81            _ => Err(ValidationError(s.to_string())),
82        }
83    }
84}
85
86pub static FORMAT_HEADER: http::HeaderName = http::HeaderName::from_static("s2-format");
87
88#[rustfmt::skip]
89#[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))]
90#[cfg_attr(feature = "utoipa", into_params(parameter_in = Header))]
91pub struct S2FormatHeader {
92    /// Defines the interpretation of record data (header name, header value, and body) with the JSON content type.
93    /// Use `raw` (default) for efficient transmission and storage of Unicode data — storage will be in UTF-8.
94    /// Use `base64` for safe transmission with efficient storage of binary data.
95    #[cfg_attr(feature = "utoipa", param(required = false, rename = "s2-format"))]
96    pub s2_format: Format,
97}
98
/// Axum request extractors for [`Json`](super::Json) and [`Proto`](super::Proto) bodies.
#[cfg(feature = "axum")]
pub mod extract {
    use axum::{
        extract::{
            FromRequest, OptionalFromRequest, Request,
            rejection::{BytesRejection, JsonRejection},
        },
        response::{IntoResponse, Response},
    };
    use bytes::Bytes;
    use serde::de::DeserializeOwned;

    /// Extracts a required JSON body by delegating to [`axum::Json`], including its
    /// [`JsonRejection`]s.
    impl<S, T> FromRequest<S> for super::Json<T>
    where
        S: Send + Sync,
        T: DeserializeOwned,
    {
        type Rejection = JsonRejection;

        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
            <axum::Json<T> as FromRequest<S>>::from_request(req, state)
                .await
                .map(|axum::Json(inner)| Self(inner))
        }
    }

    /// Extracts an optional JSON body.
    ///
    /// Yields `None` when the request has no `Content-Type` header or when the body
    /// is empty. A `Content-Type` that is present but not JSON is rejected with
    /// [`JsonRejection::MissingJsonContentType`]. A non-empty body that fails to
    /// deserialize is rejected too.
    impl<S, T> OptionalFromRequest<S> for super::Json<T>
    where
        S: Send + Sync,
        T: DeserializeOwned,
    {
        type Rejection = JsonRejection;

        async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
            let is_json = match req.headers().get(http::header::CONTENT_TYPE) {
                None => return Ok(None),
                Some(ctype) => crate::mime::parse(ctype)
                    .as_ref()
                    .is_some_and(crate::mime::is_json),
            };
            if !is_json {
                return Err(JsonRejection::MissingJsonContentType(Default::default()));
            }
            let body = Bytes::from_request(req, state)
                .await
                .map_err(JsonRejection::BytesRejection)?;
            if body.is_empty() {
                // An empty body is treated the same as an absent one.
                return Ok(None);
            }
            axum::Json::<T>::from_bytes(&body).map(|axum::Json(inner)| Some(Self(inner)))
        }
    }

    /// Extractor for an optional JSON body.
    ///
    /// Workaround for <https://github.com/tokio-rs/axum/issues/3623>. It defers to the
    /// `OptionalFromRequest` implementation for [`Json`](super::Json), so a missing
    /// `Content-Type` or an empty body yields `JsonOpt(None)`.
    #[derive(Debug)]
    pub struct JsonOpt<T>(pub Option<T>);

    impl<S, T> FromRequest<S> for JsonOpt<T>
    where
        S: Send + Sync,
        T: DeserializeOwned,
    {
        type Rejection = JsonRejection;

        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
            let extracted =
                <super::Json<T> as OptionalFromRequest<S>>::from_request(req, state).await?;
            Ok(Self(extracted.map(|super::Json(inner)| inner)))
        }
    }

    /// Rejection produced when a [`Proto`](super::Proto) body cannot be extracted.
    #[derive(Debug, thiserror::Error)]
    pub enum ProtoRejection {
        /// Reading the request body failed.
        #[error(transparent)]
        BytesRejection(#[from] BytesRejection),
        /// The body was read but could not be decoded as the expected message.
        #[error(transparent)]
        Decode(#[from] prost::DecodeError),
    }

    impl IntoResponse for ProtoRejection {
        /// Body-read failures keep axum's own response. Decode failures become
        /// `400 Bad Request` with the decode error in the body.
        fn into_response(self) -> Response {
            let decode_err = match self {
                Self::BytesRejection(rejection) => return rejection.into_response(),
                Self::Decode(err) => err,
            };
            let message = format!("Invalid protobuf body: {decode_err}");
            (http::StatusCode::BAD_REQUEST, message).into_response()
        }
    }

    /// Buffers the whole request body and decodes it as `T`. The `Content-Type`
    /// header is not inspected.
    impl<S, T> FromRequest<S> for super::Proto<T>
    where
        S: Send + Sync,
        T: prost::Message + Default,
    {
        type Rejection = ProtoRejection;

        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
            let body = Bytes::from_request(req, state).await?;
            let message = T::decode(body)?;
            Ok(Self(message))
        }
    }
}