// s2_api/v1/stream/sse.rs

1use std::{str::FromStr, time::Duration};
2
3use s2_common::{http::ParseableHeader, types};
4use serde::Serialize;
5
6use super::ReadBatch;
7
8static LAST_EVENT_ID_HEADER: http::HeaderName = http::HeaderName::from_static("last-event-id");
9
/// Identifier attached to `batch` events and parsed back from the `Last-Event-ID` header.
///
/// Its text form is `seq_num,count,bytes` (see the `Display` and `FromStr` impls).
/// `PartialEq`/`Eq` are derived so ids can be compared directly, e.g. when checking
/// a client-supplied id against a stream position.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LastEventId {
    /// First component of the text form.
    pub seq_num: u64,
    /// Second component of the text form.
    pub count: usize,
    /// Third component of the text form.
    pub bytes: usize,
}
16
17impl ParseableHeader for LastEventId {
18    fn name() -> &'static http::HeaderName {
19        &LAST_EVENT_ID_HEADER
20    }
21}
22
23impl Serialize for LastEventId {
24    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
25    where
26        S: serde::Serializer,
27    {
28        self.to_string().serialize(serializer)
29    }
30}
31
32impl std::fmt::Display for LastEventId {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        let Self {
35            seq_num,
36            count,
37            bytes,
38        } = self;
39        write!(f, "{seq_num},{count},{bytes}")
40    }
41}
42
43impl FromStr for LastEventId {
44    type Err = types::ValidationError;
45
46    fn from_str(s: &str) -> Result<Self, Self::Err> {
47        let mut iter = s.splitn(3, ",");
48
49        fn get_next<T>(
50            iter: &mut std::str::SplitN<&str>,
51            field: &str,
52        ) -> Result<T, types::ValidationError>
53        where
54            T: FromStr,
55            <T as FromStr>::Err: std::fmt::Display,
56        {
57            let item = iter
58                .next()
59                .ok_or_else(|| format!("missing {field} in Last-Event-Id"))?;
60            item.parse()
61                .map_err(|e| format!("invalid {field} in Last-Event-ID: {e}").into())
62        }
63
64        let seq_num = get_next(&mut iter, "seq_num")?;
65        let count = get_next(&mut iter, "count")?;
66        let bytes = get_next(&mut iter, "bytes")?;
67
68        Ok(Self {
69            seq_num,
70            count,
71            bytes,
72        })
73    }
74}
75
76macro_rules! event {
77    ($name:ident, $val:expr) => {
78        #[derive(Serialize)]
79        #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
80        #[serde(rename_all = "snake_case")]
81        pub enum $name {
82            $name,
83        }
84
85        impl AsRef<str> for $name {
86            fn as_ref(&self) -> &str {
87                $val
88            }
89        }
90    };
91}
92
93event!(Batch, "batch");
94event!(Error, "error");
95event!(Ping, "ping");
96
// Events emitted on an SSE read stream.
//
// `#[serde(untagged)]` serializes a variant as just its fields, with no variant tag;
// the `event` marker field carries the event name instead. Comments here are plain
// `//` rather than `///` so they do not leak into the utoipa-derived schema.
#[derive(Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(untagged)]
pub enum ReadEvent {
    // A `batch` event; the axum conversion also emits `id` as the SSE `id:` field.
    #[cfg_attr(feature = "utoipa", schema(title = "batch"))]
    Batch {
        #[cfg_attr(feature = "utoipa", schema(inline))]
        event: Batch,
        data: ReadBatch,
        // `LastEventId` serializes as the string "seq_num,count,bytes", hence the String schema.
        #[cfg_attr(feature = "utoipa", schema(value_type = String, pattern = "^[0-9]+,[0-9]+,[0-9]+$"))]
        id: LastEventId,
    },
    // An `error` event whose data is a plain message string.
    #[cfg_attr(feature = "utoipa", schema(title = "error"))]
    Error {
        #[cfg_attr(feature = "utoipa", schema(inline))]
        event: Error,
        data: String,
    },
    // A `ping` event; `ReadEvent::ping` fills in a millisecond timestamp.
    #[cfg_attr(feature = "utoipa", schema(title = "ping"))]
    Ping {
        #[cfg_attr(feature = "utoipa", schema(inline))]
        event: Ping,
        data: PingEventData,
    },
    // The `[DONE]` marker. `#[serde(skip)]` keeps serde from serializing this variant.
    // In this file it is framed only by the axum conversion, as bare `[DONE]` data with
    // no `event:` name.
    #[cfg_attr(feature = "utoipa", schema(title = "done"))]
    #[serde(skip)]
    Done {
        #[cfg_attr(feature = "utoipa", schema(value_type = String, pattern = r"^\[DONE\]$"))]
        data: DoneEventData,
    },
}
128
/// Returns the wall-clock time elapsed since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn elapsed_since_epoch() -> Duration {
    std::time::SystemTime::UNIX_EPOCH
        .elapsed()
        .expect("healthy clock")
}
134
135impl ReadEvent {
136    pub fn batch(data: ReadBatch, id: LastEventId) -> Self {
137        Self::Batch {
138            event: Batch::Batch,
139            data,
140            id,
141        }
142    }
143
144    pub fn error(data: String) -> Self {
145        Self::Error {
146            event: Error::Error,
147            data,
148        }
149    }
150
151    pub fn ping() -> Self {
152        Self::Ping {
153            event: Ping::Ping,
154            data: PingEventData {
155                timestamp: elapsed_since_epoch().as_millis() as u64,
156            },
157        }
158    }
159
160    pub fn done() -> Self {
161        Self::Done {
162            data: DoneEventData,
163        }
164    }
165}
166
#[cfg(feature = "axum")]
impl TryFrom<ReadEvent> for axum::response::sse::Event {
    type Error = axum::Error;

    /// Renders a [`ReadEvent`] as an axum SSE frame.
    ///
    /// `batch` and `ping` payloads are JSON-encoded, which is the only fallible step.
    /// `error` data is sent as-is. `Done` is sent as bare `[DONE]` data with no
    /// `event:` name. Builder calls keep their order because fields are written in
    /// call order.
    fn try_from(event: ReadEvent) -> Result<Self, Self::Error> {
        let frame = Self::default();
        match event {
            ReadEvent::Batch { event, data, id } => {
                let frame = frame.event(event).id(id.to_string());
                frame.json_data(data)
            }
            ReadEvent::Error { event, data } => {
                let frame = frame.event(event).data(data);
                Ok(frame)
            }
            ReadEvent::Ping { event, data } => frame.event(event).json_data(data),
            ReadEvent::Done { data } => {
                let frame = frame.data(data);
                Ok(frame)
            }
        }
    }
}
183
// Payload of `ReadEvent::Done`. Its text form (via `AsRef<str>`) is the literal
// `[DONE]`, which the axum conversion sends as the event's data.
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename = "[DONE]")]
pub struct DoneEventData;

impl AsRef<str> for DoneEventData {
    /// Returns the `[DONE]` marker text.
    fn as_ref(&self) -> &str {
        "[DONE]"
    }
}
194
// Payload of a `ping` event. Plain `//` comments keep the utoipa schema text unchanged.
#[rustfmt::skip]
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct PingEventData {
    // Milliseconds since the Unix epoch, as filled in by `ReadEvent::ping`.
    pub timestamp: u64,
}