1use std::net::IpAddr;
5use std::pin::Pin;
6
7use eventsource_stream::Eventsource;
8use futures_core::Stream;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use tokio_stream::StreamExt;
11
12use crate::error::A2aError;
13use crate::jsonrpc::{
14 JsonRpcRequest, JsonRpcResponse, METHOD_CANCEL_TASK, METHOD_GET_TASK, METHOD_SEND_MESSAGE,
15 METHOD_SEND_STREAMING_MESSAGE, SendMessageParams, TaskIdParams,
16};
17use crate::types::{Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
18
19pub type TaskEventStream = Pin<Box<dyn Stream<Item = Result<TaskEvent, A2aError>> + Send>>;
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(untagged)]
23pub enum TaskEvent {
24 StatusUpdate(TaskStatusUpdateEvent),
25 ArtifactUpdate(TaskArtifactUpdateEvent),
26}
27
28pub struct A2aClient {
29 client: reqwest::Client,
30 require_tls: bool,
31 ssrf_protection: bool,
32}
33
34impl A2aClient {
35 #[must_use]
36 pub fn new(client: reqwest::Client) -> Self {
37 Self {
38 client,
39 require_tls: false,
40 ssrf_protection: false,
41 }
42 }
43
44 #[must_use]
45 pub fn with_security(mut self, require_tls: bool, ssrf_protection: bool) -> Self {
46 self.require_tls = require_tls;
47 self.ssrf_protection = ssrf_protection;
48 self
49 }
50
51 pub async fn send_message(
54 &self,
55 endpoint: &str,
56 params: SendMessageParams,
57 token: Option<&str>,
58 ) -> Result<Task, A2aError> {
59 self.rpc_call(endpoint, METHOD_SEND_MESSAGE, params, token)
60 .await
61 }
62
63 pub async fn stream_message(
66 &self,
67 endpoint: &str,
68 params: SendMessageParams,
69 token: Option<&str>,
70 ) -> Result<TaskEventStream, A2aError> {
71 self.validate_endpoint(endpoint).await?;
72 let request = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
73 let mut req = self.client.post(endpoint).json(&request);
74 if let Some(t) = token {
75 req = req.bearer_auth(t);
76 }
77 let resp = req.send().await?;
78
79 if !resp.status().is_success() {
80 let status = resp.status();
81 let body = resp.text().await.unwrap_or_default();
82 let truncated = if body.len() > 256 {
84 format!("{}…", &body[..256])
85 } else {
86 body
87 };
88 return Err(A2aError::Stream(format!("HTTP {status}: {truncated}")));
89 }
90
91 let event_stream = resp.bytes_stream().eventsource();
92 let mapped = event_stream.filter_map(|event| match event {
93 Ok(event) => {
94 if event.data.is_empty() || event.data == "[DONE]" {
95 return None;
96 }
97 match serde_json::from_str::<JsonRpcResponse<TaskEvent>>(&event.data) {
98 Ok(rpc_resp) => match rpc_resp.into_result() {
99 Ok(task_event) => Some(Ok(task_event)),
100 Err(rpc_err) => Some(Err(A2aError::from(rpc_err))),
101 },
102 Err(e) => Some(Err(A2aError::Stream(format!(
103 "failed to parse SSE event: {e}"
104 )))),
105 }
106 }
107 Err(e) => Some(Err(A2aError::Stream(format!("SSE stream error: {e}")))),
108 });
109
110 Ok(Box::pin(mapped))
111 }
112
113 pub async fn get_task(
116 &self,
117 endpoint: &str,
118 params: TaskIdParams,
119 token: Option<&str>,
120 ) -> Result<Task, A2aError> {
121 self.rpc_call(endpoint, METHOD_GET_TASK, params, token)
122 .await
123 }
124
125 pub async fn cancel_task(
128 &self,
129 endpoint: &str,
130 params: TaskIdParams,
131 token: Option<&str>,
132 ) -> Result<Task, A2aError> {
133 self.rpc_call(endpoint, METHOD_CANCEL_TASK, params, token)
134 .await
135 }
136
137 async fn validate_endpoint(&self, endpoint: &str) -> Result<(), A2aError> {
138 if self.require_tls && !endpoint.starts_with("https://") {
139 return Err(A2aError::Security(format!(
140 "TLS required but endpoint uses HTTP: {endpoint}"
141 )));
142 }
143
144 if self.ssrf_protection {
145 let url: url::Url = endpoint
146 .parse()
147 .map_err(|e| A2aError::Security(format!("invalid URL: {e}")))?;
148
149 if let Some(host) = url.host_str() {
150 let addrs = tokio::net::lookup_host(format!(
151 "{}:{}",
152 host,
153 url.port_or_known_default().unwrap_or(443)
154 ))
155 .await
156 .map_err(|e| A2aError::Security(format!("DNS resolution failed: {e}")))?;
157
158 for addr in addrs {
159 if is_private_ip(addr.ip()) {
160 return Err(A2aError::Security(format!(
161 "SSRF protection: private IP {} for host {host}",
162 addr.ip()
163 )));
164 }
165 }
166 }
167 }
168
169 Ok(())
170 }
171
172 async fn rpc_call<P: Serialize, R: DeserializeOwned>(
173 &self,
174 endpoint: &str,
175 method: &str,
176 params: P,
177 token: Option<&str>,
178 ) -> Result<R, A2aError> {
179 self.validate_endpoint(endpoint).await?;
180 let request = JsonRpcRequest::new(method, params);
181 let mut req = self.client.post(endpoint).json(&request);
182 if let Some(t) = token {
183 req = req.bearer_auth(t);
184 }
185 let resp = req.send().await?;
186 let rpc_response: JsonRpcResponse<R> = resp.json().await?;
187 rpc_response.into_result().map_err(A2aError::from)
188 }
189}
190
/// Returns `true` if `ip` must not be contacted while SSRF protection is enabled.
///
/// IPv4 addresses are judged by `is_private_ipv4`. For IPv6, this blocks:
/// * loopback (`::1`), unspecified (`::`) and multicast (`ff00::/8`);
/// * link-local (`fe80::/10`) and unique-local (`fc00::/7`);
/// * addresses that embed an IPv4 address which is itself blocked, namely IPv4-mapped
///   (`::ffff:0:0/96`) and the NAT64 well-known prefix (`64:ff9b::/96`).
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_ipv4(v4),
        IpAddr::V6(v6) => {
            if v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() {
                return true;
            }
            let seg = v6.segments();
            // fe80::/10 — link-local.
            if seg[0] & 0xffc0 == 0xfe80 {
                return true;
            }
            // fc00::/7 — unique local addresses.
            if seg[0] & 0xfe00 == 0xfc00 {
                return true;
            }
            // ::ffff:a.b.c.d — IPv4-mapped; judge the embedded IPv4 address.
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private_ipv4(v4);
            }
            // 64:ff9b::a.b.c.d — NAT64 well-known prefix; a NAT64 gateway would forward
            // to the embedded IPv4 address, so judge that address.
            if seg[..6] == [0x64, 0xff9b, 0, 0, 0, 0] {
                let [a, b] = seg[6].to_be_bytes();
                let [c, d] = seg[7].to_be_bytes();
                return is_private_ipv4(std::net::Ipv4Addr::new(a, b, c, d));
            }
            false
        }
    }
}

