tower_a2a/layer/
auth.rs

1//! Authentication layer for A2A protocol
2
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use base64::{engine::general_purpose, Engine as _};
10use tower_layer::Layer;
11use tower_service::Service;
12
13use crate::{
14    protocol::error::A2AError,
15    service::{A2ARequest, A2AResponse},
16};
17
/// Authentication credentials attached to A2A requests.
///
/// `Debug` is implemented by hand rather than derived so that secret material
/// (bearer tokens, API keys, passwords) is printed as `<redacted>` instead of
/// leaking whenever a value — or anything containing one — is formatted with
/// `{:?}`.
#[derive(Clone)]
pub enum AuthCredentials {
    /// Bearer token authentication
    Bearer(String),

    /// API key authentication
    ApiKey {
        /// Secret key, sent as the header value.
        key: String,
        /// Name of the header that carries the key.
        header: String,
    },

    /// Basic HTTP authentication
    Basic {
        /// User name (not treated as secret; shown in `Debug` output).
        username: String,
        /// Password (redacted in `Debug` output).
        password: String,
    },
}

impl std::fmt::Debug for AuthCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `format_args!` renders without surrounding quotes, which makes the
        // placeholder visually distinct from a real string value.
        match self {
            AuthCredentials::Bearer(_) => f
                .debug_tuple("Bearer")
                .field(&format_args!("<redacted>"))
                .finish(),
            AuthCredentials::ApiKey { header, .. } => f
                .debug_struct("ApiKey")
                .field("key", &format_args!("<redacted>"))
                .field("header", header)
                .finish(),
            AuthCredentials::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &format_args!("<redacted>"))
                .finish(),
        }
    }
}
30
31impl AuthCredentials {
32    /// Create bearer token credentials
33    pub fn bearer(token: impl Into<String>) -> Self {
34        Self::Bearer(token.into())
35    }
36
37    /// Create API key credentials
38    pub fn api_key(key: impl Into<String>, header: impl Into<String>) -> Self {
39        Self::ApiKey {
40            key: key.into(),
41            header: header.into(),
42        }
43    }
44
45    /// Create basic auth credentials
46    pub fn basic(username: impl Into<String>, password: impl Into<String>) -> Self {
47        Self::Basic {
48            username: username.into(),
49            password: password.into(),
50        }
51    }
52
53    /// Get the header name and value for this credential
54    pub fn to_header(&self) -> (String, String) {
55        match self {
56            AuthCredentials::Bearer(token) => {
57                ("Authorization".to_string(), format!("Bearer {}", token))
58            }
59            AuthCredentials::ApiKey { key, header } => (header.clone(), key.clone()),
60            AuthCredentials::Basic { username, password } => {
61                let credentials = format!("{}:{}", username, password);
62                let encoded = general_purpose::STANDARD.encode(credentials.as_bytes());
63                ("Authorization".to_string(), format!("Basic {}", encoded))
64            }
65        }
66    }
67}
68
69/// Authentication layer
70#[derive(Clone)]
71pub struct AuthLayer {
72    credentials: AuthCredentials,
73}
74
75impl AuthLayer {
76    /// Create a new authentication layer
77    pub fn new(credentials: AuthCredentials) -> Self {
78        Self { credentials }
79    }
80
81    /// Create a bearer authentication layer
82    pub fn bearer(token: impl Into<String>) -> Self {
83        Self::new(AuthCredentials::bearer(token))
84    }
85
86    /// Create an API key authentication layer
87    pub fn api_key(key: impl Into<String>, header: impl Into<String>) -> Self {
88        Self::new(AuthCredentials::api_key(key, header))
89    }
90}
91
92impl<S> Layer<S> for AuthLayer {
93    type Service = AuthService<S>;
94
95    fn layer(&self, inner: S) -> Self::Service {
96        AuthService {
97            inner,
98            credentials: self.credentials.clone(),
99        }
100    }
101}
102
103/// Authentication service
104#[derive(Clone)]
105pub struct AuthService<S> {
106    inner: S,
107    credentials: AuthCredentials,
108}
109
110impl<S> Service<A2ARequest> for AuthService<S>
111where
112    S: Service<A2ARequest, Response = A2AResponse, Error = A2AError> + Clone + Send + 'static,
113    S::Future: Send,
114{
115    type Response = A2AResponse;
116    type Error = A2AError;
117    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
118
119    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120        self.inner.poll_ready(cx)
121    }
122
123    fn call(&mut self, mut req: A2ARequest) -> Self::Future {
124        // Inject credentials into request context
125        req.context.auth = Some(self.credentials.clone());
126
127        let mut inner = self.inner.clone();
128        Box::pin(async move { inner.call(req).await })
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bearer_credentials() {
        let creds = AuthCredentials::bearer("test-token");
        let (header, value) = creds.to_header();

        assert_eq!(header, "Authorization");
        assert_eq!(value, "Bearer test-token");
    }

    #[test]
    fn test_api_key_credentials() {
        let creds = AuthCredentials::api_key("secret-key", "X-API-Key");
        let (header, value) = creds.to_header();

        assert_eq!(header, "X-API-Key");
        assert_eq!(value, "secret-key");
    }

    #[test]
    fn test_basic_credentials() {
        let creds = AuthCredentials::basic("user", "pass");
        let (header, value) = creds.to_header();

        assert_eq!(header, "Authorization");
        // Pin the exact encoding: base64("user:pass").
        assert_eq!(value, "Basic dXNlcjpwYXNz");
    }

    #[test]
    fn test_basic_credentials_rfc7617_example() {
        // Example from RFC 7617 section 2; exercises base64 padding.
        let creds = AuthCredentials::basic("Aladdin", "open sesame");
        let (header, value) = creds.to_header();

        assert_eq!(header, "Authorization");
        assert_eq!(value, "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==");
    }
}