Skip to main content

systemprompt_api/services/middleware/negotiation/
mod.rs

1//! Content-negotiation middleware.
2//!
3//! Parses the `Accept` header into an [`AcceptedFormat`] (one of
4//! [`AcceptedMediaType`]) honouring `q=` quality weights, and stores it in the
5//! request extensions so handlers can serve JSON, Markdown, or HTML from a
6//! single route.
7
8use axum::extract::Request;
9use axum::middleware::Next;
10use axum::response::Response;
11
/// A response representation this module can negotiate.
///
/// Defaults to [`AcceptedMediaType::Json`], which is also the fallback when the
/// `Accept` header names nothing recognised here.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum AcceptedMediaType {
    #[default]
    Json,
    Markdown,
    Html,
}

impl AcceptedMediaType {
    /// The `Content-Type` header value matching this representation.
    #[must_use]
    pub const fn content_type(&self) -> &'static str {
        match self {
            Self::Json => "application/json",
            Self::Markdown => "text/markdown; charset=utf-8",
            Self::Html => "text/html; charset=utf-8",
        }
    }

    /// Returns `true` only for [`AcceptedMediaType::Markdown`].
    #[must_use]
    pub const fn is_markdown(&self) -> bool {
        matches!(self, Self::Markdown)
    }
}

/// The negotiated format for a request; defaults to JSON.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AcceptedFormat(pub AcceptedMediaType);

impl AcceptedFormat {
    /// The negotiated media type.
    #[must_use]
    pub const fn media_type(&self) -> AcceptedMediaType {
        self.0
    }

    /// Returns `true` when Markdown was negotiated.
    #[must_use]
    pub const fn is_markdown(&self) -> bool {
        self.0.is_markdown()
    }
}

/// One recognised media range from an `Accept` header together with its weight.
struct MediaTypeEntry {
    media_type: AcceptedMediaType,
    /// Quality weight, clamped to `0.0..=1.0` and never `NaN`.
    quality: f32,
}

/// Parses an `Accept` header value and picks the most preferred format we can serve.
///
/// Each comma-separated media range is matched (ASCII case-insensitively) against
/// the types we produce: `text/markdown`/`text/x-markdown` → Markdown,
/// `application/json`/`*/*` → JSON, `text/html`/`application/xhtml+xml` → HTML.
/// Other ranges are ignored. A range's weight comes from its first parseable `q`
/// parameter. A missing or malformed weight counts as `1.0`, and weights are clamped
/// to `0.0..=1.0`.
///
/// Ranges with weight `0` are "not acceptable" and are never chosen. Among the
/// rest, the highest weight wins, and on a tie the range listed first wins. If
/// nothing acceptable remains, the result is [`AcceptedMediaType::Json`].
#[must_use]
pub fn parse_accept_header(header_value: &str) -> AcceptedFormat {
    let mut best: Option<MediaTypeEntry> = None;

    for entry in header_value.split(',').filter_map(parse_media_range) {
        // q=0 explicitly marks a type as not acceptable (HTTP quality values).
        if entry.quality <= 0.0 {
            continue;
        }
        // Strict `>` keeps the earliest entry on ties, preserving header order.
        let replace = match &best {
            Some(current) => entry.quality > current.quality,
            None => true,
        };
        if replace {
            best = Some(entry);
        }
    }

    AcceptedFormat(best.map_or(AcceptedMediaType::Json, |entry| entry.media_type))
}

/// Parses a single `Accept` element such as `text/html;level=1;q=0.5`.
///
/// Returns `None` for blank elements and for media ranges not handled here.
fn parse_media_range(part: &str) -> Option<MediaTypeEntry> {
    let part = part.trim();
    let (media_type_str, params) = part.split_once(';').unwrap_or((part, ""));

    // Media types are ASCII case-insensitive; avoid Unicode folding, which would
    // let look-alike characters (e.g. the Kelvin sign) match ASCII letters.
    let media_type = match media_type_str.trim().to_ascii_lowercase().as_str() {
        "text/markdown" | "text/x-markdown" => AcceptedMediaType::Markdown,
        "application/json" | "*/*" => AcceptedMediaType::Json,
        "text/html" | "application/xhtml+xml" => AcceptedMediaType::Html,
        _ => return None,
    };

    let quality = params.split(';').find_map(parse_quality).unwrap_or(1.0);

    Some(MediaTypeEntry {
        media_type,
        quality,
    })
}

/// Extracts the weight from one media-range parameter if it is a valid `q`/`Q`.
///
/// Returns `None` for other parameters and for unparseable values. `NaN`
/// parses as an `f32` but cannot be ordered, so it is rejected like any other
/// malformed weight.
fn parse_quality(param: &str) -> Option<f32> {
    let (name, value) = param.trim().split_once('=')?;
    if !name.eq_ignore_ascii_case("q") {
        return None;
    }
    let q = value.parse::<f32>().ok()?;
    if q.is_nan() {
        None
    } else {
        Some(q.clamp(0.0, 1.0))
    }
}
107
108pub async fn content_negotiation_middleware(mut request: Request, next: Next) -> Response {
109    let accepted_format = request
110        .headers()
111        .get(http::header::ACCEPT)
112        .and_then(|v| v.to_str().ok())
113        .map_or_else(AcceptedFormat::default, parse_accept_header);
114
115    request.extensions_mut().insert(accepted_format);
116
117    next.run(request).await
118}