// s2_api/data.rs

1use std::str::FromStr;
2
3use base64ct::{Base64, Encoding as _};
4use bytes::Bytes;
5use s2_common::types::ValidationError;
6
/// Newtype wrapper around a value exchanged as JSON.
///
/// As a response it is serialized through `axum::Json`. With the `axum`
/// feature it is also usable as a request extractor (see the `extract` module).
#[derive(Debug)]
pub struct Json<T>(pub T);

#[cfg(feature = "axum")]
impl<T: serde::Serialize> axum::response::IntoResponse for Json<T> {
    /// Delegates serialization to `axum::Json`.
    fn into_response(self) -> axum::response::Response {
        axum::response::IntoResponse::into_response(axum::Json(self.0))
    }
}
20
/// Newtype wrapper around a protobuf message.
///
/// As a response it is encoded with `prost` and sent with
/// `Content-Type: application/protobuf`.
#[derive(Debug)]
pub struct Proto<T>(pub T);

#[cfg(feature = "axum")]
impl<T: prost::Message> axum::response::IntoResponse for Proto<T> {
    fn into_response(self) -> axum::response::Response {
        let Self(message) = self;
        let encoded: Vec<u8> = message.encode_to_vec();
        let content_type = (
            http::header::CONTENT_TYPE,
            http::header::HeaderValue::from_static("application/protobuf"),
        );
        axum::response::IntoResponse::into_response(([content_type], encoded))
    }
}
38
39#[rustfmt::skip]
40#[derive(Debug, Default, Clone, Copy)]
41#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
42pub enum Format {
43    #[default]
44    #[cfg_attr(feature = "utoipa", schema(rename = "raw"))]
45    Raw,
46    #[cfg_attr(feature = "utoipa", schema(rename = "base64"))]
47    Base64,
48}
49
50impl s2_common::http::ParseableHeader for Format {
51    fn name() -> &'static http::HeaderName {
52        &FORMAT_HEADER
53    }
54}
55
56impl Format {
57    pub fn encode(self, bytes: &[u8]) -> String {
58        match self {
59            Format::Raw => String::from_utf8_lossy(bytes).into_owned(),
60            Format::Base64 => Base64::encode_string(bytes),
61        }
62    }
63
64    pub fn decode(self, s: String) -> Result<Bytes, ValidationError> {
65        Ok(match self {
66            Format::Raw => s.into_bytes().into(),
67            Format::Base64 => Base64::decode_vec(&s)
68                .map_err(|_| ValidationError("invalid Base64 encoding".to_owned()))?
69                .into(),
70        })
71    }
72}
73
74impl FromStr for Format {
75    type Err = ValidationError;
76
77    fn from_str(s: &str) -> Result<Self, Self::Err> {
78        match s.trim() {
79            "raw" | "json" => Ok(Self::Raw),
80            "base64" | "json-binsafe" => Ok(Self::Base64),
81            _ => Err(ValidationError(s.to_string())),
82        }
83    }
84}
85
86pub static FORMAT_HEADER: http::HeaderName = http::HeaderName::from_static("s2-format");
87
88#[rustfmt::skip]
89#[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))]
90#[cfg_attr(feature = "utoipa", into_params(parameter_in = Header))]
91pub struct S2FormatHeader {
92    /// Defines the interpretation of record data (header name, header value, and body) with the JSON content type.
93    /// Use `raw` (default) for efficient transmission and storage of Unicode data — storage will be in UTF-8.
94    /// Use `base64` for safe transmission with efficient storage of binary data.
95    #[cfg_attr(feature = "utoipa", param(required = false, rename = "s2-format"))]
96    pub s2_format: Format,
97}
98
// Header-parameter description of the optional `s2-encryption-key` header
// (utoipa `IntoParams`). `Debug` is implemented by hand below so that key
// material is never written to logs or panic messages.
#[rustfmt::skip]
#[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))]
#[cfg_attr(feature = "utoipa", into_params(parameter_in = Header))]
pub struct S2EncryptionKeyHeader {
    /// Encryption key material for append and read operations.
    /// Provide base64-encoded key when stream encryption is enabled.
    #[cfg_attr(feature = "utoipa", param(required = false, rename = "s2-encryption-key", value_type = String))]
    pub s2_encryption_key: String,
}

