1use std::pin::Pin;
7
8use eventsource_stream::Eventsource;
9use futures_core::Stream;
10use serde::{Deserialize, Serialize, de::DeserializeOwned};
11use tokio_stream::StreamExt;
12use zeph_common::net::is_private_ip;
13
14use crate::error::A2aError;
15use crate::jsonrpc::{
16 JsonRpcRequest, JsonRpcResponse, METHOD_CANCEL_TASK, METHOD_GET_TASK, METHOD_SEND_MESSAGE,
17 METHOD_SEND_STREAMING_MESSAGE, SendMessageParams, TaskIdParams,
18};
19use crate::types::{Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
20
21pub type TaskEventStream = Pin<Box<dyn Stream<Item = Result<TaskEvent, A2aError>> + Send>>;
26
/// One event carried by a [`TaskEventStream`].
///
/// `A2aClient::stream_message` decodes each SSE `data` payload as a JSON-RPC
/// response whose `result` is a `TaskEvent`.
///
/// Because of `#[serde(untagged)]`, serde tries the variants in declaration order
/// and returns the first one that deserializes successfully. The order of the
/// variants below is therefore significant; do not reorder them without checking
/// that an artifact-update payload cannot also deserialize as a status update.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TaskEvent {
    /// Wraps a [`TaskStatusUpdateEvent`].
    StatusUpdate(TaskStatusUpdateEvent),
    /// Wraps a [`TaskArtifactUpdateEvent`].
    ArtifactUpdate(TaskArtifactUpdateEvent),
}
40
41pub struct A2aClient {
74 client: reqwest::Client,
75 require_tls: bool,
76 ssrf_protection: bool,
77}
78
79impl A2aClient {
80 #[must_use]
85 pub fn new(client: reqwest::Client) -> Self {
86 Self {
87 client,
88 require_tls: false,
89 ssrf_protection: false,
90 }
91 }
92
93 #[must_use]
110 pub fn with_security(mut self, require_tls: bool, ssrf_protection: bool) -> Self {
111 self.require_tls = require_tls;
112 self.ssrf_protection = ssrf_protection;
113 self
114 }
115
116 pub async fn send_message(
119 &self,
120 endpoint: &str,
121 params: SendMessageParams,
122 token: Option<&str>,
123 ) -> Result<Task, A2aError> {
124 self.rpc_call(endpoint, METHOD_SEND_MESSAGE, params, token)
125 .await
126 }
127
128 pub async fn stream_message(
131 &self,
132 endpoint: &str,
133 params: SendMessageParams,
134 token: Option<&str>,
135 ) -> Result<TaskEventStream, A2aError> {
136 self.validate_endpoint(endpoint).await?;
137 let request = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
138 let mut req = self.client.post(endpoint).json(&request);
139 if let Some(t) = token {
140 req = req.bearer_auth(t);
141 }
142 let resp = req.send().await?;
143
144 if !resp.status().is_success() {
145 let status = resp.status();
146 let body = resp.text().await.unwrap_or_default();
147 let truncated = if body.len() > 256 {
149 format!("{}…", &body[..256])
150 } else {
151 body
152 };
153 return Err(A2aError::Stream(format!("HTTP {status}: {truncated}")));
154 }
155
156 let event_stream = resp.bytes_stream().eventsource();
157 let mapped = event_stream.filter_map(|event| match event {
158 Ok(event) => {
159 if event.data.is_empty() || event.data == "[DONE]" {
160 return None;
161 }
162 match serde_json::from_str::<JsonRpcResponse<TaskEvent>>(&event.data) {
163 Ok(rpc_resp) => match rpc_resp.into_result() {
164 Ok(task_event) => Some(Ok(task_event)),
165 Err(rpc_err) => Some(Err(A2aError::from(rpc_err))),
166 },
167 Err(e) => Some(Err(A2aError::Stream(format!(
168 "failed to parse SSE event: {e}"
169 )))),
170 }
171 }
172 Err(e) => Some(Err(A2aError::Stream(format!("SSE stream error: {e}")))),
173 });
174
175 Ok(Box::pin(mapped))
176 }
177
178 pub async fn get_task(
181 &self,
182 endpoint: &str,
183 params: TaskIdParams,
184 token: Option<&str>,
185 ) -> Result<Task, A2aError> {
186 self.rpc_call(endpoint, METHOD_GET_TASK, params, token)
187 .await
188 }
189
190 pub async fn cancel_task(
193 &self,
194 endpoint: &str,
195 params: TaskIdParams,
196 token: Option<&str>,
197 ) -> Result<Task, A2aError> {
198 self.rpc_call(endpoint, METHOD_CANCEL_TASK, params, token)
199 .await
200 }
201
202 async fn validate_endpoint(&self, endpoint: &str) -> Result<(), A2aError> {
203 if self.require_tls && !endpoint.starts_with("https://") {
204 return Err(A2aError::Security(format!(
205 "TLS required but endpoint uses HTTP: {endpoint}"
206 )));
207 }
208
209 if self.ssrf_protection {
210 let url: url::Url = endpoint
211 .parse()
212 .map_err(|e| A2aError::Security(format!("invalid URL: {e}")))?;
213
214 if let Some(host) = url.host_str() {
215 let addrs = tokio::net::lookup_host(format!(
216 "{}:{}",
217 host,
218 url.port_or_known_default().unwrap_or(443)
219 ))
220 .await
221 .map_err(|e| A2aError::Security(format!("DNS resolution failed: {e}")))?;
222
223 for addr in addrs {
224 if is_private_ip(addr.ip()) {
225 return Err(A2aError::Security(format!(
226 "SSRF protection: private IP {} for host {host}",
227 addr.ip()
228 )));
229 }
230 }
231 }
232 }
233
234 Ok(())
235 }
236
237 async fn rpc_call<P: Serialize, R: DeserializeOwned>(
238 &self,
239 endpoint: &str,
240 method: &str,
241 params: P,
242 token: Option<&str>,
243 ) -> Result<R, A2aError> {
244 self.validate_endpoint(endpoint).await?;
245 let request = JsonRpcRequest::new(method, params);
246 let mut req = self.client.post(endpoint).json(&request);
247 if let Some(t) = token {
248 req = req.bearer_auth(t);
249 }
250 let resp = req.send().await?;
251 let rpc_response: JsonRpcResponse<R> = resp.json().await?;
252 rpc_response.into_result().map_err(A2aError::from)
253 }
254}
255
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use super::*;
    use crate::jsonrpc::{JsonRpcError, JsonRpcResponse};
    use crate::types::{
        Artifact, Message, Part, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };

    // ---- fixtures -------------------------------------------------------------

    /// Client built with `A2aClient::new`, i.e. with no security checks.
    fn plain_client() -> A2aClient {
        A2aClient::new(reqwest::Client::new())
    }

    /// Client with the given TLS / SSRF settings applied.
    fn secured_client(require_tls: bool, ssrf_protection: bool) -> A2aClient {
        plain_client().with_security(require_tls, ssrf_protection)
    }

    /// Send-message parameters carrying a single user text message.
    fn text_params(text: &str) -> SendMessageParams {
        SendMessageParams {
            message: Message::user_text(text),
            configuration: None,
        }
    }

    /// Task-id parameters without a history length.
    fn task_id_params(id: &str) -> TaskIdParams {
        TaskIdParams {
            id: id.into(),
            history_length: None,
        }
    }

    /// Status with the placeholder timestamp `"ts"`.
    fn make_status(state: TaskState, message: Option<Message>) -> TaskStatus {
        TaskStatus {
            state,
            timestamp: "ts".into(),
            message,
        }
    }

    /// Status-update event for task `t-1` with no context id.
    fn status_update(status: TaskStatus, is_final: bool) -> TaskStatusUpdateEvent {
        TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status,
            is_final,
        }
    }

    /// Final artifact-update event for task `t-1` with one text part.
    fn artifact_update(text: &str) -> TaskArtifactUpdateEvent {
        TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact: Artifact {
                artifact_id: "a-1".into(),
                name: None,
                parts: vec![Part::text(text)],
                metadata: None,
            },
            is_final: true,
        }
    }

    /// Completed task `t-1` with no artifacts, history or metadata.
    fn completed_task() -> Task {
        Task {
            id: "t-1".into(),
            context_id: None,
            status: make_status(TaskState::Completed, None),
            artifacts: vec![],
            history: vec![],
            metadata: None,
        }
    }

    /// JSON-RPC 2.0 response envelope with a string id.
    fn rpc_response(
        id: &str,
        result: Option<Task>,
        error: Option<JsonRpcError>,
    ) -> JsonRpcResponse<Task> {
        JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String(id.into()),
            result,
            error,
        }
    }

    fn parse_ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    // ---- TaskEvent serde ------------------------------------------------------

    #[test]
    fn task_event_deserialize_status_update() {
        let source = status_update(
            make_status(TaskState::Working, Some(Message::user_text("thinking..."))),
            false,
        );
        let encoded = serde_json::to_string(&source).unwrap();
        let decoded: TaskEvent = serde_json::from_str(&encoded).unwrap();
        assert!(matches!(decoded, TaskEvent::StatusUpdate(_)));
    }

    #[test]
    fn task_event_deserialize_artifact_update() {
        let encoded = serde_json::to_string(&artifact_update("result")).unwrap();
        let decoded: TaskEvent = serde_json::from_str(&encoded).unwrap();
        assert!(matches!(decoded, TaskEvent::ArtifactUpdate(_)));
    }

    // ---- JSON-RPC envelopes ---------------------------------------------------

    #[test]
    fn rpc_response_with_task_result() {
        let wire =
            serde_json::to_string(&rpc_response("req-1", Some(completed_task()), None)).unwrap();
        let decoded: JsonRpcResponse<Task> = serde_json::from_str(&wire).unwrap();
        let task = decoded.into_result().unwrap();
        assert_eq!(task.id, "t-1");
        assert_eq!(task.status.state, TaskState::Completed);
    }

    #[test]
    fn rpc_response_with_error() {
        let failing = rpc_response(
            "req-1",
            None,
            Some(JsonRpcError {
                code: -32001,
                message: "task not found".into(),
                data: None,
            }),
        );
        let wire = serde_json::to_string(&failing).unwrap();
        let decoded: JsonRpcResponse<Task> = serde_json::from_str(&wire).unwrap();
        assert_eq!(decoded.into_result().unwrap_err().code, -32001);
    }

    #[test]
    fn a2a_client_construction() {
        drop(plain_client());
    }

    // ---- is_private_ip --------------------------------------------------------

    #[test]
    fn is_private_ip_loopback() {
        for ip in [IpAddr::V4(Ipv4Addr::LOCALHOST), IpAddr::V6(Ipv6Addr::LOCALHOST)] {
            assert!(is_private_ip(ip));
        }
    }

    #[test]
    fn is_private_ip_private_ranges() {
        for text in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
            assert!(is_private_ip(parse_ip(text)));
        }
    }

    #[test]
    fn is_private_ip_link_local() {
        assert!(is_private_ip(parse_ip("169.254.0.1")));
    }

    #[test]
    fn is_private_ip_unspecified() {
        for text in ["0.0.0.0", "::"] {
            assert!(is_private_ip(parse_ip(text)));
        }
    }

    #[test]
    fn is_private_ip_public() {
        for text in ["8.8.8.8", "1.1.1.1"] {
            assert!(!is_private_ip(parse_ip(text)));
        }
    }

    // ---- validate_endpoint ----------------------------------------------------

    #[tokio::test]
    async fn tls_enforcement_rejects_http() {
        let err = secured_client(true, false)
            .validate_endpoint("http://example.com/rpc")
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
        assert!(err.to_string().contains("TLS required"));
    }

    #[tokio::test]
    async fn tls_enforcement_allows_https() {
        let outcome = secured_client(true, false)
            .validate_endpoint("https://example.com/rpc")
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn ssrf_protection_rejects_localhost() {
        let err = secured_client(false, true)
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await
            .unwrap_err();
        assert!(err.to_string().contains("SSRF"));
    }

    #[tokio::test]
    async fn no_security_allows_http_localhost() {
        let outcome = plain_client()
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await;
        assert!(outcome.is_ok());
    }

    // ---- request serialization ------------------------------------------------

    #[test]
    fn jsonrpc_request_serialization_for_send_message() {
        let wire =
            serde_json::to_string(&JsonRpcRequest::new(METHOD_SEND_MESSAGE, text_params("hello")))
                .unwrap();
        for needle in ["\"method\":\"message/send\"", "\"jsonrpc\":\"2.0\"", "\"hello\""] {
            assert!(wire.contains(needle));
        }
    }

    #[test]
    fn jsonrpc_request_serialization_for_get_task() {
        let lookup = TaskIdParams {
            id: "task-123".into(),
            history_length: Some(5),
        };
        let wire = serde_json::to_string(&JsonRpcRequest::new(METHOD_GET_TASK, lookup)).unwrap();
        for needle in ["\"method\":\"tasks/get\"", "\"task-123\"", "\"historyLength\":5"] {
            assert!(wire.contains(needle));
        }
    }

    #[test]
    fn jsonrpc_request_serialization_for_cancel_task() {
        let wire = serde_json::to_string(&JsonRpcRequest::new(
            METHOD_CANCEL_TASK,
            task_id_params("task-456"),
        ))
        .unwrap();
        assert!(wire.contains("\"method\":\"tasks/cancel\""));
        assert!(!wire.contains("historyLength"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_stream() {
        let wire = serde_json::to_string(&JsonRpcRequest::new(
            METHOD_SEND_STREAMING_MESSAGE,
            text_params("stream me"),
        ))
        .unwrap();
        assert!(wire.contains("\"method\":\"message/stream\""));
    }

    // ---- transport failures ---------------------------------------------------

    #[tokio::test]
    async fn send_message_connection_error() {
        let err = plain_client()
            .send_message("http://127.0.0.1:1/rpc", text_params("hello"), None)
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::Http(_)));
    }

    #[tokio::test]
    async fn get_task_connection_error() {
        let err = plain_client()
            .get_task("http://127.0.0.1:1/rpc", task_id_params("t-1"), None)
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::Http(_)));
    }

    #[tokio::test]
    async fn cancel_task_connection_error() {
        let err = plain_client()
            .cancel_task("http://127.0.0.1:1/rpc", task_id_params("t-1"), None)
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::Http(_)));
    }

    #[tokio::test]
    async fn stream_message_connection_error() {
        let outcome = plain_client()
            .stream_message("http://127.0.0.1:1/rpc", text_params("stream me"), None)
            .await;
        assert!(outcome.is_err());
    }

    // ---- TLS enforcement on every entry point ---------------------------------

    #[tokio::test]
    async fn stream_message_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .stream_message("http://example.com/rpc", text_params("hello"), None)
            .await;
        if let Err(A2aError::Security(msg)) = outcome {
            assert!(msg.contains("TLS required"));
        } else {
            panic!("expected Security error");
        }
    }

    #[tokio::test]
    async fn send_message_tls_required_rejects_http() {
        let err = secured_client(true, false)
            .send_message("http://example.com/rpc", text_params("hello"), None)
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
    }

    #[tokio::test]
    async fn get_task_tls_required_rejects_http() {
        let err = secured_client(true, false)
            .get_task("http://example.com/rpc", task_id_params("t-1"), None)
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
    }

    #[tokio::test]
    async fn cancel_task_tls_required_rejects_http() {
        let err = secured_client(true, false)
            .cancel_task("http://example.com/rpc", task_id_params("t-1"), None)
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
    }

    #[tokio::test]
    async fn validate_endpoint_invalid_url_with_ssrf() {
        let err = secured_client(false, true)
            .validate_endpoint("not-a-url")
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
    }

    // ---- configuration flags --------------------------------------------------

    #[test]
    fn with_security_returns_configured_client() {
        let configured = secured_client(true, true);
        assert!(configured.require_tls && configured.ssrf_protection);
    }

    #[test]
    fn default_client_no_security() {
        let configured = plain_client();
        assert!(!configured.require_tls && !configured.ssrf_protection);
    }

    // ---- TaskEvent derives ----------------------------------------------------

    #[test]
    fn task_event_clone() {
        let original =
            TaskEvent::StatusUpdate(status_update(make_status(TaskState::Working, None), false));
        let duplicate = original.clone();
        assert_eq!(
            serde_json::to_string(&original).unwrap(),
            serde_json::to_string(&duplicate).unwrap()
        );
    }

    #[test]
    fn task_event_debug() {
        let event = TaskEvent::ArtifactUpdate(artifact_update("data"));
        let rendered = format!("{event:?}");
        assert!(rendered.contains("ArtifactUpdate"));
    }

    #[test]
    fn is_private_ip_ipv4_non_private() {
        assert!(!is_private_ip(parse_ip("93.184.216.34")));
    }

    #[test]
    fn is_private_ip_ipv6_non_private() {
        assert!(!is_private_ip(parse_ip("2001:db8::1")));
    }

    #[test]
    fn rpc_response_error_takes_priority_over_result() {
        let conflicting = rpc_response(
            "1",
            Some(completed_task()),
            Some(JsonRpcError {
                code: -32001,
                message: "error".into(),
                data: None,
            }),
        );
        assert_eq!(conflicting.into_result().unwrap_err().code, -32001);
    }

    #[test]
    fn rpc_response_neither_result_nor_error() {
        let empty = rpc_response("1", None, None);
        assert_eq!(empty.into_result().unwrap_err().code, -32603);
    }

    #[test]
    fn task_event_serialize_round_trip() {
        let finished = TaskStatus {
            state: TaskState::Completed,
            timestamp: "2025-01-01T00:00:00Z".into(),
            message: Some(Message::user_text("done")),
        };
        let event = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
            context_id: Some("ctx-1".into()),
            ..status_update(finished, true)
        });
        let wire = serde_json::to_string(&event).unwrap();
        let back: TaskEvent = serde_json::from_str(&wire).unwrap();
        assert!(matches!(back, TaskEvent::StatusUpdate(_)));
    }
}
705
#[cfg(test)]
mod wiremock_tests {
    use tokio_stream::StreamExt;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, Respond, ResponseTemplate};

    use crate::client::A2aClient;
    use crate::error::A2aError;
    use crate::jsonrpc::{SendMessageParams, TaskIdParams};
    use crate::testing::*;
    use crate::types::Message;

    /// Starts a mock server whose `POST /rpc` route always answers with `responder`.
    async fn rpc_server(responder: impl Respond + 'static) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(responder)
            .mount(&server)
            .await;
        server
    }

    /// Full URL of the mock server's JSON-RPC route.
    fn rpc_url(server: &MockServer) -> String {
        format!("{}/rpc", server.uri())
    }

    /// Client without security checks, so plain-HTTP mock URLs are accepted.
    fn plain_client() -> A2aClient {
        A2aClient::new(reqwest::Client::new())
    }

    fn text_params(text: &str) -> SendMessageParams {
        SendMessageParams {
            message: Message::user_text(text),
            configuration: None,
        }
    }

    fn task_id_params(id: &str) -> TaskIdParams {
        TaskIdParams {
            id: id.into(),
            history_length: None,
        }
    }

    #[tokio::test]
    async fn send_message_success() {
        let server = rpc_server(task_rpc_response("task-1", "submitted")).await;
        let task = plain_client()
            .send_message(&rpc_url(&server), text_params("hello"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-1");
    }

    #[tokio::test]
    async fn send_message_rpc_error() {
        let server = rpc_server(task_rpc_error_response(-32001, "task not found")).await;
        let err = plain_client()
            .send_message(&rpc_url(&server), text_params("hi"), None)
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::JsonRpc { code: -32001, .. }));
    }

    #[tokio::test]
    async fn send_message_with_bearer_auth() {
        // Mounted by hand: this route additionally requires the bearer header.
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/rpc"))
            .and(header("authorization", "Bearer secret-token"))
            .respond_with(task_rpc_response("task-auth", "submitted"))
            .mount(&server)
            .await;

        let task = plain_client()
            .send_message(&rpc_url(&server), text_params("secure"), Some("secret-token"))
            .await
            .unwrap();
        assert_eq!(task.id, "task-auth");
    }

    #[tokio::test]
    async fn get_task_success() {
        let server = rpc_server(task_rpc_response("task-get", "completed")).await;
        let task = plain_client()
            .get_task(&rpc_url(&server), task_id_params("task-get"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-get");
    }

    #[tokio::test]
    async fn cancel_task_success() {
        let server = rpc_server(task_rpc_response("task-cancel", "canceled")).await;
        let task = plain_client()
            .cancel_task(&rpc_url(&server), task_id_params("task-cancel"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-cancel");
    }

    #[tokio::test]
    async fn stream_message_success() {
        let server = rpc_server(sse_task_events_response("task-stream", "result content")).await;
        let events: Vec<_> = plain_client()
            .stream_message(&rpc_url(&server), text_params("stream"), None)
            .await
            .unwrap()
            .collect()
            .await;
        assert!(!events.is_empty());
    }

    #[tokio::test]
    async fn stream_message_http_error() {
        let server =
            rpc_server(ResponseTemplate::new(500).set_body_string("Internal Server Error")).await;
        let outcome = plain_client()
            .stream_message(&rpc_url(&server), text_params("fail"), None)
            .await;
        let err = outcome.err().expect("expected error");
        assert!(matches!(err, A2aError::Stream(_)));
    }
}