Skip to main content

systemprompt_api/services/middleware/negotiation/
mod.rs

1use axum::extract::Request;
2use axum::middleware::Next;
3use axum::response::Response;
4
/// Response representations this service can negotiate.
///
/// `Json` is the default variant.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum AcceptedMediaType {
    /// `application/json`.
    #[default]
    Json,
    /// `text/markdown`.
    Markdown,
    /// `text/html`.
    Html,
}

impl AcceptedMediaType {
    /// Returns the `Content-Type` header value to send for this representation.
    pub const fn content_type(&self) -> &'static str {
        match *self {
            Self::Html => "text/html; charset=utf-8",
            Self::Markdown => "text/markdown; charset=utf-8",
            Self::Json => "application/json",
        }
    }

    /// Returns `true` only for the [`AcceptedMediaType::Markdown`] variant.
    pub const fn is_markdown(&self) -> bool {
        // Spelled out exhaustively so a new variant forces a decision here.
        match *self {
            Self::Markdown => true,
            Self::Json | Self::Html => false,
        }
    }
}
26
27#[derive(Debug, Clone, Copy)]
28pub struct AcceptedFormat(pub AcceptedMediaType);
29
30impl Default for AcceptedFormat {
31    fn default() -> Self {
32        Self(AcceptedMediaType::Json)
33    }
34}
35
36impl AcceptedFormat {
37    pub const fn media_type(&self) -> AcceptedMediaType {
38        self.0
39    }
40
41    pub const fn is_markdown(&self) -> bool {
42        self.0.is_markdown()
43    }
44}
45
/// One recognised media range from an `Accept` header, paired with its weight.
struct MediaTypeEntry {
    /// Representation the media range maps to.
    media_type: AcceptedMediaType,
    /// Weight taken from the entry's `q` parameter; 1.0 when none is given.
    quality: f32,
}
50
51fn parse_accept_header(header_value: &str) -> AcceptedFormat {
52    let mut entries = Vec::new();
53
54    for part in header_value.split(',') {
55        let part = part.trim();
56        if part.is_empty() {
57            continue;
58        }
59
60        let (media_type_str, params) = part
61            .split_once(';')
62            .map_or((part, ""), |(m, p)| (m.trim(), p));
63
64        let quality = params
65            .split(';')
66            .find_map(|p| {
67                let p = p.trim();
68                if let Some(q_str) = p.strip_prefix("q=") {
69                    q_str.parse::<f32>().ok().map(|q| q.clamp(0.0, 1.0))
70                } else {
71                    None
72                }
73            })
74            .unwrap_or(1.0);
75
76        let media_type = match media_type_str.to_lowercase().as_str() {
77            "text/markdown" | "text/x-markdown" => Some(AcceptedMediaType::Markdown),
78            "application/json" | "*/*" => Some(AcceptedMediaType::Json),
79            "text/html" | "application/xhtml+xml" => Some(AcceptedMediaType::Html),
80            _ => None,
81        };
82
83        if let Some(mt) = media_type {
84            entries.push(MediaTypeEntry {
85                media_type: mt,
86                quality,
87            });
88        }
89    }
90
91    entries.sort_by(|a, b| {
92        b.quality
93            .partial_cmp(&a.quality)
94            .unwrap_or(std::cmp::Ordering::Equal)
95    });
96
97    let media_type = entries
98        .first()
99        .map_or(AcceptedMediaType::Json, |e| e.media_type);
100
101    AcceptedFormat(media_type)
102}
103
104pub async fn content_negotiation_middleware(mut request: Request, next: Next) -> Response {
105    let accepted_format = request
106        .headers()
107        .get(http::header::ACCEPT)
108        .and_then(|v| v.to_str().ok())
109        .map_or_else(AcceptedFormat::default, parse_accept_header);
110
111    request.extensions_mut().insert(accepted_format);
112
113    next.run(request).await
114}