impl std::fmt::Debug for S2EncryptionKeyHeader {
    /// Formats the struct with the key replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("S2EncryptionKeyHeader")
            .field("s2_encryption_key", &"<redacted>")
            .finish()
    }
}
109
110#[cfg(feature = "axum")]
111pub mod extract {
112    use std::borrow::Cow;
113
114    use axum::{
115        extract::{FromRequest, OptionalFromRequest, Request, rejection::BytesRejection},
116        response::{IntoResponse, Response},
117    };
118    use bytes::Bytes;
119    use serde::de::DeserializeOwned;
120
121    /// Rejection type for JSON extraction, owned by s2-api.
122    #[derive(Debug)]
123    #[non_exhaustive]
124    pub enum JsonExtractionRejection {
125        SyntaxError {
126            status: http::StatusCode,
127            message: Cow<'static, str>,
128        },
129        DataError {
130            status: http::StatusCode,
131            message: Cow<'static, str>,
132        },
133        MissingContentType,
134        Other {
135            status: http::StatusCode,
136            message: Cow<'static, str>,
137        },
138    }
139
140    const MISSING_CONTENT_TYPE_MSG: &str = "Expected request with `Content-Type: application/json`";
141
142    impl JsonExtractionRejection {
143        pub fn body_text(&self) -> &str {
144            match self {
145                Self::SyntaxError { message, .. }
146                | Self::DataError { message, .. }
147                | Self::Other { message, .. } => message,
148                Self::MissingContentType => MISSING_CONTENT_TYPE_MSG,
149            }
150        }
151
152        pub fn status(&self) -> http::StatusCode {
153            match self {
154                Self::SyntaxError { status, .. }
155                | Self::DataError { status, .. }
156                | Self::Other { status, .. } => *status,
157                Self::MissingContentType => http::StatusCode::UNSUPPORTED_MEDIA_TYPE,
158            }
159        }
160    }
161
162    impl std::fmt::Display for JsonExtractionRejection {
163        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164            f.write_str(self.body_text())
165        }
166    }
167
168    impl std::error::Error for JsonExtractionRejection {}
169
170    impl IntoResponse for JsonExtractionRejection {
171        fn into_response(self) -> Response {
172            let status = self.status();
173            match self {
174                Self::SyntaxError { message, .. }
175                | Self::DataError { message, .. }
176                | Self::Other { message, .. } => match message {
177                    Cow::Borrowed(s) => (status, s).into_response(),
178                    Cow::Owned(s) => (status, s).into_response(),
179                },
180                Self::MissingContentType => (status, MISSING_CONTENT_TYPE_MSG).into_response(),
181            }
182        }
183    }
184
185    fn classify_sonic_error(err: sonic_rs::Error) -> JsonExtractionRejection {
186        use sonic_rs::error::Category;
187        match err.classify() {
188            Category::TypeUnmatched | Category::NotFound => JsonExtractionRejection::DataError {
189                status: http::StatusCode::UNPROCESSABLE_ENTITY,
190                message: err.to_string().into(),
191            },
192            Category::Io => JsonExtractionRejection::Other {
193                status: http::StatusCode::INTERNAL_SERVER_ERROR,
194                message: err.to_string().into(),
195            },
196            _ => JsonExtractionRejection::SyntaxError {
197                status: http::StatusCode::BAD_REQUEST,
198                message: err.to_string().into(),
199            },
200        }
201    }
202
203    impl<S, T> FromRequest<S> for super::Json<T>
204    where
205        S: Send + Sync,
206        T: DeserializeOwned,
207    {
208        type Rejection = JsonExtractionRejection;
209
210        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
211            let Some(ctype) = req.headers().get(http::header::CONTENT_TYPE) else {
212                return Err(JsonExtractionRejection::MissingContentType);
213            };
214            if !crate::mime::parse(ctype)
215                .as_ref()
216                .is_some_and(crate::mime::is_json)
217            {
218                return Err(JsonExtractionRejection::MissingContentType);
219            }
220            let bytes = Bytes::from_request(req, state).await.map_err(|e| {
221                JsonExtractionRejection::Other {
222                    status: e.status(),
223                    message: e.body_text().into(),
224                }
225            })?;
226            sonic_rs::from_slice(&bytes)
227                .map(Self)
228                .map_err(classify_sonic_error)
229        }
230    }
231
232    impl<S, T> OptionalFromRequest<S> for super::Json<T>
233    where
234        S: Send + Sync,
235        T: DeserializeOwned,
236    {
237        type Rejection = JsonExtractionRejection;
238
239        async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
240            let Some(ctype) = req.headers().get(http::header::CONTENT_TYPE) else {
241                return Ok(None);
242            };
243            if !crate::mime::parse(ctype)
244                .as_ref()
245                .is_some_and(crate::mime::is_json)
246            {
247                return Err(JsonExtractionRejection::MissingContentType);
248            }
249            let bytes = Bytes::from_request(req, state).await.map_err(|e| {
250                JsonExtractionRejection::Other {
251                    status: e.status(),
252                    message: e.body_text().into(),
253                }
254            })?;
255            if bytes.is_empty() {
256                return Ok(None);
257            }
258            sonic_rs::from_slice(&bytes)
259                .map(|v| Some(Self(v)))
260                .map_err(classify_sonic_error)
261        }
262    }
263
264    /// Workaround for https://github.com/tokio-rs/axum/issues/3623
265    #[derive(Debug)]
266    pub struct JsonOpt<T>(pub Option<T>);
267
268    impl<S, T> FromRequest<S> for JsonOpt<T>
269    where
270        S: Send + Sync,
271        T: DeserializeOwned,
272    {
273        type Rejection = JsonExtractionRejection;
274
275        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
276            match <super::Json<T> as OptionalFromRequest<S>>::from_request(req, state).await {
277                Ok(Some(super::Json(value))) => Ok(Self(Some(value))),
278                Ok(None) => Ok(Self(None)),
279                Err(e) => Err(e),
280            }
281        }
282    }
283
284    #[derive(Debug, thiserror::Error)]
285    pub enum ProtoRejection {
286        #[error(transparent)]
287        BytesRejection(#[from] BytesRejection),
288        #[error(transparent)]
289        Decode(#[from] prost::DecodeError),
290    }
291
292    impl IntoResponse for ProtoRejection {
293        fn into_response(self) -> Response {
294            match self {
295                ProtoRejection::BytesRejection(e) => e.into_response(),
296                ProtoRejection::Decode(e) => (
297                    http::StatusCode::BAD_REQUEST,
298                    format!("Invalid protobuf body: {e}"),
299                )
300                    .into_response(),
301            }
302        }
303    }
304
305    impl<S, T> FromRequest<S> for super::Proto<T>
306    where
307        S: Send + Sync,
308        T: prost::Message + Default,
309    {
310        type Rejection = ProtoRejection;
311
312        async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
313            let bytes = Bytes::from_request(req, state).await?;
314            Ok(super::Proto(T::decode(bytes)?))
315        }
316    }
317
    // Unit tests for JSON error classification and for parity between the
    // serde_json and sonic_rs deserializers on API types.
    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::v1::{
            config::{BasinReconfiguration, StreamReconfiguration},
            stream::{AppendInput, AppendRecord, Header},
        };

        /// Deserializes `json` with sonic_rs and maps any error through
        /// `classify_sonic_error`, the same path the `Json` extractors use
        /// after buffering the body.
        fn classify_json_error<T: DeserializeOwned>(
            json: &[u8],
        ) -> Result<T, JsonExtractionRejection> {
            sonic_rs::from_slice(json).map_err(classify_sonic_error)
        }

        /// Verify that our rejection wrapper preserves axum's status code
        /// classification for a variety of invalid JSON payloads, now using
        /// sonic-rs as the deserializer.
        #[test]
        fn json_error_classification() {
            let cases: &[(&[u8], http::StatusCode)] = &[
                // Syntax errors → 400 (except the `{} trailing` case noted below)
                (b"not json", http::StatusCode::BAD_REQUEST),
                // `{}` is valid JSON but missing `records` — the data error is
                // reported before checking trailing chars, so this is 422, not 400.
                (b"{} trailing", http::StatusCode::UNPROCESSABLE_ENTITY),
                (b"", http::StatusCode::BAD_REQUEST),
                (b"{truncated", http::StatusCode::BAD_REQUEST),
                // Data errors → 422
                (b"{}", http::StatusCode::UNPROCESSABLE_ENTITY),
                (
                    br#"{"records": "nope"}"#,
                    http::StatusCode::UNPROCESSABLE_ENTITY,
                ),
                (
                    br#"{"records": [{"body": 123}]}"#,
                    http::StatusCode::UNPROCESSABLE_ENTITY,
                ),
            ];

            for (input, expected_status) in cases {
                let err = classify_json_error::<AppendInput>(input).expect_err(&format!(
                    "expected error for {:?}",
                    String::from_utf8_lossy(input)
                ));
                assert_eq!(
                    err.status(),
                    *expected_status,
                    "wrong status for {:?}: got {}, body: {}",
                    String::from_utf8_lossy(input),
                    err.status(),
                    err.body_text(),
                );
            }
        }

        // Sanity check: a minimal well-formed `AppendInput` is accepted.
        #[test]
        fn valid_json_parses_successfully() {
            let input = br#"{"records": [], "match_seq_num": null}"#;
            let result = classify_json_error::<AppendInput>(input);
            assert!(result.is_ok());
        }

        /// Differential test: serialize with serde_json, deserialize with
        /// both serde_json and sonic_rs, assert semantic equality.
        #[test]
        fn serde_json_sonic_rs_roundtrip() {
            // Equality is checked via `Debug` output, so `T` only needs `Debug`
            // (not `PartialEq`).
            fn assert_roundtrip<T>(input: &T)
            where
                T: serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug,
            {
                let json = serde_json::to_vec(input).unwrap();
                let from_serde: T = serde_json::from_slice(&json).unwrap();
                let from_sonic: T = sonic_rs::from_slice(&json).unwrap();
                assert_eq!(
                    format!("{from_serde:?}"),
                    format!("{from_sonic:?}"),
                    "roundtrip mismatch for {}",
                    String::from_utf8_lossy(&json),
                );
            }

            // AppendInput variants
            assert_roundtrip(&AppendInput {
                records: vec![],
                match_seq_num: None,
                fencing_token: None,
            });
            assert_roundtrip(&AppendInput {
                records: vec![AppendRecord {
                    timestamp: None,
                    headers: vec![Header("key".into(), "val".into())],
                    body: "hello world".into(),
                }],
                match_seq_num: Some(42),
                fencing_token: Some("token".parse().unwrap()),
            });

            // StreamReconfiguration: exercises Maybe<T> in all three states
            use s2_common::maybe::Maybe;

            use crate::v1::config::{StorageClass, TimestampingMode, TimestampingReconfiguration};

            // All fields unspecified (empty JSON object)
            assert_roundtrip(&StreamReconfiguration {
                storage_class: Maybe::Unspecified,
                retention_policy: Maybe::Unspecified,
                timestamping: Maybe::Unspecified,
                delete_on_empty: Maybe::Unspecified,
            });
            // Mix of specified-null and specified-value
            assert_roundtrip(&StreamReconfiguration {
                storage_class: Maybe::Specified(Some(StorageClass::Express)),
                retention_policy: Maybe::Specified(None),
                timestamping: Maybe::Specified(Some(TimestampingReconfiguration {
                    mode: Maybe::Specified(Some(TimestampingMode::ClientRequire)),
                    uncapped: Maybe::Specified(Some(true)),
                })),
                delete_on_empty: Maybe::Unspecified,
            });

            // BasinReconfiguration: nested Maybe<Option<StreamReconfiguration>>
            assert_roundtrip(&BasinReconfiguration {
                default_stream_config: Maybe::Specified(None),
                stream_cipher: Maybe::Unspecified,
                create_stream_on_append: Maybe::Specified(true),
                create_stream_on_read: Maybe::Unspecified,
            });
        }
    }
447}