systemprompt_security/extraction/
token.rs1use axum::http::HeaderMap;
2use std::error::Error;
3use std::fmt;
4
5use super::cookie::{CookieExtractionError, CookieExtractor};
6
/// Header name consulted for MCP proxy credentials unless the extractor is
/// configured with a different one.
const DEFAULT_MCP_HEADER_NAME: &str = "x-mcp-proxy-auth";

/// Scheme prefix, including its trailing space, that is stripped from bearer
/// header values before the token is returned.
const BEARER_PREFIX: &str = "Bearer ";
9
/// A place in the request headers where a token may be looked for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtractionMethod {
    /// Read the token from the `authorization` header.
    AuthorizationHeader,
    /// Read the token from the configurable MCP proxy header.
    McpProxyHeader,
    /// Read the token from the request cookies.
    Cookie,
}

impl fmt::Display for ExtractionMethod {
    /// Writes a short human-readable label for the method.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::AuthorizationHeader => "Authorization header",
            Self::McpProxyHeader => "MCP proxy header",
            Self::Cookie => "Cookie",
        };
        // `write_str` emits the label verbatim, ignoring width/padding flags,
        // exactly like `write!` with a plain literal does.
        f.write_str(label)
    }
}
26
27#[derive(Debug, Clone)]
28pub struct TokenExtractor {
29 fallback_chain: Vec<ExtractionMethod>,
30 cookie_name: String,
31 mcp_header_name: String,
32}
33
34impl TokenExtractor {
35 #[must_use]
36 pub fn new(fallback_chain: Vec<ExtractionMethod>) -> Self {
37 Self {
38 fallback_chain,
39 cookie_name: CookieExtractor::DEFAULT_COOKIE_NAME.to_owned(),
40 mcp_header_name: DEFAULT_MCP_HEADER_NAME.to_owned(),
41 }
42 }
43
44 #[must_use]
45 pub fn with_cookie_name(mut self, name: String) -> Self {
46 self.cookie_name = name;
47 self
48 }
49
50 #[must_use]
51 pub fn with_mcp_header_name(mut self, name: String) -> Self {
52 self.mcp_header_name = name;
53 self
54 }
55
56 #[must_use]
57 pub fn standard() -> Self {
58 Self::new(vec![
59 ExtractionMethod::AuthorizationHeader,
60 ExtractionMethod::McpProxyHeader,
61 ExtractionMethod::Cookie,
62 ])
63 }
64
65 #[must_use]
66 pub fn browser_only() -> Self {
67 Self::new(vec![
68 ExtractionMethod::AuthorizationHeader,
69 ExtractionMethod::Cookie,
70 ])
71 }
72
73 #[must_use]
74 pub fn api_only() -> Self {
75 Self::new(vec![ExtractionMethod::AuthorizationHeader])
76 }
77
78 #[must_use]
79 pub fn chain(&self) -> &[ExtractionMethod] {
80 &self.fallback_chain
81 }
82
83 pub fn extract(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
84 for method in &self.fallback_chain {
85 match method {
86 ExtractionMethod::AuthorizationHeader => {
87 if let Ok(token) = Self::extract_from_authorization(headers) {
88 return Ok(token);
89 }
90 },
91 ExtractionMethod::McpProxyHeader => {
92 if let Ok(token) = self.extract_from_mcp_proxy(headers) {
93 return Ok(token);
94 }
95 },
96 ExtractionMethod::Cookie => {
97 if let Ok(token) = self.extract_from_cookie(headers) {
98 return Ok(token);
99 }
100 },
101 }
102 }
103
104 Err(TokenExtractionError::NoTokenFound)
105 }
106
107 pub fn extract_from_authorization(headers: &HeaderMap) -> Result<String, TokenExtractionError> {
108 let auth_headers = headers.get_all("authorization");
109
110 if auth_headers.iter().count() == 0 {
111 return Err(TokenExtractionError::MissingAuthorizationHeader);
112 }
113
114 for auth_value in &auth_headers {
115 let Ok(auth_header) = auth_value.to_str().map_err(|e| {
116 tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
117 e
118 }) else {
119 continue;
120 };
121
122 if let Some(token) = auth_header.strip_prefix(BEARER_PREFIX) {
123 let token = token.trim();
124 if !token.is_empty() {
125 return Ok(token.to_owned());
126 }
127 }
128 }
129
130 Err(TokenExtractionError::InvalidAuthorizationFormat)
131 }
132
133 pub fn extract_from_mcp_proxy(
134 &self,
135 headers: &HeaderMap,
136 ) -> Result<String, TokenExtractionError> {
137 let header_value = headers
138 .get(&self.mcp_header_name)
139 .ok_or(TokenExtractionError::MissingMcpProxyHeader)?;
140
141 let auth_header = header_value
142 .to_str()
143 .map_err(|_e| TokenExtractionError::InvalidMcpProxyFormat)?;
144
145 auth_header
146 .strip_prefix(BEARER_PREFIX)
147 .ok_or(TokenExtractionError::InvalidMcpProxyFormat)
148 .map(str::to_owned)
149 }
150
151 pub fn extract_from_cookie(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
152 CookieExtractor::new(self.cookie_name.clone())
153 .extract(headers)
154 .map_err(|err| match err {
155 CookieExtractionError::MissingCookie => TokenExtractionError::MissingCookie,
156 CookieExtractionError::InvalidCookieFormat => {
157 TokenExtractionError::InvalidCookieFormat
158 },
159 CookieExtractionError::TokenNotFoundInCookie => {
160 TokenExtractionError::TokenNotFoundInCookie
161 },
162 })
163 }
164}
165
/// Reasons a token could not be extracted from request headers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenExtractionError {
    /// No method in the fallback chain produced a token.
    NoTokenFound,
    /// The request has no Authorization header.
    MissingAuthorizationHeader,
    /// An Authorization header is present but is not `Bearer <token>`.
    InvalidAuthorizationFormat,
    /// The MCP proxy header is absent.
    MissingMcpProxyHeader,
    /// The MCP proxy header is present but is not `Bearer <token>`.
    InvalidMcpProxyFormat,
    /// The cookie extractor reported a missing cookie.
    MissingCookie,
    /// The cookie extractor reported an invalid cookie format.
    InvalidCookieFormat,
    /// The cookie extractor did not find the token among the cookies.
    TokenNotFoundInCookie,
}

impl fmt::Display for TokenExtractionError {
    /// Writes the human-readable message for the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::NoTokenFound => "No token found in request",
            Self::MissingAuthorizationHeader => "Missing Authorization header",
            Self::InvalidAuthorizationFormat => {
                "Invalid Authorization header format (expected 'Bearer <token>')"
            },
            Self::MissingMcpProxyHeader => "Missing MCP proxy authorization header",
            Self::InvalidMcpProxyFormat => {
                "Invalid MCP proxy header format (expected 'Bearer <token>')"
            },
            Self::MissingCookie => "Missing cookie header",
            Self::InvalidCookieFormat => "Invalid cookie format",
            Self::TokenNotFoundInCookie => "Token not found in cookies",
        };
        // Emitted verbatim (no padding), matching `write!` with a literal.
        f.write_str(message)
    }
}

impl Error for TokenExtractionError {}