systemprompt_api/services/middleware/negotiation/
mod.rs1use axum::extract::Request;
2use axum::middleware::Next;
3use axum::response::Response;
4
/// A response representation the API can produce for a client.
///
/// `Json` is the [`Default`] variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AcceptedMediaType {
    /// Content type `application/json`.
    #[default]
    Json,
    /// Content type `text/markdown; charset=utf-8`.
    Markdown,
    /// Content type `text/html; charset=utf-8`.
    Html,
}

13impl AcceptedMediaType {
14 pub const fn content_type(&self) -> &'static str {
15 match self {
16 Self::Json => "application/json",
17 Self::Markdown => "text/markdown; charset=utf-8",
18 Self::Html => "text/html; charset=utf-8",
19 }
20 }
21
22 pub const fn is_markdown(&self) -> bool {
23 matches!(self, Self::Markdown)
24 }
25}
26
27#[derive(Debug, Clone, Copy)]
28pub struct AcceptedFormat(pub AcceptedMediaType);
29
30impl Default for AcceptedFormat {
31 fn default() -> Self {
32 Self(AcceptedMediaType::Json)
33 }
34}
35
36impl AcceptedFormat {
37 pub const fn media_type(&self) -> AcceptedMediaType {
38 self.0
39 }
40
41 pub const fn is_markdown(&self) -> bool {
42 self.0.is_markdown()
43 }
44}
45
46struct MediaTypeEntry {
47 media_type: AcceptedMediaType,
48 quality: f32,
49}
50
51pub fn parse_accept_header(header_value: &str) -> AcceptedFormat {
52 let mut entries = Vec::new();
53
54 for part in header_value.split(',') {
55 let part = part.trim();
56 if part.is_empty() {
57 continue;
58 }
59
60 let (media_type_str, params) = part
61 .split_once(';')
62 .map_or((part, ""), |(m, p)| (m.trim(), p));
63
64 let quality = params
65 .split(';')
66 .find_map(|p| {
67 let p = p.trim();
68 p.strip_prefix("q=")
69 .and_then(|q_str| q_str.parse::<f32>().ok().map(|q| q.clamp(0.0, 1.0)))
70 })
71 .unwrap_or(1.0);
72
73 let media_type = match media_type_str.to_lowercase().as_str() {
74 "text/markdown" | "text/x-markdown" => Some(AcceptedMediaType::Markdown),
75 "application/json" | "*/*" => Some(AcceptedMediaType::Json),
76 "text/html" | "application/xhtml+xml" => Some(AcceptedMediaType::Html),
77 _ => None,
78 };
79
80 if let Some(mt) = media_type {
81 entries.push(MediaTypeEntry {
82 media_type: mt,
83 quality,
84 });
85 }
86 }
87
88 entries.sort_by(|a, b| {
89 b.quality
90 .partial_cmp(&a.quality)
91 .unwrap_or(std::cmp::Ordering::Equal)
92 });
93
94 let media_type = entries
95 .first()
96 .map_or(AcceptedMediaType::Json, |e| e.media_type);
97
98 AcceptedFormat(media_type)
99}
100
101pub async fn content_negotiation_middleware(mut request: Request, next: Next) -> Response {
102 let accepted_format = request
103 .headers()
104 .get(http::header::ACCEPT)
105 .and_then(|v| v.to_str().ok())
106 .map_or_else(AcceptedFormat::default, parse_accept_header);
107
108 request.extensions_mut().insert(accepted_format);
109
110 next.run(request).await
111}