systemprompt_api/services/middleware/negotiation/
mod.rs1use axum::extract::Request;
9use axum::middleware::Next;
10use axum::response::Response;
11
/// Response representations this module can negotiate.
///
/// `Json` is the default variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AcceptedMediaType {
    /// JSON representation (`application/json`).
    #[default]
    Json,
    /// Markdown representation (`text/markdown`).
    Markdown,
    /// HTML representation (`text/html`).
    Html,
}
19
20impl AcceptedMediaType {
21 pub const fn content_type(&self) -> &'static str {
22 match self {
23 Self::Json => "application/json",
24 Self::Markdown => "text/markdown; charset=utf-8",
25 Self::Html => "text/html; charset=utf-8",
26 }
27 }
28
29 pub const fn is_markdown(&self) -> bool {
30 matches!(self, Self::Markdown)
31 }
32}
33
34#[derive(Debug, Clone, Copy)]
35pub struct AcceptedFormat(pub AcceptedMediaType);
36
37impl Default for AcceptedFormat {
38 fn default() -> Self {
39 Self(AcceptedMediaType::Json)
40 }
41}
42
43impl AcceptedFormat {
44 pub const fn media_type(&self) -> AcceptedMediaType {
45 self.0
46 }
47
48 pub const fn is_markdown(&self) -> bool {
49 self.0.is_markdown()
50 }
51}
52
53struct MediaTypeEntry {
54 media_type: AcceptedMediaType,
55 quality: f32,
56}
57
58pub fn parse_accept_header(header_value: &str) -> AcceptedFormat {
59 let mut entries = Vec::new();
60
61 for part in header_value.split(',') {
62 let part = part.trim();
63 if part.is_empty() {
64 continue;
65 }
66
67 let (media_type_str, params) = part
68 .split_once(';')
69 .map_or((part, ""), |(m, p)| (m.trim(), p));
70
71 let quality = params
72 .split(';')
73 .find_map(|p| {
74 let p = p.trim();
75 p.strip_prefix("q=")
76 .and_then(|q_str| q_str.parse::<f32>().ok().map(|q| q.clamp(0.0, 1.0)))
77 })
78 .unwrap_or(1.0);
79
80 let media_type = match media_type_str.to_lowercase().as_str() {
81 "text/markdown" | "text/x-markdown" => Some(AcceptedMediaType::Markdown),
82 "application/json" | "*/*" => Some(AcceptedMediaType::Json),
83 "text/html" | "application/xhtml+xml" => Some(AcceptedMediaType::Html),
84 _ => None,
85 };
86
87 if let Some(mt) = media_type {
88 entries.push(MediaTypeEntry {
89 media_type: mt,
90 quality,
91 });
92 }
93 }
94
95 entries.sort_by(|a, b| {
96 b.quality
97 .partial_cmp(&a.quality)
98 .unwrap_or(std::cmp::Ordering::Equal)
99 });
100
101 let media_type = entries
102 .first()
103 .map_or(AcceptedMediaType::Json, |e| e.media_type);
104
105 AcceptedFormat(media_type)
106}
107
108pub async fn content_negotiation_middleware(mut request: Request, next: Next) -> Response {
109 let accepted_format = request
110 .headers()
111 .get(http::header::ACCEPT)
112 .and_then(|v| v.to_str().ok())
113 .map_or_else(AcceptedFormat::default, parse_accept_header);
114
115 request.extensions_mut().insert(accepted_format);
116
117 next.run(request).await
118}