1use http::header::CONTENT_TYPE;
28use http::StatusCode;
29use typeway_core::negotiate::ContentFormat;
30
31use crate::body::{body_from_bytes, body_from_string, BoxBody};
32use crate::response::IntoResponse;
33
34pub struct JsonFormat;
40
41impl ContentFormat for JsonFormat {
42 const CONTENT_TYPE: &'static str = "application/json";
43}
44
45pub struct TextFormat;
47
48impl ContentFormat for TextFormat {
49 const CONTENT_TYPE: &'static str = "text/plain; charset=utf-8";
50}
51
52pub struct HtmlFormat;
54
55impl ContentFormat for HtmlFormat {
56 const CONTENT_TYPE: &'static str = "text/html; charset=utf-8";
57}
58
59pub struct CsvFormat;
61
62impl ContentFormat for CsvFormat {
63 const CONTENT_TYPE: &'static str = "text/csv";
64}
65
66pub struct XmlFormat;
68
69impl ContentFormat for XmlFormat {
70 const CONTENT_TYPE: &'static str = "application/xml";
71}
72
/// Opt-in hook for XML output.
///
/// Implementing this trait gives a type a `RenderAs<XmlFormat>` implementation
/// via the blanket impl in this module. That impl sends the returned string as
/// the response body with content type `application/xml`.
pub trait RenderAsXml {
    /// Returns the XML representation of `self`.
    ///
    /// The string is used verbatim as the body; no escaping or validation is
    /// applied to it afterwards.
    fn to_xml(&self) -> String;
}
83
84impl<T: RenderAsXml> RenderAs<XmlFormat> for T {
85 fn render(&self) -> Result<(Vec<u8>, &'static str), String> {
86 Ok((self.to_xml().into_bytes(), XmlFormat::CONTENT_TYPE))
87 }
88}
89
/// Renders a value into a response body for one particular format.
///
/// `Format` is a `ContentFormat` marker type (for example `JsonFormat`).
/// Blanket implementations in this module cover `serde::Serialize` types for
/// `JsonFormat`, `std::fmt::Display` types for `TextFormat`, and `RenderAsXml`
/// types for `XmlFormat`.
pub trait RenderAs<Format: ContentFormat> {
    /// Renders `self`, returning the body bytes and the `Content-Type` value to
    /// send with them.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if rendering fails. `NegotiatedResponse`
    /// turns such an error into a `500 Internal Server Error`.
    fn render(&self) -> Result<(Vec<u8>, &'static str), String>;
}
106
107impl<T: serde::Serialize> RenderAs<JsonFormat> for T {
108 fn render(&self) -> Result<(Vec<u8>, &'static str), String> {
109 let bytes = serde_json::to_vec(self).map_err(|e| e.to_string())?;
110 Ok((bytes, JsonFormat::CONTENT_TYPE))
111 }
112}
113
114impl<T: std::fmt::Display> RenderAs<TextFormat> for T {
115 fn render(&self) -> Result<(Vec<u8>, &'static str), String> {
116 Ok((self.to_string().into_bytes(), TextFormat::CONTENT_TYPE))
117 }
118}
119
/// A set of formats, expressed as a tuple of `ContentFormat` markers, that a
/// value of type `T` can be rendered in, plus the logic that picks one per request.
///
/// Implemented by `impl_negotiate_formats!` for tuples of one to six formats.
/// `T` must implement `RenderAs` for every element. The first element of the
/// tuple is used when the request carries no `Accept` header.
pub trait NegotiateFormats<T> {
    /// Returns the `CONTENT_TYPE` of every format in the set, in tuple order.
    fn supported_types() -> Vec<&'static str>;

