systemprompt_security/extraction/
token.rs1use axum::http::HeaderMap;
2use std::error::Error;
3use std::fmt;
4
/// Authentication scheme prefix (scheme name plus one separating space) that
/// precedes a bearer token in a header value.
const BEARER_PREFIX: &str = "Bearer ";
/// Cookie name used by a freshly constructed `TokenExtractor`.
const DEFAULT_COOKIE_NAME: &str = "access_token";
/// Header name used by a freshly constructed `TokenExtractor` for the MCP proxy method.
const DEFAULT_MCP_HEADER_NAME: &str = "x-mcp-proxy-auth";
8
/// A location in an incoming request that may carry an access token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtractionMethod {
    /// The `Authorization` request header.
    AuthorizationHeader,
    /// The MCP proxy authorization header (name configurable on the extractor).
    McpProxyHeader,
    /// A named cookie in the `Cookie` request header.
    Cookie,
}

impl fmt::Display for ExtractionMethod {
    /// Writes a short human-readable label for the extraction method.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::AuthorizationHeader => "Authorization header",
            Self::McpProxyHeader => "MCP proxy header",
            Self::Cookie => "Cookie",
        };
        f.write_str(label)
    }
}
25
26#[derive(Debug, Clone)]
27pub struct TokenExtractor {
28 fallback_chain: Vec<ExtractionMethod>,
29 cookie_name: String,
30 mcp_header_name: String,
31}
32
33impl TokenExtractor {
34 #[must_use]
35 pub fn new(fallback_chain: Vec<ExtractionMethod>) -> Self {
36 Self {
37 fallback_chain,
38 cookie_name: DEFAULT_COOKIE_NAME.to_string(),
39 mcp_header_name: DEFAULT_MCP_HEADER_NAME.to_string(),
40 }
41 }
42
43 #[must_use]
44 pub fn with_cookie_name(mut self, name: String) -> Self {
45 self.cookie_name = name;
46 self
47 }
48
49 #[must_use]
50 pub fn with_mcp_header_name(mut self, name: String) -> Self {
51 self.mcp_header_name = name;
52 self
53 }
54
55 #[must_use]
56 pub fn standard() -> Self {
57 Self::new(vec![
58 ExtractionMethod::AuthorizationHeader,
59 ExtractionMethod::McpProxyHeader,
60 ExtractionMethod::Cookie,
61 ])
62 }
63
64 #[must_use]
65 pub fn browser_only() -> Self {
66 Self::new(vec![
67 ExtractionMethod::AuthorizationHeader,
68 ExtractionMethod::Cookie,
69 ])
70 }
71
72 #[must_use]
73 pub fn api_only() -> Self {
74 Self::new(vec![ExtractionMethod::AuthorizationHeader])
75 }
76
77 #[must_use]
78 pub fn chain(&self) -> &[ExtractionMethod] {
79 &self.fallback_chain
80 }
81
82 pub fn extract(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
83 for method in &self.fallback_chain {
84 match method {
85 ExtractionMethod::AuthorizationHeader => {
86 if let Ok(token) = Self::extract_from_authorization(headers) {
87 return Ok(token);
88 }
89 },
90 ExtractionMethod::McpProxyHeader => {
91 if let Ok(token) = self.extract_from_mcp_proxy(headers) {
92 return Ok(token);
93 }
94 },
95 ExtractionMethod::Cookie => {
96 if let Ok(token) = self.extract_from_cookie(headers) {
97 return Ok(token);
98 }
99 },
100 }
101 }
102
103 Err(TokenExtractionError::NoTokenFound)
104 }
105
106 pub fn extract_from_authorization(headers: &HeaderMap) -> Result<String, TokenExtractionError> {
107 let auth_headers = headers.get_all("authorization");
108
109 if auth_headers.iter().count() == 0 {
110 return Err(TokenExtractionError::MissingAuthorizationHeader);
111 }
112
113 for auth_value in &auth_headers {
114 let Ok(auth_header) = auth_value.to_str().map_err(|e| {
115 tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
116 e
117 }) else {
118 continue;
119 };
120
121 if let Some(token) = auth_header.strip_prefix(BEARER_PREFIX) {
122 let token = token.trim();
123 if !token.is_empty() {
124 return Ok(token.to_string());
125 }
126 }
127 }
128
129 Err(TokenExtractionError::InvalidAuthorizationFormat)
130 }
131
132 pub fn extract_from_mcp_proxy(
133 &self,
134 headers: &HeaderMap,
135 ) -> Result<String, TokenExtractionError> {
136 let header_value = headers
137 .get(&self.mcp_header_name)
138 .ok_or(TokenExtractionError::MissingMcpProxyHeader)?;
139
140 let auth_header = header_value
141 .to_str()
142 .map_err(|_| TokenExtractionError::InvalidMcpProxyFormat)?;
143
144 auth_header
145 .strip_prefix(BEARER_PREFIX)
146 .ok_or(TokenExtractionError::InvalidMcpProxyFormat)
147 .map(ToString::to_string)
148 }
149
150 pub fn extract_from_cookie(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
151 let cookie_header = headers
152 .get("cookie")
153 .ok_or(TokenExtractionError::MissingCookie)?
154 .to_str()
155 .map_err(|_| TokenExtractionError::InvalidCookieFormat)?;
156
157 for cookie in cookie_header.split(';') {
158 let cookie = cookie.trim();
159 let cookie_prefix = format!("{}=", self.cookie_name);
160 if let Some(value) = cookie.strip_prefix(&cookie_prefix) {
161 if !value.is_empty() {
162 return Ok(value.to_string());
163 }
164 }
165 }
166
167 Err(TokenExtractionError::TokenNotFoundInCookie)
168 }
169}
170
/// Reason an access token could not be extracted from a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenExtractionError {
    /// No extraction method in the chain yielded a token.
    NoTokenFound,
    /// The request has no `Authorization` header.
    MissingAuthorizationHeader,
    /// An `Authorization` header is present, but no value held a usable bearer token.
    InvalidAuthorizationFormat,
    /// The MCP proxy authorization header is absent.
    MissingMcpProxyHeader,
    /// The MCP proxy header is present but is not a usable `Bearer <token>` value.
    InvalidMcpProxyFormat,
    /// The request has no `Cookie` header.
    MissingCookie,
    /// The `Cookie` header could not be read as a string.
    InvalidCookieFormat,
    /// The `Cookie` header was readable, but the token cookie was absent or empty.
    TokenNotFoundInCookie,
}

impl fmt::Display for TokenExtractionError {
    /// Writes a human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::NoTokenFound => "No token found in request",
            Self::MissingAuthorizationHeader => "Missing Authorization header",
            Self::InvalidAuthorizationFormat => {
                "Invalid Authorization header format (expected 'Bearer <token>')"
            },
            Self::MissingMcpProxyHeader => "Missing MCP proxy authorization header",
            Self::InvalidMcpProxyFormat => {
                "Invalid MCP proxy header format (expected 'Bearer <token>')"
            },
            Self::MissingCookie => "Missing cookie header",
            Self::InvalidCookieFormat => "Invalid cookie format",
            Self::TokenNotFoundInCookie => "Token not found in cookies",
        };
        f.write_str(message)
    }
}

impl Error for TokenExtractionError {}