Skip to main content

zeph_a2a/
client.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::net::IpAddr;
5use std::pin::Pin;
6
7use eventsource_stream::Eventsource;
8use futures_core::Stream;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use tokio_stream::StreamExt;
11
12use crate::error::A2aError;
13use crate::jsonrpc::{
14    JsonRpcRequest, JsonRpcResponse, METHOD_CANCEL_TASK, METHOD_GET_TASK, METHOD_SEND_MESSAGE,
15    METHOD_SEND_STREAMING_MESSAGE, SendMessageParams, TaskIdParams,
16};
17use crate::types::{Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
18
/// Pinned, boxed stream of task events (or errors) returned by
/// [`A2aClient::stream_message`].
pub type TaskEventStream = Pin<Box<dyn Stream<Item = Result<TaskEvent, A2aError>> + Send>>;
20
/// A single event delivered over the `message/stream` SSE channel.
///
/// The enum is `#[serde(untagged)]`: serde tries each variant in declaration order
/// and keeps the first one that deserializes. Reordering the variants can therefore
/// change how ambiguous payloads are classified.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TaskEvent {
    /// A task status change (see [`TaskStatusUpdateEvent`]).
    StatusUpdate(TaskStatusUpdateEvent),
    /// A new or updated task artifact (see [`TaskArtifactUpdateEvent`]).
    ArtifactUpdate(TaskArtifactUpdateEvent),
}
27
28pub struct A2aClient {
29    client: reqwest::Client,
30    require_tls: bool,
31    ssrf_protection: bool,
32}
33
34impl A2aClient {
35    #[must_use]
36    pub fn new(client: reqwest::Client) -> Self {
37        Self {
38            client,
39            require_tls: false,
40            ssrf_protection: false,
41        }
42    }
43
44    #[must_use]
45    pub fn with_security(mut self, require_tls: bool, ssrf_protection: bool) -> Self {
46        self.require_tls = require_tls;
47        self.ssrf_protection = ssrf_protection;
48        self
49    }
50
51    /// # Errors
52    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
53    pub async fn send_message(
54        &self,
55        endpoint: &str,
56        params: SendMessageParams,
57        token: Option<&str>,
58    ) -> Result<Task, A2aError> {
59        self.rpc_call(endpoint, METHOD_SEND_MESSAGE, params, token)
60            .await
61    }
62
63    /// # Errors
64    /// Returns `A2aError` on network failure or if the SSE connection cannot be established.
65    pub async fn stream_message(
66        &self,
67        endpoint: &str,
68        params: SendMessageParams,
69        token: Option<&str>,
70    ) -> Result<TaskEventStream, A2aError> {
71        self.validate_endpoint(endpoint).await?;
72        let request = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
73        let mut req = self.client.post(endpoint).json(&request);
74        if let Some(t) = token {
75            req = req.bearer_auth(t);
76        }
77        let resp = req.send().await?;
78
79        if !resp.status().is_success() {
80            let status = resp.status();
81            let body = resp.text().await.unwrap_or_default();
82            // Truncate body to avoid leaking large upstream error responses.
83            let truncated = if body.len() > 256 {
84                format!("{}…", &body[..256])
85            } else {
86                body
87            };
88            return Err(A2aError::Stream(format!("HTTP {status}: {truncated}")));
89        }
90
91        let event_stream = resp.bytes_stream().eventsource();
92        let mapped = event_stream.filter_map(|event| match event {
93            Ok(event) => {
94                if event.data.is_empty() || event.data == "[DONE]" {
95                    return None;
96                }
97                match serde_json::from_str::<JsonRpcResponse<TaskEvent>>(&event.data) {
98                    Ok(rpc_resp) => match rpc_resp.into_result() {
99                        Ok(task_event) => Some(Ok(task_event)),
100                        Err(rpc_err) => Some(Err(A2aError::from(rpc_err))),
101                    },
102                    Err(e) => Some(Err(A2aError::Stream(format!(
103                        "failed to parse SSE event: {e}"
104                    )))),
105                }
106            }
107            Err(e) => Some(Err(A2aError::Stream(format!("SSE stream error: {e}")))),
108        });
109
110        Ok(Box::pin(mapped))
111    }
112
113    /// # Errors
114    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
115    pub async fn get_task(
116        &self,
117        endpoint: &str,
118        params: TaskIdParams,
119        token: Option<&str>,
120    ) -> Result<Task, A2aError> {
121        self.rpc_call(endpoint, METHOD_GET_TASK, params, token)
122            .await
123    }
124
125    /// # Errors
126    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
127    pub async fn cancel_task(
128        &self,
129        endpoint: &str,
130        params: TaskIdParams,
131        token: Option<&str>,
132    ) -> Result<Task, A2aError> {
133        self.rpc_call(endpoint, METHOD_CANCEL_TASK, params, token)
134            .await
135    }
136
137    async fn validate_endpoint(&self, endpoint: &str) -> Result<(), A2aError> {
138        if self.require_tls && !endpoint.starts_with("https://") {
139            return Err(A2aError::Security(format!(
140                "TLS required but endpoint uses HTTP: {endpoint}"
141            )));
142        }
143
144        if self.ssrf_protection {
145            let url: url::Url = endpoint
146                .parse()
147                .map_err(|e| A2aError::Security(format!("invalid URL: {e}")))?;
148
149            if let Some(host) = url.host_str() {
150                let addrs = tokio::net::lookup_host(format!(
151                    "{}:{}",
152                    host,
153                    url.port_or_known_default().unwrap_or(443)
154                ))
155                .await
156                .map_err(|e| A2aError::Security(format!("DNS resolution failed: {e}")))?;
157
158                for addr in addrs {
159                    if is_private_ip(addr.ip()) {
160                        return Err(A2aError::Security(format!(
161                            "SSRF protection: private IP {} for host {host}",
162                            addr.ip()
163                        )));
164                    }
165                }
166            }
167        }
168
169        Ok(())
170    }
171
172    async fn rpc_call<P: Serialize, R: DeserializeOwned>(
173        &self,
174        endpoint: &str,
175        method: &str,
176        params: P,
177        token: Option<&str>,
178    ) -> Result<R, A2aError> {
179        self.validate_endpoint(endpoint).await?;
180        let request = JsonRpcRequest::new(method, params);
181        let mut req = self.client.post(endpoint).json(&request);
182        if let Some(t) = token {
183            req = req.bearer_auth(t);
184        }
185        let resp = req.send().await?;
186        let rpc_response: JsonRpcResponse<R> = resp.json().await?;
187        rpc_response.into_result().map_err(A2aError::from)
188    }
189}
190
/// Returns `true` if `ip` must not be contacted when SSRF protection is enabled.
///
/// Covers loopback, private, link-local, unspecified, broadcast, multicast and
/// other non-globally-routable ranges for IPv4 and IPv6. IPv4-mapped IPv6
/// addresses (`::ffff:a.b.c.d`) are classified by their embedded IPv4 address,
/// because connecting to them reaches that IPv4 host.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_ipv4(v4),
        IpAddr::V6(v6) => {
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private_ipv4(v4);
            }
            let first = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                // ff00::/8 multicast
                || v6.is_multicast()
                // fe80::/10 link-local
                || first & 0xffc0 == 0xfe80
                // fec0::/10 site-local (deprecated, still internal)
                || first & 0xffc0 == 0xfec0
                // fc00::/7 unique local (includes e.g. fd00:ec2::254)
                || first & 0xfe00 == 0xfc00
        }
    }
}

