Skip to main content

systemprompt_security/extraction/
token.rs

1use axum::http::HeaderMap;
2use std::error::Error;
3use std::fmt;
4
/// Cookie name a `TokenExtractor` looks for unless overridden with `with_cookie_name`.
const DEFAULT_COOKIE_NAME: &str = "access_token";
/// Header name a `TokenExtractor` reads for MCP proxy authentication unless overridden with
/// `with_mcp_header_name`.
const DEFAULT_MCP_HEADER_NAME: &str = "x-mcp-proxy-auth";
/// Prefix (including the trailing space) that must precede the token in the Authorization and
/// MCP proxy header values. Matching is case-sensitive.
const BEARER_PREFIX: &str = "Bearer ";
8
/// A location in the request headers where a token can be looked up.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtractionMethod {
    /// Look in the `authorization` header.
    AuthorizationHeader,
    /// Look in the MCP proxy authentication header.
    McpProxyHeader,
    /// Look in the `cookie` header.
    Cookie,
}

impl fmt::Display for ExtractionMethod {
    /// Writes a short human-readable label for the method.
    ///
    /// The label is written verbatim; formatter width/padding flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::AuthorizationHeader => "Authorization header",
            Self::McpProxyHeader => "MCP proxy header",
            Self::Cookie => "Cookie",
        };
        f.write_str(label)
    }
}
25
/// Pulls a token out of request headers by trying a configurable list of
/// `ExtractionMethod`s in order.
#[derive(Debug, Clone)]
pub struct TokenExtractor {
    /// Methods attempted by `extract`, in order; the first one that yields a token wins.
    fallback_chain: Vec<ExtractionMethod>,
    /// Name of the cookie that `extract_from_cookie` searches for.
    cookie_name: String,
    /// Name of the header that `extract_from_mcp_proxy` reads.
    mcp_header_name: String,
}
32
33impl TokenExtractor {
34    pub fn new(fallback_chain: Vec<ExtractionMethod>) -> Self {
35        Self {
36            fallback_chain,
37            cookie_name: DEFAULT_COOKIE_NAME.to_string(),
38            mcp_header_name: DEFAULT_MCP_HEADER_NAME.to_string(),
39        }
40    }
41
42    pub fn with_cookie_name(mut self, name: String) -> Self {
43        self.cookie_name = name;
44        self
45    }
46
47    pub fn with_mcp_header_name(mut self, name: String) -> Self {
48        self.mcp_header_name = name;
49        self
50    }
51
52    pub fn standard() -> Self {
53        Self::new(vec![
54            ExtractionMethod::AuthorizationHeader,
55            ExtractionMethod::McpProxyHeader,
56            ExtractionMethod::Cookie,
57        ])
58    }
59
60    pub fn browser_only() -> Self {
61        Self::new(vec![
62            ExtractionMethod::AuthorizationHeader,
63            ExtractionMethod::Cookie,
64        ])
65    }
66
67    pub fn api_only() -> Self {
68        Self::new(vec![ExtractionMethod::AuthorizationHeader])
69    }
70
71    pub fn chain(&self) -> &[ExtractionMethod] {
72        &self.fallback_chain
73    }
74
75    pub fn extract(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
76        for method in &self.fallback_chain {
77            match method {
78                ExtractionMethod::AuthorizationHeader => {
79                    if let Ok(token) = Self::extract_from_authorization(headers) {
80                        return Ok(token);
81                    }
82                },
83                ExtractionMethod::McpProxyHeader => {
84                    if let Ok(token) = self.extract_from_mcp_proxy(headers) {
85                        return Ok(token);
86                    }
87                },
88                ExtractionMethod::Cookie => {
89                    if let Ok(token) = self.extract_from_cookie(headers) {
90                        return Ok(token);
91                    }
92                },
93            }
94        }
95
96        Err(TokenExtractionError::NoTokenFound)
97    }
98
99    pub fn extract_from_authorization(headers: &HeaderMap) -> Result<String, TokenExtractionError> {
100        let auth_headers = headers.get_all("authorization");
101
102        if auth_headers.iter().count() == 0 {
103            return Err(TokenExtractionError::MissingAuthorizationHeader);
104        }
105
106        for auth_value in &auth_headers {
107            let Ok(auth_header) = auth_value.to_str().map_err(|e| {
108                tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
109                e
110            }) else {
111                continue;
112            };
113
114            if let Some(token) = auth_header.strip_prefix(BEARER_PREFIX) {
115                let token = token.trim();
116                if !token.is_empty() {
117                    return Ok(token.to_string());
118                }
119            }
120        }
121
122        Err(TokenExtractionError::InvalidAuthorizationFormat)
123    }
124
125    pub fn extract_from_mcp_proxy(
126        &self,
127        headers: &HeaderMap,
128    ) -> Result<String, TokenExtractionError> {
129        let header_value = headers
130            .get(&self.mcp_header_name)
131            .ok_or(TokenExtractionError::MissingMcpProxyHeader)?;
132
133        let auth_header = header_value
134            .to_str()
135            .map_err(|_| TokenExtractionError::InvalidMcpProxyFormat)?;
136
137        auth_header
138            .strip_prefix(BEARER_PREFIX)
139            .ok_or(TokenExtractionError::InvalidMcpProxyFormat)
140            .map(ToString::to_string)
141    }
142
143    pub fn extract_from_cookie(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
144        let cookie_header = headers
145            .get("cookie")
146            .ok_or(TokenExtractionError::MissingCookie)?
147            .to_str()
148            .map_err(|_| TokenExtractionError::InvalidCookieFormat)?;
149
150        for cookie in cookie_header.split(';') {
151            let cookie = cookie.trim();
152            let cookie_prefix = format!("{}=", self.cookie_name);
153            if let Some(value) = cookie.strip_prefix(&cookie_prefix) {
154                if !value.is_empty() {
155                    return Ok(value.to_string());
156                }
157            }
158        }
159
160        Err(TokenExtractionError::TokenNotFoundInCookie)
161    }
162}
163
/// Reasons a token could not be extracted from request headers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenExtractionError {
    /// None of the configured extraction methods produced a token.
    NoTokenFound,
    /// The request has no Authorization header.
    MissingAuthorizationHeader,
    /// An Authorization header exists but does not carry a usable `Bearer <token>` value.
    InvalidAuthorizationFormat,
    /// The request has no MCP proxy authorization header.
    MissingMcpProxyHeader,
    /// The MCP proxy header exists but does not carry a usable `Bearer <token>` value.
    InvalidMcpProxyFormat,
    /// The request has no cookie header.
    MissingCookie,
    /// The cookie header could not be read.
    InvalidCookieFormat,
    /// The cookie header does not contain the token cookie.
    TokenNotFoundInCookie,
}

impl fmt::Display for TokenExtractionError {
    /// Writes the human-readable message for the error.
    ///
    /// The message is written verbatim; formatter width/padding flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::NoTokenFound => "No token found in request",
            Self::MissingAuthorizationHeader => "Missing Authorization header",
            Self::InvalidAuthorizationFormat => {
                "Invalid Authorization header format (expected 'Bearer <token>')"
            },
            Self::MissingMcpProxyHeader => "Missing MCP proxy authorization header",
            Self::InvalidMcpProxyFormat => {
                "Invalid MCP proxy header format (expected 'Bearer <token>')"
            },
            Self::MissingCookie => "Missing cookie header",
            Self::InvalidCookieFormat => "Invalid cookie format",
            Self::TokenNotFoundInCookie => "Token not found in cookies",
        };
        f.write_str(message)
    }
}

impl Error for TokenExtractionError {}