1use std::{
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use base64::{engine::general_purpose, Engine as _};
10use tower_layer::Layer;
11use tower_service::Service;
12
13use crate::{
14 protocol::error::A2AError,
15 service::{A2ARequest, A2AResponse},
16};
17
/// Credentials that the auth middleware attaches to outgoing A2A requests.
///
/// `Debug` is implemented by hand (instead of derived) so that secret
/// material — bearer tokens, API keys and passwords — is redacted and cannot
/// leak through `{:?}` formatting in logs or error messages.
#[derive(Clone)]
pub enum AuthCredentials {
    /// Bearer token, rendered as `Authorization: Bearer <token>`.
    Bearer(String),

    /// API key sent verbatim as the value of the header named by `header`.
    ApiKey { key: String, header: String },

    /// HTTP Basic credentials, rendered as
    /// `Authorization: Basic <base64(username:password)>`.
    Basic { username: String, password: String },
}

impl std::fmt::Debug for AuthCredentials {
    /// Formats the variant and its non-secret fields; secret fields are
    /// replaced by a `<redacted>` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        match self {
            Self::Bearer(_) => f.debug_tuple("Bearer").field(&REDACTED).finish(),
            Self::ApiKey { header, .. } => f
                .debug_struct("ApiKey")
                .field("key", &REDACTED)
                .field("header", header)
                .finish(),
            // The username is kept for diagnosability; only the password is secret.
            Self::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &REDACTED)
                .finish(),
        }
    }
}
30
31impl AuthCredentials {
32 pub fn bearer(token: impl Into<String>) -> Self {
34 Self::Bearer(token.into())
35 }
36
37 pub fn api_key(key: impl Into<String>, header: impl Into<String>) -> Self {
39 Self::ApiKey {
40 key: key.into(),
41 header: header.into(),
42 }
43 }
44
45 pub fn basic(username: impl Into<String>, password: impl Into<String>) -> Self {
47 Self::Basic {
48 username: username.into(),
49 password: password.into(),
50 }
51 }
52
53 pub fn to_header(&self) -> (String, String) {
55 match self {
56 AuthCredentials::Bearer(token) => {
57 ("Authorization".to_string(), format!("Bearer {}", token))
58 }
59 AuthCredentials::ApiKey { key, header } => (header.clone(), key.clone()),
60 AuthCredentials::Basic { username, password } => {
61 let credentials = format!("{}:{}", username, password);
62 let encoded = general_purpose::STANDARD.encode(credentials.as_bytes());
63 ("Authorization".to_string(), format!("Basic {}", encoded))
64 }
65 }
66 }
67}
68
69#[derive(Clone)]
71pub struct AuthLayer {
72 credentials: AuthCredentials,
73}
74
75impl AuthLayer {
76 pub fn new(credentials: AuthCredentials) -> Self {
78 Self { credentials }
79 }
80
81 pub fn bearer(token: impl Into<String>) -> Self {
83 Self::new(AuthCredentials::bearer(token))
84 }
85
86 pub fn api_key(key: impl Into<String>, header: impl Into<String>) -> Self {
88 Self::new(AuthCredentials::api_key(key, header))
89 }
90}
91
92impl<S> Layer<S> for AuthLayer {
93 type Service = AuthService<S>;
94
95 fn layer(&self, inner: S) -> Self::Service {
96 AuthService {
97 inner,
98 credentials: self.credentials.clone(),
99 }
100 }
101}
102
103#[derive(Clone)]
105pub struct AuthService<S> {
106 inner: S,
107 credentials: AuthCredentials,
108}
109
110impl<S> Service<A2ARequest> for AuthService<S>
111where
112 S: Service<A2ARequest, Response = A2AResponse, Error = A2AError> + Clone + Send + 'static,
113 S::Future: Send,
114{
115 type Response = A2AResponse;
116 type Error = A2AError;
117 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
118
119 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120 self.inner.poll_ready(cx)
121 }
122
123 fn call(&mut self, mut req: A2ARequest) -> Self::Future {
124 req.context.auth = Some(self.credentials.clone());
126
127 let mut inner = self.inner.clone();
128 Box::pin(async move { inner.call(req).await })
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bearer_credentials() {
        let creds = AuthCredentials::bearer("test-token");
        let (header, value) = creds.to_header();

        assert_eq!(header, "Authorization");
        assert_eq!(value, "Bearer test-token");
    }

    #[test]
    fn test_api_key_credentials() {
        let creds = AuthCredentials::api_key("secret-key", "X-API-Key");
        let (header, value) = creds.to_header();

        assert_eq!(header, "X-API-Key");
        assert_eq!(value, "secret-key");
    }

    #[test]
    fn test_basic_credentials() {
        let creds = AuthCredentials::basic("user", "pass");
        let (header, value) = creds.to_header();

        assert_eq!(header, "Authorization");
        // base64("user:pass") with the standard alphabet; pinning the exact
        // value catches a wrong separator or wrong encoder, not just the prefix.
        assert_eq!(value, "Basic dXNlcjpwYXNz");
    }

    #[test]
    fn test_basic_credentials_empty_password() {
        let creds = AuthCredentials::basic("user", "");
        let (header, value) = creds.to_header();

        assert_eq!(header, "Authorization");
        // base64("user:") requires padding.
        assert_eq!(value, "Basic dXNlcjo=");
    }

    #[test]
    fn test_layer_constructors_store_credentials() {
        let (header, value) = AuthLayer::bearer("tok").credentials.to_header();
        assert_eq!(header, "Authorization");
        assert_eq!(value, "Bearer tok");

        let (header, value) = AuthLayer::api_key("k", "X-Key").credentials.to_header();
        assert_eq!(header, "X-Key");
        assert_eq!(value, "k");
    }
}