1use std::str::FromStr;
2
3use base64ct::{Base64, Encoding as _};
4use bytes::Bytes;
5use s2_common::types::ValidationError;
6
/// Newtype wrapper around a value `T` that is exchanged as JSON.
///
/// With the `axum` feature enabled, this file implements `IntoResponse` for it
/// (delegating to `axum::Json`) and request extraction in the `extract` module.
#[derive(Debug)]
pub struct Json<T>(pub T);
9
/// Turns the wrapped value into a response by handing it to `axum::Json`.
#[cfg(feature = "axum")]
impl<T: serde::Serialize> axum::response::IntoResponse for Json<T> {
    fn into_response(self) -> axum::response::Response {
        // All serialization and header handling is delegated to axum's own wrapper.
        axum::Json(self.0).into_response()
    }
}
20
/// Newtype wrapper around a value `T` that is exchanged as a protobuf message.
///
/// With the `axum` feature enabled, this file implements `IntoResponse` for it
/// (for `T: prost::Message`) and request extraction in the `extract` module.
#[derive(Debug)]
pub struct Proto<T>(pub T);
23
/// Responds with the protobuf encoding of the wrapped message and a
/// `Content-Type: application/protobuf` header.
#[cfg(feature = "axum")]
impl<T: prost::Message> axum::response::IntoResponse for Proto<T> {
    fn into_response(self) -> axum::response::Response {
        let Self(message) = self;
        let content_type = http::header::HeaderValue::from_static("application/protobuf");
        (
            [(http::header::CONTENT_TYPE, content_type)],
            message.encode_to_vec(),
        )
            .into_response()
    }
}
38
// Selects how bytes are mapped to and from strings; see `Format::encode` and
// `Format::decode` for the exact conversions.
//
// NOTE(review): plain `//` comments are used here on purpose. This type derives
// `utoipa::ToSchema`, and doc comments may surface in the generated schema.
#[rustfmt::skip]
#[derive(Debug, Default, Clone, Copy)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Format {
    // Default variant: strings are taken as (or lossily converted to) UTF-8 text.
    #[default]
    #[cfg_attr(feature = "utoipa", schema(rename = "raw"))]
    Raw,
    // Strings carry Base64-encoded bytes.
    #[cfg_attr(feature = "utoipa", schema(rename = "base64"))]
    Base64,
}
49
/// Associates [`Format`] with the `s2-format` header through `ParseableHeader`.
impl s2_common::http::ParseableHeader for Format {
    /// Returns the header name this type is read from ([`FORMAT_HEADER`]).
    fn name() -> &'static http::HeaderName {
        &FORMAT_HEADER
    }
}
55
56impl Format {
57 pub fn encode(self, bytes: &[u8]) -> String {
58 match self {
59 Format::Raw => String::from_utf8_lossy(bytes).into_owned(),
60 Format::Base64 => Base64::encode_string(bytes),
61 }
62 }
63
64 pub fn decode(self, s: String) -> Result<Bytes, ValidationError> {
65 Ok(match self {
66 Format::Raw => s.into_bytes().into(),
67 Format::Base64 => Base64::decode_vec(&s)
68 .map_err(|_| ValidationError("invalid Base64 encoding".to_owned()))?
69 .into(),
70 })
71 }
72}
73
74impl FromStr for Format {
75 type Err = ValidationError;
76
77 fn from_str(s: &str) -> Result<Self, Self::Err> {
78 match s.trim() {
79 "raw" | "json" => Ok(Self::Raw),
80 "base64" | "json-binsafe" => Ok(Self::Base64),
81 _ => Err(ValidationError(s.to_string())),
82 }
83 }
84}
85
/// Name of the HTTP header (`s2-format`) from which a [`Format`] is read.
pub static FORMAT_HEADER: http::HeaderName = http::HeaderName::from_static("s2-format");
87
88#[rustfmt::skip]
89#[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))]
90#[cfg_attr(feature = "utoipa", into_params(parameter_in = Header))]
91pub struct S2FormatHeader {
92 #[cfg_attr(feature = "utoipa", param(required = false, rename = "s2-format"))]
96 pub s2_format: Format,
97}
98
// Declares the optional `s2-encryption` header as a parameter (utoipa
// `IntoParams` with `parameter_in = Header` when the `utoipa` feature is enabled).
//
// NOTE(review): plain `//` comments are used on purpose; doc comments on
// `IntoParams` fields may surface in the generated parameter description.
#[rustfmt::skip]
#[derive(Debug)]
#[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))]
#[cfg_attr(feature = "utoipa", into_params(parameter_in = Header))]
pub struct S2EncryptionHeader {
    // Raw value of the `s2-encryption` header; declared as not required and
    // documented to utoipa as a plain `String`.
    #[cfg_attr(feature = "utoipa", param(required = false, rename = "s2-encryption", value_type = String))]
    pub s2_encryption: String,
}
109
110#[cfg(feature = "axum")]
111pub mod extract {
112 use std::borrow::Cow;
113
114 use axum::{
115 extract::{FromRequest, OptionalFromRequest, Request, rejection::BytesRejection},
116 response::{IntoResponse, Response},
117 };
118 use bytes::Bytes;
119 use serde::de::DeserializeOwned;
120
    /// Rejection produced by the JSON extractors in this module.
    ///
    /// Values are usually built from axum's `JsonRejection` through the `From`
    /// impl in this module; each variant keeps the status code and body text
    /// reported by the underlying axum rejection where one exists.
    #[derive(Debug)]
    #[non_exhaustive]
    pub enum JsonExtractionRejection {
        /// Built from axum's `JsonSyntaxError` rejection.
        SyntaxError {
            status: http::StatusCode,
            message: Cow<'static, str>,
        },
        /// Built from axum's `JsonDataError` rejection.
        DataError {
            status: http::StatusCode,
            message: Cow<'static, str>,
        },
        /// The request's content type was missing or not JSON; reported with
        /// `415 Unsupported Media Type` and a fixed message.
        MissingContentType,
        /// Any other failure, e.g. an unclassified axum rejection or a failure
        /// to read the request body.
        Other {
            status: http::StatusCode,
            message: Cow<'static, str>,
        },
    }
139
    /// Fixed body text for `JsonExtractionRejection::MissingContentType`.
    const MISSING_CONTENT_TYPE_MSG: &str = "Expected request with `Content-Type: application/json`";
141
142 impl JsonExtractionRejection {
143 pub fn body_text(&self) -> &str {
144 match self {
145 Self::SyntaxError { message, .. }
146 | Self::DataError { message, .. }
147 | Self::Other { message, .. } => message,
148 Self::MissingContentType => MISSING_CONTENT_TYPE_MSG,
149 }
150 }
151
152 pub fn status(&self) -> http::StatusCode {
153 match self {
154 Self::SyntaxError { status, .. }
155 | Self::DataError { status, .. }
156 | Self::Other { status, .. } => *status,
157 Self::MissingContentType => http::StatusCode::UNSUPPORTED_MEDIA_TYPE,
158 }
159 }
160 }
161
162 impl std::fmt::Display for JsonExtractionRejection {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.write_str(self.body_text())
165 }
166 }
167
    // Uses the trait's default methods only; `Debug` comes from the derive on
    // the enum and `Display` from the impl in this module.
    impl std::error::Error for JsonExtractionRejection {}
