Skip to main content

systemprompt_security/extraction/
token.rs

1use axum::http::HeaderMap;
2use std::error::Error;
3use std::fmt;
4
/// Scheme prefix, including the separating space, expected before a bearer token.
const BEARER_PREFIX: &str = "Bearer ";
/// Cookie name consulted when no custom cookie name has been configured.
const DEFAULT_COOKIE_NAME: &str = "access_token";
/// Header consulted for MCP proxy credentials when no custom header has been configured.
const DEFAULT_MCP_HEADER_NAME: &str = "x-mcp-proxy-auth";
8
/// A location in an incoming request where a bearer token may be found.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExtractionMethod {
    /// The `Authorization` header, in `Bearer <token>` form.
    AuthorizationHeader,
    /// The configurable MCP proxy header, in `Bearer <token>` form.
    McpProxyHeader,
    /// A named cookie inside the `Cookie` header.
    Cookie,
}
15
16impl fmt::Display for ExtractionMethod {
17    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18        match self {
19            Self::AuthorizationHeader => write!(f, "Authorization header"),
20            Self::McpProxyHeader => write!(f, "MCP proxy header"),
21            Self::Cookie => write!(f, "Cookie"),
22        }
23    }
24}
25
26#[derive(Debug, Clone)]
27pub struct TokenExtractor {
28    fallback_chain: Vec<ExtractionMethod>,
29    cookie_name: String,
30    mcp_header_name: String,
31}
32
33impl TokenExtractor {
34    #[must_use]
35    pub fn new(fallback_chain: Vec<ExtractionMethod>) -> Self {
36        Self {
37            fallback_chain,
38            cookie_name: DEFAULT_COOKIE_NAME.to_string(),
39            mcp_header_name: DEFAULT_MCP_HEADER_NAME.to_string(),
40        }
41    }
42
43    #[must_use]
44    pub fn with_cookie_name(mut self, name: String) -> Self {
45        self.cookie_name = name;
46        self
47    }
48
49    #[must_use]
50    pub fn with_mcp_header_name(mut self, name: String) -> Self {
51        self.mcp_header_name = name;
52        self
53    }
54
55    #[must_use]
56    pub fn standard() -> Self {
57        Self::new(vec![
58            ExtractionMethod::AuthorizationHeader,
59            ExtractionMethod::McpProxyHeader,
60            ExtractionMethod::Cookie,
61        ])
62    }
63
64    #[must_use]
65    pub fn browser_only() -> Self {
66        Self::new(vec![
67            ExtractionMethod::AuthorizationHeader,
68            ExtractionMethod::Cookie,
69        ])
70    }
71
72    #[must_use]
73    pub fn api_only() -> Self {
74        Self::new(vec![ExtractionMethod::AuthorizationHeader])
75    }
76
77    #[must_use]
78    pub fn chain(&self) -> &[ExtractionMethod] {
79        &self.fallback_chain
80    }
81
82    pub fn extract(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
83        for method in &self.fallback_chain {
84            match method {
85                ExtractionMethod::AuthorizationHeader => {
86                    if let Ok(token) = Self::extract_from_authorization(headers) {
87                        return Ok(token);
88                    }
89                },
90                ExtractionMethod::McpProxyHeader => {
91                    if let Ok(token) = self.extract_from_mcp_proxy(headers) {
92                        return Ok(token);
93                    }
94                },
95                ExtractionMethod::Cookie => {
96                    if let Ok(token) = self.extract_from_cookie(headers) {
97                        return Ok(token);
98                    }
99                },
100            }
101        }
102
103        Err(TokenExtractionError::NoTokenFound)
104    }
105
106    pub fn extract_from_authorization(headers: &HeaderMap) -> Result<String, TokenExtractionError> {
107        let auth_headers = headers.get_all("authorization");
108
109        if auth_headers.iter().count() == 0 {
110            return Err(TokenExtractionError::MissingAuthorizationHeader);
111        }
112
113        for auth_value in &auth_headers {
114            let Ok(auth_header) = auth_value.to_str().map_err(|e| {
115                tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
116                e
117            }) else {
118                continue;
119            };
120
121            if let Some(token) = auth_header.strip_prefix(BEARER_PREFIX) {
122                let token = token.trim();
123                if !token.is_empty() {
124                    return Ok(token.to_string());
125                }
126            }
127        }
128
129        Err(TokenExtractionError::InvalidAuthorizationFormat)
130    }
131
132    pub fn extract_from_mcp_proxy(
133        &self,
134        headers: &HeaderMap,
135    ) -> Result<String, TokenExtractionError> {
136        let header_value = headers
137            .get(&self.mcp_header_name)
138            .ok_or(TokenExtractionError::MissingMcpProxyHeader)?;
139
140        let auth_header = header_value
141            .to_str()
142            .map_err(|_| TokenExtractionError::InvalidMcpProxyFormat)?;
143
144        auth_header
145            .strip_prefix(BEARER_PREFIX)
146            .ok_or(TokenExtractionError::InvalidMcpProxyFormat)
147            .map(ToString::to_string)
148    }
149
150    pub fn extract_from_cookie(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
151        let cookie_header = headers
152            .get("cookie")
153            .ok_or(TokenExtractionError::MissingCookie)?
154            .to_str()
155            .map_err(|_| TokenExtractionError::InvalidCookieFormat)?;
156
157        for cookie in cookie_header.split(';') {
158            let cookie = cookie.trim();
159            let cookie_prefix = format!("{}=", self.cookie_name);
160            if let Some(value) = cookie.strip_prefix(&cookie_prefix) {
161                if !value.is_empty() {
162                    return Ok(value.to_string());
163                }
164            }
165        }
166
167        Err(TokenExtractionError::TokenNotFoundInCookie)
168    }
169}
170
/// Reasons a bearer token could not be extracted from a request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TokenExtractionError {
    /// No method in the extractor's fallback chain produced a token.
    NoTokenFound,
    /// The request carries no `Authorization` header.
    MissingAuthorizationHeader,
    /// An `Authorization` header exists but holds no usable `Bearer <token>` value.
    InvalidAuthorizationFormat,
    /// The configured MCP proxy header is absent.
    MissingMcpProxyHeader,
    /// The MCP proxy header is unreadable or not in `Bearer <token>` form.
    InvalidMcpProxyFormat,
    /// The request carries no `Cookie` header.
    MissingCookie,
    /// The `Cookie` header could not be read as text.
    InvalidCookieFormat,
    /// The `Cookie` header holds no non-empty value for the configured cookie name.
    TokenNotFoundInCookie,
}
182
183impl fmt::Display for TokenExtractionError {
184    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
185        match self {
186            Self::NoTokenFound => write!(f, "No token found in request"),
187            Self::MissingAuthorizationHeader => {
188                write!(f, "Missing Authorization header")
189            },
190            Self::InvalidAuthorizationFormat => {
191                write!(
192                    f,
193                    "Invalid Authorization header format (expected 'Bearer <token>')"
194                )
195            },
196            Self::MissingMcpProxyHeader => {
197                write!(f, "Missing MCP proxy authorization header")
198            },
199            Self::InvalidMcpProxyFormat => {
200                write!(
201                    f,
202                    "Invalid MCP proxy header format (expected 'Bearer <token>')"
203                )
204            },
205            Self::MissingCookie => write!(f, "Missing cookie header"),
206            Self::InvalidCookieFormat => write!(f, "Invalid cookie format"),
207            Self::TokenNotFoundInCookie => write!(f, "Token not found in cookies"),
208        }
209    }
210}
211
212impl Error for TokenExtractionError {}