1use crate::{
4 client::config::ClientConfig,
5 prelude::A2AError,
6 protocol::{A2AOperation, AgentCard, Message, Task, TaskStatus},
7 service::{A2ARequest, A2AResponse, RequestContext},
8};
9use tower_service::Service;
10
11pub struct AgentClient<S> {
33 service: S,
34 config: ClientConfig,
35}
36
37impl<S> AgentClient<S>
38where
39 S: Service<A2ARequest, Response = A2AResponse, Error = A2AError>,
40{
41 pub fn new(service: S, config: ClientConfig) -> Self {
48 Self { service, config }
49 }
50
51 pub fn config(&self) -> &ClientConfig {
53 &self.config
54 }
55
56 fn build_context(&self) -> RequestContext {
58 RequestContext {
59 agent_url: self.config.agent_url.clone(),
60 auth: None, timeout: Some(self.config.timeout),
62 metadata: Default::default(),
63 }
64 }
65
66 pub async fn send_message(&mut self, message: Message) -> Result<Task, A2AError> {
80 let operation = A2AOperation::SendMessage {
81 message,
82 stream: false,
83 context_id: None,
84 task_id: None,
85 };
86
87 let request = A2ARequest::new(operation, self.build_context());
88 let response = self.service.call(request).await?;
89
90 match response {
91 A2AResponse::Task(task) => Ok(*task),
92 _ => Err(A2AError::Protocol(
93 "Expected task response from send_message".into(),
94 )),
95 }
96 }
97
98 pub async fn send_message_streaming(&mut self, message: Message) -> Result<Task, A2AError> {
102 let operation = A2AOperation::SendMessage {
103 message,
104 stream: true,
105 context_id: None,
106 task_id: None,
107 };
108
109 let request = A2ARequest::new(operation, self.build_context());
110 let response = self.service.call(request).await?;
111
112 match response {
113 A2AResponse::Task(task) => Ok(*task),
114 _ => Err(A2AError::Protocol(
115 "Expected task response from send_message_streaming".into(),
116 )),
117 }
118 }
119
120 pub async fn send_message_in_context(
127 &mut self,
128 message: Message,
129 context_id: String,
130 ) -> Result<Task, A2AError> {
131 let operation = A2AOperation::SendMessage {
132 message,
133 stream: false,
134 context_id: Some(context_id),
135 task_id: None,
136 };
137
138 let request = A2ARequest::new(operation, self.build_context());
139 let response = self.service.call(request).await?;
140
141 match response {
142 A2AResponse::Task(task) => Ok(*task),
143 _ => Err(A2AError::Protocol(
144 "Expected task response from send_message_in_context".into(),
145 )),
146 }
147 }
148
149 pub async fn get_task(&mut self, task_id: String) -> Result<Task, A2AError> {
163 let operation = A2AOperation::GetTask { task_id };
164
165 let request = A2ARequest::new(operation, self.build_context());
166 let response = self.service.call(request).await?;
167
168 match response {
169 A2AResponse::Task(task) => Ok(*task),
170 _ => Err(A2AError::Protocol(
171 "Expected task response from get_task".into(),
172 )),
173 }
174 }
175
176 pub async fn list_tasks(
187 &mut self,
188 status: Option<TaskStatus>,
189 limit: Option<u32>,
190 ) -> Result<Vec<Task>, A2AError> {
191 let operation = A2AOperation::ListTasks {
192 status,
193 limit,
194 offset: None,
195 next_token: None,
196 };
197
198 let request = A2ARequest::new(operation, self.build_context());
199 let response = self.service.call(request).await?;
200
201 match response {
202 A2AResponse::TaskList { tasks, .. } => Ok(tasks),
203 _ => Err(A2AError::Protocol(
204 "Expected task list response from list_tasks".into(),
205 )),
206 }
207 }
208
209 pub async fn list_all_tasks(&mut self) -> Result<Vec<Task>, A2AError> {
211 self.list_tasks(None, None).await
212 }
213
214 pub async fn list_tasks_by_status(
216 &mut self,
217 status: TaskStatus,
218 ) -> Result<Vec<Task>, A2AError> {
219 self.list_tasks(Some(status), None).await
220 }
221
222 pub async fn cancel_task(&mut self, task_id: String) -> Result<Task, A2AError> {
232 let operation = A2AOperation::CancelTask { task_id };
233
234 let request = A2ARequest::new(operation, self.build_context());
235 let response = self.service.call(request).await?;
236
237 match response {
238 A2AResponse::Task(task) => Ok(*task),
239 _ => Err(A2AError::Protocol(
240 "Expected task response from cancel_task".into(),
241 )),
242 }
243 }
244
245 pub async fn discover(&mut self) -> Result<AgentCard, A2AError> {
253 let operation = A2AOperation::DiscoverAgent;
254
255 let request = A2ARequest::new(operation, self.build_context());
256 let response = self.service.call(request).await?;
257
258 match response {
259 A2AResponse::AgentCard(card) => Ok(*card),
260 _ => Err(A2AError::Protocol(
261 "Expected agent card response from discover".into(),
262 )),
263 }
264 }
265
266 pub async fn poll_until_complete(
281 &mut self,
282 task_id: String,
283 poll_interval_ms: u64,
284 max_attempts: usize,
285 ) -> Result<Task, A2AError> {
286 let mut attempts = 0;
287
288 loop {
289 let task = self.get_task(task_id.clone()).await?;
290
291 if task.is_terminal() {
292 return Ok(task);
293 }
294
295 attempts += 1;
296 if max_attempts > 0 && attempts >= max_attempts {
297 return Err(A2AError::Timeout);
298 }
299
300 tokio::time::sleep(tokio::time::Duration::from_millis(poll_interval_ms)).await;
301 }
302 }
303}
304
#[cfg(test)]
mod tests {
    //! Round-trip tests for [`AgentClient`] over a mock transport that always
    //! replies with a fixed JSON payload.

