Skip to main content

systemprompt_security/extraction/
token.rs

1use axum::http::HeaderMap;
2use std::error::Error;
3use std::fmt;
4
5use super::cookie::{CookieExtractionError, CookieExtractor};
6
/// Header name read for `ExtractionMethod::McpProxyHeader` unless a caller
/// overrides it with `TokenExtractor::with_mcp_header_name`.
const DEFAULT_MCP_HEADER_NAME: &str = "x-mcp-proxy-auth";

/// Authentication scheme prefix, including the single separating space,
/// that must precede the token in a bearer-style header value.
const BEARER_PREFIX: &str = "Bearer ";
9
/// A location in an incoming request where an access token may be carried.
///
/// `TokenExtractor` attempts these in the order given by its fallback chain.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtractionMethod {
    /// The `Authorization` header, in `Bearer <token>` form.
    AuthorizationHeader,
    /// The MCP proxy header (`x-mcp-proxy-auth` unless reconfigured), in
    /// `Bearer <token>` form.
    McpProxyHeader,
    /// A named cookie, read via the cookie extractor.
    Cookie,
}
16
17impl fmt::Display for ExtractionMethod {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        match self {
20            Self::AuthorizationHeader => write!(f, "Authorization header"),
21            Self::McpProxyHeader => write!(f, "MCP proxy header"),
22            Self::Cookie => write!(f, "Cookie"),
23        }
24    }
25}
26
27#[derive(Debug, Clone)]
28pub struct TokenExtractor {
29    fallback_chain: Vec<ExtractionMethod>,
30    cookie_name: String,
31    mcp_header_name: String,
32}
33
34impl TokenExtractor {
35    #[must_use]
36    pub fn new(fallback_chain: Vec<ExtractionMethod>) -> Self {
37        Self {
38            fallback_chain,
39            cookie_name: CookieExtractor::DEFAULT_COOKIE_NAME.to_owned(),
40            mcp_header_name: DEFAULT_MCP_HEADER_NAME.to_owned(),
41        }
42    }
43
44    #[must_use]
45    pub fn with_cookie_name(mut self, name: String) -> Self {
46        self.cookie_name = name;
47        self
48    }
49
50    #[must_use]
51    pub fn with_mcp_header_name(mut self, name: String) -> Self {
52        self.mcp_header_name = name;
53        self
54    }
55
56    #[must_use]
57    pub fn standard() -> Self {
58        Self::new(vec![
59            ExtractionMethod::AuthorizationHeader,
60            ExtractionMethod::McpProxyHeader,
61            ExtractionMethod::Cookie,
62        ])
63    }
64
65    #[must_use]
66    pub fn browser_only() -> Self {
67        Self::new(vec![
68            ExtractionMethod::AuthorizationHeader,
69            ExtractionMethod::Cookie,
70        ])
71    }
72
73    #[must_use]
74    pub fn api_only() -> Self {
75        Self::new(vec![ExtractionMethod::AuthorizationHeader])
76    }
77
78    #[must_use]
79    pub fn chain(&self) -> &[ExtractionMethod] {
80        &self.fallback_chain
81    }
82
83    pub fn extract(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
84        for method in &self.fallback_chain {
85            match method {
86                ExtractionMethod::AuthorizationHeader => {
87                    if let Ok(token) = Self::extract_from_authorization(headers) {
88                        return Ok(token);
89                    }
90                },
91                ExtractionMethod::McpProxyHeader => {
92                    if let Ok(token) = self.extract_from_mcp_proxy(headers) {
93                        return Ok(token);
94                    }
95                },
96                ExtractionMethod::Cookie => {
97                    if let Ok(token) = self.extract_from_cookie(headers) {
98                        return Ok(token);
99                    }
100                },
101            }
102        }
103
104        Err(TokenExtractionError::NoTokenFound)
105    }
106
107    pub fn extract_from_authorization(headers: &HeaderMap) -> Result<String, TokenExtractionError> {
108        let auth_headers = headers.get_all("authorization");
109
110        if auth_headers.iter().count() == 0 {
111            return Err(TokenExtractionError::MissingAuthorizationHeader);
112        }
113
114        for auth_value in &auth_headers {
115            let Ok(auth_header) = auth_value.to_str().map_err(|e| {
116                tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
117                e
118            }) else {
119                continue;
120            };
121
122            if let Some(token) = auth_header.strip_prefix(BEARER_PREFIX) {
123                let token = token.trim();
124                if !token.is_empty() {
125                    return Ok(token.to_owned());
126                }
127            }
128        }
129
130        Err(TokenExtractionError::InvalidAuthorizationFormat)
131    }
132
133    pub fn extract_from_mcp_proxy(
134        &self,
135        headers: &HeaderMap,
136    ) -> Result<String, TokenExtractionError> {
137        let header_value = headers
138            .get(&self.mcp_header_name)
139            .ok_or(TokenExtractionError::MissingMcpProxyHeader)?;
140
141        let auth_header = header_value
142            .to_str()
143            .map_err(|_e| TokenExtractionError::InvalidMcpProxyFormat)?;
144
145        auth_header
146            .strip_prefix(BEARER_PREFIX)
147            .ok_or(TokenExtractionError::InvalidMcpProxyFormat)
148            .map(str::to_owned)
149    }
150
151    pub fn extract_from_cookie(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
152        CookieExtractor::new(self.cookie_name.clone())
153            .extract(headers)
154            .map_err(|err| match err {
155                CookieExtractionError::MissingCookie => TokenExtractionError::MissingCookie,
156                CookieExtractionError::InvalidCookieFormat => {
157                    TokenExtractionError::InvalidCookieFormat
158                },
159                CookieExtractionError::TokenNotFoundInCookie => {
160                    TokenExtractionError::TokenNotFoundInCookie
161                },
162            })
163    }
164}
165
/// Reasons a token could not be pulled out of a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenExtractionError {
    /// No method in the extractor's fallback chain produced a token.
    NoTokenFound,
    /// The request carried no `Authorization` header at all.
    MissingAuthorizationHeader,
    /// `Authorization` header(s) were present but none held a usable bearer token.
    InvalidAuthorizationFormat,
    /// The configured MCP proxy header was absent.
    MissingMcpProxyHeader,
    /// The MCP proxy header was present but not a usable `Bearer <token>` value.
    InvalidMcpProxyFormat,
    /// Counterpart of `CookieExtractionError::MissingCookie`.
    MissingCookie,
    /// Counterpart of `CookieExtractionError::InvalidCookieFormat`.
    InvalidCookieFormat,
    /// Counterpart of `CookieExtractionError::TokenNotFoundInCookie`.
    TokenNotFoundInCookie,
}
177
178impl fmt::Display for TokenExtractionError {
179    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180        match self {
181            Self::NoTokenFound => write!(f, "No token found in request"),
182            Self::MissingAuthorizationHeader => {
183                write!(f, "Missing Authorization header")
184            },
185            Self::InvalidAuthorizationFormat => {
186                write!(
187                    f,
188                    "Invalid Authorization header format (expected 'Bearer <token>')"
189                )
190            },
191            Self::MissingMcpProxyHeader => {
192                write!(f, "Missing MCP proxy authorization header")
193            },
194            Self::InvalidMcpProxyFormat => {
195                write!(
196                    f,
197                    "Invalid MCP proxy header format (expected 'Bearer <token>')"
198                )
199            },
200            Self::MissingCookie => write!(f, "Missing cookie header"),
201            Self::InvalidCookieFormat => write!(f, "Invalid cookie format"),
202            Self::TokenNotFoundInCookie => write!(f, "Token not found in cookies"),
203        }
204    }
205}
206
207impl Error for TokenExtractionError {}