Skip to main content

typeway_server/
negotiate.rs

1//! Content negotiation engine and format wrappers.
2//!
//! Provides format marker types ([`JsonFormat`], [`TextFormat`], [`HtmlFormat`],
//! [`CsvFormat`], [`XmlFormat`]), the [`RenderAs`] trait for converting domain
//! types into specific formats, and [`NegotiatedResponse`] which inspects the
//! `Accept` header to pick the best representation.
7//!
8//! # Example
9//!
10//! ```ignore
11//! use typeway_server::negotiate::*;
12//!
13//! #[derive(serde::Serialize)]
14//! struct User { id: u32, name: String }
15//!
16//! impl std::fmt::Display for User {
17//!     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18//!         write!(f, "User({}, {})", self.id, self.name)
19//!     }
20//! }
21//!
22//! async fn get_user(accept: AcceptHeader) -> NegotiatedResponse<User, (JsonFormat, TextFormat)> {
23//!     NegotiatedResponse::new(User { id: 1, name: "Alice".into() }, accept.0)
24//! }
25//! ```
26
27use http::header::CONTENT_TYPE;
28use http::StatusCode;
29use typeway_core::negotiate::ContentFormat;
30
31use crate::body::{body_from_bytes, body_from_string, BoxBody};
32use crate::response::IntoResponse;
33
34// ---------------------------------------------------------------------------
35// Format marker types
36// ---------------------------------------------------------------------------
37
38/// JSON representation. Serializes via `serde_json`.
39pub struct JsonFormat;
40
41impl ContentFormat for JsonFormat {
42    const CONTENT_TYPE: &'static str = "application/json";
43}
44
45/// Plain text representation. Uses `Display`.
46pub struct TextFormat;
47
48impl ContentFormat for TextFormat {
49    const CONTENT_TYPE: &'static str = "text/plain; charset=utf-8";
50}
51
52/// HTML representation. Uses `Display` (intended for types that produce HTML).
53pub struct HtmlFormat;
54
55impl ContentFormat for HtmlFormat {
56    const CONTENT_TYPE: &'static str = "text/html; charset=utf-8";
57}
58
59/// CSV representation.
60pub struct CsvFormat;
61
62impl ContentFormat for CsvFormat {
63    const CONTENT_TYPE: &'static str = "text/csv";
64}
65
66/// XML representation. Requires explicit [`RenderAsXml`] impls per type.
67pub struct XmlFormat;
68
69impl ContentFormat for XmlFormat {
70    const CONTENT_TYPE: &'static str = "application/xml";
71}
72
73// ---------------------------------------------------------------------------
74// RenderAsXml trait
75// ---------------------------------------------------------------------------
76
/// Trait for types that can render as XML.
///
/// Unlike JsonFormat/TextFormat which have blanket impls, XML rendering
/// requires an explicit impl per type since there's no standard XML serialization trait.
///
/// Implementing this trait is enough to use a type with [`XmlFormat`]: a
/// blanket `RenderAs<XmlFormat>` impl in this module forwards to
/// [`RenderAsXml::to_xml`].
pub trait RenderAsXml {
    /// Return the XML representation of this value.
    ///
    /// The blanket `RenderAs<XmlFormat>` impl sends the returned string's
    /// bytes as-is as the rendered body.
    fn to_xml(&self) -> String;
}
83
84impl<T: RenderAsXml> RenderAs<XmlFormat> for T {
85    fn render(&self) -> Result<(Vec<u8>, &'static str), String> {
86        Ok((self.to_xml().into_bytes(), XmlFormat::CONTENT_TYPE))
87    }
88}
89
90// ---------------------------------------------------------------------------
91// RenderAs trait
92// ---------------------------------------------------------------------------
93
/// Convert a domain type into bytes for a specific content format.
///
/// Implement this trait to teach the negotiation engine how to serialize
/// your type as a particular format.
///
/// Blanket implementations are provided for:
/// - `RenderAs<JsonFormat>` for any `T: serde::Serialize`
/// - `RenderAs<TextFormat>` for any `T: Display`
/// - `RenderAs<XmlFormat>` for any `T: RenderAsXml`
pub trait RenderAs<Format: ContentFormat> {
    /// Render this value into bytes and its content-type string.
    ///
    /// Returns `Ok((body_bytes, content_type))` on success, or `Err` with a
    /// human-readable message when the value cannot be rendered.
    /// [`NegotiatedResponse`] turns an `Err` into a `500 Internal Server Error`
    /// response whose body contains the message.
    fn render(&self) -> Result<(Vec<u8>, &'static str), String>;
}
106
107impl<T: serde::Serialize> RenderAs<JsonFormat> for T {
108    fn render(&self) -> Result<(Vec<u8>, &'static str), String> {
109        let bytes = serde_json::to_vec(self).map_err(|e| e.to_string())?;
110        Ok((bytes, JsonFormat::CONTENT_TYPE))
111    }
112}
113
114impl<T: std::fmt::Display> RenderAs<TextFormat> for T {
115    fn render(&self) -> Result<(Vec<u8>, &'static str), String> {
116        Ok((self.to_string().into_bytes(), TextFormat::CONTENT_TYPE))
117    }
118}
119
120// ---------------------------------------------------------------------------
121// NegotiateFormats trait
122// ---------------------------------------------------------------------------
123
/// Select the best format from a tuple of formats based on the `Accept` header
/// and render the domain value.
///
/// `T` is the domain type being rendered; the implementing type is the tuple
/// of format markers, e.g. `(JsonFormat, TextFormat)`. The first format in the
/// tuple acts as the default.
///
/// Implemented for format tuples of arities 1 through 6 via macro.
pub trait NegotiateFormats<T> {
    /// All supported content types, in preference order.
    ///
    /// Each entry is the format's `ContentFormat::CONTENT_TYPE`, in tuple order.
    fn supported_types() -> Vec<&'static str>;

