Skip to main content

systemprompt_api/services/middleware/negotiation/
mod.rs

1use axum::extract::Request;
2use axum::middleware::Next;
3use axum::response::Response;
4
/// A response representation this API knows how to render.
///
/// `Json` is the `#[default]` variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AcceptedMediaType {
    /// JSON (`application/json`).
    #[default]
    Json,
    /// Markdown text (`text/markdown`).
    Markdown,
    /// HTML (`text/html`).
    Html,
}
12
13impl AcceptedMediaType {
14    pub const fn content_type(&self) -> &'static str {
15        match self {
16            Self::Json => "application/json",
17            Self::Markdown => "text/markdown; charset=utf-8",
18            Self::Html => "text/html; charset=utf-8",
19        }
20    }
21
22    pub const fn is_markdown(&self) -> bool {
23        matches!(self, Self::Markdown)
24    }
25}
26
27#[derive(Debug, Clone, Copy)]
28pub struct AcceptedFormat(pub AcceptedMediaType);
29
30impl Default for AcceptedFormat {
31    fn default() -> Self {
32        Self(AcceptedMediaType::Json)
33    }
34}
35
36impl AcceptedFormat {
37    pub const fn media_type(&self) -> AcceptedMediaType {
38        self.0
39    }
40
41    pub const fn is_markdown(&self) -> bool {
42        self.0.is_markdown()
43    }
44}
45
46struct MediaTypeEntry {
47    media_type: AcceptedMediaType,
48    quality: f32,
49}
50
51pub fn parse_accept_header(header_value: &str) -> AcceptedFormat {
52    let mut entries = Vec::new();
53
54    for part in header_value.split(',') {
55        let part = part.trim();
56        if part.is_empty() {
57            continue;
58        }
59
60        let (media_type_str, params) = part
61            .split_once(';')
62            .map_or((part, ""), |(m, p)| (m.trim(), p));
63
64        let quality = params
65            .split(';')
66            .find_map(|p| {
67                let p = p.trim();
68                p.strip_prefix("q=")
69                    .and_then(|q_str| q_str.parse::<f32>().ok().map(|q| q.clamp(0.0, 1.0)))
70            })
71            .unwrap_or(1.0);
72
73        let media_type = match media_type_str.to_lowercase().as_str() {
74            "text/markdown" | "text/x-markdown" => Some(AcceptedMediaType::Markdown),
75            "application/json" | "*/*" => Some(AcceptedMediaType::Json),
76            "text/html" | "application/xhtml+xml" => Some(AcceptedMediaType::Html),
77            _ => None,
78        };
79
80        if let Some(mt) = media_type {
81            entries.push(MediaTypeEntry {
82                media_type: mt,
83                quality,
84            });
85        }
86    }
87
88    entries.sort_by(|a, b| {
89        b.quality
90            .partial_cmp(&a.quality)
91            .unwrap_or(std::cmp::Ordering::Equal)
92    });
93
94    let media_type = entries
95        .first()
96        .map_or(AcceptedMediaType::Json, |e| e.media_type);
97
98    AcceptedFormat(media_type)
99}
100
101pub async fn content_negotiation_middleware(mut request: Request, next: Next) -> Response {
102    let accepted_format = request
103        .headers()
104        .get(http::header::ACCEPT)
105        .and_then(|v| v.to_str().ok())
106        .map_or_else(AcceptedFormat::default, parse_accept_header);
107
108    request.extensions_mut().insert(accepted_format);
109
110    next.run(request).await
111}