tower_a2a/layer/
validation.rs

1//! Validation layer for A2A protocol requests and responses
2
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use tower_layer::Layer;
10use tower_service::Service;
11
12use crate::{
13    prelude::{MessagePart, TaskStatus},
14    protocol::{error::A2AError, operation::A2AOperation},
15    service::{A2ARequest, A2AResponse},
16};
17
/// Tower layer that adds A2A protocol validation around a service.
///
/// The layer is a stateless unit struct, so every instance is interchangeable.
#[derive(Debug, Clone, Default)]
pub struct A2AValidationLayer;

impl A2AValidationLayer {
    /// Builds a validation layer.
    ///
    /// The type carries no configuration, so this yields the same value as
    /// `A2AValidationLayer::default()`.
    pub fn new() -> Self {
        A2AValidationLayer
    }
}
28
29impl<S> Layer<S> for A2AValidationLayer {
30    type Service = A2AValidationService<S>;
31
32    fn layer(&self, inner: S) -> Self::Service {
33        A2AValidationService { inner }
34    }
35}
36
/// Service wrapper that validates A2A traffic flowing through an inner service.
///
/// Instances are produced by [`A2AValidationLayer`]. `Clone` and `Debug` are
/// derived, so each is available whenever the wrapped service `S` provides it.
#[derive(Clone, Debug)]
pub struct A2AValidationService<S> {
    /// The wrapped service that receives requests once they pass validation.
    inner: S,
}
42
43impl<S> A2AValidationService<S> {
44    /// Validate an A2A request
45    fn validate_request(req: &A2ARequest) -> Result<(), A2AError> {
46        match &req.operation {
47            A2AOperation::SendMessage { message, .. } => {
48                // Message must have at least one part
49                if message.parts.is_empty() {
50                    return Err(A2AError::Validation(
51                        "Message must have at least one part".into(),
52                    ));
53                }
54
55                // Validate each part (basic checks)
56                for part in &message.parts {
57                    match part {
58                        MessagePart::Text { text } => {
59                            if text.is_empty() {
60                                return Err(A2AError::Validation(
61                                    "Text part cannot be empty".into(),
62                                ));
63                            }
64                        }
65                        MessagePart::File { file_uri, .. } => {
66                            if file_uri.is_empty() {
67                                return Err(A2AError::Validation(
68                                    "File URI cannot be empty".into(),
69                                ));
70                            }
71                        }
72                        MessagePart::Data { .. } => {
73                            // Data validation could be more specific
74                        }
75                    }
76                }
77            }
78            A2AOperation::GetTask { task_id } => {
79                if task_id.is_empty() {
80                    return Err(A2AError::Validation("Task ID cannot be empty".into()));
81                }
82            }
83            A2AOperation::CancelTask { task_id } => {
84                if task_id.is_empty() {
85                    return Err(A2AError::Validation("Task ID cannot be empty".into()));
86                }
87            }
88            A2AOperation::ListTasks { limit, offset, .. } => {
89                if let Some(limit_val) = limit {
90                    if *limit_val == 0 {
91                        return Err(A2AError::Validation("Limit must be greater than 0".into()));
92                    }
93                    if *limit_val > 1000 {
94                        return Err(A2AError::Validation("Limit cannot exceed 1000".into()));
95                    }
96                }
97
98                if let Some(offset_val) = offset {
99                    if *offset_val > 1000000 {
100                        return Err(A2AError::Validation("Offset is too large".into()));
101                    }
102                }
103            }
104            A2AOperation::RegisterWebhook { url, events, .. } => {
105                if url.is_empty() {
106                    return Err(A2AError::Validation("Webhook URL cannot be empty".into()));
107                }
108                if events.is_empty() {
109                    return Err(A2AError::Validation(
110                        "Webhook must subscribe to at least one event".into(),
111                    ));
112                }
113            }
114            _ => {}
115        }
116
117        // Validate agent URL
118        if req.context.agent_url.is_empty() {
119            return Err(A2AError::Validation("Agent URL cannot be empty".into()));
120        }
121
122        Ok(())
123    }
124
125    /// Validate an A2A response
126    fn validate_response(resp: &A2AResponse) -> Result<(), A2AError> {
127        match resp {
128            A2AResponse::Task(task) => {
129                if task.id.is_empty() {
130                    return Err(A2AError::Validation("Task ID cannot be empty".into()));
131                }
132
133                // Validate task has input
134                if task.input.parts.is_empty() {
135                    return Err(A2AError::Validation(
136                        "Task input must have at least one part".into(),
137                    ));
138                }
139
140                // If task is completed, it should have output or error
141                if task.status == TaskStatus::Completed
142                    && task.output.is_none()
143                    && task.error.is_none()
144                {
145                    return Err(A2AError::Validation(
146                        "Completed task must have output or error".into(),
147                    ));
148                }
149
150                // If task is failed, it should have an error
151                if task.status == TaskStatus::Failed && task.error.is_none() {
152                    return Err(A2AError::Validation(
153                        "Failed task must have an error".into(),
154                    ));
155                }
156            }
157            A2AResponse::AgentCard(card) => {
158                if card.name.is_empty() {
159                    return Err(A2AError::Validation("Agent name cannot be empty".into()));
160                }
161                if card.endpoints.is_empty() {
162                    return Err(A2AError::Validation(
163                        "Agent card must have at least one endpoint".into(),
164                    ));
165                }
166            }
167            _ => {}
168        }
169
170        Ok(())
171    }
172}
173
174impl<S> Service<A2ARequest> for A2AValidationService<S>
175where
176    S: Service<A2ARequest, Response = A2AResponse, Error = A2AError> + Clone + Send + 'static,
177    S::Future: Send,
178{
179    type Response = A2AResponse;
180    type Error = A2AError;
181    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
182
183    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184        self.inner.poll_ready(cx)
185    }
186
187    fn call(&mut self, req: A2ARequest) -> Self::Future {
188        // Validate request before passing to inner service
189        if let Err(e) = Self::validate_request(&req) {
190            return Box::pin(async move { Err(e) });
191        }
192
193        let mut inner = self.inner.clone();
194        Box::pin(async move {
195            let response = inner.call(req).await?;
196
197            // Validate response
198            Self::validate_response(&response)?;
199
200            Ok(response)
201        })
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{
        protocol::{message::Message, task::Task},
        service::RequestContext,
    };

    /// Wraps `operation` in a request whose context has a non-empty agent URL,
    /// so only the operation-specific rules decide the outcome.
    fn request_for(operation: A2AOperation) -> A2ARequest {
        A2ARequest::new(operation, RequestContext::new("https://example.com"))
    }

    /// A message with a single text part passes request validation.
    #[test]
    fn test_validate_send_message() {
        let request = request_for(A2AOperation::SendMessage {
            message: Message::user("Hello"),
            stream: false,
            context_id: None,
            task_id: None,
        });

        let outcome = A2AValidationService::<()>::validate_request(&request);
        assert!(outcome.is_ok());
    }

    /// A message stripped of all its parts is rejected.
    #[test]
    fn test_validate_empty_message() {
        let mut message = Message::user("Test");
        message.parts.clear();

        let request = request_for(A2AOperation::SendMessage {
            message,
            stream: false,
            context_id: None,
            task_id: None,
        });

        let outcome = A2AValidationService::<()>::validate_request(&request);
        assert!(outcome.is_err());
    }

    /// A freshly created task passes response validation.
    #[test]
    fn test_validate_task_response() {
        let response = A2AResponse::Task(Box::new(Task::new("task-123", Message::user("Test"))));

        let outcome = A2AValidationService::<()>::validate_response(&response);
        assert!(outcome.is_ok());
    }
}