Skip to main content

wme_client/
realtime.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use chrono::{DateTime, Utc};
6use futures::{Stream, StreamExt};
7use tokio::sync::mpsc;
8use tracing::{debug, error, trace, warn};
9
10use wme_models::metadata::ArticleUpdate;
11use wme_models::RequestParams;
12
13use crate::{ClientError, Result, WmeClient};
14
15/// Realtime connection options.
16#[derive(Debug, Clone, Default)]
17pub struct RealtimeConnectOptions {
18    /// Global since timestamp
19    pub since: Option<DateTime<Utc>>,
20    /// Per-partition resume (critical for reconnect without data loss)
21    pub since_per_partition: Option<HashMap<String, DateTime<Utc>>>,
22    /// Which partitions to connect to
23    pub partitions: Option<Vec<u32>>,
24    /// Per-partition offsets for resume
25    pub offsets: Option<HashMap<String, u64>>,
26}
27
28impl RealtimeConnectOptions {
29    /// Create new options with just a global since timestamp.
30    pub fn since(timestamp: DateTime<Utc>) -> Self {
31        Self {
32            since: Some(timestamp),
33            since_per_partition: None,
34            partitions: None,
35            offsets: None,
36        }
37    }
38
39    /// Create new options with per-partition resume.
40    pub fn since_per_partition(partitions: HashMap<String, DateTime<Utc>>) -> Self {
41        Self {
42            since: None,
43            since_per_partition: Some(partitions),
44            partitions: None,
45            offsets: None,
46        }
47    }
48
49    /// Set specific partitions to connect to.
50    pub fn with_partitions(mut self, partitions: Vec<u32>) -> Self {
51        self.partitions = Some(partitions);
52        self
53    }
54
55    /// Set per-partition offsets for resume.
56    pub fn with_offsets(mut self, offsets: HashMap<String, u64>) -> Self {
57        self.offsets = Some(offsets);
58        self
59    }
60
61    /// Convert to JSON request body.
62    fn to_request_body(&self, filters: Option<&wme_models::RequestParams>) -> serde_json::Value {
63        let mut body = serde_json::Map::new();
64
65        if let Some(since) = self.since {
66            body.insert("since".to_string(), serde_json::json!(since.to_rfc3339()));
67        }
68
69        if let Some(ref per_partition) = self.since_per_partition {
70            let map: serde_json::Map<String, serde_json::Value> = per_partition
71                .iter()
72                .map(|(k, v)| (k.clone(), serde_json::json!(v.to_rfc3339())))
73                .collect();
74            body.insert(
75                "since_per_partition".to_string(),
76                serde_json::Value::Object(map),
77            );
78        }
79
80        if let Some(ref partitions) = self.partitions {
81            body.insert("parts".to_string(), serde_json::json!(partitions));
82        }
83
84        if let Some(ref offsets) = self.offsets {
85            let map: serde_json::Map<String, serde_json::Value> = offsets
86                .iter()
87                .map(|(k, v)| (k.clone(), serde_json::json!(v)))
88                .collect();
89            body.insert("offsets".to_string(), serde_json::Value::Object(map));
90        }
91
92        if let Some(req) = filters {
93            if let Some(ref fields) = req.fields {
94                body.insert("fields".to_string(), serde_json::json!(fields));
95            }
96            if let Some(ref filters) = req.filters {
97                body.insert("filters".to_string(), serde_json::json!(filters));
98            }
99        }
100
101        serde_json::Value::Object(body)
102    }
103}
104
105/// Realtime API client for SSE streaming.
106pub struct RealtimeClient<'a> {
107    client: &'a WmeClient,
108}
109
110impl<'a> RealtimeClient<'a> {
111    /// Create a new realtime client.
112    pub(crate) fn new(client: &'a WmeClient) -> Self {
113        Self { client }
114    }
115
116    /// Connect to the realtime SSE stream.
117    ///
118    /// # Arguments
119    ///
120    /// * `options` - Connection options including resume timestamps
121    /// * `filters` - Optional filters for the stream
122    ///
123    /// # Returns
124    ///
125    /// A stream of article updates. The stream yields `ArticleUpdate` items
126    /// parsed from the SSE events.
127    ///
128    /// # Example
129    ///
130    /// ```rust,no_run
131    /// use wme_client::{WmeClient, RealtimeConnectOptions};
132    /// use futures::StreamExt;
133    /// use chrono::Utc;
134    ///
135    /// # async fn example(client: &WmeClient) -> Result<(), Box<dyn std::error::Error>> {
136    /// let options = RealtimeConnectOptions::since(Utc::now());
137    /// let mut stream = client.realtime().connect(&options, None).await?;
138    ///
139    /// while let Some(result) = stream.next().await {
140    ///     match result {
141    ///         Ok(update) => println!("Article: {}", update.article.name),
142    ///         Err(e) => eprintln!("Error: {}", e),
143    ///     }
144    /// }
145    /// # Ok(())
146    /// # }
147    /// ```
148    pub async fn connect(
149        &self,
150        options: &RealtimeConnectOptions,
151        filters: Option<&RequestParams>,
152    ) -> Result<Box<dyn Stream<Item = Result<ArticleUpdate>> + Send + Unpin>> {
153        let url = format!("{}/v2/articles", self.client.base_urls().realtime);
154        let body = options.to_request_body(filters).to_string();
155
156        // Create reqwest request with SSE headers
157        let client = reqwest::Client::new();
158        let mut request = client
159            .post(&url)
160            .header("Accept", "text/event-stream")
161            .header("Content-Type", "application/json");
162
163        // Add authorization if available
164        if let Some(headers) = self.client.auth_headers().await? {
165            if let Some(auth) = headers.get("Authorization") {
166                request = request.header("Authorization", auth);
167            }
168        }
169
170        let response = request.body(body).send().await.map_err(ClientError::from)?;
171
172        if !response.status().is_success() {
173            return Err(ClientError::Http(format!(
174                "Failed to connect to realtime stream: {}",
175                response.status()
176            )));
177        }
178
179        let (tx, rx) = mpsc::channel(100);
180
181        // Spawn SSE streaming task
182        tokio::spawn(async move {
183            let mut stream = response.bytes_stream();
184            let mut buffer = String::new();
185            let mut current_event = String::new();
186
187            while let Some(chunk) = stream.next().await {
188                match chunk {
189                    Ok(bytes) => {
190                        // Append to buffer
191                        buffer.push_str(&String::from_utf8_lossy(&bytes));
192
193                        // Process complete lines
194                        while let Some(pos) = buffer.find('\n') {
195                            let line = buffer[..pos].to_string();
196                            buffer = buffer[pos + 1..].to_string();
197
198                            let line = line.trim_end_matches('\r');
199
200                            if line.is_empty() {
201                                // Empty line means end of event - process it
202                                if !current_event.is_empty() {
203                                    trace!(event_data = %current_event, "Processing SSE event");
204
205                                    match serde_json::from_str::<ArticleUpdate>(&current_event) {
206                                        Ok(update) => {
207                                            if tx.send(Ok(update)).await.is_err() {
208                                                debug!(
209                                                    "SSE stream receiver dropped, closing stream"
210                                                );
211                                                return;
212                                            }
213                                        }
214                                        Err(e) => {
215                                            warn!(error = %e, data = %current_event, "Failed to parse SSE event data");
216                                            if tx
217                                                .send(Err(ClientError::JsonParse(e.to_string())))
218                                                .await
219                                                .is_err()
220                                            {
221                                                return;
222                                            }
223                                        }
224                                    }
225                                    current_event.clear();
226                                }
227                            } else if let Some(data) = line.strip_prefix("data: ") {
228                                // SSE data field
229                                current_event.push_str(data);
230                            } else if line.starts_with("id: ") {
231                                // SSE id field - could be used for resuming
232                                trace!(event_id = %line, "Received SSE id");
233                            } else if line.starts_with("event: ") {
234                                // SSE event type field
235                                trace!(event_type = %line, "Received SSE event type");
236                            } else if line.starts_with(":") {
237                                // SSE comment - ignore
238                                trace!(comment = %line, "Received SSE comment");
239                            }
240                        }
241                    }
242                    Err(e) => {
243                        error!(error = %e, "SSE stream error");
244                        let _ = tx.send(Err(ClientError::Stream(e.to_string()))).await;
245                        return;
246                    }
247                }
248            }
249
250            // Process any remaining data in buffer
251            if !buffer.is_empty() {
252                let line = buffer.trim_end_matches('\r');
253                if let Some(data) = line.strip_prefix("data: ") {
254                    current_event.push_str(data);
255                }
256            }
257
258            // Process final event if any
259            if !current_event.is_empty() {
260                match serde_json::from_str::<ArticleUpdate>(&current_event) {
261                    Ok(update) => {
262                        let _ = tx.send(Ok(update)).await;
263                    }
264                    Err(e) => {
265                        let _ = tx.send(Err(ClientError::JsonParse(e.to_string()))).await;
266                    }
267                }
268            }
269
270            debug!("SSE stream ended");
271        });
272
273        // Create a stream from the receiver
274        let stream = SseStream { receiver: rx };
275        Ok(Box::new(stream))
276    }
277
278    /// List available hourly batches for a specific date and hour.
279    pub async fn list_batches(
280        &self,
281        date: &str,
282        hour: &str,
283    ) -> Result<Vec<wme_models::metadata::RealtimeBatchInfo>> {
284        self.list_batches_with_params(date, hour, None).await
285    }
286
287    /// List batches with filters and field selection.
288    pub async fn list_batches_with_params(
289        &self,
290        date: &str,
291        hour: &str,
292        params: Option<&RequestParams>,
293    ) -> Result<Vec<wme_models::metadata::RealtimeBatchInfo>> {
294        let url = format!(
295            "{}/v2/batches/{}/{}",
296            self.client.base_urls().api,
297            date,
298            hour
299        );
300        let headers = self.client.auth_headers().await?;
301
302        let response = if let Some(p) = params {
303            let body = serde_json::to_string(p)?;
304            self.client
305                .transport()
306                .request(reqwest::Method::POST, &url, headers, Some(body))
307                .await?
308        } else {
309            self.client
310                .transport()
311                .request(reqwest::Method::GET, &url, headers, None)
312                .await?
313        };
314
315        if !response.status().is_success() {
316            return Err(ClientError::Http(format!(
317                "Failed to list batches: {}",
318                response.status()
319            )));
320        }
321
322        let batches = response.json().await.map_err(ClientError::from)?;
323        Ok(batches)
324    }
325
326    /// Get specific batch metadata.
327    pub async fn get_batch_info(
328        &self,
329        date: &str,
330        hour: &str,
331        identifier: &str,
332    ) -> Result<wme_models::metadata::RealtimeBatchInfo> {
333        self.get_batch_info_with_params(date, hour, identifier, None)
334            .await
335    }
336
337    /// Get batch metadata with field selection.
338    pub async fn get_batch_info_with_params(
339        &self,
340        date: &str,
341        hour: &str,
342        identifier: &str,
343        params: Option<&RequestParams>,
344    ) -> Result<wme_models::metadata::RealtimeBatchInfo> {
345        let url = format!(
346            "{}/v2/batches/{}/{}/{}",
347            self.client.base_urls().api,
348            date,
349            hour,
350            identifier
351        );
352        let headers = self.client.auth_headers().await?;
353
354        let response = if let Some(p) = params {
355            let body = serde_json::to_string(p)?;
356            self.client
357                .transport()
358                .request(reqwest::Method::POST, &url, headers, Some(body))
359                .await?
360        } else {
361            self.client
362                .transport()
363                .request(reqwest::Method::GET, &url, headers, None)
364                .await?
365        };
366
367        if !response.status().is_success() {
368            return Err(ClientError::Http(format!(
369                "Failed to get batch info: {}",
370                response.status()
371            )));
372        }
373
374        let batch = response.json().await.map_err(ClientError::from)?;
375        Ok(batch)
376    }
377
378    /// Download a batch as a stream of bytes.
379    pub async fn download_batch(
380        &self,
381        date: &str,
382        hour: &str,
383        identifier: &str,
384        range: Option<&str>,
385    ) -> Result<Box<dyn Stream<Item = Result<bytes::Bytes>> + Send + Unpin>> {
386        let url = format!(
387            "{}/v2/batches/{}/{}/{}/download",
388            self.client.base_urls().api,
389            date,
390            hour,
391            identifier
392        );
393
394        let mut headers = self.client.auth_headers().await?;
395
396        if let Some(range) = range {
397            headers = headers.or_else(|| Some(std::collections::HashMap::new()));
398            if let Some(ref mut h) = headers {
399                h.insert(reqwest::header::RANGE.to_string(), range.to_string());
400            }
401        }
402
403        let stream = self
404            .client
405            .transport()
406            .stream(reqwest::Method::GET, &url, headers)
407            .await?;
408
409        Ok(stream)
410    }
411
412    /// Download a batch and parse it as a stream of articles.
413    ///
414    /// This method downloads the batch tarball, decompresses it, and parses
415    /// the NDJSON content into Article structs.
416    ///
417    /// # Example
418    ///
419    /// ```rust,no_run
420    /// use wme_client::{WmeClient, RealtimeConnectOptions};
421    /// use futures::StreamExt;
422    ///
423    /// # async fn example(client: &WmeClient) -> Result<(), Box<dyn std::error::Error>> {
424    /// let mut stream = client.realtime().stream_batch("2024-01-15", "12", "batch_id").await?;
425    ///
426    /// while let Some(result) = stream.next().await {
427    ///     match result {
428    ///         Ok(article) => println!("Article: {}", article.name),
429    ///         Err(e) => eprintln!("Error: {}", e),
430    ///     }
431    /// }
432    /// # Ok(())
433    /// # }
434    /// ```
435    pub async fn stream_batch(
436        &self,
437        date: &str,
438        hour: &str,
439        identifier: &str,
440    ) -> Result<Box<dyn Stream<Item = Result<wme_models::Article>> + Send + Unpin>> {
441        use async_compression::tokio::bufread::GzipDecoder;
442        use tokio::io::{AsyncBufReadExt, BufReader as TokioBufReader};
443        use tokio_tar::Archive;
444
445        // Download the batch
446        let byte_stream = self.download_batch(date, hour, identifier, None).await?;
447
448        // Convert byte stream to async reader
449        let reader = tokio_util::io::StreamReader::new(
450            byte_stream.map(|result| result.map_err(std::io::Error::other)),
451        );
452
453        // Decompress gzip
454        let gz_decoder = GzipDecoder::new(TokioBufReader::new(reader));
455
456        // Create tar archive
457        let mut archive = Archive::new(gz_decoder);
458
459        let (tx, rx) = mpsc::channel(100);
460
461        // Spawn task to process tarball entries
462        tokio::spawn(async move {
463            let mut entries = archive.entries().map_err(|e| {
464                error!(error = %e, "Failed to read tar archive entries");
465                ClientError::Io(format!("Failed to read tar entries: {}", e))
466            })?;
467
468            while let Some(entry) = entries.next().await {
469                match entry {
470                    Ok(mut entry) => {
471                        // Read NDJSON lines from this entry
472                        let mut lines = TokioBufReader::new(&mut entry).lines();
473
474                        while let Ok(Some(line)) = lines.next_line().await {
475                            if line.trim().is_empty() {
476                                continue;
477                            }
478
479                            match serde_json::from_str::<wme_models::Article>(&line) {
480                                Ok(article) => {
481                                    if tx.send(Ok(article)).await.is_err() {
482                                        debug!("Batch stream receiver dropped");
483                                        return Ok::<(), ClientError>(());
484                                    }
485                                }
486                                Err(e) => {
487                                    warn!(error = %e, line = %line, "Failed to parse NDJSON line");
488                                    if tx
489                                        .send(Err(ClientError::JsonParse(e.to_string())))
490                                        .await
491                                        .is_err()
492                                    {
493                                        return Ok::<(), ClientError>(());
494                                    }
495                                }
496                            }
497                        }
498                    }
499                    Err(e) => {
500                        error!(error = %e, "Error reading tar entry");
501                        let _ = tx
502                            .send(Err(ClientError::Io(format!("Tar entry error: {}", e))))
503                            .await;
504                    }
505                }
506            }
507
508            debug!("Batch stream completed");
509            Ok(())
510        });
511
512        // Create stream from receiver
513        let stream = BatchStream { receiver: rx };
514        Ok(Box::new(stream))
515    }
516
517    /// Get batch download metadata (HEAD request).
518    pub async fn head_batch_download(
519        &self,
520        date: &str,
521        hour: &str,
522        identifier: &str,
523    ) -> Result<reqwest::header::HeaderMap> {
524        let url = format!(
525            "{}/v2/batches/{}/{}/{}/download",
526            self.client.base_urls().api,
527            date,
528            hour,
529            identifier
530        );
531        let headers = self.client.auth_headers().await?;
532
533        let response = self
534            .client
535            .transport()
536            .request(reqwest::Method::HEAD, &url, headers, None)
537            .await?;
538
539        match response.status().as_u16() {
540            200 => Ok(response.headers().clone()),
541            status => Err(ClientError::Http(format!(
542                "Failed to get batch headers: {}",
543                status
544            ))),
545        }
546    }
547}
548
549/// Stream implementation for SSE events.
550struct SseStream {
551    receiver: mpsc::Receiver<Result<ArticleUpdate>>,
552}
553
554impl Stream for SseStream {
555    type Item = Result<ArticleUpdate>;
556
557    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
558        self.receiver.poll_recv(cx)
559    }
560}
561
562/// Stream implementation for batch articles.
563struct BatchStream {
564    receiver: mpsc::Receiver<Result<wme_models::Article>>,
565}
566
567impl Stream for BatchStream {
568    type Item = Result<wme_models::Article>;
569
570    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
571        self.receiver.poll_recv(cx)
572    }
573}