turul_a2a_auth/
api_key.rs1use std::collections::HashMap;
4use std::fmt;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use turul_a2a::middleware::{
9 A2aMiddleware, AuthFailureKind, AuthIdentity, MiddlewareError, RequestContext,
10 SecurityContribution,
11};
12
13#[async_trait]
16pub trait ApiKeyLookup: Send + Sync {
17 async fn lookup(&self, key: &str) -> Option<String>;
18}
19
/// In-memory key table supplied once at construction time.
///
/// No `Debug` implementation is provided for this type, so the stored keys
/// cannot end up in formatted output by accident.
pub struct StaticApiKeyLookup {
    /// Maps each accepted API key to the owner it authenticates as.
    keys: HashMap<String, String>,
}

impl StaticApiKeyLookup {
    /// Wraps `keys` (API key -> owner) without copying or validating it.
    pub fn new(keys: HashMap<String, String>) -> Self {
        StaticApiKeyLookup { keys }
    }
}
30
31#[async_trait]
32impl ApiKeyLookup for StaticApiKeyLookup {
33 async fn lookup(&self, key: &str) -> Option<String> {
34 self.keys.get(key).cloned()
35 }
36}
37
/// In-memory key table whose `Debug` output never reveals its contents.
///
/// Only the number of entries is printed, so values of this type can be
/// logged without exposing API keys or owner names.
pub struct RedactedApiKeyLookup {
    /// Maps each accepted API key to the owner it authenticates as.
    keys: HashMap<String, String>,
}

impl RedactedApiKeyLookup {
    /// Wraps `keys` (API key -> owner) without copying or validating it.
    pub fn new(keys: HashMap<String, String>) -> Self {
        RedactedApiKeyLookup { keys }
    }

    /// Number of API keys held in the table.
    pub fn len(&self) -> usize {
        let table = &self.keys;
        table.len()
    }

    /// `true` when the table holds no API keys at all.
    pub fn is_empty(&self) -> bool {
        let table = &self.keys;
        table.is_empty()
    }
}