/// Returns `true` if the IPv4 address `v4` must not be contacted while SSRF protection
/// is enabled.
///
/// Blocks loopback, RFC 1918 private, link-local, unspecified, limited broadcast and
/// multicast addresses. It also blocks `0.0.0.0/8` ("this network") and `100.64.0.0/10`
/// (RFC 6598 shared address space, which some cloud providers use for internal services
/// such as instance metadata).
fn is_private_ipv4(v4: std::net::Ipv4Addr) -> bool {
    let [first, second, ..] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_unspecified()
        || v4.is_broadcast()
        || v4.is_multicast()
        || first == 0
        || (first == 100 && (64..=127).contains(&second))
}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jsonrpc::{JsonRpcError, JsonRpcResponse};
    use crate::types::{
        Artifact, Message, Part, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };

    // ----- shared fixtures --------------------------------------------------------------

    /// Loopback port 1 is expected to refuse connections, so requests fail in transport.
    const UNREACHABLE_ENDPOINT: &str = "http://127.0.0.1:1/rpc";
    /// Plain-HTTP endpoint that TLS enforcement must reject.
    const PLAIN_HTTP_ENDPOINT: &str = "http://example.com/rpc";
    /// Loopback endpoint used to exercise SSRF protection.
    const LOOPBACK_ENDPOINT: &str = "http://127.0.0.1:8080/rpc";

    fn default_client() -> A2aClient {
        A2aClient::new(reqwest::Client::new())
    }

    fn secured_client(require_tls: bool, ssrf_protection: bool) -> A2aClient {
        default_client().with_security(require_tls, ssrf_protection)
    }

    fn message_params(text: &'static str) -> SendMessageParams {
        SendMessageParams {
            message: Message::user_text(text),
            configuration: None,
        }
    }

    fn task_params(id: &'static str) -> TaskIdParams {
        TaskIdParams {
            id: id.into(),
            history_length: None,
        }
    }

    fn status_event(
        state: TaskState,
        message: Option<Message>,
        is_final: bool,
    ) -> TaskStatusUpdateEvent {
        TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status: TaskStatus {
                state,
                timestamp: "ts".into(),
                message,
            },
            is_final,
        }
    }

    fn artifact_event(text: &'static str) -> TaskArtifactUpdateEvent {
        TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact: Artifact {
                artifact_id: "a-1".into(),
                name: None,
                parts: vec![Part::text(text)],
                metadata: None,
            },
            is_final: true,
        }
    }

    fn completed_task() -> Task {
        Task {
            id: "t-1".into(),
            context_id: None,
            status: TaskStatus {
                state: TaskState::Completed,
                timestamp: "ts".into(),
                message: None,
            },
            artifacts: vec![],
            history: vec![],
            metadata: None,
        }
    }

    fn rpc_error(message: &'static str) -> JsonRpcError {
        JsonRpcError {
            code: -32001,
            message: message.into(),
            data: None,
        }
    }

    fn task_response(
        id: &'static str,
        result: Option<Task>,
        error: Option<JsonRpcError>,
    ) -> JsonRpcResponse<Task> {
        JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String(id.into()),
            result,
            error,
        }
    }

    /// Serializes `value` to JSON and parses it back as a [`TaskEvent`].
    fn reparse_as_event<T: Serialize>(value: &T) -> TaskEvent {
        let json = serde_json::to_string(value).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    /// Parses `addr` as an IP address and runs it through `is_private_ip`.
    fn is_private_str(addr: &str) -> bool {
        is_private_ip(addr.parse().unwrap())
    }

    // ----- TaskEvent ----------------------------------------------------------------------

    #[test]
    fn task_event_deserialize_status_update() {
        let event = status_event(
            TaskState::Working,
            Some(Message::user_text("thinking...")),
            false,
        );
        assert!(matches!(reparse_as_event(&event), TaskEvent::StatusUpdate(_)));
    }

    #[test]
    fn task_event_deserialize_artifact_update() {
        let event = artifact_event("result");
        assert!(matches!(reparse_as_event(&event), TaskEvent::ArtifactUpdate(_)));
    }

    #[test]
    fn task_event_clone() {
        let original = TaskEvent::StatusUpdate(status_event(TaskState::Working, None, false));
        let copy = original.clone();
        assert_eq!(
            serde_json::to_string(&original).unwrap(),
            serde_json::to_string(&copy).unwrap()
        );
    }

    #[test]
    fn task_event_debug() {
        let event = TaskEvent::ArtifactUpdate(artifact_event("data"));
        let rendered = format!("{event:?}");
        assert!(rendered.contains("ArtifactUpdate"));
    }

    #[test]
    fn task_event_serialize_round_trip() {
        let event = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: Some("ctx-1".into()),
            status: TaskStatus {
                state: TaskState::Completed,
                timestamp: "2025-01-01T00:00:00Z".into(),
                message: Some(Message::user_text("done")),
            },
            is_final: true,
        });
        assert!(matches!(reparse_as_event(&event), TaskEvent::StatusUpdate(_)));
    }

    // ----- JSON-RPC responses -------------------------------------------------------------

    #[test]
    fn rpc_response_with_task_result() {
        let sent = task_response("req-1", Some(completed_task()), None);
        let json = serde_json::to_string(&sent).unwrap();
        let back: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
        let task = back.into_result().unwrap();
        assert_eq!(task.id, "t-1");
        assert_eq!(task.status.state, TaskState::Completed);
    }

    #[test]
    fn rpc_response_with_error() {
        let sent = task_response("req-1", None, Some(rpc_error("task not found")));
        let json = serde_json::to_string(&sent).unwrap();
        let back: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
        assert_eq!(back.into_result().unwrap_err().code, -32001);
    }

    #[test]
    fn rpc_response_error_takes_priority_over_result() {
        let resp = task_response("1", Some(completed_task()), Some(rpc_error("error")));
        assert_eq!(resp.into_result().unwrap_err().code, -32001);
    }

    #[test]
    fn rpc_response_neither_result_nor_error() {
        let resp = task_response("1", None, None);
        assert_eq!(resp.into_result().unwrap_err().code, -32603);
    }

    // ----- JSON-RPC requests --------------------------------------------------------------

    #[test]
    fn jsonrpc_request_serialization_for_send_message() {
        let req = JsonRpcRequest::new(METHOD_SEND_MESSAGE, message_params("hello"));
        let json = serde_json::to_string(&req).unwrap();
        for needle in ["\"method\":\"message/send\"", "\"jsonrpc\":\"2.0\"", "\"hello\""] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn jsonrpc_request_serialization_for_get_task() {
        let params = TaskIdParams {
            id: "task-123".into(),
            history_length: Some(5),
        };
        let json = serde_json::to_string(&JsonRpcRequest::new(METHOD_GET_TASK, params)).unwrap();
        for needle in ["\"method\":\"tasks/get\"", "\"task-123\"", "\"historyLength\":5"] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn jsonrpc_request_serialization_for_cancel_task() {
        let req = JsonRpcRequest::new(METHOD_CANCEL_TASK, task_params("task-456"));
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"tasks/cancel\""));
        assert!(!json.contains("historyLength"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_stream() {
        let req = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, message_params("stream me"));
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"message/stream\""));
    }

    // ----- client configuration -----------------------------------------------------------

    #[test]
    fn a2a_client_construction() {
        drop(default_client());
    }

    #[test]
    fn with_security_returns_configured_client() {
        let client = secured_client(true, true);
        assert!(client.require_tls);
        assert!(client.ssrf_protection);
    }

    #[test]
    fn default_client_no_security() {
        let client = default_client();
        assert!(!client.require_tls);
        assert!(!client.ssrf_protection);
    }

    // ----- endpoint validation ------------------------------------------------------------

    #[tokio::test]
    async fn tls_enforcement_rejects_http() {
        let err = secured_client(true, false)
            .validate_endpoint(PLAIN_HTTP_ENDPOINT)
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
        assert!(err.to_string().contains("TLS required"));
    }

    #[tokio::test]
    async fn tls_enforcement_allows_https() {
        let outcome = secured_client(true, false)
            .validate_endpoint("https://example.com/rpc")
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn ssrf_protection_rejects_localhost() {
        let err = secured_client(false, true)
            .validate_endpoint(LOOPBACK_ENDPOINT)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("SSRF"));
    }

    #[tokio::test]
    async fn no_security_allows_http_localhost() {
        let outcome = default_client().validate_endpoint(LOOPBACK_ENDPOINT).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn validate_endpoint_invalid_url_with_ssrf() {
        let outcome = secured_client(false, true).validate_endpoint("not-a-url").await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    // ----- transport failures -------------------------------------------------------------

    #[tokio::test]
    async fn send_message_connection_error() {
        let outcome = default_client()
            .send_message(UNREACHABLE_ENDPOINT, message_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn get_task_connection_error() {
        let outcome = default_client()
            .get_task(UNREACHABLE_ENDPOINT, task_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn cancel_task_connection_error() {
        let outcome = default_client()
            .cancel_task(UNREACHABLE_ENDPOINT, task_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn stream_message_connection_error() {
        let outcome = default_client()
            .stream_message(UNREACHABLE_ENDPOINT, message_params("stream me"), None)
            .await;
        assert!(outcome.is_err());
    }

    // ----- TLS enforcement on public methods ----------------------------------------------

    #[tokio::test]
    async fn stream_message_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .stream_message(PLAIN_HTTP_ENDPOINT, message_params("hello"), None)
            .await;
        if let Err(A2aError::Security(msg)) = outcome {
            assert!(msg.contains("TLS required"));
        } else {
            panic!("expected Security error");
        }
    }

    #[tokio::test]
    async fn send_message_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .send_message(PLAIN_HTTP_ENDPOINT, message_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn get_task_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .get_task(PLAIN_HTTP_ENDPOINT, task_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn cancel_task_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .cancel_task(PLAIN_HTTP_ENDPOINT, task_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    // ----- is_private_ip ------------------------------------------------------------------

    #[test]
    fn is_private_ip_loopback() {
        assert!(is_private_ip(std::net::Ipv4Addr::LOCALHOST.into()));
        assert!(is_private_ip(std::net::Ipv6Addr::LOCALHOST.into()));
    }

    #[test]
    fn is_private_ip_private_ranges() {
        for addr in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
            assert!(is_private_str(addr), "{addr} should be private");
        }
    }

    #[test]
    fn is_private_ip_link_local() {
        assert!(is_private_str("169.254.0.1"));
    }

    #[test]
    fn is_private_ip_unspecified() {
        for addr in ["0.0.0.0", "::"] {
            assert!(is_private_str(addr), "{addr} should be private");
        }
    }

    #[test]
    fn is_private_ip_public() {
        for addr in ["8.8.8.8", "1.1.1.1"] {
            assert!(!is_private_str(addr), "{addr} should be public");
        }
    }

    #[test]
    fn is_private_ip_ipv4_non_private() {
        assert!(!is_private_str("93.184.216.34"));
    }

    #[test]
    fn is_private_ip_ipv6_non_private() {
        assert!(!is_private_str("2001:db8::1"));
    }
}