tower_a2a/client/
agent.rs

1//! High-level A2A agent client
2
3use crate::{
4    client::config::ClientConfig,
5    prelude::A2AError,
6    protocol::{A2AOperation, AgentCard, Message, Task, TaskStatus},
7    service::{A2ARequest, A2AResponse, RequestContext},
8};
9use tower_service::Service;
10
11/// High-level A2A client for interacting with agents
12///
13/// This client wraps a Tower service and provides convenient methods for common A2A operations.
14/// The service is generic over any implementation that satisfies the Service trait bounds.
15///
16/// # Example
17///
18/// ```rust,no_run
19/// use tower_a2a::prelude::*;
20///
21/// # async fn example() -> Result<(), A2AError> {
22/// let url = "https://agent.example.com".parse().unwrap();
23/// let mut client = A2AClientBuilder::new_http(url)
24///     .build()?;
25///
26/// let message = Message::user("Hello, agent!");
27/// let task = client.send_message(message).await?;
28/// println!("Task created: {}", task.id);
29/// # Ok(())
30/// # }
31/// ```
32pub struct AgentClient<S> {
33    service: S,
34    config: ClientConfig,
35}
36
37impl<S> AgentClient<S>
38where
39    S: Service<A2ARequest, Response = A2AResponse, Error = A2AError>,
40{
41    /// Create a new agent client
42    ///
43    /// # Arguments
44    ///
45    /// * `service` - The Tower service that handles requests
46    /// * `config` - Client configuration
47    pub fn new(service: S, config: ClientConfig) -> Self {
48        Self { service, config }
49    }
50
51    /// Get the client configuration
52    pub fn config(&self) -> &ClientConfig {
53        &self.config
54    }
55
56    /// Build a request context from the client configuration
57    fn build_context(&self) -> RequestContext {
58        RequestContext {
59            agent_url: self.config.agent_url.clone(),
60            auth: None, // Set by AuthLayer
61            timeout: Some(self.config.timeout),
62            metadata: Default::default(),
63        }
64    }
65
66    /// Send a message to the agent and get a task
67    ///
68    /// # Arguments
69    ///
70    /// * `message` - The message to send to the agent
71    ///
72    /// # Returns
73    ///
74    /// A task representing the agent's processing of the message
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if the message fails to send or the response is invalid
79    pub async fn send_message(&mut self, message: Message) -> Result<Task, A2AError> {
80        let operation = A2AOperation::SendMessage {
81            message,
82            stream: false,
83            context_id: None,
84            task_id: None,
85        };
86
87        let request = A2ARequest::new(operation, self.build_context());
88        let response = self.service.call(request).await?;
89
90        match response {
91            A2AResponse::Task(task) => Ok(*task),
92            _ => Err(A2AError::Protocol(
93                "Expected task response from send_message".into(),
94            )),
95        }
96    }
97
98    /// Send a message with streaming enabled
99    ///
100    /// Note: Streaming is not yet fully implemented
101    pub async fn send_message_streaming(&mut self, message: Message) -> Result<Task, A2AError> {
102        let operation = A2AOperation::SendMessage {
103            message,
104            stream: true,
105            context_id: None,
106            task_id: None,
107        };
108
109        let request = A2ARequest::new(operation, self.build_context());
110        let response = self.service.call(request).await?;
111
112        match response {
113            A2AResponse::Task(task) => Ok(*task),
114            _ => Err(A2AError::Protocol(
115                "Expected task response from send_message_streaming".into(),
116            )),
117        }
118    }
119
120    /// Send a message in a specific context for multi-turn conversations
121    ///
122    /// # Arguments
123    ///
124    /// * `message` - The message to send
125    /// * `context_id` - The context ID for grouping related messages
126    pub async fn send_message_in_context(
127        &mut self,
128        message: Message,
129        context_id: String,
130    ) -> Result<Task, A2AError> {
131        let operation = A2AOperation::SendMessage {
132            message,
133            stream: false,
134            context_id: Some(context_id),
135            task_id: None,
136        };
137
138        let request = A2ARequest::new(operation, self.build_context());
139        let response = self.service.call(request).await?;
140
141        match response {
142            A2AResponse::Task(task) => Ok(*task),
143            _ => Err(A2AError::Protocol(
144                "Expected task response from send_message_in_context".into(),
145            )),
146        }
147    }
148
149    /// Get a task by ID
150    ///
151    /// # Arguments
152    ///
153    /// * `task_id` - The unique identifier of the task to retrieve
154    ///
155    /// # Returns
156    ///
157    /// The task with the specified ID
158    ///
159    /// # Errors
160    ///
161    /// Returns `A2AError::TaskNotFound` if the task doesn't exist
162    pub async fn get_task(&mut self, task_id: String) -> Result<Task, A2AError> {
163        let operation = A2AOperation::GetTask { task_id };
164
165        let request = A2ARequest::new(operation, self.build_context());
166        let response = self.service.call(request).await?;
167
168        match response {
169            A2AResponse::Task(task) => Ok(*task),
170            _ => Err(A2AError::Protocol(
171                "Expected task response from get_task".into(),
172            )),
173        }
174    }
175
176    /// List tasks with optional filtering
177    ///
178    /// # Arguments
179    ///
180    /// * `status` - Optional filter by task status
181    /// * `limit` - Maximum number of tasks to return (default: 100)
182    ///
183    /// # Returns
184    ///
185    /// A vector of tasks matching the query
186    pub async fn list_tasks(
187        &mut self,
188        status: Option<TaskStatus>,
189        limit: Option<u32>,
190    ) -> Result<Vec<Task>, A2AError> {
191        let operation = A2AOperation::ListTasks {
192            status,
193            limit,
194            offset: None,
195            next_token: None,
196        };
197
198        let request = A2ARequest::new(operation, self.build_context());
199        let response = self.service.call(request).await?;
200
201        match response {
202            A2AResponse::TaskList { tasks, .. } => Ok(tasks),
203            _ => Err(A2AError::Protocol(
204                "Expected task list response from list_tasks".into(),
205            )),
206        }
207    }
208
209    /// List all tasks without filtering
210    pub async fn list_all_tasks(&mut self) -> Result<Vec<Task>, A2AError> {
211        self.list_tasks(None, None).await
212    }
213
214    /// List tasks with a specific status
215    pub async fn list_tasks_by_status(
216        &mut self,
217        status: TaskStatus,
218    ) -> Result<Vec<Task>, A2AError> {
219        self.list_tasks(Some(status), None).await
220    }
221
222    /// Cancel a task by ID
223    ///
224    /// # Arguments
225    ///
226    /// * `task_id` - The unique identifier of the task to cancel
227    ///
228    /// # Returns
229    ///
230    /// The updated task with cancelled status
231    pub async fn cancel_task(&mut self, task_id: String) -> Result<Task, A2AError> {
232        let operation = A2AOperation::CancelTask { task_id };
233
234        let request = A2ARequest::new(operation, self.build_context());
235        let response = self.service.call(request).await?;
236
237        match response {
238            A2AResponse::Task(task) => Ok(*task),
239            _ => Err(A2AError::Protocol(
240                "Expected task response from cancel_task".into(),
241            )),
242        }
243    }
244
245    /// Discover agent capabilities by fetching the Agent Card
246    ///
247    /// This retrieves the agent's metadata from `/.well-known/agent-card.json`
248    ///
249    /// # Returns
250    ///
251    /// The agent's capability card
252    pub async fn discover(&mut self) -> Result<AgentCard, A2AError> {
253        let operation = A2AOperation::DiscoverAgent;
254
255        let request = A2ARequest::new(operation, self.build_context());
256        let response = self.service.call(request).await?;
257
258        match response {
259            A2AResponse::AgentCard(card) => Ok(*card),
260            _ => Err(A2AError::Protocol(
261                "Expected agent card response from discover".into(),
262            )),
263        }
264    }
265
266    /// Poll a task until it reaches a terminal state
267    ///
268    /// This is a convenience method that repeatedly calls get_task until
269    /// the task is completed, failed, cancelled, or rejected.
270    ///
271    /// # Arguments
272    ///
273    /// * `task_id` - The task ID to poll
274    /// * `poll_interval` - How often to poll (in milliseconds)
275    /// * `max_attempts` - Maximum number of polling attempts (0 = unlimited)
276    ///
277    /// # Returns
278    ///
279    /// The final task state
280    pub async fn poll_until_complete(
281        &mut self,
282        task_id: String,
283        poll_interval_ms: u64,
284        max_attempts: usize,
285    ) -> Result<Task, A2AError> {
286        let mut attempts = 0;
287
288        loop {
289            let task = self.get_task(task_id.clone()).await?;
290
291            if task.is_terminal() {
292                return Ok(task);
293            }
294
295            attempts += 1;
296            if max_attempts > 0 && attempts >= max_attempts {
297                return Err(A2AError::Timeout);
298            }
299
300            tokio::time::sleep(tokio::time::Duration::from_millis(poll_interval_ms)).await;
301        }
302    }
303}
304
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use bytes::Bytes;

    use super::*;
    use crate::{
        codec::JsonCodec,
        protocol::message::Message,
        service::A2AProtocolService,
        transport::{mock::MockTransport, TransportResponse},
    };

    // Builds an `AgentClient` over a `MockTransport` that answers every request
    // with HTTP 200 and the JSON serialization of `$payload`. The payload is
    // built inside the mock handler, once per request. A macro (rather than a
    // helper fn) keeps the concrete service type inferred, so it never needs to
    // be spelled out here.
    macro_rules! mock_client_returning {
        ($payload:expr) => {{
            let transport = MockTransport::new(|_req| {
                let body = serde_json::to_vec(&$payload).unwrap();
                TransportResponse::new(200).body(Bytes::from(body))
            });
            let service = A2AProtocolService::new(transport, Arc::new(JsonCodec::new()));
            AgentClient::new(service, ClientConfig::new("https://example.com"))
        }};
    }

    #[tokio::test]
    async fn test_send_message() {
        let mut client = mock_client_returning!(Task::new("task-123", Message::user("Test")));

        let created = client.send_message(Message::user("Hello")).await.unwrap();

        assert_eq!(created.id, "task-123");
    }

    #[tokio::test]
    async fn test_get_task() {
        let mut client = mock_client_returning!(Task::new("task-456", Message::user("Test")));

        let fetched = client.get_task(String::from("task-456")).await.unwrap();

        assert_eq!(fetched.id, "task-456");
    }

    #[tokio::test]
    async fn test_discover() {
        use crate::protocol::agent::{AgentCapabilities, AgentCard};

        let mut client = mock_client_returning!(AgentCard::new(
            "Test Agent",
            "A test agent",
            AgentCapabilities::default()
        ));

        let discovered = client.discover().await.unwrap();

        assert_eq!(discovered.name, "Test Agent");
    }
}