syncable_ag_ui_client/
sse.rs1use std::pin::Pin;
20use std::task::{Context, Poll};
21
22use syncable_ag_ui_core::{Event, JsonValue};
23use futures::Stream;
24use reqwest::Client;
25use reqwest_eventsource::{Event as SseEvent, EventSource};
26
27use crate::error::{ClientError, Result};
28
/// Connection settings for an [`SseClient`].
///
/// Start from [`SseConfig::new`] (or [`Default::default`]) and refine it with the
/// chained builder methods.
#[derive(Debug, Clone)]
pub struct SseConfig {
    /// Connection timeout handed to the HTTP client by
    /// [`SseClient::connect_with_config`].
    ///
    /// Defaults to 30 seconds.
    pub connect_timeout: std::time::Duration,
    /// Extra HTTP headers attached to the request, as `(name, value)` pairs in
    /// insertion order.
    ///
    /// Defaults to empty.
    pub headers: Vec<(String, String)>,
}

impl Default for SseConfig {
    /// Returns a 30-second connect timeout and no extra headers.
    fn default() -> Self {
        Self {
            connect_timeout: std::time::Duration::from_secs(30),
            headers: Vec::new(),
        }
    }
}

impl SseConfig {
    /// Creates a configuration populated with the default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the connection timeout.
    #[must_use]
    pub fn connect_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Appends a header to send with the request.
    ///
    /// Duplicate names are kept (each pair is appended), and neither the name nor
    /// the value is validated here.
    #[must_use]
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.push((name.into(), value.into()));
        self
    }

    /// Appends an `Authorization: Bearer <token>` header.
    #[must_use]
    pub fn bearer_token(self, token: impl Into<String>) -> Self {
        self.header("Authorization", format!("Bearer {}", token.into()))
    }
}
70
71pub struct SseClient {
76 event_source: EventSource,
77}
78
79impl SseClient {
80 pub async fn connect(url: &str) -> Result<Self> {
92 Self::connect_with_config(url, SseConfig::default()).await
93 }
94
95 pub async fn connect_with_config(url: &str, config: SseConfig) -> Result<Self> {
111 let client = Client::builder()
112 .timeout(config.connect_timeout)
113 .build()
114 .map_err(|e| ClientError::connection(e.to_string()))?;
115
116 let mut request = client.get(url);
117
118 for (name, value) in config.headers {
119 request = request.header(&name, &value);
120 }
121
122 let event_source = EventSource::new(request)
123 .map_err(|e| ClientError::connection(e.to_string()))?;
124
125 Ok(Self { event_source })
126 }
127
128 pub fn into_stream(self) -> SseEventStream {
132 SseEventStream {
133 event_source: self.event_source,
134 }
135 }
136}
137
138pub struct SseEventStream {
142 event_source: EventSource,
143}
144
145impl Stream for SseEventStream {
146 type Item = Result<Event<JsonValue>>;
147
148 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149 loop {
150 match Pin::new(&mut self.event_source).poll_next(cx) {
151 Poll::Ready(Some(Ok(sse_event))) => {
152 match sse_event {
153 SseEvent::Open => {
154 continue;
156 }
157 SseEvent::Message(msg) => {
158 match serde_json::from_str::<Event<JsonValue>>(&msg.data) {
160 Ok(event) => return Poll::Ready(Some(Ok(event))),
161 Err(e) => {
162 return Poll::Ready(Some(Err(ClientError::parse(format!(
163 "failed to parse event: {}",
164 e
165 )))))
166 }
167 }
168 }
169 }
170 }
171 Poll::Ready(Some(Err(e))) => {
172 return Poll::Ready(Some(Err(ClientError::sse(e.to_string()))))
173 }
174 Poll::Ready(None) => return Poll::Ready(None),
175 Poll::Pending => return Poll::Pending,
176 }
177 }
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// The default configuration uses a 30 s timeout and carries no headers.
    #[test]
    fn test_sse_config_default() {
        let cfg = SseConfig::default();

        assert_eq!(cfg.connect_timeout, Duration::from_secs(30));
        assert!(cfg.headers.is_empty());
    }

    /// Builder calls are applied in order; the bearer token becomes an
    /// `Authorization` header appended after any custom headers.
    #[test]
    fn test_sse_config_builder() {
        let wanted_timeout = Duration::from_secs(60);
        let cfg = SseConfig::new()
            .connect_timeout(wanted_timeout)
            .header("X-Custom", "value")
            .bearer_token("token123");

        let wanted_headers = vec![
            ("X-Custom".to_string(), "value".to_string()),
            ("Authorization".to_string(), "Bearer token123".to_string()),
        ];

        assert_eq!(cfg.connect_timeout, wanted_timeout);
        assert_eq!(cfg.headers, wanted_headers);
    }
}