Skip to main content

zeph_a2a/
client.rs

1use std::net::IpAddr;
2use std::pin::Pin;
3
4use eventsource_stream::Eventsource;
5use futures_core::Stream;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use tokio_stream::StreamExt;
8
9use crate::error::A2aError;
10use crate::jsonrpc::{
11    JsonRpcRequest, JsonRpcResponse, METHOD_CANCEL_TASK, METHOD_GET_TASK, METHOD_SEND_MESSAGE,
12    METHOD_SEND_STREAMING_MESSAGE, SendMessageParams, TaskIdParams,
13};
14use crate::types::{Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
15
16pub type TaskEventStream = Pin<Box<dyn Stream<Item = Result<TaskEvent, A2aError>> + Send>>;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19#[serde(untagged)]
20pub enum TaskEvent {
21    StatusUpdate(TaskStatusUpdateEvent),
22    ArtifactUpdate(TaskArtifactUpdateEvent),
23}
24
25pub struct A2aClient {
26    client: reqwest::Client,
27    require_tls: bool,
28    ssrf_protection: bool,
29}
30
31impl A2aClient {
32    #[must_use]
33    pub fn new(client: reqwest::Client) -> Self {
34        Self {
35            client,
36            require_tls: false,
37            ssrf_protection: false,
38        }
39    }
40
41    #[must_use]
42    pub fn with_security(mut self, require_tls: bool, ssrf_protection: bool) -> Self {
43        self.require_tls = require_tls;
44        self.ssrf_protection = ssrf_protection;
45        self
46    }
47
48    /// # Errors
49    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
50    pub async fn send_message(
51        &self,
52        endpoint: &str,
53        params: SendMessageParams,
54        token: Option<&str>,
55    ) -> Result<Task, A2aError> {
56        self.rpc_call(endpoint, METHOD_SEND_MESSAGE, params, token)
57            .await
58    }
59
60    /// # Errors
61    /// Returns `A2aError` on network failure or if the SSE connection cannot be established.
62    pub async fn stream_message(
63        &self,
64        endpoint: &str,
65        params: SendMessageParams,
66        token: Option<&str>,
67    ) -> Result<TaskEventStream, A2aError> {
68        self.validate_endpoint(endpoint).await?;
69        let request = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
70        let mut req = self.client.post(endpoint).json(&request);
71        if let Some(t) = token {
72            req = req.bearer_auth(t);
73        }
74        let resp = req.send().await?;
75
76        if !resp.status().is_success() {
77            let status = resp.status();
78            let body = resp.text().await.unwrap_or_default();
79            // Truncate body to avoid leaking large upstream error responses.
80            let truncated = if body.len() > 256 {
81                format!("{}…", &body[..256])
82            } else {
83                body
84            };
85            return Err(A2aError::Stream(format!("HTTP {status}: {truncated}")));
86        }
87
88        let event_stream = resp.bytes_stream().eventsource();
89        let mapped = event_stream.filter_map(|event| match event {
90            Ok(event) => {
91                if event.data.is_empty() || event.data == "[DONE]" {
92                    return None;
93                }
94                match serde_json::from_str::<JsonRpcResponse<TaskEvent>>(&event.data) {
95                    Ok(rpc_resp) => match rpc_resp.into_result() {
96                        Ok(task_event) => Some(Ok(task_event)),
97                        Err(rpc_err) => Some(Err(A2aError::from(rpc_err))),
98                    },
99                    Err(e) => Some(Err(A2aError::Stream(format!(
100                        "failed to parse SSE event: {e}"
101                    )))),
102                }
103            }
104            Err(e) => Some(Err(A2aError::Stream(format!("SSE stream error: {e}")))),
105        });
106
107        Ok(Box::pin(mapped))
108    }
109
110    /// # Errors
111    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
112    pub async fn get_task(
113        &self,
114        endpoint: &str,
115        params: TaskIdParams,
116        token: Option<&str>,
117    ) -> Result<Task, A2aError> {
118        self.rpc_call(endpoint, METHOD_GET_TASK, params, token)
119            .await
120    }
121
122    /// # Errors
123    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
124    pub async fn cancel_task(
125        &self,
126        endpoint: &str,
127        params: TaskIdParams,
128        token: Option<&str>,
129    ) -> Result<Task, A2aError> {
130        self.rpc_call(endpoint, METHOD_CANCEL_TASK, params, token)
131            .await
132    }
133
134    async fn validate_endpoint(&self, endpoint: &str) -> Result<(), A2aError> {
135        if self.require_tls && !endpoint.starts_with("https://") {
136            return Err(A2aError::Security(format!(
137                "TLS required but endpoint uses HTTP: {endpoint}"
138            )));
139        }
140
141        if self.ssrf_protection {
142            let url: url::Url = endpoint
143                .parse()
144                .map_err(|e| A2aError::Security(format!("invalid URL: {e}")))?;
145
146            if let Some(host) = url.host_str() {
147                let addrs = tokio::net::lookup_host(format!(
148                    "{}:{}",
149                    host,
150                    url.port_or_known_default().unwrap_or(443)
151                ))
152                .await
153                .map_err(|e| A2aError::Security(format!("DNS resolution failed: {e}")))?;
154
155                for addr in addrs {
156                    if is_private_ip(addr.ip()) {
157                        return Err(A2aError::Security(format!(
158                            "SSRF protection: private IP {} for host {host}",
159                            addr.ip()
160                        )));
161                    }
162                }
163            }
164        }
165
166        Ok(())
167    }
168
169    async fn rpc_call<P: Serialize, R: DeserializeOwned>(
170        &self,
171        endpoint: &str,
172        method: &str,
173        params: P,
174        token: Option<&str>,
175    ) -> Result<R, A2aError> {
176        self.validate_endpoint(endpoint).await?;
177        let request = JsonRpcRequest::new(method, params);
178        let mut req = self.client.post(endpoint).json(&request);
179        if let Some(t) = token {
180            req = req.bearer_auth(t);
181        }
182        let resp = req.send().await?;
183        let rpc_response: JsonRpcResponse<R> = resp.json().await?;
184        rpc_response.into_result().map_err(A2aError::from)
185    }
186}
187
/// Returns `true` when `ip` must not be contacted while SSRF protection is enabled.
///
/// IPv4 addresses are checked by `is_blocked_ipv4`. For IPv6, this blocks:
/// - loopback, unspecified, and multicast addresses;
/// - link-local (`fe80::/10`) and unique-local (`fc00::/7`) ranges;
/// - forms that carry an IPv4 address, which is then checked with the IPv4 rules:
///   IPv4-mapped `::ffff:a.b.c.d`, IPv4-compatible `::a.b.c.d`, and NAT64
///   `64:ff9b::/96`.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_blocked_ipv4(v4),
        IpAddr::V6(v6) => {
            if v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() {
                return true;
            }
            let seg = v6.segments();
            // fe80::/10 — link-local
            if seg[0] & 0xffc0 == 0xfe80 {
                return true;
            }
            // fc00::/7 — unique local
            if seg[0] & 0xfe00 == 0xfc00 {
                return true;
            }
            // ::ffff:a.b.c.d (IPv4-mapped) and ::a.b.c.d (deprecated IPv4-compatible)
            // address the embedded IPv4 host, so apply the IPv4 rules to it.
            if let Some(v4) = v6.to_ipv4() {
                return is_blocked_ipv4(v4);
            }
            // 64:ff9b::/96 — NAT64 well-known prefix (RFC 6052). The translator forwards
            // traffic to the IPv4 address held in the low 32 bits.
            if seg[..6] == [0x64, 0xff9b, 0, 0, 0, 0] {
                let [.., a, b, c, d] = v6.octets();
                return is_blocked_ipv4(std::net::Ipv4Addr::new(a, b, c, d));
            }
            false
        }
    }
}

