Skip to main content

zeph_a2a/
client.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! A2A protocol HTTP client with optional TLS enforcement and SSRF protection.
5
6use std::pin::Pin;
7use std::time::Duration;
8
9use eventsource_stream::Eventsource;
10use futures_core::Stream;
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use tokio_stream::StreamExt;
13use zeph_common::net::is_private_ip;
14
15use crate::error::A2aError;
16use crate::jsonrpc::{
17    JsonRpcRequest, JsonRpcResponse, METHOD_CANCEL_TASK, METHOD_GET_TASK, METHOD_SEND_MESSAGE,
18    METHOD_SEND_STREAMING_MESSAGE, SendMessageParams, TaskIdParams,
19};
20use crate::types::{Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
21
22/// A pinned, heap-allocated stream of [`TaskEvent`]s from a streaming A2A call.
23///
24/// Produced by [`A2aClient::stream_message`]. Each item is either a status update
25/// or an artifact update; errors are surfaced inline as `Err(A2aError)`.
26pub type TaskEventStream = Pin<Box<dyn Stream<Item = Result<TaskEvent, A2aError>> + Send>>;
27
28/// A single event received on a streaming (`message/stream`) A2A connection.
29///
30/// The A2A spec multiplexes two event kinds over the same SSE channel. This enum
31/// uses `#[serde(untagged)]` so that the deserializer inspects the `kind` field
32/// inside the inner struct to determine the variant.
33#[non_exhaustive]
34#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(untagged)]
36pub enum TaskEvent {
37    /// A task lifecycle transition (e.g., `submitted` → `working` → `completed`).
38    StatusUpdate(TaskStatusUpdateEvent),
39    /// A new or updated output artifact from the agent.
40    ArtifactUpdate(TaskArtifactUpdateEvent),
41}
42
43/// HTTP client for the A2A protocol.
44///
45/// `A2aClient` wraps a `reqwest::Client` and provides typed methods for the four
46/// A2A JSON-RPC operations: `message/send`, `message/stream`, `tasks/get`, and
47/// `tasks/cancel`. Each call optionally accepts a bearer token for authentication.
48///
49/// # Security
50///
51/// Use [`with_security`](A2aClient::with_security) to harden the client for
52/// production deployments:
53/// - `require_tls = true` rejects any `http://` endpoint before connecting.
54/// - `ssrf_protection = true` resolves the endpoint's hostname via DNS and rejects
55///   addresses in private/loopback ranges (10/8, 172.16/12, 192.168/16, 127/8, etc.).
56///
57/// # Examples
58///
59/// ```rust,no_run
60/// use zeph_a2a::{A2aClient, SendMessageParams, Message};
61///
62/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
63/// let client = A2aClient::new(reqwest::Client::new())
64///     .with_security(true, true); // require HTTPS, block SSRF
65///
66/// let params = SendMessageParams {
67///     message: Message::user_text("Summarize this page."),
68///     configuration: None,
69/// };
70/// let task = client.send_message("https://agent.example.com/a2a", params, Some("tok")).await?;
71/// println!("Task state: {:?}", task.status.state);
72/// # Ok(())
73/// # }
74/// ```
75pub struct A2aClient {
76    client: reqwest::Client,
77    require_tls: bool,
78    ssrf_protection: bool,
79    /// Per-request timeout applied to `rpc_call` (send + JSON parse) and to the initial
80    /// `send()` in `stream_message`. The SSE body stream itself is not bounded — that
81    /// is the caller's responsibility.
82    ///
83    /// If the underlying `reqwest::Client` was also built with `.timeout()`, both limits
84    /// race: whichever fires first wins. `request_timeout` takes semantic priority because
85    /// it maps to `A2aError::Timeout`; the reqwest-level timeout maps to `A2aError::Http`.
86    request_timeout: Duration,
87}
88
89impl A2aClient {
90    /// Create a new `A2aClient` with no security restrictions.
91    ///
92    /// Security features are disabled by default for local/dev usage. Enable them
93    /// with [`with_security`](Self::with_security) for production deployments.
94    #[must_use]
95    pub fn new(client: reqwest::Client) -> Self {
96        Self {
97            client,
98            require_tls: false,
99            ssrf_protection: false,
100            request_timeout: Duration::from_secs(30),
101        }
102    }
103
104    /// Configure TLS enforcement and SSRF protection for this client.
105    ///
106    /// Both flags default to `false`. This method uses the builder pattern and
107    /// can be chained directly after [`new`](Self::new).
108    ///
109    /// - `require_tls`: reject any endpoint that does not start with `https://`.
110    /// - `ssrf_protection`: resolve the endpoint hostname via DNS and reject private IP ranges.
111    ///
112    /// # Examples
113    ///
114    /// ```rust
115    /// use zeph_a2a::A2aClient;
116    ///
117    /// let client = A2aClient::new(reqwest::Client::new())
118    ///     .with_security(true, true);
119    /// ```
120    #[must_use]
121    pub fn with_security(mut self, require_tls: bool, ssrf_protection: bool) -> Self {
122        self.require_tls = require_tls;
123        self.ssrf_protection = ssrf_protection;
124        self
125    }
126
127    /// Set the per-request timeout for RPC and streaming connection calls (default: 30 seconds).
128    ///
129    /// Applied to the full send + JSON response parse in `rpc_call`, and to the initial
130    /// HTTP `send()` in `stream_message`. The SSE body stream after connection is intentionally
131    /// unbounded — streams can legitimately run for a long time.
132    #[must_use]
133    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
134        self.request_timeout = timeout;
135        self
136    }
137
138    /// # Errors
139    /// Returns `A2aError` on network, JSON, or JSON-RPC errors, or `A2aError::Timeout`
140    /// if the request exceeds the configured `request_timeout`.
141    pub async fn send_message(
142        &self,
143        endpoint: &str,
144        params: SendMessageParams,
145        token: Option<&str>,
146    ) -> Result<Task, A2aError> {
147        self.rpc_call(endpoint, METHOD_SEND_MESSAGE, params, token)
148            .await
149    }
150
151    /// # Errors
152    /// Returns `A2aError` on network failure or if the SSE connection cannot be established.
153    pub async fn stream_message(
154        &self,
155        endpoint: &str,
156        params: SendMessageParams,
157        token: Option<&str>,
158    ) -> Result<TaskEventStream, A2aError> {
159        self.validate_endpoint(endpoint).await?;
160        let request = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
161        let mut req = self.client.post(endpoint).json(&request);
162        if let Some(t) = token {
163            req = req.bearer_auth(t);
164        }
165        let resp = tokio::time::timeout(self.request_timeout, req.send())
166            .await
167            .map_err(|_| A2aError::Timeout(self.request_timeout))?
168            .map_err(A2aError::Http)?;
169
170        if !resp.status().is_success() {
171            let status = resp.status();
172            let body = tokio::time::timeout(Duration::from_secs(5), resp.text())
173                .await
174                .unwrap_or(Ok(String::new()))
175                .unwrap_or_default();
176            // Truncate body to avoid leaking large upstream error responses.
177            let truncated = if body.len() > 256 {
178                format!("{}…", &body[..256])
179            } else {
180                body
181            };
182            return Err(A2aError::Stream(format!("HTTP {status}: {truncated}")));
183        }
184
185        let event_stream = resp.bytes_stream().eventsource();
186        let mapped = event_stream.filter_map(|event| match event {
187            Ok(event) => {
188                if event.data.is_empty() || event.data == "[DONE]" {
189                    return None;
190                }
191                match serde_json::from_str::<JsonRpcResponse<TaskEvent>>(&event.data) {
192                    Ok(rpc_resp) => match rpc_resp.into_result() {
193                        Ok(task_event) => Some(Ok(task_event)),
194                        Err(rpc_err) => Some(Err(A2aError::from(rpc_err))),
195                    },
196                    Err(e) => Some(Err(A2aError::Stream(format!(
197                        "failed to parse SSE event: {e}"
198                    )))),
199                }
200            }
201            Err(e) => Some(Err(A2aError::Stream(format!("SSE stream error: {e}")))),
202        });
203
204        Ok(Box::pin(mapped))
205    }
206
207    /// # Errors
208    /// Returns `A2aError` on network, JSON, or JSON-RPC errors, or `A2aError::Timeout`
209    /// if the request exceeds the configured `request_timeout`.
210    pub async fn get_task(
211        &self,
212        endpoint: &str,
213        params: TaskIdParams,
214        token: Option<&str>,
215    ) -> Result<Task, A2aError> {
216        self.rpc_call(endpoint, METHOD_GET_TASK, params, token)
217            .await
218    }
219
220    /// # Errors
221    /// Returns `A2aError` on network, JSON, or JSON-RPC errors, or `A2aError::Timeout`
222    /// if the request exceeds the configured `request_timeout`.
223    pub async fn cancel_task(
224        &self,
225        endpoint: &str,
226        params: TaskIdParams,
227        token: Option<&str>,
228    ) -> Result<Task, A2aError> {
229        self.rpc_call(endpoint, METHOD_CANCEL_TASK, params, token)
230            .await
231    }
232
233    async fn validate_endpoint(&self, endpoint: &str) -> Result<(), A2aError> {
234        if self.require_tls && !endpoint.starts_with("https://") {
235            return Err(A2aError::Security(format!(
236                "TLS required but endpoint uses HTTP: {endpoint}"
237            )));
238        }
239
240        if self.ssrf_protection {
241            let url: url::Url = endpoint
242                .parse()
243                .map_err(|e| A2aError::Security(format!("invalid URL: {e}")))?;
244
245            if let Some(host) = url.host_str() {
246                let addrs = tokio::net::lookup_host(format!(
247                    "{}:{}",
248                    host,
249                    url.port_or_known_default().unwrap_or(443)
250                ))
251                .await
252                .map_err(|e| A2aError::Security(format!("DNS resolution failed: {e}")))?;
253
254                for addr in addrs {
255                    if is_private_ip(addr.ip()) {
256                        return Err(A2aError::Security(format!(
257                            "SSRF protection: private IP {} for host {host}",
258                            addr.ip()
259                        )));
260                    }
261                }
262            }
263        }
264
265        Ok(())
266    }
267
268    async fn rpc_call<P: Serialize, R: DeserializeOwned>(
269        &self,
270        endpoint: &str,
271        method: &str,
272        params: P,
273        token: Option<&str>,
274    ) -> Result<R, A2aError> {
275        self.validate_endpoint(endpoint).await?;
276        let request = JsonRpcRequest::new(method, params);
277        let mut req = self.client.post(endpoint).json(&request);
278        if let Some(t) = token {
279            req = req.bearer_auth(t);
280        }
281        let rpc_response: JsonRpcResponse<R> = tokio::time::timeout(self.request_timeout, async {
282            let resp = req.send().await?;
283            resp.json().await
284        })
285        .await
286        .map_err(|_| A2aError::Timeout(self.request_timeout))?
287        .map_err(A2aError::Http)?;
288        rpc_response.into_result().map_err(A2aError::from)
289    }
290}
291
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use super::*;
    use crate::jsonrpc::{JsonRpcError, JsonRpcResponse};
    use crate::types::{
        Artifact, Message, Part, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };

    // ---- shared fixtures --------------------------------------------------------

    /// Local endpoint assumed to have nothing listening, so requests fail to connect.
    const UNREACHABLE_RPC: &str = "http://127.0.0.1:1/rpc";
    /// Plain-HTTP endpoint used to exercise `require_tls` rejection.
    const PLAIN_HTTP_RPC: &str = "http://example.com/rpc";

    /// Client with every security feature disabled (the `new` default).
    fn plain_client() -> A2aClient {
        A2aClient::new(reqwest::Client::new())
    }

    /// Client that enforces HTTPS but does not perform SSRF checks.
    fn tls_only_client() -> A2aClient {
        plain_client().with_security(true, false)
    }

    /// Client that performs SSRF checks but accepts plain HTTP.
    fn ssrf_only_client() -> A2aClient {
        plain_client().with_security(false, true)
    }

    /// `message/send` parameters carrying a single user text message.
    fn text_params(text: &str) -> SendMessageParams {
        SendMessageParams {
            message: Message::user_text(text),
            configuration: None,
        }
    }

    /// Task-id parameters without a history length.
    fn task_id(id: &str) -> TaskIdParams {
        TaskIdParams {
            id: id.into(),
            history_length: None,
        }
    }

    fn status(state: TaskState, timestamp: &str, message: Option<Message>) -> TaskStatus {
        TaskStatus {
            state,
            timestamp: timestamp.into(),
            message,
        }
    }

    /// Status update for task `t-1` without a context id.
    fn status_event(status: TaskStatus, is_final: bool) -> TaskStatusUpdateEvent {
        TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status,
            is_final,
        }
    }

    /// Final artifact update for task `t-1` with one text part.
    fn artifact_event(text: &str) -> TaskArtifactUpdateEvent {
        TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact: Artifact {
                artifact_id: "a-1".into(),
                name: None,
                parts: vec![Part::text(text)],
                metadata: None,
            },
            is_final: true,
        }
    }

    /// Completed task with no artifacts, history, or metadata.
    fn completed_task(id: &str) -> Task {
        Task {
            id: id.into(),
            context_id: None,
            status: status(TaskState::Completed, "ts", None),
            artifacts: vec![],
            history: vec![],
            metadata: None,
        }
    }

    fn task_response(
        id: &str,
        result: Option<Task>,
        error: Option<JsonRpcError>,
    ) -> JsonRpcResponse<Task> {
        JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String(id.into()),
            result,
            error,
        }
    }

    /// JSON-RPC error with the application code `-32001`.
    fn rpc_error(message: &str) -> JsonRpcError {
        JsonRpcError {
            code: -32001,
            message: message.into(),
            data: None,
        }
    }

    fn ip(literal: &str) -> IpAddr {
        literal.parse().unwrap()
    }

    // ---- TaskEvent serde --------------------------------------------------------

    #[test]
    fn task_event_deserialize_status_update() {
        let event = status_event(
            status(TaskState::Working, "ts", Some(Message::user_text("thinking..."))),
            false,
        );
        let json = serde_json::to_string(&event).unwrap();
        let parsed: TaskEvent = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, TaskEvent::StatusUpdate(_)));
    }

    #[test]
    fn task_event_deserialize_artifact_update() {
        let json = serde_json::to_string(&artifact_event("result")).unwrap();
        let parsed: TaskEvent = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, TaskEvent::ArtifactUpdate(_)));
    }

    // ---- JSON-RPC envelopes -----------------------------------------------------

    #[test]
    fn rpc_response_with_task_result() {
        let original = task_response("req-1", Some(completed_task("t-1")), None);
        let json = serde_json::to_string(&original).unwrap();
        let decoded: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
        let task = decoded.into_result().unwrap();
        assert_eq!(task.id, "t-1");
        assert_eq!(task.status.state, TaskState::Completed);
    }

    #[test]
    fn rpc_response_with_error() {
        let original = task_response("req-1", None, Some(rpc_error("task not found")));
        let json = serde_json::to_string(&original).unwrap();
        let decoded: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.into_result().unwrap_err().code, -32001);
    }

    // ---- client construction ----------------------------------------------------

    #[test]
    fn a2a_client_construction() {
        drop(plain_client());
    }

    // ---- is_private_ip ----------------------------------------------------------

    #[test]
    fn is_private_ip_loopback() {
        assert!(is_private_ip(IpAddr::from(Ipv4Addr::LOCALHOST)));
        assert!(is_private_ip(IpAddr::from(Ipv6Addr::LOCALHOST)));
    }

    #[test]
    fn is_private_ip_private_ranges() {
        for literal in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
            assert!(is_private_ip(ip(literal)), "{literal} should be private");
        }
    }

    #[test]
    fn is_private_ip_link_local() {
        assert!(is_private_ip(ip("169.254.0.1")));
    }

    #[test]
    fn is_private_ip_unspecified() {
        for literal in ["0.0.0.0", "::"] {
            assert!(is_private_ip(ip(literal)), "{literal} should be private");
        }
    }

    #[test]
    fn is_private_ip_public() {
        for literal in ["8.8.8.8", "1.1.1.1"] {
            assert!(!is_private_ip(ip(literal)), "{literal} should be public");
        }
    }

    // ---- endpoint validation ----------------------------------------------------

    #[tokio::test]
    async fn tls_enforcement_rejects_http() {
        let client = tls_only_client();
        let outcome = client.validate_endpoint(PLAIN_HTTP_RPC).await;
        let Err(err) = outcome else {
            panic!("plain HTTP must be rejected when TLS is required");
        };
        assert!(matches!(err, A2aError::Security(_)));
        assert!(err.to_string().contains("TLS required"));
    }

    #[tokio::test]
    async fn tls_enforcement_allows_https() {
        let client = tls_only_client();
        assert!(client.validate_endpoint("https://example.com/rpc").await.is_ok());
    }

    #[tokio::test]
    async fn ssrf_protection_rejects_localhost() {
        let client = ssrf_only_client();
        let err = client
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await
            .unwrap_err();
        assert!(err.to_string().contains("SSRF"));
    }

    #[tokio::test]
    async fn no_security_allows_http_localhost() {
        let client = plain_client();
        assert!(client.validate_endpoint("http://127.0.0.1:8080/rpc").await.is_ok());
    }

    // ---- request serialization --------------------------------------------------

    #[test]
    fn jsonrpc_request_serialization_for_send_message() {
        let json =
            serde_json::to_string(&JsonRpcRequest::new(METHOD_SEND_MESSAGE, text_params("hello")))
                .unwrap();
        for needle in ["\"method\":\"message/send\"", "\"jsonrpc\":\"2.0\"", "\"hello\""] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn jsonrpc_request_serialization_for_get_task() {
        let params = TaskIdParams {
            id: "task-123".into(),
            history_length: Some(5),
        };
        let json = serde_json::to_string(&JsonRpcRequest::new(METHOD_GET_TASK, params)).unwrap();
        for needle in ["\"method\":\"tasks/get\"", "\"task-123\"", "\"historyLength\":5"] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn jsonrpc_request_serialization_for_cancel_task() {
        let json =
            serde_json::to_string(&JsonRpcRequest::new(METHOD_CANCEL_TASK, task_id("task-456")))
                .unwrap();
        assert!(json.contains("\"method\":\"tasks/cancel\""));
        assert!(!json.contains("historyLength"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_stream() {
        let request = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, text_params("stream me"));
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"method\":\"message/stream\""));
    }

    // ---- connection failures ----------------------------------------------------

    #[tokio::test]
    async fn send_message_connection_error() {
        let client = plain_client();
        let outcome = client
            .send_message(UNREACHABLE_RPC, text_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn get_task_connection_error() {
        let client = plain_client();
        let outcome = client.get_task(UNREACHABLE_RPC, task_id("t-1"), None).await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn cancel_task_connection_error() {
        let client = plain_client();
        let outcome = client
            .cancel_task(UNREACHABLE_RPC, task_id("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn stream_message_connection_error() {
        let client = plain_client();
        let outcome = client
            .stream_message(UNREACHABLE_RPC, text_params("stream me"), None)
            .await;
        assert!(outcome.is_err());
    }

    // ---- TLS enforcement through the public API ---------------------------------

    #[tokio::test]
    async fn stream_message_tls_required_rejects_http() {
        let client = tls_only_client();
        let outcome = client
            .stream_message(PLAIN_HTTP_RPC, text_params("hello"), None)
            .await;
        let Err(A2aError::Security(msg)) = outcome else {
            panic!("expected Security error");
        };
        assert!(msg.contains("TLS required"));
    }

    #[tokio::test]
    async fn send_message_tls_required_rejects_http() {
        let client = tls_only_client();
        let outcome = client
            .send_message(PLAIN_HTTP_RPC, text_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn get_task_tls_required_rejects_http() {
        let client = tls_only_client();
        let outcome = client.get_task(PLAIN_HTTP_RPC, task_id("t-1"), None).await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn cancel_task_tls_required_rejects_http() {
        let client = tls_only_client();
        let outcome = client
            .cancel_task(PLAIN_HTTP_RPC, task_id("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn validate_endpoint_invalid_url_with_ssrf() {
        let client = ssrf_only_client();
        let outcome = client.validate_endpoint("not-a-url").await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    // ---- builder flags ----------------------------------------------------------

    #[test]
    fn with_security_returns_configured_client() {
        let client = plain_client().with_security(true, true);
        assert!(client.require_tls);
        assert!(client.ssrf_protection);
    }

    #[test]
    fn default_client_no_security() {
        let client = plain_client();
        assert!(!client.require_tls);
        assert!(!client.ssrf_protection);
    }

    // ---- TaskEvent derives ------------------------------------------------------

    #[test]
    fn task_event_clone() {
        let event =
            TaskEvent::StatusUpdate(status_event(status(TaskState::Working, "ts", None), false));
        let copy = event.clone();
        assert_eq!(
            serde_json::to_string(&event).unwrap(),
            serde_json::to_string(&copy).unwrap()
        );
    }

    #[test]
    fn task_event_debug() {
        let rendered = format!("{:?}", TaskEvent::ArtifactUpdate(artifact_event("data")));
        assert!(rendered.contains("ArtifactUpdate"));
    }

    #[test]
    fn is_private_ip_ipv4_non_private() {
        assert!(!is_private_ip(ip("93.184.216.34")));
    }

    #[test]
    fn is_private_ip_ipv6_non_private() {
        assert!(!is_private_ip(ip("2001:db8::1")));
    }

    #[test]
    fn rpc_response_error_takes_priority_over_result() {
        let envelope = task_response("1", Some(completed_task("t-1")), Some(rpc_error("error")));
        assert_eq!(envelope.into_result().unwrap_err().code, -32001);
    }

    #[test]
    fn rpc_response_neither_result_nor_error() {
        let envelope = task_response("1", None, None);
        assert_eq!(envelope.into_result().unwrap_err().code, -32603);
    }

    #[test]
    fn task_event_serialize_round_trip() {
        let finished = status(
            TaskState::Completed,
            "2025-01-01T00:00:00Z",
            Some(Message::user_text("done")),
        );
        let event = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
            context_id: Some("ctx-1".into()),
            ..status_event(finished, true)
        });
        let json = serde_json::to_string(&event).unwrap();
        let back: TaskEvent = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, TaskEvent::StatusUpdate(_)));
    }
}
741
#[cfg(test)]
mod wiremock_tests {
    use std::time::Duration;

    use tokio_stream::StreamExt;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockBuilder, MockServer, ResponseTemplate};

    use crate::client::A2aClient;
    use crate::error::A2aError;
    use crate::jsonrpc::{SendMessageParams, TaskIdParams};
    use crate::testing::*;
    use crate::types::Message;

    /// Matcher for the single JSON-RPC route every test talks to: `POST /rpc`.
    fn post_rpc() -> MockBuilder {
        Mock::given(method("POST")).and(path("/rpc"))
    }

    /// Absolute URL of the mock server's `/rpc` route.
    fn rpc_url(server: &MockServer) -> String {
        format!("{}/rpc", server.uri())
    }

    /// Client with default (disabled) security settings.
    fn client() -> A2aClient {
        A2aClient::new(reqwest::Client::new())
    }

    fn text_params(text: &str) -> SendMessageParams {
        SendMessageParams {
            message: Message::user_text(text),
            configuration: None,
        }
    }

    fn task_id(id: &str) -> TaskIdParams {
        TaskIdParams {
            id: id.into(),
            history_length: None,
        }
    }

    #[tokio::test]
    async fn send_message_success() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(task_rpc_response("task-1", "submitted"))
            .mount(&server)
            .await;

        let a2a = client();
        let task = a2a
            .send_message(&rpc_url(&server), text_params("hello"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-1");
    }

    #[tokio::test]
    async fn send_message_rpc_error() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(task_rpc_error_response(-32001, "task not found"))
            .mount(&server)
            .await;

        let a2a = client();
        let outcome = a2a
            .send_message(&rpc_url(&server), text_params("hi"), None)
            .await;
        assert!(matches!(
            outcome,
            Err(A2aError::JsonRpc { code: -32001, .. })
        ));
    }

    #[tokio::test]
    async fn send_message_with_bearer_auth() {
        let server = MockServer::start().await;
        post_rpc()
            .and(header("authorization", "Bearer secret-token"))
            .respond_with(task_rpc_response("task-auth", "submitted"))
            .mount(&server)
            .await;

        let a2a = client();
        let task = a2a
            .send_message(&rpc_url(&server), text_params("secure"), Some("secret-token"))
            .await
            .unwrap();
        assert_eq!(task.id, "task-auth");
    }

    #[tokio::test]
    async fn get_task_success() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(task_rpc_response("task-get", "completed"))
            .mount(&server)
            .await;

        let a2a = client();
        let task = a2a
            .get_task(&rpc_url(&server), task_id("task-get"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-get");
    }

    #[tokio::test]
    async fn cancel_task_success() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(task_rpc_response("task-cancel", "canceled"))
            .mount(&server)
            .await;

        let a2a = client();
        let task = a2a
            .cancel_task(&rpc_url(&server), task_id("task-cancel"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-cancel");
    }

    #[tokio::test]
    async fn stream_message_success() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(sse_task_events_response("task-stream", "result content"))
            .mount(&server)
            .await;

        let a2a = client();
        let stream = a2a
            .stream_message(&rpc_url(&server), text_params("stream"), None)
            .await
            .unwrap();
        let events: Vec<_> = stream.collect().await;
        assert!(!events.is_empty());
    }

    #[tokio::test]
    async fn stream_message_http_error() {
        let server = MockServer::start().await;
        post_rpc()
            .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error"))
            .mount(&server)
            .await;

        let a2a = client();
        let outcome = a2a
            .stream_message(&rpc_url(&server), text_params("fail"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Stream(_))));
    }

    #[tokio::test]
    async fn rpc_call_times_out() {
        let server = MockServer::start().await;
        // A valid response that arrives far later than the client's 100ms budget.
        let late_body = serde_json::json!({
            "jsonrpc": "2.0",
            "id": "req-1",
            "result": {
                "id": "t-1",
                "status": {"state": "completed", "timestamp": "2026-01-01T00:00:00Z"}
            }
        });
        post_rpc()
            .respond_with(
                ResponseTemplate::new(200)
                    .set_delay(Duration::from_secs(5))
                    .set_body_json(late_body),
            )
            .mount(&server)
            .await;

        let a2a = client().with_request_timeout(Duration::from_millis(100));
        let outcome = a2a
            .send_message(&rpc_url(&server), text_params("hello"), None)
            .await;
        assert!(
            matches!(outcome, Err(A2aError::Timeout(_))),
            "expected Timeout error"
        );
    }
}