1use eventsource_stream::Eventsource;
6use futures::stream::{Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::protocol::error::A2AError;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SseEvent {
15 pub kind: String,
17
18 pub payload: Value,
20
21 #[serde(default)]
23 pub final_event: bool,
24}
25
26impl SseEvent {
27 pub fn is_terminal(&self) -> bool {
29 if self.final_event {
30 return true;
31 }
32
33 if let Some(state) = self.payload.get("state").and_then(|s| s.as_str()) {
35 matches!(state, "completed" | "failed" | "canceled" | "rejected")
36 } else {
37 false
38 }
39 }
40
41 pub fn is_error(&self) -> bool {
43 if let Some(state) = self.payload.get("state").and_then(|s| s.as_str()) {
44 matches!(state, "failed" | "canceled" | "rejected")
45 } else {
46 false
47 }
48 }
49}
50
/// Decoder that turns a raw server-sent-events byte stream into [`SseEvent`]s.
///
/// The codec holds no state, so it is `Copy` and all instances are interchangeable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SseCodec;
54
55impl SseCodec {
56 pub fn new() -> Self {
58 Self
59 }
60
61 pub fn parse_stream<S>(&self, byte_stream: S) -> impl Stream<Item = Result<SseEvent, A2AError>>
66 where
67 S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
68 {
69 byte_stream.eventsource().map(|result| {
70 match result {
71 Ok(event) => {
72 let jsonrpc: Value = serde_json::from_str(&event.data).map_err(|e| {
74 A2AError::Protocol(format!("Failed to parse SSE event data: {}", e))
75 })?;
76
77 if let Some(error) = jsonrpc.get("error") {
79 let error_msg = error
80 .get("message")
81 .and_then(|m| m.as_str())
82 .unwrap_or("Unknown error");
83 return Err(A2AError::Protocol(format!(
84 "SSE stream error: {}",
85 error_msg
86 )));
87 }
88
89 let result = jsonrpc.get("result").ok_or_else(|| {
91 A2AError::Protocol("SSE event missing 'result' field".to_string())
92 })?;
93
94 let final_event = result
96 .get("final")
97 .and_then(|f| f.as_bool())
98 .unwrap_or(false);
99
100 let kind = result
102 .get("kind")
103 .and_then(|k| k.as_str())
104 .unwrap_or("event")
105 .to_string();
106
107 Ok(SseEvent {
108 kind,
109 payload: result.clone(),
110 final_event,
111 })
112 }
113 Err(e) => Err(A2AError::Transport(format!("SSE stream error: {}", e))),
114 }
115 })
116 }
117}
118
119impl Default for SseCodec {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use serde_json::json;

    use super::*;

    /// Shorthand for building an [`SseEvent`] by hand.
    fn event(kind: &str, payload: Value, final_event: bool) -> SseEvent {
        SseEvent {
            kind: kind.to_string(),
            payload,
            final_event,
        }
    }

    /// Wraps one JSON document in an SSE `data:` line terminated by a blank line.
    fn frame(json: &str) -> String {
        format!("data: {}\n\n", json)
    }

    /// Runs `sse_data`, delivered as a single chunk, through [`SseCodec::parse_stream`] and
    /// collects every item it yields.
    async fn decode(sse_data: String) -> Vec<Result<SseEvent, A2AError>> {
        let codec = SseCodec::new();
        let byte_stream = futures::stream::once(async move {
            Ok::<bytes::Bytes, reqwest::Error>(bytes::Bytes::from(sse_data))
        });
        let items: Vec<_> = codec.parse_stream(byte_stream).collect().await;
        items
    }

    /// Asserts that `result` is an [`A2AError::Protocol`] whose message contains `needle`.
    fn assert_protocol_error(result: Option<Result<SseEvent, A2AError>>, needle: &str) {
        match result {
            Some(Err(A2AError::Protocol(msg))) => {
                assert!(msg.contains(needle), "unexpected message: {}", msg);
            }
            other => panic!("Expected Protocol error, got {:?}", other),
        }
    }

    #[test]
    fn test_sse_event_is_terminal() {
        assert!(event("status-update", json!({ "state": "completed" }), false).is_terminal());
        assert!(event("artifact-update", json!({}), true).is_terminal());
        assert!(!event("status-update", json!({ "state": "running" }), false).is_terminal());
    }

    #[test]
    fn test_sse_event_every_terminal_state() {
        for state in ["completed", "failed", "canceled", "rejected"] {
            let ev = event("status-update", json!({ "state": state }), false);
            assert!(ev.is_terminal(), "state {} should be terminal", state);
        }
        let ev = event("status-update", json!({ "state": "input-required" }), false);
        assert!(!ev.is_terminal());
    }

    #[test]
    fn test_sse_event_is_error() {
        assert!(event("status-update", json!({ "state": "failed" }), false).is_error());
        assert!(!event("status-update", json!({ "state": "completed" }), false).is_error());
    }

    #[test]
    fn test_sse_event_final_flag_is_not_an_error() {
        assert!(!event("artifact-update", json!({}), true).is_error());
    }

    #[tokio::test]
    async fn test_parse_sse_stream() {
        let data = frame(
            r#"{"jsonrpc":"2.0","result":{"kind":"status-update","state":"running"},"id":"1"}"#,
        ) + &frame(
            r#"{"jsonrpc":"2.0","result":{"kind":"artifact-update","final":true},"id":"2"}"#,
        );
        let mut results = decode(data).await.into_iter();

        let event1 = results.next().unwrap().unwrap();
        assert_eq!(event1.kind, "status-update");
        assert!(!event1.final_event);

        let event2 = results.next().unwrap().unwrap();
        assert_eq!(event2.kind, "artifact-update");
        assert!(event2.final_event);
    }

    #[tokio::test]
    async fn test_parse_sse_missing_kind_defaults_to_event() {
        let data = frame(r#"{"jsonrpc":"2.0","result":{"state":"working"},"id":"1"}"#);
        let ev = decode(data).await.into_iter().next().unwrap().unwrap();
        assert_eq!(ev.kind, "event");
        assert!(!ev.final_event);
        assert_eq!(ev.payload, json!({ "state": "working" }));
    }

    #[tokio::test]
    async fn test_parse_sse_error() {
        let data = frame(
            r#"{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid Request"},"id":"1"}"#,
        );
        assert_protocol_error(decode(data).await.into_iter().next(), "Invalid Request");
    }

    #[tokio::test]
    async fn test_parse_sse_error_without_message() {
        let data = frame(r#"{"jsonrpc":"2.0","error":{"code":-32603},"id":"1"}"#);
        assert_protocol_error(decode(data).await.into_iter().next(), "Unknown error");
    }

    #[tokio::test]
    async fn test_parse_sse_invalid_json() {
        assert_protocol_error(
            decode(frame("not json")).await.into_iter().next(),
            "Failed to parse SSE event data",
        );
    }

    #[tokio::test]
    async fn test_parse_sse_missing_result() {
        let data = frame(r#"{"jsonrpc":"2.0","id":"1"}"#);
        assert_protocol_error(
            decode(data).await.into_iter().next(),
            "missing 'result' field",
        );
    }
}