vtcode_core/a2a/
client.rs1use std::sync::{
5 Arc,
6 atomic::{AtomicU64, Ordering},
7};
8
9use anyhow::Context;
10use futures::{Stream, StreamExt};
11use reqwest::Client;
12use serde_json::Value;
13
14use crate::a2a::agent_card::AgentCard;
15use crate::a2a::errors::{A2aError, A2aErrorCode, A2aResult};
16use crate::a2a::rpc::{
17 JsonRpcRequest, ListTasksParams, METHOD_MESSAGE_SEND, METHOD_MESSAGE_STREAM,
18 METHOD_TASKS_CANCEL, METHOD_TASKS_GET, METHOD_TASKS_LIST, METHOD_TASKS_PUSH_CONFIG_GET,
19 METHOD_TASKS_PUSH_CONFIG_SET, MessageSendParams, SendStreamingMessageResponse, StreamingEvent,
20 TaskIdParams, TaskPushNotificationConfig, TaskQueryParams,
21};
22use crate::a2a::types::Task;
23
24#[derive(Clone, Debug)]
26pub struct A2aClient {
27 base_url: String,
28 http: Client,
29 request_id: Arc<AtomicU64>,
30}
31
32impl A2aClient {
33 pub fn new(base_url: impl Into<String>) -> A2aResult<Self> {
35 let http = Client::builder()
36 .build()
37 .context("Failed to build HTTP client")
38 .map_err(|e| A2aError::Internal(e.to_string()))?;
39
40 Ok(Self {
41 base_url: base_url.into().trim_end_matches('/').to_string(),
42 http,
43 request_id: Arc::new(AtomicU64::new(1)),
44 })
45 }
46
47 fn next_id(&self) -> String {
48 let id = self.request_id.fetch_add(1, Ordering::Relaxed);
49 format!("a2a-{}", id)
50 }
51
52 fn rpc_url(&self) -> String {
53 format!("{}/a2a", self.base_url)
54 }
55
56 fn stream_url(&self) -> String {
57 format!("{}/a2a/stream", self.base_url)
58 }
59
60 fn agent_card_url(&self) -> String {
61 format!("{}/.well-known/agent-card.json", self.base_url)
62 }
63
64 pub async fn agent_card(&self) -> A2aResult<AgentCard> {
66 let resp = self
67 .http
68 .get(self.agent_card_url())
69 .send()
70 .await
71 .context("Failed to fetch agent card")
72 .map_err(|e| A2aError::Internal(e.to_string()))?;
73
74 let status = resp.status();
75 if !status.is_success() {
76 return Err(A2aError::rpc(
77 A2aErrorCode::InvalidAgentResponse,
78 format!("Agent card request failed with status {status}"),
79 ));
80 }
81
82 let card = resp
83 .json::<AgentCard>()
84 .await
85 .context("Invalid agent card response")
86 .map_err(|e| A2aError::Internal(e.to_string()))?;
87 Ok(card)
88 }
89
90 pub async fn send_message(&self, params: MessageSendParams) -> A2aResult<Task> {
92 let result_value = self
93 .call_rpc(METHOD_MESSAGE_SEND, Some(serde_json::to_value(¶ms)?))
94 .await?;
95 let task: Task = serde_json::from_value(result_value)
96 .context("Failed to deserialize task")
97 .map_err(|e| A2aError::Internal(e.to_string()))?;
98 Ok(task)
99 }
100
101 pub async fn stream_message(
103 &self,
104 params: MessageSendParams,
105 ) -> A2aResult<impl Stream<Item = A2aResult<StreamingEvent>>> {
106 let req = JsonRpcRequest::with_string_id(
107 METHOD_MESSAGE_STREAM,
108 Some(serde_json::to_value(¶ms)?),
109 self.next_id(),
110 );
111
112 let response = self
113 .http
114 .post(self.stream_url())
115 .header("accept", "text/event-stream")
116 .json(&req)
117 .send()
118 .await
119 .context("Failed to open streaming request")
120 .map_err(|e| A2aError::Internal(e.to_string()))?;
121
122 let status = response.status();
123 if !status.is_success() {
124 return Err(A2aError::rpc(
125 A2aErrorCode::InvalidAgentResponse,
126 format!("Streaming request failed with status {status}"),
127 ));
128 }
129
130 let byte_stream = response.bytes_stream();
131
132 let stream = async_stream::try_stream! {
133 let mut buffer = Vec::new();
134 futures::pin_mut!(byte_stream);
135
136 while let Some(chunk) = byte_stream.next().await {
137 let chunk = chunk.context("Failed to read streaming chunk")
138 .map_err(|e| A2aError::Internal(e.to_string()))?;
139 buffer.extend_from_slice(&chunk);
140
141 while let Some(pos) = find_double_newline(&buffer) {
142 let event_bytes = buffer.drain(..pos + 2).collect::<Vec<u8>>();
143 if let Some(event) = parse_sse_event(&event_bytes)? {
144 yield event;
145 }
146 }
147 }
148
149 #[expect(clippy::collapsible_if)]
150 if !buffer.is_empty() {
151 if let Some(event) = parse_sse_event(&buffer)? {
152 yield event;
153 }
154 }
155 };
156
157 Ok(stream)
158 }
159
160 pub async fn get_task(&self, task_id: String) -> A2aResult<Task> {
162 let params = serde_json::to_value(TaskQueryParams {
163 id: task_id,
164 history_length: None,
165 })?;
166 let result_value = self.call_rpc(METHOD_TASKS_GET, Some(params)).await?;
167 let task: Task = serde_json::from_value(result_value)
168 .context("Failed to deserialize task")
169 .map_err(|e| A2aError::Internal(e.to_string()))?;
170 Ok(task)
171 }
172
173 pub async fn list_tasks(&self, params: Option<ListTasksParams>) -> A2aResult<Value> {
175 let result_value = self
176 .call_rpc(
177 METHOD_TASKS_LIST,
178 params.map(serde_json::to_value).transpose()?,
179 )
180 .await?;
181 Ok(result_value)
182 }
183
184 pub async fn cancel_task(&self, task_id: String) -> A2aResult<Task> {
186 let params = serde_json::to_value(TaskIdParams { id: task_id })?;
187 let result_value = self.call_rpc(METHOD_TASKS_CANCEL, Some(params)).await?;
188 let task: Task = serde_json::from_value(result_value)
189 .context("Failed to deserialize task")
190 .map_err(|e| A2aError::Internal(e.to_string()))?;
191 Ok(task)
192 }
193
194 pub async fn set_push_config(&self, config: TaskPushNotificationConfig) -> A2aResult<bool> {
196 let value = self
197 .call_rpc(
198 METHOD_TASKS_PUSH_CONFIG_SET,
199 Some(serde_json::to_value(config)?),
200 )
201 .await?;
202 let success = value
204 .get("success")
205 .and_then(|v| v.as_bool())
206 .unwrap_or(false);
207 Ok(success)
208 }
209
210 pub async fn get_push_config(
212 &self,
213 task_id: String,
214 ) -> A2aResult<Option<TaskPushNotificationConfig>> {
215 let params = serde_json::to_value(TaskIdParams { id: task_id })?;
216 let value = self
217 .call_rpc(METHOD_TASKS_PUSH_CONFIG_GET, Some(params))
218 .await?;
219 if value.is_null() {
220 return Ok(None);
221 }
222 let cfg: TaskPushNotificationConfig = serde_json::from_value(value)
223 .context("Failed to deserialize push notification config")
224 .map_err(|e| A2aError::Internal(e.to_string()))?;
225 Ok(Some(cfg))
226 }
227
228 async fn call_rpc(&self, method: &str, params: Option<Value>) -> A2aResult<Value> {
229 let request = JsonRpcRequest::with_string_id(method, params, self.next_id());
230
231 let resp = self
232 .http
233 .post(self.rpc_url())
234 .json(&request)
235 .send()
236 .await
237 .context("RPC request failed")
238 .map_err(|e| A2aError::Internal(e.to_string()))?;
239
240 let status = resp.status();
241 let json: Value = resp
242 .json()
243 .await
244 .context("Failed to parse RPC response")
245 .map_err(|e| A2aError::Internal(e.to_string()))?;
246
247 if !status.is_success() {
248 return Err(A2aError::rpc(
249 A2aErrorCode::InvalidAgentResponse,
250 format!("RPC failed with status {status}: {json:?}"),
251 ));
252 }
253
254 let rpc_response: crate::a2a::rpc::JsonRpcResponse = serde_json::from_value(json.clone())
256 .context("Invalid JSON-RPC response")
257 .map_err(|e| A2aError::Internal(e.to_string()))?;
258
259 if let Some(result) = rpc_response.result {
260 Ok(result)
261 } else if let Some(err) = rpc_response.error {
262 Err(A2aError::rpc(err.code.into(), err.message))
263 } else {
264 Err(A2aError::rpc(
265 A2aErrorCode::InvalidAgentResponse,
266 "Empty RPC response",
267 ))
268 }
269 }
270}
271
/// Locates the first blank-line event terminator in `buf`.
///
/// Returns an offset `p` such that `buf[..p + 2]` ends exactly at the end of
/// the terminator, or `None` when `buf` holds no complete terminator yet.
///
/// Both LF and CRLF line endings are recognised:
///
/// * `\n\n` starting at index `i` yields `i` (unchanged from LF-only input);
/// * `\n\r\n` starting at index `i` (the tail of `\r\n\r\n`) yields `i + 1`.
///
/// Bare-CR line endings are not recognised.
fn find_double_newline(buf: &[u8]) -> Option<usize> {
    buf.iter().enumerate().find_map(|(idx, &byte)| {
        if byte != b'\n' {
            return None;
        }
        // `idx < buf.len()`, so `idx + 1` is always a valid slice start.
        let rest = &buf[idx + 1..];
        if rest.starts_with(b"\n") {
            Some(idx)
        } else if rest.starts_with(b"\r\n") {
            Some(idx + 1)
        } else {
            None
        }
    })
}
276
277fn parse_sse_event(bytes: &[u8]) -> A2aResult<Option<StreamingEvent>> {
279 let text = std::str::from_utf8(bytes)
281 .context("Invalid UTF-8 in SSE event")
282 .map_err(|e| A2aError::Internal(e.to_string()))?;
283
284 for line in text.lines() {
285 if let Some(payload) = line.strip_prefix("data: ") {
286 let wrapper: SendStreamingMessageResponse = serde_json::from_str(payload)
288 .context("Failed to deserialize streaming event")
289 .map_err(|e| A2aError::Internal(e.to_string()))?;
290 return Ok(Some(wrapper.event));
291 }
292 }
293
294 Ok(None)
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// The terminator offset is the index of the first byte of `\n\n`.
    #[test]
    fn test_find_double_newline() {
        let located = find_double_newline(b"data: x\n\nrest");
        assert_eq!(located, Some(7));
    }

    /// An event without any `data:` line produces no streaming event.
    #[test]
    fn test_parse_sse_event_empty() {
        assert!(parse_sse_event(b"event: ping\n\n").unwrap().is_none());
    }
}