Skip to main content

turul_a2a_auth/
api_key.rs

1//! API Key auth middleware.
2
3use std::collections::HashMap;
4use std::fmt;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use turul_a2a::middleware::{
9    A2aMiddleware, AuthFailureKind, AuthIdentity, MiddlewareError, RequestContext,
10    SecurityContribution,
11};
12
/// Trait for resolving API key to owner identity.
/// Returns `Some(owner)` if the key is valid, `None` if invalid.
///
/// `ApiKeyMiddleware` turns a `None` result into
/// `Unauthenticated(InvalidApiKey)`. It also rejects a `Some(owner)` whose
/// owner is empty or whitespace-only, with `Unauthenticated(EmptyPrincipal)`.
/// Implementations therefore do not need to filter blank owners themselves.
///
/// The `Send + Sync` bound lets implementations be shared behind
/// `Arc<dyn ApiKeyLookup>`, which is how `ApiKeyMiddleware` stores them.
#[async_trait]
pub trait ApiKeyLookup: Send + Sync {
    /// Resolves `key` to the owner identity it authenticates as.
    ///
    /// When called from `ApiKeyMiddleware`, `key` is the raw, non-empty value
    /// of the configured header. Implementations should avoid logging it.
    async fn lookup(&self, key: &str) -> Option<String>;
}
19
/// In-memory `ApiKeyLookup` backed by a fixed map of API key → owner.
///
/// The map is supplied once, at construction. This type exposes no method that
/// changes it afterwards. Note that, unlike `RedactedApiKeyLookup`, it has no
/// `Debug` impl at all.
pub struct StaticApiKeyLookup {
    /// Each accepted API key, mapped to the owner identity it authenticates as.
    keys: HashMap<String, String>,
}

impl StaticApiKeyLookup {
    /// Builds a lookup that accepts exactly the keys present in `keys`.
    pub fn new(keys: HashMap<String, String>) -> Self {
        StaticApiKeyLookup { keys }
    }
}
30
31#[async_trait]
32impl ApiKeyLookup for StaticApiKeyLookup {
33    async fn lookup(&self, key: &str) -> Option<String> {
34        self.keys.get(key).cloned()
35    }
36}
37
/// First-party `ApiKeyLookup` reference implementation with a redacted
/// `Debug` that never exposes key material.
///
/// Internal storage uses a plain `HashMap<String, String>`. Keys remain
/// in-process strings, but they are unreachable via `Debug`, `Display`, or
/// `Serialize`. This is the simplest shape that satisfies the ADR
/// invariant. Adopter implementations may reach for more elaborate
/// containers (`secrecy::SecretString` with newtype, etc.) if their
/// deployment requires stronger guarantees.
///
/// ```
/// use std::collections::HashMap;
/// use turul_a2a_auth::RedactedApiKeyLookup;
///
/// let mut keys = HashMap::new();
/// keys.insert("api-key-abc".to_string(), "owner-alice".to_string());
/// let lookup = RedactedApiKeyLookup::new(keys);
///
/// let dbg = format!("{lookup:?}");
/// assert!(!dbg.contains("api-key-abc"), "keys must never print: {dbg}");
/// assert!(dbg.contains("RedactedApiKeyLookup"));
/// ```
pub struct RedactedApiKeyLookup {
    /// API key → owner. Deliberately private, so it cannot be read or
    /// printed from outside this module.
    keys: HashMap<String, String>,
}

impl RedactedApiKeyLookup {
    /// Creates a lookup that accepts exactly the keys in `keys`, each mapped
    /// to the owner identity it authenticates as.
    pub fn new(keys: HashMap<String, String>) -> Self {
        Self { keys }
    }

    /// Number of keys currently registered. Safe to expose — reveals no
    /// key material.
    #[must_use]
    pub fn len(&self) -> usize {
        self.keys.len()
    }