    use std::sync::Arc;

    use crate::{
        codec::JsonCodec,
        protocol::message::Message,
        service::A2AProtocolService,
        transport::{mock::MockTransport, TransportResponse},
    };
    use bytes::Bytes;

    use super::*;

    /// `send_message` surfaces the task carried in the agent's reply.
    #[tokio::test]
    async fn test_send_message() {
        // Every request is answered with the same serialized task.
        let transport = MockTransport::new(|_request| {
            let payload =
                serde_json::to_vec(&Task::new("task-123", Message::user("Test"))).unwrap();
            TransportResponse::new(200).body(Bytes::from(payload))
        });
        let mut client = AgentClient::new(
            A2AProtocolService::new(transport, Arc::new(JsonCodec::new())),
            ClientConfig::new("https://example.com"),
        );

        let reply = client.send_message(Message::user("Hello")).await.unwrap();

        assert_eq!(reply.id, "task-123");
    }

    /// `get_task` decodes the task returned by the agent.
    #[tokio::test]
    async fn test_get_task() {
        let transport = MockTransport::new(|_request| {
            let payload =
                serde_json::to_vec(&Task::new("task-456", Message::user("Test"))).unwrap();
            TransportResponse::new(200).body(Bytes::from(payload))
        });
        let mut client = AgentClient::new(
            A2AProtocolService::new(transport, Arc::new(JsonCodec::new())),
            ClientConfig::new("https://example.com"),
        );

        let fetched = client.get_task("task-456".to_string()).await.unwrap();

        assert_eq!(fetched.id, "task-456");
    }

    /// `discover` decodes the agent card returned by the agent.
    #[tokio::test]
    async fn test_discover() {
        use crate::protocol::agent::{AgentCapabilities, AgentCard};

        let transport = MockTransport::new(|_request| {
            let payload = serde_json::to_vec(&AgentCard::new(
                "Test Agent",
                "A test agent",
                AgentCapabilities::default(),
            ))
            .unwrap();
            TransportResponse::new(200).body(Bytes::from(payload))
        });
        let mut client = AgentClient::new(
            A2AProtocolService::new(transport, Arc::new(JsonCodec::new())),
            ClientConfig::new("https://example.com"),
        );

        let discovered = client.discover().await.unwrap();

        assert_eq!(discovered.name, "Test Agent");
    }
}