// vtcode_core/a2a/client.rs
//! A2A client for interacting with remote A2A agents.
//!
//! Provides helper methods for agent-card discovery, JSON-RPC task operations, and
//! server-sent-event (SSE) streaming.

4use std::sync::{
5    Arc,
6    atomic::{AtomicU64, Ordering},
7};
8
9use anyhow::Context;
10use futures::{Stream, StreamExt};
11use reqwest::Client;
12use serde_json::Value;
13
14use crate::a2a::agent_card::AgentCard;
15use crate::a2a::errors::{A2aError, A2aErrorCode, A2aResult};
16use crate::a2a::rpc::{
17    JsonRpcRequest, ListTasksParams, METHOD_MESSAGE_SEND, METHOD_MESSAGE_STREAM,
18    METHOD_TASKS_CANCEL, METHOD_TASKS_GET, METHOD_TASKS_LIST, METHOD_TASKS_PUSH_CONFIG_GET,
19    METHOD_TASKS_PUSH_CONFIG_SET, MessageSendParams, SendStreamingMessageResponse, StreamingEvent,
20    TaskIdParams, TaskPushNotificationConfig, TaskQueryParams,
21};
22use crate::a2a::types::Task;
23
24/// HTTP client for interacting with A2A agents
25#[derive(Clone, Debug)]
26pub struct A2aClient {
27    base_url: String,
28    http: Client,
29    request_id: Arc<AtomicU64>,
30}
31
32impl A2aClient {
33    /// Create a new client with default reqwest settings
34    pub fn new(base_url: impl Into<String>) -> A2aResult<Self> {
35        let http = Client::builder()
36            .build()
37            .context("Failed to build HTTP client")
38            .map_err(|e| A2aError::Internal(e.to_string()))?;
39
40        Ok(Self {
41            base_url: base_url.into().trim_end_matches('/').to_string(),
42            http,
43            request_id: Arc::new(AtomicU64::new(1)),
44        })
45    }
46
47    fn next_id(&self) -> String {
48        let id = self.request_id.fetch_add(1, Ordering::Relaxed);
49        format!("a2a-{}", id)
50    }
51
52    fn rpc_url(&self) -> String {
53        format!("{}/a2a", self.base_url)
54    }
55
56    fn stream_url(&self) -> String {
57        format!("{}/a2a/stream", self.base_url)
58    }
59
60    fn agent_card_url(&self) -> String {
61        format!("{}/.well-known/agent-card.json", self.base_url)
62    }
63
64    /// Fetch the remote agent card
65    pub async fn agent_card(&self) -> A2aResult<AgentCard> {
66        let resp = self
67            .http
68            .get(self.agent_card_url())
69            .send()
70            .await
71            .context("Failed to fetch agent card")
72            .map_err(|e| A2aError::Internal(e.to_string()))?;
73
74        let status = resp.status();
75        if !status.is_success() {
76            return Err(A2aError::rpc(
77                A2aErrorCode::InvalidAgentResponse,
78                format!("Agent card request failed with status {status}"),
79            ));
80        }
81
82        let card = resp
83            .json::<AgentCard>()
84            .await
85            .context("Invalid agent card response")
86            .map_err(|e| A2aError::Internal(e.to_string()))?;
87        Ok(card)
88    }
89
90    /// Send a message/send RPC
91    pub async fn send_message(&self, params: MessageSendParams) -> A2aResult<Task> {
92        let result_value = self
93            .call_rpc(METHOD_MESSAGE_SEND, Some(serde_json::to_value(&params)?))
94            .await?;
95        let task: Task = serde_json::from_value(result_value)
96            .context("Failed to deserialize task")
97            .map_err(|e| A2aError::Internal(e.to_string()))?;
98        Ok(task)
99    }
100
101    /// Send a message/stream RPC and consume streaming events
102    pub async fn stream_message(
103        &self,
104        params: MessageSendParams,
105    ) -> A2aResult<impl Stream<Item = A2aResult<StreamingEvent>>> {
106        let req = JsonRpcRequest::with_string_id(
107            METHOD_MESSAGE_STREAM,
108            Some(serde_json::to_value(&params)?),
109            self.next_id(),
110        );
111
112        let response = self
113            .http
114            .post(self.stream_url())
115            .header("accept", "text/event-stream")
116            .json(&req)
117            .send()
118            .await
119            .context("Failed to open streaming request")
120            .map_err(|e| A2aError::Internal(e.to_string()))?;
121
122        let status = response.status();
123        if !status.is_success() {
124            return Err(A2aError::rpc(
125                A2aErrorCode::InvalidAgentResponse,
126                format!("Streaming request failed with status {status}"),
127            ));
128        }
129
130        let byte_stream = response.bytes_stream();
131
132        let stream = async_stream::try_stream! {
133            let mut buffer = Vec::new();
134            futures::pin_mut!(byte_stream);
135
136            while let Some(chunk) = byte_stream.next().await {
137                let chunk = chunk.context("Failed to read streaming chunk")
138                    .map_err(|e| A2aError::Internal(e.to_string()))?;
139                buffer.extend_from_slice(&chunk);
140
141                while let Some(pos) = find_double_newline(&buffer) {
142                    let event_bytes = buffer.drain(..pos + 2).collect::<Vec<u8>>();
143                    if let Some(event) = parse_sse_event(&event_bytes)? {
144                        yield event;
145                    }
146                }
147            }
148
149            #[expect(clippy::collapsible_if)]
150            if !buffer.is_empty() {
151                if let Some(event) = parse_sse_event(&buffer)? {
152                    yield event;
153                }
154            }
155        };
156
157        Ok(stream)
158    }
159
160    /// Get a task by ID
161    pub async fn get_task(&self, task_id: String) -> A2aResult<Task> {
162        let params = serde_json::to_value(TaskQueryParams {
163            id: task_id,
164            history_length: None,
165        })?;
166        let result_value = self.call_rpc(METHOD_TASKS_GET, Some(params)).await?;
167        let task: Task = serde_json::from_value(result_value)
168            .context("Failed to deserialize task")
169            .map_err(|e| A2aError::Internal(e.to_string()))?;
170        Ok(task)
171    }
172
173    /// List tasks with filters
174    pub async fn list_tasks(&self, params: Option<ListTasksParams>) -> A2aResult<Value> {
175        let result_value = self
176            .call_rpc(
177                METHOD_TASKS_LIST,
178                params.map(serde_json::to_value).transpose()?,
179            )
180            .await?;
181        Ok(result_value)
182    }
183
184    /// Cancel a task
185    pub async fn cancel_task(&self, task_id: String) -> A2aResult<Task> {
186        let params = serde_json::to_value(TaskIdParams { id: task_id })?;
187        let result_value = self.call_rpc(METHOD_TASKS_CANCEL, Some(params)).await?;
188        let task: Task = serde_json::from_value(result_value)
189            .context("Failed to deserialize task")
190            .map_err(|e| A2aError::Internal(e.to_string()))?;
191        Ok(task)
192    }
193
194    /// Set push notification config
195    pub async fn set_push_config(&self, config: TaskPushNotificationConfig) -> A2aResult<bool> {
196        let value = self
197            .call_rpc(
198                METHOD_TASKS_PUSH_CONFIG_SET,
199                Some(serde_json::to_value(config)?),
200            )
201            .await?;
202        // Server returns {"success": true}
203        let success = value
204            .get("success")
205            .and_then(|v| v.as_bool())
206            .unwrap_or(false);
207        Ok(success)
208    }
209
210    /// Get push notification config
211    pub async fn get_push_config(
212        &self,
213        task_id: String,
214    ) -> A2aResult<Option<TaskPushNotificationConfig>> {
215        let params = serde_json::to_value(TaskIdParams { id: task_id })?;
216        let value = self
217            .call_rpc(METHOD_TASKS_PUSH_CONFIG_GET, Some(params))
218            .await?;
219        if value.is_null() {
220            return Ok(None);
221        }
222        let cfg: TaskPushNotificationConfig = serde_json::from_value(value)
223            .context("Failed to deserialize push notification config")
224            .map_err(|e| A2aError::Internal(e.to_string()))?;
225        Ok(Some(cfg))
226    }
227
228    async fn call_rpc(&self, method: &str, params: Option<Value>) -> A2aResult<Value> {
229        let request = JsonRpcRequest::with_string_id(method, params, self.next_id());
230
231        let resp = self
232            .http
233            .post(self.rpc_url())
234            .json(&request)
235            .send()
236            .await
237            .context("RPC request failed")
238            .map_err(|e| A2aError::Internal(e.to_string()))?;
239
240        let status = resp.status();
241        let json: Value = resp
242            .json()
243            .await
244            .context("Failed to parse RPC response")
245            .map_err(|e| A2aError::Internal(e.to_string()))?;
246
247        if !status.is_success() {
248            return Err(A2aError::rpc(
249                A2aErrorCode::InvalidAgentResponse,
250                format!("RPC failed with status {status}: {json:?}"),
251            ));
252        }
253
254        // Deserialize JSON-RPC envelope
255        let rpc_response: crate::a2a::rpc::JsonRpcResponse = serde_json::from_value(json.clone())
256            .context("Invalid JSON-RPC response")
257            .map_err(|e| A2aError::Internal(e.to_string()))?;
258
259        if let Some(result) = rpc_response.result {
260            Ok(result)
261        } else if let Some(err) = rpc_response.error {
262            Err(A2aError::rpc(err.code.into(), err.message))
263        } else {
264            Err(A2aError::rpc(
265                A2aErrorCode::InvalidAgentResponse,
266                "Empty RPC response",
267            ))
268        }
269    }
270}
271
/// Locate the end of the first complete SSE event in `buf`.
///
/// An SSE event ends with a blank line. The delimiter may be a plain `"\n\n"`, a
/// CRLF-framed `"\r\n\r\n"`, or an LF line followed by a CRLF blank line (`"\n\r\n"`),
/// since the SSE format allows either line ending. Lone `\r` line endings are not
/// recognised.
///
/// Returns `Some(pos)` such that `buf[..pos + 2]` is the complete event including its
/// terminating blank line. For `"\n\n"`, `pos` is the index of the first `\n`. Returns
/// `None` if no complete event is buffered yet.
fn find_double_newline(buf: &[u8]) -> Option<usize> {
    (0..buf.len()).find_map(|start| {
        let rest = &buf[start..];
        if rest.starts_with(b"\n\n") {
            Some(start)
        } else if rest.starts_with(b"\r\n\r\n") {
            // Point at the second CRLF so that `pos + 2` lands just past the delimiter.
            Some(start + 2)
        } else if rest.starts_with(b"\n\r\n") {
            Some(start + 1)
        } else {
            None
        }
    })
}