    /// Pick the best format for the given `Accept` header and render `value`.
    ///
    /// `accept` is the raw `Accept` header value, or `None` when the request
    /// had none. Returns the rendered bytes together with the content type of
    /// the chosen format, or the renderer's error message.
    fn negotiate_and_render(
        value: &T,
        accept: Option<&str>,
    ) -> Result<(Vec<u8>, &'static str), String>;
}
138
/// Parse an `Accept` header value into a list of (media_type, quality) pairs,
/// sorted by quality descending.
///
/// Entries are separated by `,`; everything after the first `;` of an entry is
/// treated as its parameter list. Entries with equal quality keep the order in
/// which they appeared in the header (the sort is stable).
///
/// Quality handling:
/// - the `q` parameter name is matched case-insensitively (`q=` and `Q=`);
/// - a missing or unparseable weight defaults to `1.0`;
/// - non-finite weights (`NaN`, `inf`) are treated as unparseable, so the
///   comparison used for sorting is always a total order even for hostile
///   input;
/// - finite weights are clamped to the valid range `0.0..=1.0`.
///
/// Entries with an empty media type (for example `";q=0.5"`) are skipped.
fn parse_accept(accept: &str) -> Vec<(&str, f32)> {
    let mut entries: Vec<(&str, f32)> = accept
        .split(',')
        .filter_map(|entry| {
            let entry = entry.trim();
            if entry.is_empty() {
                return None;
            }
            let mut parts = entry.splitn(2, ';');
            let media_type = parts.next()?.trim();
            if media_type.is_empty() {
                return None;
            }
            let quality = parts
                .next()
                .and_then(|params| {
                    params.split(';').find_map(|param| {
                        let (name, value) = param.split_once('=')?;
                        if !name.trim().eq_ignore_ascii_case("q") {
                            return None;
                        }
                        value
                            .trim()
                            .parse::<f32>()
                            .ok()
                            // `"NaN"` and `"inf"` parse successfully as f32;
                            // reject them so sorting stays well-defined.
                            .filter(|q| q.is_finite())
                            .map(|q| q.clamp(0.0, 1.0))
                    })
                })
                .unwrap_or(1.0);
            Some((media_type, quality))
        })
        .collect();
    // All weights are finite here, so `partial_cmp` never yields `None`; the
    // fallback is kept purely as a safety net.
    entries.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    entries
}
167
/// Check whether a media type from the Accept header matches a supported
/// content type. Supports wildcard matching (`*/*`, `application/*`).
///
/// Parameters (such as `; charset=utf-8`) on either side are ignored, and the
/// comparison is ASCII case-insensitive because media type and subtype names
/// are case-insensitive in HTTP (so `Application/JSON` matches
/// `application/json`).
fn media_type_matches(accept_type: &str, supported: &str) -> bool {
    // Compare only the `type/subtype` part of each side.
    let accept_type = accept_type.split(';').next().unwrap_or(accept_type).trim();
    if accept_type == "*/*" {
        return true;
    }
    let supported_base = supported.split(';').next().unwrap_or(supported).trim();
    if accept_type.eq_ignore_ascii_case(supported_base) {
        return true;
    }
    // Check type/* wildcard (e.g. "text/*" matches "text/plain")
    if let Some(prefix) = accept_type.strip_suffix("/*") {
        if let Some(sup_prefix) = supported_base.split('/').next() {
            return prefix.eq_ignore_ascii_case(sup_prefix);
        }
    }
    false
}
187
188/// Implementation helper: given an Accept header and a list of supported types
189/// (in preference order), return the index of the best matching type.
190fn best_match(accept: Option<&str>, supported: &[&str]) -> usize {
191    let accept = match accept {
192        Some(a) if !a.is_empty() => a,
193        _ => return 0, // No Accept header -> use first (default) format
194    };
195
196    let entries = parse_accept(accept);
197
198    // For each Accept entry (sorted by quality), find the first supported type
199    // that matches.
200    for (media_type, _quality) in &entries {
201        for (idx, supported_type) in supported.iter().enumerate() {
202            if media_type_matches(media_type, supported_type) {
203                return idx;
204            }
205        }
206    }
207
208    // No match found -> default to first format
209    0
210}
211
212// Generate NegotiateFormats impls for tuples of arities 1-6.
213macro_rules! impl_negotiate_formats {
214    // Single format
215    ([$F1:ident], [$idx1:tt]) => {
216        impl<T, $F1> NegotiateFormats<T> for ($F1,)
217        where
218            $F1: ContentFormat,
219            T: RenderAs<$F1>,
220        {
221            fn supported_types() -> Vec<&'static str> {
222                vec![$F1::CONTENT_TYPE]
223            }
224
225            fn negotiate_and_render(
226                value: &T,
227                _accept: Option<&str>,
228            ) -> Result<(Vec<u8>, &'static str), String> {
229                <T as RenderAs<$F1>>::render(value)
230            }
231        }
232    };
233    // Multiple formats
234    ([$F1:ident $(, $FN:ident)*], [$idx1:tt $(, $idxN:tt)*]) => {
235        impl<T, $F1 $(, $FN)*> NegotiateFormats<T> for ($F1, $($FN,)*)
236        where
237            $F1: ContentFormat,
238            $($FN: ContentFormat,)*
239            T: RenderAs<$F1> $(+ RenderAs<$FN>)*,
240        {
241            fn supported_types() -> Vec<&'static str> {
242                vec![$F1::CONTENT_TYPE $(, $FN::CONTENT_TYPE)*]
243            }
244
245            fn negotiate_and_render(
246                value: &T,
247                accept: Option<&str>,
248            ) -> Result<(Vec<u8>, &'static str), String> {
249                let supported = [$F1::CONTENT_TYPE $(, $FN::CONTENT_TYPE)*];
250                let idx = best_match(accept, &supported);
251                // Dispatch to the correct RenderAs impl based on index.
252                let renderers: Vec<Box<dyn Fn(&T) -> Result<(Vec<u8>, &'static str), String>>> = vec![
253                    Box::new(|v| <T as RenderAs<$F1>>::render(v)),
254                    $(Box::new(|v| <T as RenderAs<$FN>>::render(v)),)*
255                ];
256                (renderers[idx])(value)
257            }
258        }
259    };
260}
261
262impl_negotiate_formats!([F1], [0]);
263impl_negotiate_formats!([F1, F2], [0, 1]);
264impl_negotiate_formats!([F1, F2, F3], [0, 1, 2]);
265impl_negotiate_formats!([F1, F2, F3, F4], [0, 1, 2, 3]);
266impl_negotiate_formats!([F1, F2, F3, F4, F5], [0, 1, 2, 3, 4]);
267impl_negotiate_formats!([F1, F2, F3, F4, F5, F6], [0, 1, 2, 3, 4, 5]);
268
269// ---------------------------------------------------------------------------
270// NegotiatedResponse
271// ---------------------------------------------------------------------------
272
/// Response wrapper that defers the choice of representation until the
/// response is built, using the request's `Accept` header.
///
/// Type parameters:
/// - `T`: the domain value to render.
/// - `Formats`: a tuple of format markers such as `(JsonFormat, TextFormat)`;
///   it is only used at the type level.
///
/// Implements [`IntoResponse`] when `Formats: NegotiateFormats<T>`.
pub struct NegotiatedResponse<T, Formats> {
    /// The domain value that will be rendered.
    value: T,
    /// Raw `Accept` header value, or `None` when the request had none.
    accept: Option<String>,
    /// Carries the `Formats` type parameter without storing a value.
    _formats: std::marker::PhantomData<Formats>,
}