/// IPv4 part of [`is_private_ip`]: loopback, RFC 1918 private, link-local (which
/// includes cloud metadata at 169.254.169.254), multicast, "this network" 0.0.0.0/8,
/// shared address space 100.64.0.0/10, and reserved 240.0.0.0/4 (which includes
/// broadcast).
fn is_blocked_ipv4(v4: std::net::Ipv4Addr) -> bool {
    let [first, second, ..] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_multicast()
        // 0.0.0.0/8 — "this network"; includes the unspecified address 0.0.0.0.
        || first == 0
        // 100.64.0.0/10 — shared address space (RFC 6598, carrier-grade NAT).
        || (first == 100 && (second & 0xc0) == 0x40)
        // 240.0.0.0/4 — reserved; includes the limited broadcast 255.255.255.255.
        || first >= 240
}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jsonrpc::{JsonRpcError, JsonRpcResponse};
    use crate::types::{
        Artifact, Message, Part, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };

    // ---------------------------------------------------------------------------------
    // Shared fixtures
    // ---------------------------------------------------------------------------------

    /// Local endpoint on port 1, used to provoke a connection error.
    const UNREACHABLE_ENDPOINT: &str = "http://127.0.0.1:1/rpc";
    /// Plain-HTTP endpoint that TLS enforcement must reject before any I/O.
    const PLAIN_HTTP_ENDPOINT: &str = "http://example.com/rpc";

    /// Client from `A2aClient::new`, i.e. with every security check off.
    fn plain_client() -> A2aClient {
        A2aClient::new(reqwest::Client::new())
    }

    /// Client with the given TLS / SSRF settings applied.
    fn secured_client(require_tls: bool, ssrf_protection: bool) -> A2aClient {
        plain_client().with_security(require_tls, ssrf_protection)
    }

    /// Send-message params carrying a single user text message and no configuration.
    fn message_params(text: &'static str) -> SendMessageParams {
        SendMessageParams {
            message: Message::user_text(text),
            configuration: None,
        }
    }

    /// Task-id params for `id` with no history length.
    fn task_params(id: &'static str) -> TaskIdParams {
        TaskIdParams {
            id: id.into(),
            history_length: None,
        }
    }

    fn task_status(
        state: TaskState,
        timestamp: &'static str,
        message: Option<Message>,
    ) -> TaskStatus {
        TaskStatus {
            state,
            timestamp: timestamp.into(),
            message,
        }
    }

    /// Status-update event for task `t-1` with no context id.
    fn status_update(status: TaskStatus, is_final: bool) -> TaskStatusUpdateEvent {
        TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status,
            is_final,
        }
    }

    /// Final artifact-update event for task `t-1` holding one text part.
    fn artifact_update(text: &'static str) -> TaskArtifactUpdateEvent {
        TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact: Artifact {
                artifact_id: "a-1".into(),
                name: None,
                parts: vec![Part::text(text)],
                metadata: None,
            },
            is_final: true,
        }
    }

    /// Completed task `t-1` with no artifacts, history, or metadata.
    fn completed_task() -> Task {
        Task {
            id: "t-1".into(),
            context_id: None,
            status: task_status(TaskState::Completed, "ts", None),
            artifacts: vec![],
            history: vec![],
            metadata: None,
        }
    }

    fn rpc_error(message: &'static str) -> JsonRpcError {
        JsonRpcError {
            code: -32001,
            message: message.into(),
            data: None,
        }
    }

    fn task_response(
        id: &'static str,
        result: Option<Task>,
        error: Option<JsonRpcError>,
    ) -> JsonRpcResponse<Task> {
        JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String(id.into()),
            result,
            error,
        }
    }

    /// Serializes `value` to JSON and parses the text back as a `TaskEvent`.
    fn reparse_as_task_event<T: Serialize>(value: &T) -> TaskEvent {
        let json = serde_json::to_string(value).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    // ---------------------------------------------------------------------------------
    // TaskEvent (de)serialization
    // ---------------------------------------------------------------------------------

    #[test]
    fn task_event_deserialize_status_update() {
        let event = status_update(
            task_status(TaskState::Working, "ts", Some(Message::user_text("thinking..."))),
            false,
        );
        assert!(matches!(
            reparse_as_task_event(&event),
            TaskEvent::StatusUpdate(_)
        ));
    }

    #[test]
    fn task_event_deserialize_artifact_update() {
        let event = artifact_update("result");
        assert!(matches!(
            reparse_as_task_event(&event),
            TaskEvent::ArtifactUpdate(_)
        ));
    }

    #[test]
    fn task_event_clone() {
        let original = TaskEvent::StatusUpdate(status_update(
            task_status(TaskState::Working, "ts", None),
            false,
        ));
        let copy = original.clone();
        assert_eq!(
            serde_json::to_string(&original).unwrap(),
            serde_json::to_string(&copy).unwrap()
        );
    }

    #[test]
    fn task_event_debug() {
        let event = TaskEvent::ArtifactUpdate(artifact_update("data"));
        assert!(format!("{event:?}").contains("ArtifactUpdate"));
    }

    #[test]
    fn task_event_serialize_round_trip() {
        let event = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
            context_id: Some("ctx-1".into()),
            ..status_update(
                task_status(
                    TaskState::Completed,
                    "2025-01-01T00:00:00Z",
                    Some(Message::user_text("done")),
                ),
                true,
            )
        });
        assert!(matches!(
            reparse_as_task_event(&event),
            TaskEvent::StatusUpdate(_)
        ));
    }

    // ---------------------------------------------------------------------------------
    // JSON-RPC response handling
    // ---------------------------------------------------------------------------------

    #[test]
    fn rpc_response_with_task_result() {
        let outgoing = task_response("req-1", Some(completed_task()), None);
        let json = serde_json::to_string(&outgoing).unwrap();
        let incoming: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
        let task = incoming.into_result().unwrap();
        assert_eq!(task.id, "t-1");
        assert_eq!(task.status.state, TaskState::Completed);
    }

    #[test]
    fn rpc_response_with_error() {
        let outgoing = task_response("req-1", None, Some(rpc_error("task not found")));
        let json = serde_json::to_string(&outgoing).unwrap();
        let incoming: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
        assert_eq!(incoming.into_result().unwrap_err().code, -32001);
    }

    #[test]
    fn rpc_response_error_takes_priority_over_result() {
        let resp = task_response("1", Some(completed_task()), Some(rpc_error("error")));
        assert_eq!(resp.into_result().unwrap_err().code, -32001);
    }

    #[test]
    fn rpc_response_neither_result_nor_error() {
        let resp = task_response("1", None, None);
        assert_eq!(resp.into_result().unwrap_err().code, -32603);
    }

    // ---------------------------------------------------------------------------------
    // JSON-RPC request serialization
    // ---------------------------------------------------------------------------------

    #[test]
    fn jsonrpc_request_serialization_for_send_message() {
        let req = JsonRpcRequest::new(METHOD_SEND_MESSAGE, message_params("hello"));
        let json = serde_json::to_string(&req).unwrap();
        for needle in [r#""method":"message/send""#, r#""jsonrpc":"2.0""#, r#""hello""#] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn jsonrpc_request_serialization_for_get_task() {
        let params = TaskIdParams {
            history_length: Some(5),
            ..task_params("task-123")
        };
        let req = JsonRpcRequest::new(METHOD_GET_TASK, params);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains(r#""method":"tasks/get""#));
        assert!(json.contains(r#""task-123""#));
        assert!(json.contains(r#""historyLength":5"#));
    }

    #[test]
    fn jsonrpc_request_serialization_for_cancel_task() {
        let req = JsonRpcRequest::new(METHOD_CANCEL_TASK, task_params("task-456"));
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains(r#""method":"tasks/cancel""#));
        assert!(!json.contains("historyLength"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_stream() {
        let req = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, message_params("stream me"));
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains(r#""method":"message/stream""#));
    }

    // ---------------------------------------------------------------------------------
    // Client configuration
    // ---------------------------------------------------------------------------------

    #[test]
    fn a2a_client_construction() {
        drop(plain_client());
    }

    #[test]
    fn with_security_returns_configured_client() {
        let client = secured_client(true, true);
        assert!(client.require_tls);
        assert!(client.ssrf_protection);
    }

    #[test]
    fn default_client_no_security() {
        let client = plain_client();
        assert!(!client.require_tls);
        assert!(!client.ssrf_protection);
    }

    // ---------------------------------------------------------------------------------
    // is_private_ip
    // ---------------------------------------------------------------------------------

    #[test]
    fn is_private_ip_loopback() {
        assert!(is_private_ip(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)));
        assert!(is_private_ip(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)));
    }

    #[test]
    fn is_private_ip_private_ranges() {
        for addr in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
            assert!(is_private_ip(ip(addr)), "{addr} should be private");
        }
    }

    #[test]
    fn is_private_ip_link_local() {
        assert!(is_private_ip(ip("169.254.0.1")));
    }

    #[test]
    fn is_private_ip_unspecified() {
        for addr in ["0.0.0.0", "::"] {
            assert!(is_private_ip(ip(addr)), "{addr} should be private");
        }
    }

    #[test]
    fn is_private_ip_public() {
        for addr in ["8.8.8.8", "1.1.1.1"] {
            assert!(!is_private_ip(ip(addr)), "{addr} should be public");
        }
    }

    #[test]
    fn is_private_ip_ipv4_non_private() {
        assert!(!is_private_ip(ip("93.184.216.34")));
    }

    #[test]
    fn is_private_ip_ipv6_non_private() {
        assert!(!is_private_ip(ip("2001:db8::1")));
    }

    // ---------------------------------------------------------------------------------
    // Endpoint validation
    // ---------------------------------------------------------------------------------

    #[tokio::test]
    async fn tls_enforcement_rejects_http() {
        let outcome = secured_client(true, false)
            .validate_endpoint(PLAIN_HTTP_ENDPOINT)
            .await;
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
        assert!(err.to_string().contains("TLS required"));
    }

    #[tokio::test]
    async fn tls_enforcement_allows_https() {
        let outcome = secured_client(true, false)
            .validate_endpoint("https://example.com/rpc")
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn ssrf_protection_rejects_localhost() {
        let outcome = secured_client(false, true)
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await;
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("SSRF"));
    }

    #[tokio::test]
    async fn no_security_allows_http_localhost() {
        let outcome = plain_client()
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn validate_endpoint_invalid_url_with_ssrf() {
        let outcome = secured_client(false, true)
            .validate_endpoint("not-a-url")
            .await;
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), A2aError::Security(_)));
    }

    // ---------------------------------------------------------------------------------
    // Transport failures
    // ---------------------------------------------------------------------------------

    #[tokio::test]
    async fn send_message_connection_error() {
        let outcome = plain_client()
            .send_message(UNREACHABLE_ENDPOINT, message_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn get_task_connection_error() {
        let outcome = plain_client()
            .get_task(UNREACHABLE_ENDPOINT, task_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn cancel_task_connection_error() {
        let outcome = plain_client()
            .cancel_task(UNREACHABLE_ENDPOINT, task_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn stream_message_connection_error() {
        let outcome = plain_client()
            .stream_message(UNREACHABLE_ENDPOINT, message_params("stream me"), None)
            .await;
        assert!(outcome.is_err());
    }

    // ---------------------------------------------------------------------------------
    // TLS enforcement on every public call
    // ---------------------------------------------------------------------------------

    #[tokio::test]
    async fn stream_message_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .stream_message(PLAIN_HTTP_ENDPOINT, message_params("hello"), None)
            .await;
        let Err(A2aError::Security(msg)) = outcome else {
            panic!("expected Security error");
        };
        assert!(msg.contains("TLS required"));
    }

    #[tokio::test]
    async fn send_message_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .send_message(PLAIN_HTTP_ENDPOINT, message_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn get_task_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .get_task(PLAIN_HTTP_ENDPOINT, task_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn cancel_task_tls_required_rejects_http() {
        let outcome = secured_client(true, false)
            .cancel_task(PLAIN_HTTP_ENDPOINT, task_params("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }
}