169
170 impl IntoResponse for JsonExtractionRejection {
171 fn into_response(self) -> Response {
172 let status = self.status();
173 match self {
174 Self::SyntaxError { message, .. }
175 | Self::DataError { message, .. }
176 | Self::Other { message, .. } => match message {
177 Cow::Borrowed(s) => (status, s).into_response(),
178 Cow::Owned(s) => (status, s).into_response(),
179 },
180 Self::MissingContentType => (status, MISSING_CONTENT_TYPE_MSG).into_response(),
181 }
182 }
183 }
184
185 impl From<axum::extract::rejection::JsonRejection> for JsonExtractionRejection {
187 fn from(rej: axum::extract::rejection::JsonRejection) -> Self {
188 use axum::extract::rejection::JsonRejection::*;
189 match rej {
190 JsonDataError(e) => Self::DataError {
191 status: e.status(),
192 message: e.body_text().into(),
193 },
194 JsonSyntaxError(e) => Self::SyntaxError {
195 status: e.status(),
196 message: e.body_text().into(),
197 },
198 MissingJsonContentType(_) => Self::MissingContentType,
199 other => Self::Other {
200 status: other.status(),
201 message: other.body_text().into(),
202 },
203 }
204 }
205 }
206
207 impl<S, T> FromRequest<S> for super::Json<T>
208 where
209 S: Send + Sync,
210 T: DeserializeOwned,
211 {
212 type Rejection = JsonExtractionRejection;
213
214 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
215 let axum::Json(value) = <axum::Json<T> as FromRequest<S>>::from_request(req, state)
216 .await
217 .map_err(JsonExtractionRejection::from)?;
218 Ok(Self(value))
219 }
220 }
221
222 impl<S, T> OptionalFromRequest<S> for super::Json<T>
223 where
224 S: Send + Sync,
225 T: DeserializeOwned,
226 {
227 type Rejection = JsonExtractionRejection;
228
229 async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
230 let Some(ctype) = req.headers().get(http::header::CONTENT_TYPE) else {
231 return Ok(None);
232 };
233 if !crate::mime::parse(ctype)
234 .as_ref()
235 .is_some_and(crate::mime::is_json)
236 {
237 return Err(JsonExtractionRejection::MissingContentType);
238 }
239 let bytes = Bytes::from_request(req, state).await.map_err(|e| {
240 JsonExtractionRejection::Other {
241 status: e.status(),
242 message: e.body_text().into(),
243 }
244 })?;
245 if bytes.is_empty() {
246 return Ok(None);
247 }
248 let value = axum::Json::<T>::from_bytes(&bytes)
249 .map_err(JsonExtractionRejection::from)?
250 .0;
251 Ok(Some(Self(value)))
252 }
253 }
254
    /// Extractor for an optional JSON body.
    ///
    /// Holds `None` when the optional extraction for `Json<T>` yields no value
    /// (no `Content-Type` header, or an empty body), and `Some(value)` otherwise.
    #[derive(Debug)]
    pub struct JsonOpt<T>(pub Option<T>);
