Skip to main content

zeph_a2a/
client.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! A2A protocol HTTP client with optional TLS enforcement and SSRF protection.
5
6use std::pin::Pin;
7
8use eventsource_stream::Eventsource;
9use futures_core::Stream;
10use serde::{Deserialize, Serialize, de::DeserializeOwned};
11use tokio_stream::StreamExt;
12use zeph_common::net::is_private_ip;
13
14use crate::error::A2aError;
15use crate::jsonrpc::{
16    JsonRpcRequest, JsonRpcResponse, METHOD_CANCEL_TASK, METHOD_GET_TASK, METHOD_SEND_MESSAGE,
17    METHOD_SEND_STREAMING_MESSAGE, SendMessageParams, TaskIdParams,
18};
19use crate::types::{Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
20
21/// A pinned, heap-allocated stream of [`TaskEvent`]s from a streaming A2A call.
22///
23/// Produced by [`A2aClient::stream_message`]. Each item is either a status update
24/// or an artifact update; errors are surfaced inline as `Err(A2aError)`.
25pub type TaskEventStream = Pin<Box<dyn Stream<Item = Result<TaskEvent, A2aError>> + Send>>;
26
27/// A single event received on a streaming (`message/stream`) A2A connection.
28///
29/// The A2A spec multiplexes two event kinds over the same SSE channel. This enum
30/// uses `#[serde(untagged)]` so that the deserializer inspects the `kind` field
31/// inside the inner struct to determine the variant.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33#[serde(untagged)]
34pub enum TaskEvent {
35    /// A task lifecycle transition (e.g., `submitted` → `working` → `completed`).
36    StatusUpdate(TaskStatusUpdateEvent),
37    /// A new or updated output artifact from the agent.
38    ArtifactUpdate(TaskArtifactUpdateEvent),
39}
40
41/// HTTP client for the A2A protocol.
42///
43/// `A2aClient` wraps a `reqwest::Client` and provides typed methods for the four
44/// A2A JSON-RPC operations: `message/send`, `message/stream`, `tasks/get`, and
45/// `tasks/cancel`. Each call optionally accepts a bearer token for authentication.
46///
47/// # Security
48///
49/// Use [`with_security`](A2aClient::with_security) to harden the client for
50/// production deployments:
51/// - `require_tls = true` rejects any `http://` endpoint before connecting.
52/// - `ssrf_protection = true` resolves the endpoint's hostname via DNS and rejects
53///   addresses in private/loopback ranges (10/8, 172.16/12, 192.168/16, 127/8, etc.).
54///
55/// # Examples
56///
57/// ```rust,no_run
58/// use zeph_a2a::{A2aClient, SendMessageParams, Message};
59///
60/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
61/// let client = A2aClient::new(reqwest::Client::new())
62///     .with_security(true, true); // require HTTPS, block SSRF
63///
64/// let params = SendMessageParams {
65///     message: Message::user_text("Summarize this page."),
66///     configuration: None,
67/// };
68/// let task = client.send_message("https://agent.example.com/a2a", params, Some("tok")).await?;
69/// println!("Task state: {:?}", task.status.state);
70/// # Ok(())
71/// # }
72/// ```
73pub struct A2aClient {
74    client: reqwest::Client,
75    require_tls: bool,
76    ssrf_protection: bool,
77}
78
79impl A2aClient {
80    /// Create a new `A2aClient` with no security restrictions.
81    ///
82    /// Security features are disabled by default for local/dev usage. Enable them
83    /// with [`with_security`](Self::with_security) for production deployments.
84    #[must_use]
85    pub fn new(client: reqwest::Client) -> Self {
86        Self {
87            client,
88            require_tls: false,
89            ssrf_protection: false,
90        }
91    }
92
93    /// Configure TLS enforcement and SSRF protection for this client.
94    ///
95    /// Both flags default to `false`. This method uses the builder pattern and
96    /// can be chained directly after [`new`](Self::new).
97    ///
98    /// - `require_tls`: reject any endpoint that does not start with `https://`.
99    /// - `ssrf_protection`: resolve the endpoint hostname via DNS and reject private IP ranges.
100    ///
101    /// # Examples
102    ///
103    /// ```rust
104    /// use zeph_a2a::A2aClient;
105    ///
106    /// let client = A2aClient::new(reqwest::Client::new())
107    ///     .with_security(true, true);
108    /// ```
109    #[must_use]
110    pub fn with_security(mut self, require_tls: bool, ssrf_protection: bool) -> Self {
111        self.require_tls = require_tls;
112        self.ssrf_protection = ssrf_protection;
113        self
114    }
115
116    /// # Errors
117    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
118    pub async fn send_message(
119        &self,
120        endpoint: &str,
121        params: SendMessageParams,
122        token: Option<&str>,
123    ) -> Result<Task, A2aError> {
124        self.rpc_call(endpoint, METHOD_SEND_MESSAGE, params, token)
125            .await
126    }
127
128    /// # Errors
129    /// Returns `A2aError` on network failure or if the SSE connection cannot be established.
130    pub async fn stream_message(
131        &self,
132        endpoint: &str,
133        params: SendMessageParams,
134        token: Option<&str>,
135    ) -> Result<TaskEventStream, A2aError> {
136        self.validate_endpoint(endpoint).await?;
137        let request = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
138        let mut req = self.client.post(endpoint).json(&request);
139        if let Some(t) = token {
140            req = req.bearer_auth(t);
141        }
142        let resp = req.send().await?;
143
144        if !resp.status().is_success() {
145            let status = resp.status();
146            let body = resp.text().await.unwrap_or_default();
147            // Truncate body to avoid leaking large upstream error responses.
148            let truncated = if body.len() > 256 {
149                format!("{}…", &body[..256])
150            } else {
151                body
152            };
153            return Err(A2aError::Stream(format!("HTTP {status}: {truncated}")));
154        }
155
156        let event_stream = resp.bytes_stream().eventsource();
157        let mapped = event_stream.filter_map(|event| match event {
158            Ok(event) => {
159                if event.data.is_empty() || event.data == "[DONE]" {
160                    return None;
161                }
162                match serde_json::from_str::<JsonRpcResponse<TaskEvent>>(&event.data) {
163                    Ok(rpc_resp) => match rpc_resp.into_result() {
164                        Ok(task_event) => Some(Ok(task_event)),
165                        Err(rpc_err) => Some(Err(A2aError::from(rpc_err))),
166                    },
167                    Err(e) => Some(Err(A2aError::Stream(format!(
168                        "failed to parse SSE event: {e}"
169                    )))),
170                }
171            }
172            Err(e) => Some(Err(A2aError::Stream(format!("SSE stream error: {e}")))),
173        });
174
175        Ok(Box::pin(mapped))
176    }
177
178    /// # Errors
179    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
180    pub async fn get_task(
181        &self,
182        endpoint: &str,
183        params: TaskIdParams,
184        token: Option<&str>,
185    ) -> Result<Task, A2aError> {
186        self.rpc_call(endpoint, METHOD_GET_TASK, params, token)
187            .await
188    }
189
190    /// # Errors
191    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
192    pub async fn cancel_task(
193        &self,
194        endpoint: &str,
195        params: TaskIdParams,
196        token: Option<&str>,
197    ) -> Result<Task, A2aError> {
198        self.rpc_call(endpoint, METHOD_CANCEL_TASK, params, token)
199            .await
200    }
201
202    async fn validate_endpoint(&self, endpoint: &str) -> Result<(), A2aError> {
203        if self.require_tls && !endpoint.starts_with("https://") {
204            return Err(A2aError::Security(format!(
205                "TLS required but endpoint uses HTTP: {endpoint}"
206            )));
207        }
208
209        if self.ssrf_protection {
210            let url: url::Url = endpoint
211                .parse()
212                .map_err(|e| A2aError::Security(format!("invalid URL: {e}")))?;
213
214            if let Some(host) = url.host_str() {
215                let addrs = tokio::net::lookup_host(format!(
216                    "{}:{}",
217                    host,
218                    url.port_or_known_default().unwrap_or(443)
219                ))
220                .await
221                .map_err(|e| A2aError::Security(format!("DNS resolution failed: {e}")))?;
222
223                for addr in addrs {
224                    if is_private_ip(addr.ip()) {
225                        return Err(A2aError::Security(format!(
226                            "SSRF protection: private IP {} for host {host}",
227                            addr.ip()
228                        )));
229                    }
230                }
231            }
232        }
233
234        Ok(())
235    }
236
237    async fn rpc_call<P: Serialize, R: DeserializeOwned>(
238        &self,
239        endpoint: &str,
240        method: &str,
241        params: P,
242        token: Option<&str>,
243    ) -> Result<R, A2aError> {
244        self.validate_endpoint(endpoint).await?;
245        let request = JsonRpcRequest::new(method, params);
246        let mut req = self.client.post(endpoint).json(&request);
247        if let Some(t) = token {
248            req = req.bearer_auth(t);
249        }
250        let resp = req.send().await?;
251        let rpc_response: JsonRpcResponse<R> = resp.json().await?;
252        rpc_response.into_result().map_err(A2aError::from)
253    }
254}
255
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use super::*;
    use crate::jsonrpc::{JsonRpcError, JsonRpcResponse};
    use crate::types::{
        Artifact, Message, Part, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };

    /// Loopback port 1: nothing listens there, so requests fail while connecting.
    const UNREACHABLE_ENDPOINT: &str = "http://127.0.0.1:1/rpc";
    /// Plain-HTTP endpoint used to exercise TLS enforcement.
    const PLAIN_HTTP_ENDPOINT: &str = "http://example.com/rpc";

    // ---- fixtures ----------------------------------------------------------

    /// Client with every security check disabled.
    fn open_client() -> A2aClient {
        A2aClient::new(reqwest::Client::new())
    }

    /// Client with the given TLS / SSRF flags.
    fn hardened_client(require_tls: bool, ssrf_protection: bool) -> A2aClient {
        open_client().with_security(require_tls, ssrf_protection)
    }

    fn text_params(text: &str) -> SendMessageParams {
        SendMessageParams {
            message: Message::user_text(text),
            configuration: None,
        }
    }

    fn id_params(id: &str) -> TaskIdParams {
        TaskIdParams {
            id: id.into(),
            history_length: None,
        }
    }

    fn status_with(state: TaskState, message: Option<Message>) -> TaskStatus {
        TaskStatus {
            state,
            timestamp: "ts".into(),
            message,
        }
    }

    fn status_update(status: TaskStatus, is_final: bool) -> TaskStatusUpdateEvent {
        TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status,
            is_final,
        }
    }

    fn artifact_update(text: &str) -> TaskArtifactUpdateEvent {
        TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact: Artifact {
                artifact_id: "a-1".into(),
                name: None,
                parts: vec![Part::text(text)],
                metadata: None,
            },
            is_final: true,
        }
    }

    fn completed_task(id: &str) -> Task {
        Task {
            id: id.into(),
            context_id: None,
            status: status_with(TaskState::Completed, None),
            artifacts: vec![],
            history: vec![],
            metadata: None,
        }
    }

    fn task_response(
        id: &str,
        result: Option<Task>,
        error: Option<JsonRpcError>,
    ) -> JsonRpcResponse<Task> {
        JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String(id.into()),
            result,
            error,
        }
    }

    /// Serialize `value` to JSON text and parse that text back as `T`.
    fn reparse<T: DeserializeOwned>(value: &impl Serialize) -> T {
        let text = serde_json::to_string(value).unwrap();
        serde_json::from_str(&text).unwrap()
    }

    /// JSON text of the JSON-RPC request the client would send for `method`.
    fn request_json<P: Serialize>(method: &str, params: P) -> String {
        serde_json::to_string(&JsonRpcRequest::new(method, params)).unwrap()
    }

    // ---- TaskEvent ---------------------------------------------------------

    #[test]
    fn task_event_deserialize_status_update() {
        let event = status_update(
            status_with(TaskState::Working, Some(Message::user_text("thinking..."))),
            false,
        );
        let parsed: TaskEvent = reparse(&event);
        assert!(matches!(parsed, TaskEvent::StatusUpdate(_)));
    }

    #[test]
    fn task_event_deserialize_artifact_update() {
        let parsed: TaskEvent = reparse(&artifact_update("result"));
        assert!(matches!(parsed, TaskEvent::ArtifactUpdate(_)));
    }

    #[test]
    fn task_event_clone() {
        let event = TaskEvent::StatusUpdate(status_update(
            status_with(TaskState::Working, None),
            false,
        ));
        let copy = event.clone();
        assert_eq!(
            serde_json::to_string(&event).unwrap(),
            serde_json::to_string(&copy).unwrap()
        );
    }

    #[test]
    fn task_event_debug() {
        let event = TaskEvent::ArtifactUpdate(artifact_update("data"));
        assert!(format!("{event:?}").contains("ArtifactUpdate"));
    }

    #[test]
    fn task_event_serialize_round_trip() {
        let finished = TaskStatus {
            state: TaskState::Completed,
            timestamp: "2025-01-01T00:00:00Z".into(),
            message: Some(Message::user_text("done")),
        };
        let event = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
            context_id: Some("ctx-1".into()),
            ..status_update(finished, true)
        });
        let back: TaskEvent = reparse(&event);
        assert!(matches!(back, TaskEvent::StatusUpdate(_)));
    }

    // ---- JSON-RPC responses -------------------------------------------------

    #[test]
    fn rpc_response_with_task_result() {
        let original = task_response("req-1", Some(completed_task("t-1")), None);
        let back: JsonRpcResponse<Task> = reparse(&original);
        let task = back.into_result().unwrap();
        assert_eq!(task.id, "t-1");
        assert_eq!(task.status.state, TaskState::Completed);
    }

    #[test]
    fn rpc_response_with_error() {
        let failure = JsonRpcError {
            code: -32001,
            message: "task not found".into(),
            data: None,
        };
        let back: JsonRpcResponse<Task> = reparse(&task_response("req-1", None, Some(failure)));
        assert_eq!(back.into_result().unwrap_err().code, -32001);
    }

    #[test]
    fn rpc_response_error_takes_priority_over_result() {
        let failure = JsonRpcError {
            code: -32001,
            message: "error".into(),
            data: None,
        };
        let resp = task_response("1", Some(completed_task("t-1")), Some(failure));
        assert_eq!(resp.into_result().unwrap_err().code, -32001);
    }

    #[test]
    fn rpc_response_neither_result_nor_error() {
        let resp = task_response("1", None, None);
        assert_eq!(resp.into_result().unwrap_err().code, -32603);
    }

    // ---- client configuration ------------------------------------------------

    #[test]
    fn a2a_client_construction() {
        drop(open_client());
    }

    #[test]
    fn with_security_returns_configured_client() {
        let client = hardened_client(true, true);
        assert!(client.require_tls);
        assert!(client.ssrf_protection);
    }

    #[test]
    fn default_client_no_security() {
        let client = open_client();
        assert!(!client.require_tls);
        assert!(!client.ssrf_protection);
    }

    // ---- is_private_ip -----------------------------------------------------

    #[test]
    fn is_private_ip_loopback() {
        assert!(is_private_ip(IpAddr::V4(Ipv4Addr::LOCALHOST)));
        assert!(is_private_ip(IpAddr::V6(Ipv6Addr::LOCALHOST)));
    }

    #[test]
    fn is_private_ip_private_ranges() {
        for addr in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
            assert!(is_private_ip(addr.parse().unwrap()), "{addr} should be private");
        }
    }

    #[test]
    fn is_private_ip_link_local() {
        assert!(is_private_ip("169.254.0.1".parse().unwrap()));
    }

    #[test]
    fn is_private_ip_unspecified() {
        for addr in ["0.0.0.0", "::"] {
            assert!(is_private_ip(addr.parse().unwrap()), "{addr} should be private");
        }
    }

    #[test]
    fn is_private_ip_public() {
        for addr in ["8.8.8.8", "1.1.1.1"] {
            assert!(!is_private_ip(addr.parse().unwrap()), "{addr} should be public");
        }
    }

    #[test]
    fn is_private_ip_ipv4_non_private() {
        assert!(!is_private_ip("93.184.216.34".parse().unwrap()));
    }

    #[test]
    fn is_private_ip_ipv6_non_private() {
        assert!(!is_private_ip("2001:db8::1".parse().unwrap()));
    }

    // ---- endpoint validation -------------------------------------------------

    #[tokio::test]
    async fn tls_enforcement_rejects_http() {
        let err = hardened_client(true, false)
            .validate_endpoint(PLAIN_HTTP_ENDPOINT)
            .await
            .expect_err("plain HTTP must be rejected when TLS is required");
        assert!(matches!(err, A2aError::Security(_)));
        assert!(err.to_string().contains("TLS required"));
    }

    #[tokio::test]
    async fn tls_enforcement_allows_https() {
        let verdict = hardened_client(true, false)
            .validate_endpoint("https://example.com/rpc")
            .await;
        assert!(verdict.is_ok());
    }

    #[tokio::test]
    async fn ssrf_protection_rejects_localhost() {
        let err = hardened_client(false, true)
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await
            .expect_err("loopback must be rejected under SSRF protection");
        assert!(err.to_string().contains("SSRF"));
    }

    #[tokio::test]
    async fn no_security_allows_http_localhost() {
        let verdict = open_client()
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await;
        assert!(verdict.is_ok());
    }

    #[tokio::test]
    async fn validate_endpoint_invalid_url_with_ssrf() {
        let verdict = hardened_client(false, true)
            .validate_endpoint("not-a-url")
            .await;
        assert!(matches!(verdict, Err(A2aError::Security(_))));
    }

    // ---- request serialization ---------------------------------------------

    #[test]
    fn jsonrpc_request_serialization_for_send_message() {
        let json = request_json(METHOD_SEND_MESSAGE, text_params("hello"));
        for needle in ["\"method\":\"message/send\"", "\"jsonrpc\":\"2.0\"", "\"hello\""] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn jsonrpc_request_serialization_for_get_task() {
        let params = TaskIdParams {
            history_length: Some(5),
            ..id_params("task-123")
        };
        let json = request_json(METHOD_GET_TASK, params);
        for needle in ["\"method\":\"tasks/get\"", "\"task-123\"", "\"historyLength\":5"] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn jsonrpc_request_serialization_for_cancel_task() {
        let json = request_json(METHOD_CANCEL_TASK, id_params("task-456"));
        assert!(json.contains("\"method\":\"tasks/cancel\""));
        assert!(!json.contains("historyLength"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_stream() {
        let json = request_json(METHOD_SEND_STREAMING_MESSAGE, text_params("stream me"));
        assert!(json.contains("\"method\":\"message/stream\""));
    }

    // ---- connection failures --------------------------------------------------

    #[tokio::test]
    async fn send_message_connection_error() {
        let outcome = open_client()
            .send_message(UNREACHABLE_ENDPOINT, text_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn get_task_connection_error() {
        let outcome = open_client()
            .get_task(UNREACHABLE_ENDPOINT, id_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn cancel_task_connection_error() {
        let outcome = open_client()
            .cancel_task(UNREACHABLE_ENDPOINT, id_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn stream_message_connection_error() {
        let outcome = open_client()
            .stream_message(UNREACHABLE_ENDPOINT, text_params("stream me"), None)
            .await;
        assert!(outcome.is_err());
    }

    // ---- TLS enforcement on every operation --------------------------------

    #[tokio::test]
    async fn stream_message_tls_required_rejects_http() {
        let outcome = hardened_client(true, false)
            .stream_message(PLAIN_HTTP_ENDPOINT, text_params("hello"), None)
            .await;
        let Err(A2aError::Security(msg)) = outcome else {
            panic!("expected Security error");
        };
        assert!(msg.contains("TLS required"));
    }

    #[tokio::test]
    async fn send_message_tls_required_rejects_http() {
        let outcome = hardened_client(true, false)
            .send_message(PLAIN_HTTP_ENDPOINT, text_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn get_task_tls_required_rejects_http() {
        let outcome = hardened_client(true, false)
            .get_task(PLAIN_HTTP_ENDPOINT, id_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn cancel_task_tls_required_rejects_http() {
        let outcome = hardened_client(true, false)
            .cancel_task(PLAIN_HTTP_ENDPOINT, id_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }
}
705
#[cfg(test)]
mod wiremock_tests {
    use tokio_stream::StreamExt;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockBuilder, MockServer, ResponseTemplate};

    use crate::client::A2aClient;
    use crate::error::A2aError;
    use crate::jsonrpc::{SendMessageParams, TaskIdParams};
    use crate::testing::*;
    use crate::types::Message;

    /// Matcher for the JSON-RPC `POST /rpc` call every test below makes; each test
    /// attaches its own canned response.
    fn post_rpc() -> MockBuilder {
        Mock::given(method("POST")).and(path("/rpc"))
    }

    /// Full URL of the mock server's `/rpc` endpoint.
    fn rpc_url(server: &MockServer) -> String {
        format!("{}/rpc", server.uri())
    }

    /// Client with TLS and SSRF checks disabled.
    fn client() -> A2aClient {
        A2aClient::new(reqwest::Client::new())
    }

    fn text_params(text: &str) -> SendMessageParams {
        SendMessageParams {
            message: Message::user_text(text),
            configuration: None,
        }
    }

    fn id_params(id: &str) -> TaskIdParams {
        TaskIdParams {
            id: id.into(),
            history_length: None,
        }
    }

    #[tokio::test]
    async fn send_message_success() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(task_rpc_response("task-1", "submitted"))
            .mount(&server)
            .await;

        let task = client()
            .send_message(&rpc_url(&server), text_params("hello"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-1");
    }

    #[tokio::test]
    async fn send_message_rpc_error() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(task_rpc_error_response(-32001, "task not found"))
            .mount(&server)
            .await;

        let err = client()
            .send_message(&rpc_url(&server), text_params("hi"), None)
            .await
            .expect_err("server replied with a JSON-RPC error");
        assert!(matches!(err, A2aError::JsonRpc { code: -32001, .. }));
    }

    #[tokio::test]
    async fn send_message_with_bearer_auth() {
        let server = MockServer::start().await;
        post_rpc()
            .and(header("authorization", "Bearer secret-token"))
            .respond_with(task_rpc_response("task-auth", "submitted"))
            .mount(&server)
            .await;

        let task = client()
            .send_message(&rpc_url(&server), text_params("secure"), Some("secret-token"))
            .await
            .unwrap();
        assert_eq!(task.id, "task-auth");
    }

    #[tokio::test]
    async fn get_task_success() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(task_rpc_response("task-get", "completed"))
            .mount(&server)
            .await;

        let task = client()
            .get_task(&rpc_url(&server), id_params("task-get"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-get");
    }

    #[tokio::test]
    async fn cancel_task_success() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(task_rpc_response("task-cancel", "canceled"))
            .mount(&server)
            .await;

        let task = client()
            .cancel_task(&rpc_url(&server), id_params("task-cancel"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-cancel");
    }

    #[tokio::test]
    async fn stream_message_success() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(sse_task_events_response("task-stream", "result content"))
            .mount(&server)
            .await;

        let stream = client()
            .stream_message(&rpc_url(&server), text_params("stream"), None)
            .await
            .unwrap();
        let events: Vec<_> = stream.collect().await;
        assert!(!events.is_empty());
    }

    #[tokio::test]
    async fn stream_message_http_error() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error"))
            .mount(&server)
            .await;

        let err = client()
            .stream_message(&rpc_url(&server), text_params("fail"), None)
            .await
            .err()
            .expect("expected error");
        assert!(matches!(err, A2aError::Stream(_)));
    }
}