tower_a2a/layer/
validation.rs1use std::{
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use tower_layer::Layer;
10use tower_service::Service;
11
12use crate::{
13 prelude::{MessagePart, TaskStatus},
14 protocol::{error::A2AError, operation::A2AOperation},
15 service::{A2ARequest, A2AResponse},
16};
17
/// Tower [`Layer`] that wraps an inner service in an [`A2AValidationService`].
///
/// The layer carries no configuration, so it is a zero-sized unit struct.
#[derive(Debug, Clone, Default)]
pub struct A2AValidationLayer;
21
22impl A2AValidationLayer {
23 pub fn new() -> Self {
25 Self
26 }
27}
28
29impl<S> Layer<S> for A2AValidationLayer {
30 type Service = A2AValidationService<S>;
31
32 fn layer(&self, inner: S) -> Self::Service {
33 A2AValidationService { inner }
34 }
35}
36
/// Service wrapper that validates requests before they reach the inner service
/// and validates the responses the inner service produces.
#[derive(Clone)]
pub struct A2AValidationService<S> {
    /// The wrapped service that receives requests which passed validation.
    inner: S,
}
42
43impl<S> A2AValidationService<S> {
44 fn validate_request(req: &A2ARequest) -> Result<(), A2AError> {
46 match &req.operation {
47 A2AOperation::SendMessage { message, .. } => {
48 if message.parts.is_empty() {
50 return Err(A2AError::Validation(
51 "Message must have at least one part".into(),
52 ));
53 }
54
55 for part in &message.parts {
57 match part {
58 MessagePart::Text { text } => {
59 if text.is_empty() {
60 return Err(A2AError::Validation(
61 "Text part cannot be empty".into(),
62 ));
63 }
64 }
65 MessagePart::File { file_uri, .. } => {
66 if file_uri.is_empty() {
67 return Err(A2AError::Validation(
68 "File URI cannot be empty".into(),
69 ));
70 }
71 }
72 MessagePart::Data { .. } => {
73 }
75 }
76 }
77 }
78 A2AOperation::GetTask { task_id } => {
79 if task_id.is_empty() {
80 return Err(A2AError::Validation("Task ID cannot be empty".into()));
81 }
82 }
83 A2AOperation::CancelTask { task_id } => {
84 if task_id.is_empty() {
85 return Err(A2AError::Validation("Task ID cannot be empty".into()));
86 }
87 }
88 A2AOperation::ListTasks { limit, offset, .. } => {
89 if let Some(limit_val) = limit {
90 if *limit_val == 0 {
91 return Err(A2AError::Validation("Limit must be greater than 0".into()));
92 }
93 if *limit_val > 1000 {
94 return Err(A2AError::Validation("Limit cannot exceed 1000".into()));
95 }
96 }
97
98 if let Some(offset_val) = offset {
99 if *offset_val > 1000000 {
100 return Err(A2AError::Validation("Offset is too large".into()));
101 }
102 }
103 }
104 A2AOperation::RegisterWebhook { url, events, .. } => {
105 if url.is_empty() {
106 return Err(A2AError::Validation("Webhook URL cannot be empty".into()));
107 }
108 if events.is_empty() {
109 return Err(A2AError::Validation(
110 "Webhook must subscribe to at least one event".into(),
111 ));
112 }
113 }
114 _ => {}
115 }
116
117 if req.context.agent_url.is_empty() {
119 return Err(A2AError::Validation("Agent URL cannot be empty".into()));
120 }
121
122 Ok(())
123 }
124
125 fn validate_response(resp: &A2AResponse) -> Result<(), A2AError> {
127 match resp {
128 A2AResponse::Task(task) => {
129 if task.id.is_empty() {
130 return Err(A2AError::Validation("Task ID cannot be empty".into()));
131 }
132
133 if task.input.parts.is_empty() {
135 return Err(A2AError::Validation(
136 "Task input must have at least one part".into(),
137 ));
138 }
139
140 if task.status == TaskStatus::Completed
142 && task.output.is_none()
143 && task.error.is_none()
144 {
145 return Err(A2AError::Validation(
146 "Completed task must have output or error".into(),
147 ));
148 }
149
150 if task.status == TaskStatus::Failed && task.error.is_none() {
152 return Err(A2AError::Validation(
153 "Failed task must have an error".into(),
154 ));
155 }
156 }
157 A2AResponse::AgentCard(card) => {
158 if card.name.is_empty() {
159 return Err(A2AError::Validation("Agent name cannot be empty".into()));
160 }
161 if card.endpoints.is_empty() {
162 return Err(A2AError::Validation(
163 "Agent card must have at least one endpoint".into(),
164 ));
165 }
166 }
167 _ => {}
168 }
169
170 Ok(())
171 }
172}
173
174impl<S> Service<A2ARequest> for A2AValidationService<S>
175where
176 S: Service<A2ARequest, Response = A2AResponse, Error = A2AError> + Clone + Send + 'static,
177 S::Future: Send,
178{
179 type Response = A2AResponse;
180 type Error = A2AError;
181 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
182
183 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184 self.inner.poll_ready(cx)
185 }
186
187 fn call(&mut self, req: A2ARequest) -> Self::Future {
188 if let Err(e) = Self::validate_request(&req) {
190 return Box::pin(async move { Err(e) });
191 }
192
193 let mut inner = self.inner.clone();
194 Box::pin(async move {
195 let response = inner.call(req).await?;
196
197 Self::validate_response(&response)?;
199
200 Ok(response)
201 })
202 }
203}
204
#[cfg(test)]
mod tests {
    use crate::{
        protocol::{message::Message, task::Task},
        service::RequestContext,
    };

    use super::*;

    /// The validators are associated functions that never touch the inner
    /// service, so the unit type stands in for `S`.
    type Validator = A2AValidationService<()>;

    /// Builds a non-streaming `SendMessage` request for `message` with no
    /// context or task id, addressed to a fixed example agent URL.
    fn send_message_request(message: Message) -> A2ARequest {
        A2ARequest::new(
            A2AOperation::SendMessage {
                message,
                stream: false,
                context_id: None,
                task_id: None,
            },
            RequestContext::new("https://example.com"),
        )
    }

    #[test]
    fn test_validate_send_message() {
        let request = send_message_request(Message::user("Hello"));

        assert!(Validator::validate_request(&request).is_ok());
    }

    #[test]
    fn test_validate_empty_message() {
        // Strip the parts so the message violates the "at least one part" rule.
        let mut message = Message::user("Test");
        message.parts.clear();

        let request = send_message_request(message);

        assert!(Validator::validate_request(&request).is_err());
    }

    #[test]
    fn test_validate_task_response() {
        let task = Task::new("task-123", Message::user("Test"));
        let response = A2AResponse::Task(Box::new(task));

        assert!(Validator::validate_response(&response).is_ok());
    }
}