258
259 impl<S, T> FromRequest<S> for JsonOpt<T>
260 where
261 S: Send + Sync,
262 T: DeserializeOwned,
263 {
264 type Rejection = JsonExtractionRejection;
265
266 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
267 match <super::Json<T> as OptionalFromRequest<S>>::from_request(req, state).await {
268 Ok(Some(super::Json(value))) => Ok(Self(Some(value))),
269 Ok(None) => Ok(Self(None)),
270 Err(e) => Err(e),
271 }
272 }
273 }
274
    /// Rejection produced when extracting a `Proto<T>` from a request.
    #[derive(Debug, thiserror::Error)]
    pub enum ProtoRejection {
        /// Reading the request body into `Bytes` failed.
        #[error(transparent)]
        BytesRejection(#[from] BytesRejection),
        /// The body bytes could not be decoded as the target message.
        #[error(transparent)]
        Decode(#[from] prost::DecodeError),
    }
282
283 impl IntoResponse for ProtoRejection {
284 fn into_response(self) -> Response {
285 match self {
286 ProtoRejection::BytesRejection(e) => e.into_response(),
287 ProtoRejection::Decode(e) => (
288 http::StatusCode::BAD_REQUEST,
289 format!("Invalid protobuf body: {e}"),
290 )
291 .into_response(),
292 }
293 }
294 }
295
296 impl<S, T> FromRequest<S> for super::Proto<T>
297 where
298 S: Send + Sync,
299 T: prost::Message + Default,
300 {
301 type Rejection = ProtoRejection;
302
303 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
304 let bytes = Bytes::from_request(req, state).await?;
305 Ok(super::Proto(T::decode(bytes)?))
306 }
307 }
308
309 #[cfg(test)]
310 mod tests {
311 use super::*;
312 use crate::v1::stream::AppendInput;
313
314 fn classify_json_error<T: DeserializeOwned>(
315 json: &[u8],
316 ) -> Result<T, JsonExtractionRejection> {
317 axum::Json::<T>::from_bytes(json)
318 .map(|axum::Json(v)| v)
319 .map_err(JsonExtractionRejection::from)
320 }
321
322 #[test]
326 fn json_error_classification() {
327 let cases: &[(&[u8], http::StatusCode)] = &[
328 (b"not json", http::StatusCode::BAD_REQUEST),
330 (b"{} trailing", http::StatusCode::UNPROCESSABLE_ENTITY),
333 (b"", http::StatusCode::BAD_REQUEST),
334 (b"{truncated", http::StatusCode::BAD_REQUEST),
335 (b"{}", http::StatusCode::UNPROCESSABLE_ENTITY),
337 (
338 br#"{"records": "nope"}"#,
339 http::StatusCode::UNPROCESSABLE_ENTITY,
340 ),
341 (
342 br#"{"records": [{"body": 123}]}"#,
343 http::StatusCode::UNPROCESSABLE_ENTITY,
344 ),
345 ];
346
347 for (input, expected_status) in cases {
348 let err = classify_json_error::<AppendInput>(input).expect_err(&format!(
349 "expected error for {:?}",
350 String::from_utf8_lossy(input)
351 ));
352 assert_eq!(
353 err.status(),
354 *expected_status,
355 "wrong status for {:?}: got {}, body: {}",
356 String::from_utf8_lossy(input),
357 err.status(),
358 err.body_text(),
359 );
360 }
361 }
362
363 #[test]
364 fn valid_json_parses_successfully() {
365 let input = br#"{"records": [], "match_seq_num": null}"#;
366 let result = classify_json_error::<AppendInput>(input);
367 assert!(result.is_ok());
368 }
369 }
370}