1#![allow(clippy::doc_markdown)]
5
6use alloc::collections::BTreeMap;
31use alloc::string::{String, ToString};
32use alloc::vec::Vec;
33
/// Stored state of one subscription.
///
/// The subscription's name is not a field here; [`Subscriptions`] keeps it as
/// the map key.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Subscription {
    /// Connection string; serialized with a little-endian `u32` length prefix.
    pub conn_str: String,
    /// Publication names; each is serialized with a little-endian `u16`
    /// length prefix.
    pub publications: Vec<String>,
    /// Enabled flag; serialized as a single `0` / `1` byte.
    pub enabled: bool,
    /// Last received position. `Subscriptions::update_last_received_pos`
    /// only raises this value, never lowers it.
    pub last_received_pos: u64,
}
41
42#[derive(Debug, Clone, PartialEq, Eq, Default)]
43pub struct Subscriptions {
44 inner: BTreeMap<String, Subscription>,
45}
46
/// Errors reported by [`Subscriptions`].
#[derive(Debug, Eq, PartialEq)]
pub enum SubscriptionError {
    /// `create` was asked to register a name that already exists; carries
    /// that name.
    DuplicateName(String),
    /// `deserialize` rejected the payload; carries a human-readable reason.
    Corrupt(String),
}
52
53impl Subscriptions {
54 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn len(&self) -> usize {
59 self.inner.len()
60 }
61
62 pub fn is_empty(&self) -> bool {
63 self.inner.is_empty()
64 }
65
66 pub fn contains(&self, name: &str) -> bool {
67 self.inner.contains_key(name)
68 }
69
70 pub fn get(&self, name: &str) -> Option<&Subscription> {
71 self.inner.get(name)
72 }
73
74 pub fn iter(&self) -> impl Iterator<Item = (&String, &Subscription)> {
75 self.inner.iter()
76 }
77
78 pub fn create(&mut self, name: String, sub: Subscription) -> Result<(), SubscriptionError> {
79 if self.inner.contains_key(&name) {
80 return Err(SubscriptionError::DuplicateName(name));
81 }
82 self.inner.insert(name, sub);
83 Ok(())
84 }
85
86 pub fn drop(&mut self, name: &str) -> bool {
87 self.inner.remove(name).is_some()
88 }
89
90 pub fn update_last_received_pos(&mut self, name: &str, pos: u64) -> bool {
96 if let Some(s) = self.inner.get_mut(name) {
97 if pos > s.last_received_pos {
101 s.last_received_pos = pos;
102 }
103 true
104 } else {
105 false
106 }
107 }
108
109 pub fn serialize(&self) -> Vec<u8> {
121 let mut out = Vec::with_capacity(2 + self.inner.len() * 64);
122 let n = u16::try_from(self.inner.len()).expect("≤ 65,535 subscriptions per cluster");
123 out.extend_from_slice(&n.to_le_bytes());
124 for (name, sub) in &self.inner {
125 write_short_str(&mut out, name);
126 write_long_str(&mut out, &sub.conn_str);
127 let np = u16::try_from(sub.publications.len())
128 .expect("≤ 65,535 publications per subscription");
129 out.extend_from_slice(&np.to_le_bytes());
130 for p in &sub.publications {
131 write_short_str(&mut out, p);
132 }
133 out.push(u8::from(sub.enabled));
134 out.extend_from_slice(&sub.last_received_pos.to_le_bytes());
135 }
136 out
137 }
138
139 pub fn deserialize(buf: &[u8]) -> Result<Self, SubscriptionError> {
140 let mut p = 0usize;
141 let n = read_u16(buf, &mut p)? as usize;
142 let mut inner = BTreeMap::new();
143 for _ in 0..n {
144 let name = read_short_str(buf, &mut p)?;
145 let conn_str = read_long_str(buf, &mut p)?;
146 let np = read_u16(buf, &mut p)? as usize;
147 let mut publications = Vec::with_capacity(np);
148 for _ in 0..np {
149 publications.push(read_short_str(buf, &mut p)?);
150 }
151 let enabled_byte = read_u8(buf, &mut p)?;
152 let enabled = match enabled_byte {
153 0 => false,
154 1 => true,
155 other => {
156 return Err(SubscriptionError::Corrupt(alloc::format!(
157 "invalid `enabled` byte {other}, expected 0 or 1"
158 )));
159 }
160 };
161 let last_received_pos = read_u64(buf, &mut p)?;
162 if inner
163 .insert(
164 name.clone(),
165 Subscription {
166 conn_str,
167 publications,
168 enabled,
169 last_received_pos,
170 },
171 )
172 .is_some()
173 {
174 return Err(SubscriptionError::Corrupt(alloc::format!(
175 "duplicate subscription name {name:?} in serialised payload"
176 )));
177 }
178 }
179 if p != buf.len() {
180 return Err(SubscriptionError::Corrupt(alloc::format!(
181 "trailing bytes in subscriptions payload: read {p}, len {}",
182 buf.len()
183 )));
184 }
185 Ok(Self { inner })
186 }
187}
188
/// Appends `s` to `out` as a little-endian `u16` byte length followed by the
/// string's UTF-8 bytes.
///
/// # Panics
///
/// Panics if `s` is longer than `u16::MAX` bytes.
fn write_short_str(out: &mut Vec<u8>, s: &str) {
    let payload = s.as_bytes();
    let prefix: u16 = payload
        .len()
        .try_into()
        .expect("subscription / publication name fits in u16");
    out.extend_from_slice(&prefix.to_le_bytes());
    out.extend_from_slice(payload);
}
194
/// Appends `s` to `out` as a little-endian `u32` byte length followed by the
/// string's UTF-8 bytes.
///
/// # Panics
///
/// Panics if `s` is longer than `u32::MAX` bytes.
fn write_long_str(out: &mut Vec<u8>, s: &str) {
    let payload = s.as_bytes();
    let prefix: u32 = payload.len().try_into().expect("conn_str fits in u32");
    out.extend_from_slice(&prefix.to_le_bytes());
    out.extend_from_slice(payload);
}
201
202fn read_u8(buf: &[u8], p: &mut usize) -> Result<u8, SubscriptionError> {
203 let v = buf
204 .get(*p)
205 .copied()
206 .ok_or_else(|| SubscriptionError::Corrupt("short read (u8)".to_string()))?;
207 *p += 1;
208 Ok(v)
209}
210
211fn read_u16(buf: &[u8], p: &mut usize) -> Result<u16, SubscriptionError> {
212 let slice = buf
213 .get(*p..*p + 2)
214 .ok_or_else(|| SubscriptionError::Corrupt("short read (u16)".to_string()))?;
215 let arr: [u8; 2] = slice
216 .try_into()
217 .map_err(|_| SubscriptionError::Corrupt("u16 slice".to_string()))?;
218 *p += 2;
219 Ok(u16::from_le_bytes(arr))
220}
221
222fn read_u32_as_usize(buf: &[u8], p: &mut usize) -> Result<usize, SubscriptionError> {
223 let slice = buf
224 .get(*p..*p + 4)
225 .ok_or_else(|| SubscriptionError::Corrupt("short read (u32)".to_string()))?;
226 let arr: [u8; 4] = slice
227 .try_into()
228 .map_err(|_| SubscriptionError::Corrupt("u32 slice".to_string()))?;
229 *p += 4;
230 Ok(u32::from_le_bytes(arr) as usize)
231}
232
233fn read_u64(buf: &[u8], p: &mut usize) -> Result<u64, SubscriptionError> {
234 let slice = buf
235 .get(*p..*p + 8)
236 .ok_or_else(|| SubscriptionError::Corrupt("short read (u64)".to_string()))?;
237 let arr: [u8; 8] = slice
238 .try_into()
239 .map_err(|_| SubscriptionError::Corrupt("u64 slice".to_string()))?;
240 *p += 8;
241 Ok(u64::from_le_bytes(arr))
242}
243
244fn read_short_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
245 let n = read_u16(buf, p)? as usize;
246 let slice = buf.get(*p..*p + n).ok_or_else(|| {
247 SubscriptionError::Corrupt(alloc::format!("short read (short str, {n} bytes)"))
248 })?;
249 *p += n;
250 core::str::from_utf8(slice)
251 .map(ToString::to_string)
252 .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 str: {e}")))
253}
254
255fn read_long_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
256 let n = read_u32_as_usize(buf, p)?;
257 let slice = buf.get(*p..*p + n).ok_or_else(|| {
258 SubscriptionError::Corrupt(alloc::format!("short read (long str, {n} bytes)"))
259 })?;
260 *p += n;
261 core::str::from_utf8(slice)
262 .map(ToString::to_string)
263 .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 conn_str: {e}")))
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(name, Subscription)` pair whose connection string is
    /// `host=127.0.0.1 port=<port>`.
    fn mk(
        name: &str,
        port: &str,
        pubs: &[&str],
        enabled: bool,
        pos: u64,
    ) -> (String, Subscription) {
        (
            name.to_string(),
            Subscription {
                conn_str: alloc::format!("host=127.0.0.1 port={port}"),
                publications: pubs.iter().map(|s| (*s).to_string()).collect(),
                enabled,
                last_received_pos: pos,
            },
        )
    }

    #[test]
    fn empty_roundtrips() {
        let s = Subscriptions::new();
        let bytes = s.serialize();
        assert_eq!(Subscriptions::deserialize(&bytes).unwrap(), s);
    }

    #[test]
    fn single_subscription_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        let bytes = s.serialize();
        let s2 = Subscriptions::deserialize(&bytes).unwrap();
        assert_eq!(s2, s);
        assert!(s2.contains("sub_a"));
    }

    #[test]
    fn multi_publication_roundtrips_with_nontrivial_last_pos() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_z", "20002", &["p1", "p2", "p3"], true, 1_234_567_890);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert_eq!(s2, s);
        let r = s2.get("sub_z").unwrap();
        assert_eq!(r.publications, alloc::vec!["p1", "p2", "p3"]);
        assert_eq!(r.last_received_pos, 1_234_567_890);
    }

    #[test]
    fn disabled_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_off", "20002", &["pub_a"], false, 42);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert!(!s2.get("sub_off").unwrap().enabled);
    }

    #[test]
    fn duplicate_name_errors() {
        let mut s = Subscriptions::new();
        let (n1, sub1) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n1, sub1).unwrap();
        let (n2, sub2) = mk("sub_a", "20003", &["pub_b"], true, 0);
        assert_eq!(
            s.create(n2, sub2).unwrap_err(),
            SubscriptionError::DuplicateName("sub_a".into())
        );
    }

    #[test]
    fn drop_present_and_absent() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        assert!(s.drop("sub_a"));
        assert!(!s.drop("sub_a"));
        assert!(!s.drop("never"));
    }

    #[test]
    fn update_last_pos_monotone_and_absent_returns_false() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 100);
        s.create(n, sub).unwrap();
        // A stale position is accepted (the subscription exists) but ignored.
        assert!(s.update_last_received_pos("sub_a", 50));
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 100);
        // A newer position advances the stored value.
        assert!(s.update_last_received_pos("sub_a", 200));
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 200);
        // Unknown subscription.
        assert!(!s.update_last_received_pos("missing", 1));
    }

    #[test]
    fn corrupt_enabled_byte_errors() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes()); // subscription count
        buf.extend_from_slice(&3u16.to_le_bytes()); // name length
        buf.extend_from_slice(b"bad"); // name
        buf.extend_from_slice(&0u32.to_le_bytes()); // conn_str length (empty)
        buf.extend_from_slice(&0u16.to_le_bytes()); // publication count
        buf.push(2); // invalid `enabled` byte
        buf.extend_from_slice(&0u64.to_le_bytes()); // last_received_pos
        let err = Subscriptions::deserialize(&buf).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn trailing_bytes_error() {
        let mut bytes = Subscriptions::new().serialize();
        bytes.push(0);
        assert!(matches!(
            Subscriptions::deserialize(&bytes),
            Err(SubscriptionError::Corrupt(_))
        ));
    }

    #[test]
    fn every_truncation_errors() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 7);
        s.create(n, sub).unwrap();
        let bytes = s.serialize();
        // Any strict prefix of a valid payload must be rejected, not panic.
        for cut in 0..bytes.len() {
            assert!(
                matches!(
                    Subscriptions::deserialize(&bytes[..cut]),
                    Err(SubscriptionError::Corrupt(_))
                ),
                "prefix of {cut} bytes was not rejected"
            );
        }
    }

    #[test]
    fn duplicate_name_in_payload_errors() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&2u16.to_le_bytes()); // subscription count
        for _ in 0..2 {
            buf.extend_from_slice(&3u16.to_le_bytes()); // name length
            buf.extend_from_slice(b"dup"); // name
            buf.extend_from_slice(&0u32.to_le_bytes()); // conn_str length (empty)
            buf.extend_from_slice(&0u16.to_le_bytes()); // publication count
            buf.push(1); // enabled
            buf.extend_from_slice(&0u64.to_le_bytes()); // last_received_pos
        }
        let err = Subscriptions::deserialize(&buf).unwrap_err();
        assert!(matches!(
            err,
            SubscriptionError::Corrupt(ref msg) if msg.contains("duplicate subscription name")
        ));
    }

    #[test]
    fn deterministic_order_independent_of_insert_sequence() {
        let mut s1 = Subscriptions::new();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s1.create(n, sub).unwrap();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s1.create(n, sub).unwrap();
        let mut s2 = Subscriptions::new();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s2.create(n, sub).unwrap();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s2.create(n, sub).unwrap();
        assert_eq!(s1.serialize(), s2.serialize());
    }
}