/// IPv4 part of [`is_private_ip`]: `true` for addresses that are not public unicast.
fn is_private_ipv4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, ..] = v4.octets();
    // 127.0.0.0/8
    v4.is_loopback()
        // 10/8, 172.16/12, 192.168/16
        || v4.is_private()
        // 169.254/16, includes the 169.254.169.254 metadata address
        || v4.is_link_local()
        || v4.is_broadcast()
        // 224.0.0.0/4
        || v4.is_multicast()
        // 0.0.0.0/8 "this network", includes the unspecified address
        || a == 0
        // 100.64.0.0/10 shared address space (CGNAT); includes 100.100.100.200
        || (a == 100 && b & 0xc0 == 64)
        // 240.0.0.0/4 reserved
        || a >= 240
}
228
#[cfg(test)]
mod tests {
    // Unit tests for the A2A client: event (de)serialization, JSON-RPC envelopes,
    // endpoint security validation, and connection-failure error mapping.
    use super::*;
    use crate::jsonrpc::{JsonRpcError, JsonRpcResponse};
    use crate::types::{
        Artifact, Message, Part, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };

    #[test]
    fn task_event_deserialize_status_update() {
        let event = TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status: TaskStatus {
                state: TaskState::Working,
                timestamp: "ts".into(),
                message: Some(Message::user_text("thinking...")),
            },
            is_final: false,
        };
        let json = serde_json::to_string(&event).unwrap();
        let parsed: TaskEvent = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, TaskEvent::StatusUpdate(_)));
    }

    #[test]
    fn task_event_deserialize_artifact_update() {
        let event = TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact: Artifact {
                artifact_id: "a-1".into(),
                name: None,
                parts: vec![Part::text("result")],
                metadata: None,
            },
            is_final: true,
        };
        let json = serde_json::to_string(&event).unwrap();
        let parsed: TaskEvent = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, TaskEvent::ArtifactUpdate(_)));
    }

    #[test]
    fn rpc_response_with_task_result() {
        let task = Task {
            id: "t-1".into(),
            context_id: None,
            status: TaskStatus {
                state: TaskState::Completed,
                timestamp: "ts".into(),
                message: None,
            },
            artifacts: vec![],
            history: vec![],
            metadata: None,
        };
        let resp = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("req-1".into()),
            result: Some(task),
            error: None,
        };
        let json = serde_json::to_string(&resp).unwrap();
        let back: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
        let task = back.into_result().unwrap();
        assert_eq!(task.id, "t-1");
        assert_eq!(task.status.state, TaskState::Completed);
    }

    #[test]
    fn rpc_response_with_error() {
        let resp: JsonRpcResponse<Task> = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("req-1".into()),
            result: None,
            error: Some(JsonRpcError {
                code: -32001,
                message: "task not found".into(),
                data: None,
            }),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let back: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
        let err = back.into_result().unwrap_err();
        assert_eq!(err.code, -32001);
    }

    #[test]
    fn a2a_client_construction() {
        let client = A2aClient::new(reqwest::Client::new());
        drop(client);
    }

    #[test]
    fn is_private_ip_loopback() {
        assert!(is_private_ip(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)));
        assert!(is_private_ip(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)));
    }

    #[test]
    fn is_private_ip_private_ranges() {
        assert!(is_private_ip("10.0.0.1".parse().unwrap()));
        assert!(is_private_ip("172.16.0.1".parse().unwrap()));
        assert!(is_private_ip("192.168.1.1".parse().unwrap()));
    }

    #[test]
    fn is_private_ip_link_local() {
        assert!(is_private_ip("169.254.0.1".parse().unwrap()));
    }

    #[test]
    fn is_private_ip_unspecified() {
        assert!(is_private_ip("0.0.0.0".parse().unwrap()));
        assert!(is_private_ip("::".parse().unwrap()));
    }

    #[test]
    fn is_private_ip_public() {
        assert!(!is_private_ip("8.8.8.8".parse().unwrap()));
        assert!(!is_private_ip("1.1.1.1".parse().unwrap()));
    }

    #[tokio::test]
    async fn tls_enforcement_rejects_http() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let result = client.validate_endpoint("http://example.com/rpc").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
        assert!(err.to_string().contains("TLS required"));
    }

    #[tokio::test]
    async fn tls_enforcement_allows_https() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let result = client.validate_endpoint("https://example.com/rpc").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn ssrf_protection_rejects_localhost() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(false, true);
        let result = client.validate_endpoint("http://127.0.0.1:8080/rpc").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("SSRF"));
    }

    #[tokio::test]
    async fn no_security_allows_http_localhost() {
        let client = A2aClient::new(reqwest::Client::new());
        let result = client.validate_endpoint("http://127.0.0.1:8080/rpc").await;
        assert!(result.is_ok());
    }

    #[test]
    fn jsonrpc_request_serialization_for_send_message() {
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: None,
        };
        let req = JsonRpcRequest::new(METHOD_SEND_MESSAGE, params);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"message/send\""));
        assert!(json.contains("\"jsonrpc\":\"2.0\""));
        assert!(json.contains("\"hello\""));
    }

    #[test]
    fn jsonrpc_request_serialization_for_get_task() {
        let params = TaskIdParams {
            id: "task-123".into(),
            history_length: Some(5),
        };
        let req = JsonRpcRequest::new(METHOD_GET_TASK, params);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"tasks/get\""));
        assert!(json.contains("\"task-123\""));
        assert!(json.contains("\"historyLength\":5"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_cancel_task() {
        let params = TaskIdParams {
            id: "task-456".into(),
            history_length: None,
        };
        let req = JsonRpcRequest::new(METHOD_CANCEL_TASK, params);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"tasks/cancel\""));
        assert!(!json.contains("historyLength"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_stream() {
        let params = SendMessageParams {
            message: Message::user_text("stream me"),
            configuration: None,
        };
        let req = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"message/stream\""));
    }

    // Port 1 on loopback is expected to refuse connections, exercising the
    // transport-error path without any real server.
    #[tokio::test]
    async fn send_message_connection_error() {
        let client = A2aClient::new(reqwest::Client::new());
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: None,
        };
        let result = client
            .send_message("http://127.0.0.1:1/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Http(_)));
    }

    #[tokio::test]
    async fn get_task_connection_error() {
        let client = A2aClient::new(reqwest::Client::new());
        let params = TaskIdParams {
            id: "t-1".into(),
            history_length: None,
        };
        let result = client
            .get_task("http://127.0.0.1:1/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Http(_)));
    }

    #[tokio::test]
    async fn cancel_task_connection_error() {
        let client = A2aClient::new(reqwest::Client::new());
        let params = TaskIdParams {
            id: "t-1".into(),
            history_length: None,
        };
        let result = client
            .cancel_task("http://127.0.0.1:1/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Http(_)));
    }

    #[tokio::test]
    async fn stream_message_connection_error() {
        let client = A2aClient::new(reqwest::Client::new());
        let params = SendMessageParams {
            message: Message::user_text("stream me"),
            configuration: None,
        };
        let result = client
            .stream_message("http://127.0.0.1:1/rpc", params, None)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn stream_message_tls_required_rejects_http() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: None,
        };
        let result = client
            .stream_message("http://example.com/rpc", params, None)
            .await;
        match result {
            Err(A2aError::Security(msg)) => assert!(msg.contains("TLS required")),
            _ => panic!("expected Security error"),
        }
    }

    #[tokio::test]
    async fn send_message_tls_required_rejects_http() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: None,
        };
        let result = client
            .send_message("http://example.com/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
    }

    #[tokio::test]
    async fn get_task_tls_required_rejects_http() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let params = TaskIdParams {
            id: "t-1".into(),
            history_length: None,
        };
        let result = client
            .get_task("http://example.com/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
    }

    #[tokio::test]
    async fn cancel_task_tls_required_rejects_http() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let params = TaskIdParams {
            id: "t-1".into(),
            history_length: None,
        };
        let result = client
            .cancel_task("http://example.com/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
    }

    #[tokio::test]
    async fn validate_endpoint_invalid_url_with_ssrf() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(false, true);
        let result = client.validate_endpoint("not-a-url").await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
    }

    #[test]
    fn with_security_returns_configured_client() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, true);
        assert!(client.require_tls);
        assert!(client.ssrf_protection);
    }

    #[test]
    fn default_client_no_security() {
        let client = A2aClient::new(reqwest::Client::new());
        assert!(!client.require_tls);
        assert!(!client.ssrf_protection);
    }

    #[test]
    fn task_event_clone() {
        let event = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status: TaskStatus {
                state: TaskState::Working,
                timestamp: "ts".into(),
                message: None,
            },
            is_final: false,
        });
        let cloned = event.clone();
        let json1 = serde_json::to_string(&event).unwrap();
        let json2 = serde_json::to_string(&cloned).unwrap();
        assert_eq!(json1, json2);
    }

    #[test]
    fn task_event_debug() {
        let event = TaskEvent::ArtifactUpdate(TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact: Artifact {
                artifact_id: "a-1".into(),
                name: None,
                parts: vec![Part::text("data")],
                metadata: None,
            },
            is_final: true,
        });
        let dbg = format!("{event:?}");
        assert!(dbg.contains("ArtifactUpdate"));
    }

    #[test]
    fn is_private_ip_ipv4_non_private() {
        assert!(!is_private_ip("93.184.216.34".parse().unwrap()));
    }

    #[test]
    fn is_private_ip_ipv6_non_private() {
        assert!(!is_private_ip("2001:db8::1".parse().unwrap()));
    }

    #[test]
    fn rpc_response_error_takes_priority_over_result() {
        let resp = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("1".into()),
            result: Some(Task {
                id: "t-1".into(),
                context_id: None,
                status: TaskStatus {
                    state: TaskState::Completed,
                    timestamp: "ts".into(),
                    message: None,
                },
                artifacts: vec![],
                history: vec![],
                metadata: None,
            }),
            error: Some(JsonRpcError {
                code: -32001,
                message: "error".into(),
                data: None,
            }),
        };
        let err = resp.into_result().unwrap_err();
        assert_eq!(err.code, -32001);
    }

    #[test]
    fn rpc_response_neither_result_nor_error() {
        let resp: JsonRpcResponse<Task> = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("1".into()),
            result: None,
            error: None,
        };
        let err = resp.into_result().unwrap_err();
        assert_eq!(err.code, -32603);
    }

    #[test]
    fn task_event_serialize_round_trip() {
        let event = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: Some("ctx-1".into()),
            status: TaskStatus {
                state: TaskState::Completed,
                timestamp: "2025-01-01T00:00:00Z".into(),
                message: Some(Message::user_text("done")),
            },
            is_final: true,
        });
        let json = serde_json::to_string(&event).unwrap();
        let back: TaskEvent = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, TaskEvent::StatusUpdate(_)));
    }

    // Covers the IPv6 link-local (fe80::/10) and unique-local (fc00::/7) branches.
    #[test]
    fn is_private_ip_ipv6_link_local_and_unique_local() {
        assert!(is_private_ip("fe80::1".parse().unwrap()));
        assert!(is_private_ip("fc00::1".parse().unwrap()));
        assert!(is_private_ip("fd00::1".parse().unwrap()));
    }

    // IPv4-mapped IPv6 addresses must be judged by the embedded IPv4 address.
    #[test]
    fn is_private_ip_ipv4_mapped_ipv6_uses_inner_address() {
        assert!(is_private_ip("::ffff:127.0.0.1".parse().unwrap()));
        assert!(is_private_ip("::ffff:10.0.0.1".parse().unwrap()));
        assert!(!is_private_ip("::ffff:8.8.8.8".parse().unwrap()));
    }

    #[test]
    fn is_private_ip_broadcast() {
        assert!(is_private_ip("255.255.255.255".parse().unwrap()));
    }

    // Bracketed IPv6 literals must be resolved and rejected like IPv4 literals.
    #[tokio::test]
    async fn ssrf_protection_rejects_ipv6_loopback_literal() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(false, true);
        let result = client.validate_endpoint("http://[::1]:8080/rpc").await;
        assert!(matches!(result, Err(A2aError::Security(_))));
    }

    // The TLS check runs before any SSRF/DNS work when both are enabled.
    #[tokio::test]
    async fn tls_and_ssrf_both_enabled_reject_http_before_dns() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, true);
        let result = client.validate_endpoint("http://127.0.0.1:8080/rpc").await;
        match result {
            Err(A2aError::Security(msg)) => assert!(msg.contains("TLS required")),
            _ => panic!("expected TLS Security error"),
        }
    }
}