impl fmt::Debug for RedactedApiKeyLookup {
    /// Prints the type name and the entry count only; keys and owners are
    /// intentionally left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let entry_count = self.keys.len();
        let mut rendered = f.debug_struct("RedactedApiKeyLookup");
        rendered.field("len", &entry_count);
        rendered.finish()
    }
}
87
88#[async_trait]
89impl ApiKeyLookup for RedactedApiKeyLookup {
90 async fn lookup(&self, key: &str) -> Option<String> {
91 self.keys.get(key).cloned()
92 }
93}
94
95pub struct ApiKeyMiddleware {
100 lookup: Arc<dyn ApiKeyLookup>,
101 header_name: String,
102}
103
104impl ApiKeyMiddleware {
105 pub fn new(lookup: Arc<dyn ApiKeyLookup>, header_name: impl Into<String>) -> Self {
106 Self {
107 lookup,
108 header_name: header_name.into(),
109 }
110 }
111}
112
113#[async_trait]
114impl A2aMiddleware for ApiKeyMiddleware {
115 async fn before_request(&self, ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
116 let key = ctx
117 .headers
118 .get(&self.header_name)
119 .and_then(|v| v.to_str().ok())
120 .map(|s| s.to_string());
121
122 let key = match key {
123 Some(k) if !k.is_empty() => k,
124 _ => {
125 return Err(MiddlewareError::Unauthenticated(
126 AuthFailureKind::MissingCredential,
127 ));
128 }
129 };
130
131 let owner = self
132 .lookup
133 .lookup(&key)
134 .await
135 .ok_or(MiddlewareError::Unauthenticated(
136 AuthFailureKind::InvalidApiKey,
137 ))?;
138
139 if owner.trim().is_empty() {
141 return Err(MiddlewareError::Unauthenticated(
142 AuthFailureKind::EmptyPrincipal,
143 ));
144 }
145
146 ctx.identity = AuthIdentity::Authenticated {
147 owner,
148 claims: None, };
150 Ok(())
151 }
152
153 fn security_contribution(&self) -> SecurityContribution {
154 SecurityContribution::new().with_scheme(
155 "apiKey",
156 turul_a2a_proto::SecurityScheme {
157 scheme: Some(
158 turul_a2a_proto::security_scheme::Scheme::ApiKeySecurityScheme(
159 turul_a2a_proto::ApiKeySecurityScheme {
160 description: String::new(),
161 location: "header".into(),
162 name: self.header_name.clone(),
163 },
164 ),
165 ),
166 },
167 vec![],
168 )
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Lookup fixture: one usable key plus two keys whose owners are blank
    /// (empty and whitespace-only) to exercise the empty-principal check.
    fn test_lookup() -> Arc<dyn ApiKeyLookup> {
        let mut keys = HashMap::new();
        keys.insert("valid-key".to_string(), "user-from-key".to_string());
        keys.insert("empty-owner-key".to_string(), "".to_string());
        keys.insert("whitespace-key".to_string(), " ".to_string());
        Arc::new(StaticApiKeyLookup::new(keys))
    }

    /// Middleware under test, configured with a mixed-case header name while
    /// the tests send the lowercase form.
    fn middleware() -> ApiKeyMiddleware {
        ApiKeyMiddleware::new(test_lookup(), "X-API-Key")
    }

    /// Sends a request carrying `key` (or no key header at all for `None`),
    /// expects it to be rejected, checks the context was not authenticated,
    /// and returns the error for the caller to inspect.
    async fn rejection_for(key: Option<&str>) -> MiddlewareError {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        if let Some(key) = key {
            ctx.headers.insert("x-api-key", key.parse().unwrap());
        }
        let err = mw.before_request(&mut ctx).await.unwrap_err();
        assert!(
            !ctx.identity.is_authenticated(),
            "a rejected request must not leave an authenticated identity"
        );
        err
    }

    #[tokio::test]
    async fn valid_key_sets_authenticated_identity() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        ctx.headers
            .insert("x-api-key", "valid-key".parse().unwrap());
        mw.before_request(&mut ctx).await.unwrap();
        assert!(ctx.identity.is_authenticated());
        assert_eq!(ctx.identity.owner(), "user-from-key");
        assert!(ctx.identity.claims().is_none(), "API key has no claims");
    }

    #[tokio::test]
    async fn missing_key_returns_unauthenticated() {
        let err = rejection_for(None).await;
        assert!(
            matches!(
                err,
                MiddlewareError::Unauthenticated(AuthFailureKind::MissingCredential)
            ),
            "unexpected error: {:?}",
            err
        );
    }

    #[tokio::test]
    async fn empty_key_header_returns_missing_credential() {
        let err = rejection_for(Some("")).await;
        assert!(
            matches!(
                err,
                MiddlewareError::Unauthenticated(AuthFailureKind::MissingCredential)
            ),
            "unexpected error: {:?}",
            err
        );
    }

    #[tokio::test]
    async fn invalid_key_returns_unauthenticated() {
        let err = rejection_for(Some("bad-key")).await;
        assert!(
            matches!(
                err,
                MiddlewareError::Unauthenticated(AuthFailureKind::InvalidApiKey)
            ),
            "unexpected error: {:?}",
            err
        );
    }

    #[tokio::test]
    async fn empty_owner_from_lookup_rejected() {
        let err = rejection_for(Some("empty-owner-key")).await;
        assert!(
            matches!(
                err,
                MiddlewareError::Unauthenticated(AuthFailureKind::EmptyPrincipal)
            ),
            "unexpected error: {:?}",
            err
        );
    }

    #[tokio::test]
    async fn whitespace_owner_from_lookup_rejected() {
        let err = rejection_for(Some("whitespace-key")).await;
        assert!(
            matches!(
                err,
                MiddlewareError::Unauthenticated(AuthFailureKind::EmptyPrincipal)
            ),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn redacted_lookup_debug_omits_keys_and_owners() {
        let mut keys = HashMap::new();
        keys.insert("secret-key".to_string(), "secret-owner".to_string());
        let lookup = RedactedApiKeyLookup::new(keys);
        assert_eq!(lookup.len(), 1);
        assert!(!lookup.is_empty());

        let rendered = format!("{:?}", lookup);
        assert!(!rendered.contains("secret-key"), "key leaked: {}", rendered);
        assert!(
            !rendered.contains("secret-owner"),
            "owner leaked: {}",
            rendered
        );
    }

    #[test]
    fn security_contribution_returns_api_key_scheme() {
        let mw = middleware();
        let contrib = mw.security_contribution();
        assert_eq!(contrib.schemes.len(), 1);
        assert_eq!(contrib.schemes[0].0, "apiKey");
        assert_eq!(contrib.requirements.len(), 1);
    }
}