Skip to main content

systemprompt_security/extraction/
token.rs

1use axum::http::HeaderMap;
2use std::error::Error;
3use std::fmt;
4
5use super::cookie::{CookieExtractionError, CookieExtractor};
6
/// Cookie read by `ExtractionMethod::Cookie` unless a different name is configured.
const DEFAULT_COOKIE_NAME: &str = "access_token";
/// Header read by `ExtractionMethod::McpProxyHeader` unless a different name is configured.
const DEFAULT_MCP_HEADER_NAME: &str = "x-mcp-proxy-auth";
/// Authentication scheme (including its separating space) expected in front of a token.
const BEARER_PREFIX: &str = "Bearer ";
10
/// A location in an HTTP request that a token extractor may search for a token.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExtractionMethod {
    /// The standard `Authorization` request header, in `Bearer <token>` form.
    AuthorizationHeader,
    /// The MCP proxy authorization header (its name is configurable).
    McpProxyHeader,
    /// A request cookie (its name is configurable).
    Cookie,
}

impl fmt::Display for ExtractionMethod {
    /// Writes a short human-readable name for the extraction method.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::AuthorizationHeader => "Authorization header",
            Self::McpProxyHeader => "MCP proxy header",
            Self::Cookie => "Cookie",
        };
        f.write_str(label)
    }
}
27
28#[derive(Debug, Clone)]
29pub struct TokenExtractor {
30    fallback_chain: Vec<ExtractionMethod>,
31    cookie_name: String,
32    mcp_header_name: String,
33}
34
35impl TokenExtractor {
36    #[must_use]
37    pub fn new(fallback_chain: Vec<ExtractionMethod>) -> Self {
38        Self {
39            fallback_chain,
40            cookie_name: DEFAULT_COOKIE_NAME.to_string(),
41            mcp_header_name: DEFAULT_MCP_HEADER_NAME.to_string(),
42        }
43    }
44
45    #[must_use]
46    pub fn with_cookie_name(mut self, name: String) -> Self {
47        self.cookie_name = name;
48        self
49    }
50
51    #[must_use]
52    pub fn with_mcp_header_name(mut self, name: String) -> Self {
53        self.mcp_header_name = name;
54        self
55    }
56
57    #[must_use]
58    pub fn standard() -> Self {
59        Self::new(vec![
60            ExtractionMethod::AuthorizationHeader,
61            ExtractionMethod::McpProxyHeader,
62            ExtractionMethod::Cookie,
63        ])
64    }
65
66    #[must_use]
67    pub fn browser_only() -> Self {
68        Self::new(vec![
69            ExtractionMethod::AuthorizationHeader,
70            ExtractionMethod::Cookie,
71        ])
72    }
73
74    #[must_use]
75    pub fn api_only() -> Self {
76        Self::new(vec![ExtractionMethod::AuthorizationHeader])
77    }
78
79    #[must_use]
80    pub fn chain(&self) -> &[ExtractionMethod] {
81        &self.fallback_chain
82    }
83
84    pub fn extract(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
85        for method in &self.fallback_chain {
86            match method {
87                ExtractionMethod::AuthorizationHeader => {
88                    if let Ok(token) = Self::extract_from_authorization(headers) {
89                        return Ok(token);
90                    }
91                },
92                ExtractionMethod::McpProxyHeader => {
93                    if let Ok(token) = self.extract_from_mcp_proxy(headers) {
94                        return Ok(token);
95                    }
96                },
97                ExtractionMethod::Cookie => {
98                    if let Ok(token) = self.extract_from_cookie(headers) {
99                        return Ok(token);
100                    }
101                },
102            }
103        }
104
105        Err(TokenExtractionError::NoTokenFound)
106    }
107
108    pub fn extract_from_authorization(headers: &HeaderMap) -> Result<String, TokenExtractionError> {
109        let auth_headers = headers.get_all("authorization");
110
111        if auth_headers.iter().count() == 0 {
112            return Err(TokenExtractionError::MissingAuthorizationHeader);
113        }
114
115        for auth_value in &auth_headers {
116            let Ok(auth_header) = auth_value.to_str().map_err(|e| {
117                tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
118                e
119            }) else {
120                continue;
121            };
122
123            if let Some(token) = auth_header.strip_prefix(BEARER_PREFIX) {
124                let token = token.trim();
125                if !token.is_empty() {
126                    return Ok(token.to_string());
127                }
128            }
129        }
130
131        Err(TokenExtractionError::InvalidAuthorizationFormat)
132    }
133
134    pub fn extract_from_mcp_proxy(
135        &self,
136        headers: &HeaderMap,
137    ) -> Result<String, TokenExtractionError> {
138        let header_value = headers
139            .get(&self.mcp_header_name)
140            .ok_or(TokenExtractionError::MissingMcpProxyHeader)?;
141
142        let auth_header = header_value
143            .to_str()
144            .map_err(|_| TokenExtractionError::InvalidMcpProxyFormat)?;
145
146        auth_header
147            .strip_prefix(BEARER_PREFIX)
148            .ok_or(TokenExtractionError::InvalidMcpProxyFormat)
149            .map(ToString::to_string)
150    }
151
152    pub fn extract_from_cookie(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
153        CookieExtractor::new(self.cookie_name.clone())
154            .extract(headers)
155            .map_err(|err| match err {
156                CookieExtractionError::MissingCookie => TokenExtractionError::MissingCookie,
157                CookieExtractionError::InvalidCookieFormat => {
158                    TokenExtractionError::InvalidCookieFormat
159                },
160                CookieExtractionError::TokenNotFoundInCookie => {
161                    TokenExtractionError::TokenNotFoundInCookie
162                },
163            })
164    }
165}
166
/// Reasons a token could not be extracted from a request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TokenExtractionError {
    /// No method in the extractor's fallback chain produced a token.
    NoTokenFound,
    /// The request carries no `Authorization` header.
    MissingAuthorizationHeader,
    /// No `Authorization` header value holds a usable `Bearer <token>` credential.
    InvalidAuthorizationFormat,
    /// The MCP proxy authorization header is absent.
    MissingMcpProxyHeader,
    /// The MCP proxy header is not a usable `Bearer <token>` value.
    InvalidMcpProxyFormat,
    /// Mapped from the cookie extractor's `MissingCookie` error.
    MissingCookie,
    /// Mapped from the cookie extractor's `InvalidCookieFormat` error.
    InvalidCookieFormat,
    /// Mapped from the cookie extractor's `TokenNotFoundInCookie` error.
    TokenNotFoundInCookie,
}
178
179impl fmt::Display for TokenExtractionError {
180    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        match self {
182            Self::NoTokenFound => write!(f, "No token found in request"),
183            Self::MissingAuthorizationHeader => {
184                write!(f, "Missing Authorization header")
185            },
186            Self::InvalidAuthorizationFormat => {
187                write!(
188                    f,
189                    "Invalid Authorization header format (expected 'Bearer <token>')"
190                )
191            },
192            Self::MissingMcpProxyHeader => {
193                write!(f, "Missing MCP proxy authorization header")
194            },
195            Self::InvalidMcpProxyFormat => {
196                write!(
197                    f,
198                    "Invalid MCP proxy header format (expected 'Bearer <token>')"
199                )
200            },
201            Self::MissingCookie => write!(f, "Missing cookie header"),
202            Self::InvalidCookieFormat => write!(f, "Invalid cookie format"),
203            Self::TokenNotFoundInCookie => write!(f, "Token not found in cookies"),
204        }
205    }
206}
207
208impl Error for TokenExtractionError {}