    /// Returns `true` when no keys are registered. Safe to expose — reveals
    /// no key material.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.keys.is_empty()
    }
}
79
80impl fmt::Debug for RedactedApiKeyLookup {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        f.debug_struct("RedactedApiKeyLookup")
83            .field("len", &self.keys.len())
84            .finish()
85    }
86}
87
88#[async_trait]
89impl ApiKeyLookup for RedactedApiKeyLookup {
90    async fn lookup(&self, key: &str) -> Option<String> {
91        self.keys.get(key).cloned()
92    }
93}
94
95/// API Key auth middleware.
96///
97/// Extracts API key from a configurable header and validates via `ApiKeyLookup`.
98/// Rejects empty/whitespace owner values from lookup (symmetrical with Bearer).
99pub struct ApiKeyMiddleware {
100    lookup: Arc<dyn ApiKeyLookup>,
101    header_name: String,
102}
103
104impl ApiKeyMiddleware {
105    pub fn new(lookup: Arc<dyn ApiKeyLookup>, header_name: impl Into<String>) -> Self {
106        Self {
107            lookup,
108            header_name: header_name.into(),
109        }
110    }
111}
112
113#[async_trait]
114impl A2aMiddleware for ApiKeyMiddleware {
115    async fn before_request(&self, ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
116        let key = ctx
117            .headers
118            .get(&self.header_name)
119            .and_then(|v| v.to_str().ok())
120            .map(|s| s.to_string());
121
122        let key = match key {
123            Some(k) if !k.is_empty() => k,
124            _ => {
125                return Err(MiddlewareError::Unauthenticated(
126                    AuthFailureKind::MissingCredential,
127                ));
128            }
129        };
130
131        let owner = self
132            .lookup
133            .lookup(&key)
134            .await
135            .ok_or(MiddlewareError::Unauthenticated(
136                AuthFailureKind::InvalidApiKey,
137            ))?;
138
139        // Reject empty/whitespace owners — symmetrical with Bearer principal extraction
140        if owner.trim().is_empty() {
141            return Err(MiddlewareError::Unauthenticated(
142                AuthFailureKind::EmptyPrincipal,
143            ));
144        }
145
146        ctx.identity = AuthIdentity::Authenticated {
147            owner,
148            claims: None, // API key auth has no JWT claims
149        };
150        Ok(())
151    }
152
153    fn security_contribution(&self) -> SecurityContribution {
154        SecurityContribution::new().with_scheme(
155            "apiKey",
156            turul_a2a_proto::SecurityScheme {
157                scheme: Some(
158                    turul_a2a_proto::security_scheme::Scheme::ApiKeySecurityScheme(
159                        turul_a2a_proto::ApiKeySecurityScheme {
160                            description: String::new(),
161                            location: "header".into(),
162                            name: self.header_name.clone(),
163                        },
164                    ),
165                ),
166            },
167            vec![],
168        )
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Lookup fixture: one valid key, plus two keys whose owners the
    /// middleware must reject (empty and whitespace-only).
    fn test_lookup() -> Arc<dyn ApiKeyLookup> {
        let mut keys = HashMap::new();
        keys.insert("valid-key".to_string(), "user-from-key".to_string());
        keys.insert("empty-owner-key".to_string(), "".to_string());
        keys.insert("whitespace-key".to_string(), "  ".to_string());
        Arc::new(StaticApiKeyLookup::new(keys))
    }

    /// Middleware configured with a mixed-case header name. The tests send
    /// the lowercase form, so header matching must not be case-sensitive.
    fn middleware() -> ApiKeyMiddleware {
        ApiKeyMiddleware::new(test_lookup(), "X-API-Key")
    }

    #[tokio::test]
    async fn valid_key_sets_authenticated_identity() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        ctx.headers
            .insert("x-api-key", "valid-key".parse().unwrap());
        mw.before_request(&mut ctx).await.unwrap();
        assert!(ctx.identity.is_authenticated());
        assert_eq!(ctx.identity.owner(), "user-from-key");
        assert!(ctx.identity.claims().is_none(), "API key has no claims");
    }

    #[tokio::test]
    async fn missing_key_returns_missing_credential() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        let err = mw.before_request(&mut ctx).await.unwrap_err();
        assert!(
            matches!(
                err,
                MiddlewareError::Unauthenticated(AuthFailureKind::MissingCredential)
            ),
            "unexpected error: {err:?}"
        );
    }

    #[tokio::test]
    async fn empty_key_returns_missing_credential() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        ctx.headers.insert("x-api-key", "".parse().unwrap());
        let err = mw.before_request(&mut ctx).await.unwrap_err();
        assert!(
            matches!(
                err,
                MiddlewareError::Unauthenticated(AuthFailureKind::MissingCredential)
            ),
            "unexpected error: {err:?}"
        );
    }

    #[tokio::test]
    async fn invalid_key_returns_invalid_api_key() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        ctx.headers.insert("x-api-key", "bad-key".parse().unwrap());
        let err = mw.before_request(&mut ctx).await.unwrap_err();
        assert!(
            matches!(
                err,
                MiddlewareError::Unauthenticated(AuthFailureKind::InvalidApiKey)
            ),
            "unexpected error: {err:?}"
        );
    }

    #[tokio::test]
    async fn empty_owner_from_lookup_rejected() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        ctx.headers
            .insert("x-api-key", "empty-owner-key".parse().unwrap());
        let err = mw.before_request(&mut ctx).await.unwrap_err();
        assert!(
            matches!(
                err,
                MiddlewareError::Unauthenticated(AuthFailureKind::EmptyPrincipal)
            ),
            "unexpected error: {err:?}"
        );
    }

    #[tokio::test]
    async fn whitespace_owner_from_lookup_rejected() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        ctx.headers
            .insert("x-api-key", "whitespace-key".parse().unwrap());
        let err = mw.before_request(&mut ctx).await.unwrap_err();
        assert!(
            matches!(
                err,
                MiddlewareError::Unauthenticated(AuthFailureKind::EmptyPrincipal)
            ),
            "unexpected error: {err:?}"
        );
    }

    #[tokio::test]
    async fn redacted_lookup_resolves_keys_without_leaking_them() {
        let mut keys = HashMap::new();
        keys.insert("secret-key".to_string(), "owner-bob".to_string());
        let lookup = RedactedApiKeyLookup::new(keys);

        assert_eq!(
            lookup.lookup("secret-key").await.as_deref(),
            Some("owner-bob")
        );
        assert_eq!(lookup.lookup("other-key").await, None);

        let dbg = format!("{lookup:?}");
        assert!(!dbg.contains("secret-key"), "key leaked: {dbg}");
        assert!(!dbg.contains("owner-bob"), "owner leaked: {dbg}");
    }

    #[test]
    fn security_contribution_returns_api_key_scheme() {
        let mw = middleware();
        let contrib = mw.security_contribution();
        assert_eq!(contrib.schemes.len(), 1);
        assert_eq!(contrib.schemes[0].0, "apiKey");
        assert_eq!(contrib.requirements.len(), 1);
    }
}