    /// Picks a format for the given `Accept` header value (`None` when the
    /// request had none) and renders `value` with it.
    ///
    /// Returns the body bytes and the content type of the chosen format.
    ///
    /// # Errors
    ///
    /// Propagates the error message from the chosen format's `RenderAs::render`.
    fn negotiate_and_render(
        value: &T,
        accept: Option<&str>,
    ) -> Result<(Vec<u8>, &'static str), String>;
}
138
/// Parses an HTTP `Accept` header value into `(media_range, quality)` pairs,
/// sorted from highest to lowest quality.
///
/// * Empty list elements, and elements whose media range is empty, are skipped.
/// * Only the `q` parameter is interpreted. Its name is matched
///   case-insensitively (RFC 9110 §12.4.2). All other parameters are ignored.
/// * A missing, unparsable or non-finite `q` defaults to `1.0`. Finite values
///   are clamped into `0.0..=1.0`. Rejecting NaN/infinity keeps the sort
///   comparator a total order, so a hostile header cannot upset the sort.
/// * Entries with `q=0` ("not acceptable") are kept, so that `best_match` can
///   honour explicit refusals.
/// * The sort is stable: entries with equal quality keep their header order.
fn parse_accept(accept: &str) -> Vec<(&str, f32)> {
    let mut entries: Vec<(&str, f32)> = accept
        .split(',')
        .filter_map(|entry| {
            let mut parts = entry.split(';');
            let media_type = parts.next()?.trim();
            if media_type.is_empty() {
                return None;
            }
            let quality = parts.find_map(parse_quality_param).unwrap_or(1.0);
            Some((media_type, quality))
        })
        .collect();
    // Every quality is finite (see `parse_quality_param`), so `partial_cmp`
    // always returns `Some` and the comparator is a total order.
    entries.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    entries
}

/// Parses one `;`-separated `Accept` parameter.
///
/// Returns its value when the parameter is a `q` weight with a finite numeric
/// value, clamped into `0.0..=1.0`. Returns `None` otherwise.
fn parse_quality_param(param: &str) -> Option<f32> {
    let (name, value) = param.split_once('=')?;
    if !name.trim().eq_ignore_ascii_case("q") {
        return None;
    }
    let quality: f32 = value.trim().parse().ok()?;
    if quality.is_finite() {
        Some(quality.clamp(0.0, 1.0))
    } else {
        None
    }
}

/// Returns `true` if the `Accept` media range `accept_type` (`*/*`, `type/*`
/// or `type/subtype`) matches the server-side `supported` content type.
///
/// Parameters on `supported` (such as `; charset=utf-8`) are ignored. Type and
/// subtype names are compared ASCII case-insensitively, because media type
/// names are case-insensitive (RFC 9110 §8.3.1).
fn media_type_matches(accept_type: &str, supported: &str) -> bool {
    let accept_type = accept_type.trim();
    if accept_type == "*/*" {
        return true;
    }
    let supported_base = supported.split(';').next().unwrap_or(supported).trim();
    if accept_type.eq_ignore_ascii_case(supported_base) {
        return true;
    }
    match accept_type.strip_suffix("/*") {
        Some(prefix) => {
            let top_level = supported_base.split('/').next().unwrap_or(supported_base);
            top_level.eq_ignore_ascii_case(prefix)
        }
        None => false,
    }
}

/// Ranks how specific a media range is: `*/*` (0) < `type/*` (1) < `type/subtype` (2).
fn media_range_specificity(media_range: &str) -> u8 {
    if media_range == "*/*" {
        0
    } else if media_range.ends_with("/*") {
        1
    } else {
        2
    }
}

/// Finds the entry that decides the quality of `supported`.
///
/// The deciding entry is the most specific matching media range (RFC 9110
/// §12.5.1). Among equally specific ranges the earliest, i.e. highest-quality,
/// entry wins. Returns the entry's position in `entries` together with its
/// quality, or `None` if no range matches `supported`.
fn deciding_entry(entries: &[(&str, f32)], supported: &str) -> Option<(usize, f32)> {
    let mut deciding: Option<(usize, u8)> = None;
    for (pos, (range, _)) in entries.iter().enumerate() {
        if !media_type_matches(range, supported) {
            continue;
        }
        let specificity = media_range_specificity(range);
        let more_specific = match deciding {
            Some((_, best_specificity)) => specificity > best_specificity,
            None => true,
        };
        if more_specific {
            deciding = Some((pos, specificity));
        }
    }
    deciding.map(|(pos, _)| (pos, entries[pos].1))
}

/// Returns the index into `supported` of the content type to render.
///
/// * With no `Accept` header, or a blank one, index `0` is returned.
/// * Otherwise each supported type gets the quality of its deciding range (see
///   `deciding_entry`). A quality of `0` means the client refused that type.
/// * The type with the highest positive quality wins. Ties go to the type whose
///   deciding range comes first in the quality-sorted header, then to the
///   earlier entry in `supported`.
/// * If no type has a positive quality, the first supported type the header
///   does not mention at all is returned, or `0` if every type was refused.
///   This interface cannot signal `406 Not Acceptable`, so an index is always
///   returned.
fn best_match(accept: Option<&str>, supported: &[&str]) -> usize {
    let accept = match accept {
        Some(a) if !a.trim().is_empty() => a,
        _ => return 0,
    };
    let entries = parse_accept(accept);

    // (index into `supported`, effective quality, position of deciding entry)
    let mut chosen: Option<(usize, f32, usize)> = None;
    // First supported type that no media range in the header matches.
    let mut unmentioned: Option<usize> = None;

    for (idx, supported_type) in supported.iter().enumerate() {
        match deciding_entry(&entries, supported_type) {
            None => unmentioned = unmentioned.or(Some(idx)),
            // Explicitly refused with `q=0`: never pick it.
            Some((_, quality)) if quality <= 0.0 => {}
            Some((pos, quality)) => {
                let better = match chosen {
                    None => true,
                    Some((_, best_quality, best_pos)) => {
                        quality > best_quality || (quality == best_quality && pos < best_pos)
                    }
                };
                if better {
                    chosen = Some((idx, quality, pos));
                }
            }
        }
    }

    chosen.map(|(idx, _, _)| idx).or(unmentioned).unwrap_or(0)
}
211
/// Implements `NegotiateFormats` for a tuple of format marker types.
///
/// The first list names the tuple's type parameters. The second gives each
/// one's zero-based position in the tuple, and the two lists must have the
/// same length. The positions become `match` patterns on the index returned by
/// `best_match`, which indexes the same `CONTENT_TYPE` list, so selecting a
/// renderer needs no allocation and no dynamic dispatch.
///
/// A one-element tuple always renders with its only format and ignores the
/// `Accept` header.
macro_rules! impl_negotiate_formats {
    ([$F1:ident], [$idx1:tt]) => {
        impl<T, $F1> NegotiateFormats<T> for ($F1,)
        where
            $F1: ContentFormat,
            T: RenderAs<$F1>,
        {
            fn supported_types() -> Vec<&'static str> {
                vec![$F1::CONTENT_TYPE]
            }

            fn negotiate_and_render(
                value: &T,
                _accept: Option<&str>,
            ) -> Result<(Vec<u8>, &'static str), String> {
                <T as RenderAs<$F1>>::render(value)
            }
        }
    };
    ([$F1:ident $(, $FN:ident)*], [$idx1:tt $(, $idxN:tt)*]) => {
        impl<T, $F1 $(, $FN)*> NegotiateFormats<T> for ($F1, $($FN,)*)
        where
            $F1: ContentFormat,
            $($FN: ContentFormat,)*
            T: RenderAs<$F1> $(+ RenderAs<$FN>)*,
        {
            fn supported_types() -> Vec<&'static str> {
                vec![$F1::CONTENT_TYPE $(, $FN::CONTENT_TYPE)*]
            }

            fn negotiate_and_render(
                value: &T,
                accept: Option<&str>,
            ) -> Result<(Vec<u8>, &'static str), String> {
                let supported = [$F1::CONTENT_TYPE $(, $FN::CONTENT_TYPE)*];
                match best_match(accept, &supported) {
                    $idx1 => <T as RenderAs<$F1>>::render(value),
                    $($idxN => <T as RenderAs<$FN>>::render(value),)*
                    // `best_match` only returns indices into `supported`; fall
                    // back to the first format rather than panicking if that
                    // invariant is ever broken.
                    _ => <T as RenderAs<$F1>>::render(value),
                }
            }
        }
    };
}
261
// Provide `NegotiateFormats` for tuples of one to six format markers. The
// second list gives each element's zero-based position in the tuple, which
// the macro uses to dispatch on the index chosen by `best_match`.
impl_negotiate_formats!([F1], [0]);
impl_negotiate_formats!([F1, F2], [0, 1]);
impl_negotiate_formats!([F1, F2, F3], [0, 1, 2]);
impl_negotiate_formats!([F1, F2, F3, F4], [0, 1, 2, 3]);
impl_negotiate_formats!([F1, F2, F3, F4, F5], [0, 1, 2, 3, 4]);
impl_negotiate_formats!([F1, F2, F3, F4, F5, F6], [0, 1, 2, 3, 4, 5]);
268
/// A response whose body format is chosen from the `Formats` tuple by content
/// negotiation against an `Accept` header value when it is turned into an
/// HTTP response.
///
/// `Formats` is only a type-level list of format markers and is never stored.
/// It is held as `PhantomData<fn() -> Formats>` rather than
/// `PhantomData<Formats>`, so the response's `Send`/`Sync` auto traits and drop
/// check do not depend on the marker types. The response is `Send`/`Sync`
/// whenever `T` is.
pub struct NegotiatedResponse<T, Formats> {
    /// The value to render.
    value: T,
    /// The `Accept` header value used for negotiation, if any.
    accept: Option<String>,
    _formats: std::marker::PhantomData<fn() -> Formats>,
}
285
286impl<T, Formats> NegotiatedResponse<T, Formats> {
287 pub fn new(value: T, accept: Option<String>) -> Self {
292 NegotiatedResponse {
293 value,
294 accept,
295 _formats: std::marker::PhantomData,
296 }
297 }
298}
299
300impl<T, Formats> IntoResponse for NegotiatedResponse<T, Formats>
301where
302 Formats: NegotiateFormats<T>,
303{
304 fn into_response(self) -> http::Response<BoxBody> {
305 match Formats::negotiate_and_render(&self.value, self.accept.as_deref()) {
306 Ok((body_bytes, content_type)) => {
307 let body = body_from_bytes(bytes::Bytes::from(body_bytes));
308 let mut res = http::Response::new(body);
309 if let Ok(val) = http::HeaderValue::from_str(content_type) {
310 res.headers_mut().insert(CONTENT_TYPE, val);
311 }
312 res
313 }
314 Err(e) => {
315 let mut res =
316 http::Response::new(body_from_string(format!("negotiation error: {e}")));
317 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
318 res
319 }
320 }
321 }
322}
323
/// The request's `Accept` header value, or `None` if it had no usable one.
///
/// Extracted from request parts and passed to `negotiated` to build a
/// `NegotiatedResponse`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AcceptHeader(pub Option<String>);
342
343impl crate::extract::FromRequestParts for AcceptHeader {
344 type Error = std::convert::Infallible;
345
346 fn from_request_parts(parts: &http::request::Parts) -> Result<Self, Self::Error> {
347 let accept = parts
348 .headers
349 .get(http::header::ACCEPT)
350 .and_then(|v| v.to_str().ok())
351 .map(|s| s.to_string());
352 Ok(AcceptHeader(accept))
353 }
354}
355
356impl IntoResponse for std::convert::Infallible {
358 fn into_response(self) -> http::Response<BoxBody> {
359 match self {}
360 }
361}
362
363pub fn negotiated<T, Formats>(value: T, accept: AcceptHeader) -> NegotiatedResponse<T, Formats> {
377 NegotiatedResponse::new(value, accept.0)
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Supported types whose text entry has no parameters.
    const PLAIN_PAIR: &[&str] = &["application/json", "text/plain"];
    /// Supported types whose text entry carries a charset parameter.
    const CHARSET_PAIR: &[&str] = &["application/json", "text/plain; charset=utf-8"];

    #[derive(serde::Serialize)]
    struct TestUser {
        id: u32,
        name: String,
    }

    impl std::fmt::Display for TestUser {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "User({id}, {name})", id = self.id, name = self.name)
        }
    }

    /// The fixture every rendering test works with.
    fn alice() -> TestUser {
        TestUser {
            id: 1,
            name: String::from("Alice"),
        }
    }

    /// Media ranges of a parsed `Accept` header, in negotiated order.
    fn media_types(header: &str) -> Vec<&str> {
        parse_accept(header)
            .into_iter()
            .map(|(media, _)| media)
            .collect()
    }

    /// Negotiates between JSON and plain text for the fixture user.
    fn negotiate_pair(accept: Option<&str>) -> (Vec<u8>, &'static str) {
        <(JsonFormat, TextFormat) as NegotiateFormats<TestUser>>::negotiate_and_render(
            &alice(),
            accept,
        )
        .unwrap()
    }

    /// Builds the HTTP response for the fixture user with the given `Accept` value.
    fn respond(accept: &str) -> http::Response<BoxBody> {
        let negotiated: NegotiatedResponse<TestUser, (JsonFormat, TextFormat)> =
            NegotiatedResponse::new(alice(), Some(accept.to_owned()));
        negotiated.into_response()
    }

    #[test]
    fn parse_accept_simple() {
        let parsed = parse_accept("application/json");
        assert_eq!(parsed.len(), 1);
        let (media, quality) = parsed[0];
        assert_eq!(media, "application/json");
        assert!((quality - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn parse_accept_with_quality() {
        assert_eq!(
            media_types("text/plain;q=0.5, application/json;q=0.9"),
            ["application/json", "text/plain"]
        );
    }

    #[test]
    fn parse_accept_wildcard() {
        assert_eq!(media_types("*/*"), ["*/*"]);
    }

    #[test]
    fn media_type_matches_exact() {
        let cases = [
            ("application/json", "application/json", true),
            ("application/json", "text/plain", false),
        ];
        for (accept, supported, expected) in cases {
            assert_eq!(media_type_matches(accept, supported), expected);
        }
    }

    #[test]
    fn media_type_matches_with_params() {
        assert!(media_type_matches("text/plain", "text/plain; charset=utf-8"));
    }

    #[test]
    fn media_type_matches_wildcard() {
        let cases = [
            ("*/*", "application/json", true),
            ("text/*", "text/plain", true),
            ("text/*", "application/json", false),
        ];
        for (accept, supported, expected) in cases {
            assert_eq!(media_type_matches(accept, supported), expected);
        }
    }

    #[test]
    fn best_match_no_accept() {
        assert_eq!(best_match(None, PLAIN_PAIR), 0);
    }

    #[test]
    fn best_match_wildcard() {
        assert_eq!(best_match(Some("*/*"), PLAIN_PAIR), 0);
    }

    #[test]
    fn best_match_specific() {
        assert_eq!(best_match(Some("text/plain"), CHARSET_PAIR), 1);
    }

    #[test]
    fn best_match_quality_order() {
        let header = "text/plain;q=0.9, application/json;q=0.5";
        assert_eq!(best_match(Some(header), CHARSET_PAIR), 1);
    }

    #[test]
    fn render_as_json() {
        let user = alice();
        let (body, content_type) = <TestUser as RenderAs<JsonFormat>>::render(&user).unwrap();
        assert_eq!(content_type, "application/json");
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["name"], "Alice");
    }

    #[test]
    fn render_as_text() {
        let user = alice();
        let (body, content_type) = <TestUser as RenderAs<TextFormat>>::render(&user).unwrap();
        assert_eq!(content_type, "text/plain; charset=utf-8");
        assert_eq!(String::from_utf8(body).unwrap(), "User(1, Alice)");
    }

    #[test]
    fn negotiate_json_when_accepted() {
        let (body, content_type) = negotiate_pair(Some("application/json"));
        assert_eq!(content_type, "application/json");
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["id"], 1);
    }

    #[test]
    fn negotiate_text_when_accepted() {
        let (body, content_type) = negotiate_pair(Some("text/plain"));
        assert_eq!(content_type, "text/plain; charset=utf-8");
        assert_eq!(String::from_utf8(body).unwrap(), "User(1, Alice)");
    }

    #[test]
    fn negotiate_default_on_wildcard() {
        let (_, content_type) = negotiate_pair(Some("*/*"));
        assert_eq!(content_type, "application/json");
    }

    #[test]
    fn negotiate_default_on_no_accept() {
        let (_, content_type) = negotiate_pair(None);
        assert_eq!(content_type, "application/json");
    }

    #[test]
    fn negotiated_response_into_response_json() {
        let response = respond("application/json");
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/json"
        );
    }

    #[test]
    fn negotiated_response_into_response_text() {
        let response = respond("text/plain");
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/plain; charset=utf-8"
        );
    }

    #[test]
    fn single_format_tuple() {
        let user = alice();
        let (_, content_type) =
            <(JsonFormat,) as NegotiateFormats<TestUser>>::negotiate_and_render(&user, None)
                .unwrap();
        assert_eq!(content_type, "application/json");
    }

    #[test]
    fn three_format_tuple() {
        type Triple = (JsonFormat, TextFormat, JsonFormat);
        let user = alice();
        let (_, content_type) = <Triple as NegotiateFormats<TestUser>>::negotiate_and_render(
            &user,
            Some("text/plain"),
        )
        .unwrap();
        assert_eq!(content_type, "text/plain; charset=utf-8");
    }

    #[test]
    fn supported_types_lists_all() {
        let listed = <(JsonFormat, TextFormat) as NegotiateFormats<TestUser>>::supported_types();
        assert_eq!(listed, ["application/json", "text/plain; charset=utf-8"]);
    }
}