impl<T, Formats> NegotiatedResponse<T, Formats> {
    /// Build a negotiated response from a value and the request's `Accept`
    /// header.
    ///
    /// Pass the header value as `accept`, or `None` if the request did not
    /// carry one. The [`AcceptHeader`] extractor yields exactly this value.
    pub fn new(value: T, accept: Option<String>) -> Self {
        let _formats = std::marker::PhantomData;
        Self {
            value,
            accept,
            _formats,
        }
    }
}
299
300impl<T, Formats> IntoResponse for NegotiatedResponse<T, Formats>
301where
302    Formats: NegotiateFormats<T>,
303{
304    fn into_response(self) -> http::Response<BoxBody> {
305        match Formats::negotiate_and_render(&self.value, self.accept.as_deref()) {
306            Ok((body_bytes, content_type)) => {
307                let body = body_from_bytes(bytes::Bytes::from(body_bytes));
308                let mut res = http::Response::new(body);
309                if let Ok(val) = http::HeaderValue::from_str(content_type) {
310                    res.headers_mut().insert(CONTENT_TYPE, val);
311                }
312                res
313            }
314            Err(e) => {
315                let mut res =
316                    http::Response::new(body_from_string(format!("negotiation error: {e}")));
317                *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
318                res
319            }
320        }
321    }
322}
323
324// ---------------------------------------------------------------------------
325// AcceptHeader extractor
326// ---------------------------------------------------------------------------
327
/// Extracts the `Accept` header value from the request.
///
/// The wrapped value is `None` when the request carries no usable `Accept`
/// header. Use this in handler arguments to pass into
/// [`NegotiatedResponse::new`].
///
/// # Example
///
/// ```ignore
/// use typeway_server::negotiate::*;
///
/// async fn handler(accept: AcceptHeader) -> NegotiatedResponse<MyType, (JsonFormat, TextFormat)> {
///     NegotiatedResponse::new(my_value, accept.0)
/// }
/// ```
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AcceptHeader(pub Option<String>);
342
343impl crate::extract::FromRequestParts for AcceptHeader {
344    type Error = std::convert::Infallible;
345
346    fn from_request_parts(parts: &http::request::Parts) -> Result<Self, Self::Error> {
347        let accept = parts
348            .headers
349            .get(http::header::ACCEPT)
350            .and_then(|v| v.to_str().ok())
351            .map(|s| s.to_string());
352        Ok(AcceptHeader(accept))
353    }
354}
355
// Infallible always succeeds, but we need IntoResponse for the trait bound.
//
// `Infallible` has no values, so this method can never actually be called.
impl IntoResponse for std::convert::Infallible {
    fn into_response(self) -> http::Response<BoxBody> {
        // An empty match is exhaustive for an uninhabited type.
        match self {}
    }
}
362
363// ---------------------------------------------------------------------------
364// Convenience function
365// ---------------------------------------------------------------------------
366
367/// Wrap a domain value for content negotiation with the given Accept header.
368///
369/// # Example
370///
371/// ```ignore
372/// async fn get_user(accept: AcceptHeader) -> NegotiatedResponse<User, (JsonFormat, TextFormat)> {
373///     negotiated(user, accept)
374/// }
375/// ```
376pub fn negotiated<T, Formats>(value: T, accept: AcceptHeader) -> NegotiatedResponse<T, Formats> {
377    NegotiatedResponse::new(value, accept.0)
378}
379
// Unit tests for header parsing, media-type matching, format selection,
// rendering and the final HTTP response.
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture type that is both `Serialize` (JSON) and `Display` (text), so it
    // can be negotiated between `JsonFormat` and `TextFormat`.
    #[derive(serde::Serialize)]
    struct TestUser {
        id: u32,
        name: String,
    }

    impl std::fmt::Display for TestUser {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "User({}, {})", self.id, self.name)
        }
    }

    // Builds the fixture used by every rendering test.
    fn test_user() -> TestUser {
        TestUser {
            id: 1,
            name: "Alice".to_string(),
        }
    }

    // --- parse_accept -------------------------------------------------------

    // A bare media type gets the default quality of 1.0.
    #[test]
    fn parse_accept_simple() {
        let entries = parse_accept("application/json");
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].0, "application/json");
        assert!((entries[0].1 - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn parse_accept_with_quality() {
        let entries = parse_accept("text/plain;q=0.5, application/json;q=0.9");
        assert_eq!(entries.len(), 2);
        // Sorted by quality descending
        assert_eq!(entries[0].0, "application/json");
        assert_eq!(entries[1].0, "text/plain");
    }

    #[test]
    fn parse_accept_wildcard() {
        let entries = parse_accept("*/*");
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].0, "*/*");
    }

    // --- media_type_matches -------------------------------------------------

    #[test]
    fn media_type_matches_exact() {
        assert!(media_type_matches("application/json", "application/json"));
        assert!(!media_type_matches("application/json", "text/plain"));
    }

    // Parameters on the supported type (charset) do not prevent a match.
    #[test]
    fn media_type_matches_with_params() {
        assert!(media_type_matches(
            "text/plain",
            "text/plain; charset=utf-8"
        ));
    }

    #[test]
    fn media_type_matches_wildcard() {
        assert!(media_type_matches("*/*", "application/json"));
        assert!(media_type_matches("text/*", "text/plain"));
        assert!(!media_type_matches("text/*", "application/json"));
    }

    // --- best_match ---------------------------------------------------------

    // Without an Accept header the first supported type is the default.
    #[test]
    fn best_match_no_accept() {
        let supported = &["application/json", "text/plain"];
        assert_eq!(best_match(None, supported), 0);
    }

    #[test]
    fn best_match_wildcard() {
        let supported = &["application/json", "text/plain"];
        assert_eq!(best_match(Some("*/*"), supported), 0);
    }

    #[test]
    fn best_match_specific() {
        let supported = &["application/json", "text/plain; charset=utf-8"];
        assert_eq!(best_match(Some("text/plain"), supported), 1);
    }

    // The higher-quality Accept entry wins over server preference order.
    #[test]
    fn best_match_quality_order() {
        let supported = &["application/json", "text/plain; charset=utf-8"];
        assert_eq!(
            best_match(Some("text/plain;q=0.9, application/json;q=0.5"), supported),
            1
        );
    }

    // --- RenderAs blanket impls ---------------------------------------------

    #[test]
    fn render_as_json() {
        let user = test_user();
        let (bytes, ct) = <TestUser as RenderAs<JsonFormat>>::render(&user).unwrap();
        assert_eq!(ct, "application/json");
        let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(parsed["name"], "Alice");
    }

    #[test]
    fn render_as_text() {
        let user = test_user();
        let (bytes, ct) = <TestUser as RenderAs<TextFormat>>::render(&user).unwrap();
        assert_eq!(ct, "text/plain; charset=utf-8");
        assert_eq!(String::from_utf8(bytes).unwrap(), "User(1, Alice)");
    }

    // --- NegotiateFormats tuple impls ---------------------------------------

    #[test]
    fn negotiate_json_when_accepted() {
        let user = test_user();
        let (bytes, ct) =
            <(JsonFormat, TextFormat) as NegotiateFormats<TestUser>>::negotiate_and_render(
                &user,
                Some("application/json"),
            )
            .unwrap();
        assert_eq!(ct, "application/json");
        let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(parsed["id"], 1);
    }

    #[test]
    fn negotiate_text_when_accepted() {
        let user = test_user();
        let (bytes, ct) =
            <(JsonFormat, TextFormat) as NegotiateFormats<TestUser>>::negotiate_and_render(
                &user,
                Some("text/plain"),
            )
            .unwrap();
        assert_eq!(ct, "text/plain; charset=utf-8");
        assert_eq!(String::from_utf8(bytes).unwrap(), "User(1, Alice)");
    }

    #[test]
    fn negotiate_default_on_wildcard() {
        let user = test_user();
        let (_bytes, ct) =
            <(JsonFormat, TextFormat) as NegotiateFormats<TestUser>>::negotiate_and_render(
                &user,
                Some("*/*"),
            )
            .unwrap();
        // Default to first format (JSON)
        assert_eq!(ct, "application/json");
    }

    #[test]
    fn negotiate_default_on_no_accept() {
        let user = test_user();
        let (_bytes, ct) =
            <(JsonFormat, TextFormat) as NegotiateFormats<TestUser>>::negotiate_and_render(
                &user, None,
            )
            .unwrap();
        assert_eq!(ct, "application/json");
    }

    // --- NegotiatedResponse -> http::Response -------------------------------

    #[test]
    fn negotiated_response_into_response_json() {
        let user = test_user();
        let resp: NegotiatedResponse<TestUser, (JsonFormat, TextFormat)> =
            NegotiatedResponse::new(user, Some("application/json".to_string()));
        let http_resp = resp.into_response();
        assert_eq!(http_resp.status(), StatusCode::OK);
        assert_eq!(
            http_resp.headers().get("content-type").unwrap(),
            "application/json"
        );
    }

    #[test]
    fn negotiated_response_into_response_text() {
        let user = test_user();
        let resp: NegotiatedResponse<TestUser, (JsonFormat, TextFormat)> =
            NegotiatedResponse::new(user, Some("text/plain".to_string()));
        let http_resp = resp.into_response();
        assert_eq!(http_resp.status(), StatusCode::OK);
        assert_eq!(
            http_resp.headers().get("content-type").unwrap(),
            "text/plain; charset=utf-8"
        );
    }

    // --- Other tuple arities ------------------------------------------------

    #[test]
    fn single_format_tuple() {
        let user = test_user();
        let (_bytes, ct) =
            <(JsonFormat,) as NegotiateFormats<TestUser>>::negotiate_and_render(&user, None)
                .unwrap();
        assert_eq!(ct, "application/json");
    }

    // Uses JsonFormat twice only to reach arity 3 with the available formats.
    #[test]
    fn three_format_tuple() {
        let user = test_user();
        let (_, ct) = <(JsonFormat, TextFormat, JsonFormat) as NegotiateFormats<
            TestUser,
        >>::negotiate_and_render(&user, Some("text/plain"))
        .unwrap();
        assert_eq!(ct, "text/plain; charset=utf-8");
    }

    #[test]
    fn supported_types_lists_all() {
        let types = <(JsonFormat, TextFormat) as NegotiateFormats<TestUser>>::supported_types();
        assert_eq!(types, vec!["application/json", "text/plain; charset=utf-8"]);
    }
}