1use std::net::IpAddr;
2use std::pin::Pin;
3
4use eventsource_stream::Eventsource;
5use futures_core::Stream;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use tokio_stream::StreamExt;
8
9use crate::error::A2aError;
10use crate::jsonrpc::{
11 JsonRpcRequest, JsonRpcResponse, METHOD_CANCEL_TASK, METHOD_GET_TASK, METHOD_SEND_MESSAGE,
12 METHOD_SEND_STREAMING_MESSAGE, SendMessageParams, TaskIdParams,
13};
14use crate::types::{Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
15
16pub type TaskEventStream = Pin<Box<dyn Stream<Item = Result<TaskEvent, A2aError>> + Send>>;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19#[serde(untagged)]
20pub enum TaskEvent {
21 StatusUpdate(TaskStatusUpdateEvent),
22 ArtifactUpdate(TaskArtifactUpdateEvent),
23}
24
25pub struct A2aClient {
26 client: reqwest::Client,
27 require_tls: bool,
28 ssrf_protection: bool,
29}
30
31impl A2aClient {
32 #[must_use]
33 pub fn new(client: reqwest::Client) -> Self {
34 Self {
35 client,
36 require_tls: false,
37 ssrf_protection: false,
38 }
39 }
40
41 #[must_use]
42 pub fn with_security(mut self, require_tls: bool, ssrf_protection: bool) -> Self {
43 self.require_tls = require_tls;
44 self.ssrf_protection = ssrf_protection;
45 self
46 }
47
48 pub async fn send_message(
51 &self,
52 endpoint: &str,
53 params: SendMessageParams,
54 token: Option<&str>,
55 ) -> Result<Task, A2aError> {
56 self.rpc_call(endpoint, METHOD_SEND_MESSAGE, params, token)
57 .await
58 }
59
60 pub async fn stream_message(
63 &self,
64 endpoint: &str,
65 params: SendMessageParams,
66 token: Option<&str>,
67 ) -> Result<TaskEventStream, A2aError> {
68 self.validate_endpoint(endpoint).await?;
69 let request = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
70 let mut req = self.client.post(endpoint).json(&request);
71 if let Some(t) = token {
72 req = req.bearer_auth(t);
73 }
74 let resp = req.send().await?;
75
76 if !resp.status().is_success() {
77 let status = resp.status();
78 let body = resp.text().await.unwrap_or_default();
79 let truncated = if body.len() > 256 {
81 format!("{}…", &body[..256])
82 } else {
83 body
84 };
85 return Err(A2aError::Stream(format!("HTTP {status}: {truncated}")));
86 }
87
88 let event_stream = resp.bytes_stream().eventsource();
89 let mapped = event_stream.filter_map(|event| match event {
90 Ok(event) => {
91 if event.data.is_empty() || event.data == "[DONE]" {
92 return None;
93 }
94 match serde_json::from_str::<JsonRpcResponse<TaskEvent>>(&event.data) {
95 Ok(rpc_resp) => match rpc_resp.into_result() {
96 Ok(task_event) => Some(Ok(task_event)),
97 Err(rpc_err) => Some(Err(A2aError::from(rpc_err))),
98 },
99 Err(e) => Some(Err(A2aError::Stream(format!(
100 "failed to parse SSE event: {e}"
101 )))),
102 }
103 }
104 Err(e) => Some(Err(A2aError::Stream(format!("SSE stream error: {e}")))),
105 });
106
107 Ok(Box::pin(mapped))
108 }
109
110 pub async fn get_task(
113 &self,
114 endpoint: &str,
115 params: TaskIdParams,
116 token: Option<&str>,
117 ) -> Result<Task, A2aError> {
118 self.rpc_call(endpoint, METHOD_GET_TASK, params, token)
119 .await
120 }
121
122 pub async fn cancel_task(
125 &self,
126 endpoint: &str,
127 params: TaskIdParams,
128 token: Option<&str>,
129 ) -> Result<Task, A2aError> {
130 self.rpc_call(endpoint, METHOD_CANCEL_TASK, params, token)
131 .await
132 }
133
134 async fn validate_endpoint(&self, endpoint: &str) -> Result<(), A2aError> {
135 if self.require_tls && !endpoint.starts_with("https://") {
136 return Err(A2aError::Security(format!(
137 "TLS required but endpoint uses HTTP: {endpoint}"
138 )));
139 }
140
141 if self.ssrf_protection {
142 let url: url::Url = endpoint
143 .parse()
144 .map_err(|e| A2aError::Security(format!("invalid URL: {e}")))?;
145
146 if let Some(host) = url.host_str() {
147 let addrs = tokio::net::lookup_host(format!(
148 "{}:{}",
149 host,
150 url.port_or_known_default().unwrap_or(443)
151 ))
152 .await
153 .map_err(|e| A2aError::Security(format!("DNS resolution failed: {e}")))?;
154
155 for addr in addrs {
156 if is_private_ip(addr.ip()) {
157 return Err(A2aError::Security(format!(
158 "SSRF protection: private IP {} for host {host}",
159 addr.ip()
160 )));
161 }
162 }
163 }
164 }
165
166 Ok(())
167 }
168
169 async fn rpc_call<P: Serialize, R: DeserializeOwned>(
170 &self,
171 endpoint: &str,
172 method: &str,
173 params: P,
174 token: Option<&str>,
175 ) -> Result<R, A2aError> {
176 self.validate_endpoint(endpoint).await?;
177 let request = JsonRpcRequest::new(method, params);
178 let mut req = self.client.post(endpoint).json(&request);
179 if let Some(t) = token {
180 req = req.bearer_auth(t);
181 }
182 let resp = req.send().await?;
183 let rpc_response: JsonRpcResponse<R> = resp.json().await?;
184 rpc_response.into_result().map_err(A2aError::from)
185 }
186}
187
/// Reports whether `ip` is a loopback, private, link-local or otherwise non-public
/// address that an outbound request should not target (SSRF guard).
///
/// IPv4 addresses are flagged when they are any of:
/// * loopback (127.0.0.0/8);
/// * RFC 1918 private;
/// * link-local (169.254.0.0/16, which includes the 169.254.169.254 metadata address);
/// * "this network" (0.0.0.0/8);
/// * RFC 6598 shared address space (100.64.0.0/10, which hosts some cloud metadata
///   endpoints, e.g. 100.100.100.200);
/// * the limited broadcast address.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) get the same IPv4 checks. Other
/// IPv6 addresses are flagged when they are loopback (`::1`), unspecified (`::`),
/// link-local (`fe80::/10`) or unique-local (`fc00::/7`).
fn is_private_ip(ip: IpAddr) -> bool {
    /// IPv4 classification shared by plain IPv4 and IPv4-mapped IPv6 addresses.
    fn is_private_ipv4(v4: std::net::Ipv4Addr) -> bool {
        let [first, second, ..] = v4.octets();
        v4.is_loopback()
            || v4.is_private()
            || v4.is_link_local()
            || v4.is_unspecified()
            || v4.is_broadcast()
            // 0.0.0.0/8: "this network" (RFC 1122); only 0.0.0.0 is `is_unspecified`.
            || first == 0
            // 100.64.0.0/10: shared address space (RFC 6598).
            || (first == 100 && second & 0xc0 == 0x40)
    }

    match ip {
        IpAddr::V4(v4) => is_private_ipv4(v4),
        IpAddr::V6(v6) => {
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private_ipv4(v4);
            }
            let first = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                // fe80::/10 link-local
                || first & 0xffc0 == 0xfe80
                // fc00::/7 unique local
                || first & 0xfe00 == 0xfc00
        }
    }
}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jsonrpc::{JsonRpcError, JsonRpcResponse};
    use crate::types::{
        Artifact, Message, Part, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };

    // ------------------------------------------------------------------
    // Fixtures
    // ------------------------------------------------------------------

    /// Loopback port 1; presumed to have no listener, so requests fail while connecting.
    const UNREACHABLE_ENDPOINT: &str = "http://127.0.0.1:1/rpc";

    /// Plain-HTTP endpoint used to exercise TLS enforcement.
    const HTTP_ENDPOINT: &str = "http://example.com/rpc";

    /// Client with the default (disabled) security settings.
    fn plain_client() -> A2aClient {
        A2aClient::new(reqwest::Client::new())
    }

    /// Client with explicit TLS / SSRF settings.
    fn secured_client(require_tls: bool, ssrf_protection: bool) -> A2aClient {
        plain_client().with_security(require_tls, ssrf_protection)
    }

    /// Send/stream params carrying a single user text message and no configuration.
    fn text_params(text: &'static str) -> SendMessageParams {
        SendMessageParams {
            message: Message::user_text(text),
            configuration: None,
        }
    }

    /// Task-id params for task `t-1` with no history limit.
    fn t1_params() -> TaskIdParams {
        TaskIdParams {
            id: "t-1".into(),
            history_length: None,
        }
    }

    /// Task status stamped with the placeholder timestamp `"ts"`.
    fn status(state: TaskState, message: Option<Message>) -> TaskStatus {
        TaskStatus {
            state,
            timestamp: "ts".into(),
            message,
        }
    }

    /// Status-update event for task `t-1` with no context id.
    fn status_event(
        state: TaskState,
        message: Option<Message>,
        is_final: bool,
    ) -> TaskStatusUpdateEvent {
        TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status: status(state, message),
            is_final,
        }
    }

    /// Final artifact-update event for task `t-1` holding one text part.
    fn artifact_event(text: &'static str) -> TaskArtifactUpdateEvent {
        TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact: Artifact {
                artifact_id: "a-1".into(),
                name: None,
                parts: vec![Part::text(text)],
                metadata: None,
            },
            is_final: true,
        }
    }

    /// Completed task `t-1` with no artifacts, history or metadata.
    fn completed_task() -> Task {
        Task {
            id: "t-1".into(),
            context_id: None,
            status: status(TaskState::Completed, None),
            artifacts: vec![],
            history: vec![],
            metadata: None,
        }
    }

    /// JSON-RPC 2.0 response envelope for a `Task` result.
    fn task_response(
        id: &'static str,
        result: Option<Task>,
        error: Option<JsonRpcError>,
    ) -> JsonRpcResponse<Task> {
        JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String(id.into()),
            result,
            error,
        }
    }

    /// JSON-RPC error with code -32001 and the given message.
    fn rpc_error(message: &'static str) -> JsonRpcError {
        JsonRpcError {
            code: -32001,
            message: message.into(),
            data: None,
        }
    }

    /// Serializes `value` to a JSON string and parses it back as `T`.
    fn round_trip<S, T>(value: &S) -> T
    where
        S: serde::Serialize,
        T: serde::de::DeserializeOwned,
    {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// Serialized JSON-RPC request for `method` with `params`.
    fn rpc_request_json<P: serde::Serialize>(method: &str, params: P) -> String {
        serde_json::to_string(&JsonRpcRequest::new(method, params)).unwrap()
    }

    /// Parses an IP literal.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    // ------------------------------------------------------------------
    // TaskEvent
    // ------------------------------------------------------------------

    #[test]
    fn task_event_deserialize_status_update() {
        let source = status_event(
            TaskState::Working,
            Some(Message::user_text("thinking...")),
            false,
        );
        let decoded: TaskEvent = round_trip(&source);
        assert!(matches!(decoded, TaskEvent::StatusUpdate(_)));
    }

    #[test]
    fn task_event_deserialize_artifact_update() {
        let decoded: TaskEvent = round_trip(&artifact_event("result"));
        assert!(matches!(decoded, TaskEvent::ArtifactUpdate(_)));
    }

    #[test]
    fn task_event_clone() {
        let original = TaskEvent::StatusUpdate(status_event(TaskState::Working, None, false));
        let copy = original.clone();
        assert_eq!(
            serde_json::to_string(&original).unwrap(),
            serde_json::to_string(&copy).unwrap()
        );
    }

    #[test]
    fn task_event_debug() {
        let event = TaskEvent::ArtifactUpdate(artifact_event("data"));
        assert!(format!("{event:?}").contains("ArtifactUpdate"));
    }

    #[test]
    fn task_event_serialize_round_trip() {
        let mut inner = status_event(
            TaskState::Completed,
            Some(Message::user_text("done")),
            true,
        );
        inner.context_id = Some("ctx-1".into());
        inner.status.timestamp = "2025-01-01T00:00:00Z".into();
        let decoded: TaskEvent = round_trip(&TaskEvent::StatusUpdate(inner));
        assert!(matches!(decoded, TaskEvent::StatusUpdate(_)));
    }

    // ------------------------------------------------------------------
    // JSON-RPC envelopes
    // ------------------------------------------------------------------

    #[test]
    fn rpc_response_with_task_result() {
        let decoded: JsonRpcResponse<Task> =
            round_trip(&task_response("req-1", Some(completed_task()), None));
        let task = decoded.into_result().unwrap();
        assert_eq!(task.id, "t-1");
        assert_eq!(task.status.state, TaskState::Completed);
    }

    #[test]
    fn rpc_response_with_error() {
        let decoded: JsonRpcResponse<Task> =
            round_trip(&task_response("req-1", None, Some(rpc_error("task not found"))));
        assert_eq!(decoded.into_result().unwrap_err().code, -32001);
    }

    #[test]
    fn rpc_response_error_takes_priority_over_result() {
        let both = task_response("1", Some(completed_task()), Some(rpc_error("error")));
        assert_eq!(both.into_result().unwrap_err().code, -32001);
    }

    #[test]
    fn rpc_response_neither_result_nor_error() {
        let empty = task_response("1", None, None);
        assert_eq!(empty.into_result().unwrap_err().code, -32603);
    }

    #[test]
    fn jsonrpc_request_serialization_for_send_message() {
        let json = rpc_request_json(METHOD_SEND_MESSAGE, text_params("hello"));
        for needle in ["\"method\":\"message/send\"", "\"jsonrpc\":\"2.0\"", "\"hello\""] {
            assert!(json.contains(needle));
        }
    }

    #[test]
    fn jsonrpc_request_serialization_for_get_task() {
        let params = TaskIdParams {
            id: "task-123".into(),
            history_length: Some(5),
        };
        let json = rpc_request_json(METHOD_GET_TASK, params);
        for needle in ["\"method\":\"tasks/get\"", "\"task-123\"", "\"historyLength\":5"] {
            assert!(json.contains(needle));
        }
    }

    #[test]
    fn jsonrpc_request_serialization_for_cancel_task() {
        let params = TaskIdParams {
            id: "task-456".into(),
            history_length: None,
        };
        let json = rpc_request_json(METHOD_CANCEL_TASK, params);
        assert!(json.contains("\"method\":\"tasks/cancel\""));
        assert!(!json.contains("historyLength"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_stream() {
        let json = rpc_request_json(METHOD_SEND_STREAMING_MESSAGE, text_params("stream me"));
        assert!(json.contains("\"method\":\"message/stream\""));
    }

    // ------------------------------------------------------------------
    // Client configuration
    // ------------------------------------------------------------------

    #[test]
    fn a2a_client_construction() {
        drop(plain_client());
    }

    #[test]
    fn with_security_returns_configured_client() {
        let configured = secured_client(true, true);
        assert!(configured.require_tls);
        assert!(configured.ssrf_protection);
    }

    #[test]
    fn default_client_no_security() {
        let defaults = plain_client();
        assert!(!defaults.require_tls);
        assert!(!defaults.ssrf_protection);
    }

    // ------------------------------------------------------------------
    // is_private_ip
    // ------------------------------------------------------------------

    #[test]
    fn is_private_ip_loopback() {
        assert!(is_private_ip(IpAddr::from(std::net::Ipv4Addr::LOCALHOST)));
        assert!(is_private_ip(IpAddr::from(std::net::Ipv6Addr::LOCALHOST)));
    }

    #[test]
    fn is_private_ip_private_ranges() {
        for addr in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
            assert!(is_private_ip(ip(addr)));
        }
    }

    #[test]
    fn is_private_ip_link_local() {
        assert!(is_private_ip(ip("169.254.0.1")));
    }

    #[test]
    fn is_private_ip_unspecified() {
        assert!(is_private_ip(ip("0.0.0.0")));
        assert!(is_private_ip(ip("::")));
    }

    #[test]
    fn is_private_ip_public() {
        assert!(!is_private_ip(ip("8.8.8.8")));
        assert!(!is_private_ip(ip("1.1.1.1")));
    }

    #[test]
    fn is_private_ip_ipv4_non_private() {
        assert!(!is_private_ip(ip("93.184.216.34")));
    }

    #[test]
    fn is_private_ip_ipv6_non_private() {
        assert!(!is_private_ip(ip("2001:db8::1")));
    }

    // ------------------------------------------------------------------
    // validate_endpoint
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn tls_enforcement_rejects_http() {
        let err = secured_client(true, false)
            .validate_endpoint(HTTP_ENDPOINT)
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
        assert!(err.to_string().contains("TLS required"));
    }

    #[tokio::test]
    async fn tls_enforcement_allows_https() {
        let outcome = secured_client(true, false)
            .validate_endpoint("https://example.com/rpc")
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn ssrf_protection_rejects_localhost() {
        let err = secured_client(false, true)
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await
            .unwrap_err();
        assert!(err.to_string().contains("SSRF"));
    }

    #[tokio::test]
    async fn no_security_allows_http_localhost() {
        let outcome = plain_client()
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn validate_endpoint_invalid_url_with_ssrf() {
        let outcome = secured_client(false, true)
            .validate_endpoint("not-a-url")
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    // ------------------------------------------------------------------
    // Transport failures
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn send_message_connection_error() {
        let outcome = plain_client()
            .send_message(UNREACHABLE_ENDPOINT, text_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn get_task_connection_error() {
        let outcome = plain_client()
            .get_task(UNREACHABLE_ENDPOINT, t1_params(), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn cancel_task_connection_error() {
        let outcome = plain_client()
            .cancel_task(UNREACHABLE_ENDPOINT, t1_params(), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn stream_message_connection_error() {
        let outcome = plain_client()
            .stream_message(UNREACHABLE_ENDPOINT, text_params("stream me"), None)
            .await;
        assert!(outcome.is_err());
    }

    // ------------------------------------------------------------------
    // TLS enforcement per method
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn stream_message_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .stream_message(HTTP_ENDPOINT, text_params("hello"), None)
            .await;
        if let Err(A2aError::Security(msg)) = outcome {
            assert!(msg.contains("TLS required"));
        } else {
            panic!("expected Security error");
        }
    }

    #[tokio::test]
    async fn send_message_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .send_message(HTTP_ENDPOINT, text_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn get_task_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .get_task(HTTP_ENDPOINT, t1_params(), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn cancel_task_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .cancel_task(HTTP_ENDPOINT, t1_params(), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }
}