277/// Parse a single SSE event from raw bytes
278fn parse_sse_event(bytes: &[u8]) -> A2aResult<Option<StreamingEvent>> {
279    // SSE events are lines starting with "data: " and separated by blank line
280    let text = std::str::from_utf8(bytes)
281        .context("Invalid UTF-8 in SSE event")
282        .map_err(|e| A2aError::Internal(e.to_string()))?;
283
284    for line in text.lines() {
285        if let Some(payload) = line.strip_prefix("data: ") {
286            // Parse the streaming response wrapper
287            let wrapper: SendStreamingMessageResponse = serde_json::from_str(payload)
288                .context("Failed to deserialize streaming event")
289                .map_err(|e| A2aError::Internal(e.to_string()))?;
290            return Ok(Some(wrapper.event));
291        }
292    }
293
294    Ok(None)
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_double_newline() {
        let data = b"data: x\n\nrest";
        assert_eq!(find_double_newline(data), Some(7));
    }

    #[test]
    fn test_find_double_newline_incomplete_event() {
        // No blank line yet: the event is still incomplete.
        assert_eq!(find_double_newline(b""), None);
        assert_eq!(find_double_newline(b"data: x\n"), None);
    }

    #[test]
    fn test_parse_sse_event_empty() {
        let res = parse_sse_event(b"event: ping\n\n").unwrap();
        assert!(res.is_none());
    }

    #[test]
    fn test_parse_sse_event_ignores_comments_and_empty_input() {
        assert!(parse_sse_event(b"").unwrap().is_none());
        assert!(parse_sse_event(b": keep-alive\n\n").unwrap().is_none());
    }

    #[test]
    fn test_parse_sse_event_rejects_invalid_payload() {
        // Data that is not a streaming response wrapper.
        assert!(parse_sse_event(b"data: not json\n\n").is_err());
        // Bytes that are not UTF-8.
        assert!(parse_sse_event(&[0xff, 0xfe, b'\n', b'\n']).is_err());
    }
}