systemprompt_api/services/middleware/negotiation/
mod.rs1use axum::extract::Request;
2use axum::middleware::Next;
3use axum::response::Response;
4
/// Response representations this API knows how to negotiate.
///
/// `Json` is the default variant.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum AcceptedMediaType {
    /// `application/json` (the default).
    #[default]
    Json,
    /// `text/markdown`.
    Markdown,
    /// `text/html`.
    Html,
}

impl AcceptedMediaType {
    /// Returns the `Content-Type` header value for this representation.
    ///
    /// The textual formats carry an explicit `charset=utf-8`; JSON does not,
    /// because the `application/json` registration defines no charset parameter.
    pub const fn content_type(&self) -> &'static str {
        match *self {
            AcceptedMediaType::Markdown => "text/markdown; charset=utf-8",
            AcceptedMediaType::Html => "text/html; charset=utf-8",
            AcceptedMediaType::Json => "application/json",
        }
    }

    /// Returns `true` only for [`AcceptedMediaType::Markdown`].
    pub const fn is_markdown(&self) -> bool {
        matches!(*self, AcceptedMediaType::Markdown)
    }
}
26
27#[derive(Debug, Clone, Copy)]
28pub struct AcceptedFormat(pub AcceptedMediaType);
29
30impl Default for AcceptedFormat {
31 fn default() -> Self {
32 Self(AcceptedMediaType::Json)
33 }
34}
35
36impl AcceptedFormat {
37 pub const fn media_type(&self) -> AcceptedMediaType {
38 self.0
39 }
40
41 pub const fn is_markdown(&self) -> bool {
42 self.0.is_markdown()
43 }
44}
45
46struct MediaTypeEntry {
47 media_type: AcceptedMediaType,
48 quality: f32,
49}
50
51fn parse_accept_header(header_value: &str) -> AcceptedFormat {
52 let mut entries = Vec::new();
53
54 for part in header_value.split(',') {
55 let part = part.trim();
56 if part.is_empty() {
57 continue;
58 }
59
60 let (media_type_str, params) = part
61 .split_once(';')
62 .map_or((part, ""), |(m, p)| (m.trim(), p));
63
64 let quality = params
65 .split(';')
66 .find_map(|p| {
67 let p = p.trim();
68 if let Some(q_str) = p.strip_prefix("q=") {
69 q_str.parse::<f32>().ok().map(|q| q.clamp(0.0, 1.0))
70 } else {
71 None
72 }
73 })
74 .unwrap_or(1.0);
75
76 let media_type = match media_type_str.to_lowercase().as_str() {
77 "text/markdown" | "text/x-markdown" => Some(AcceptedMediaType::Markdown),
78 "application/json" | "*/*" => Some(AcceptedMediaType::Json),
79 "text/html" | "application/xhtml+xml" => Some(AcceptedMediaType::Html),
80 _ => None,
81 };
82
83 if let Some(mt) = media_type {
84 entries.push(MediaTypeEntry {
85 media_type: mt,
86 quality,
87 });
88 }
89 }
90
91 entries.sort_by(|a, b| {
92 b.quality
93 .partial_cmp(&a.quality)
94 .unwrap_or(std::cmp::Ordering::Equal)
95 });
96
97 let media_type = entries
98 .first()
99 .map_or(AcceptedMediaType::Json, |e| e.media_type);
100
101 AcceptedFormat(media_type)
102}
103
104pub async fn content_negotiation_middleware(mut request: Request, next: Next) -> Response {
105 let accepted_format = request
106 .headers()
107 .get(http::header::ACCEPT)
108 .and_then(|v| v.to_str().ok())
109 .map_or_else(AcceptedFormat::default, parse_accept_header);
110
111 request.extensions_mut().insert(accepted_format);
112
113 next.run(request).await
114}