Skip to main content

spg_engine/
subscriptions.rs

1// pedantic doc_markdown flags the embedded wire-format spec block
2// and a handful of proper nouns; allowing at the module level
3// keeps the spec readable.
4#![allow(clippy::doc_markdown)]
5
6//! v6.1.4 — logical-replication subscription catalog.
7//!
8//! In-memory table of subscriptions, owned by the engine. The
9//! catalog persists across restarts via the snapshot envelope's
10//! v4 trailer block (see `crate::lib::build_envelope`) — same
11//! mechanism v6.1.2 added for publications, just an extra section.
12//!
13//! Subscriptions are the receive side of logical replication. A
14//! `CreateSubscription` row holds:
15//!   - `name`              the local identifier
16//!   - `conn_str`          PG keyword=value string the worker
17//!                         parses for `host=…` and `port=…`
18//!   - `publications`      list of remote publication names
19//!   - `enabled`           v6.1.4 hard-codes to `true`; ALTER
20//!                         SUBSCRIPTION ENABLE / DISABLE lands
21//!                         in a future sub-version
22//!   - `last_received_pos` master-WAL byte offset the worker has
23//!                         applied through (updated live by the
24//!                         worker, persisted at the next snapshot)
25//!
//! The worker itself lives in
//! `spg-server::replication::run_subscription_worker`; the engine layer
//! only owns the catalog state, persists it in snapshots, and answers
//! `SHOW SUBSCRIPTIONS`.
29
30use alloc::collections::BTreeMap;
31use alloc::string::{String, ToString};
32use alloc::vec::Vec;
33
/// One entry of the subscription catalog. The subscription's name is not
/// stored here; it is the key under which this value lives in
/// `Subscriptions`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Subscription {
    /// PG `keyword=value` connection string; the worker parses it for
    /// `host=…` and `port=…`.
    pub conn_str: String,
    /// Names of the remote publications this subscription follows.
    pub publications: Vec<String>,
    /// Whether the subscription is enabled.
    pub enabled: bool,
    /// Master-WAL byte offset the worker has applied through.
    pub last_received_pos: u64,
}
41
42#[derive(Debug, Clone, PartialEq, Eq, Default)]
43pub struct Subscriptions {
44    inner: BTreeMap<String, Subscription>,
45}
46
/// Failures reported by the subscription catalog.
#[derive(Eq, PartialEq, Debug)]
pub enum SubscriptionError {
    /// A subscription with this name is already registered.
    DuplicateName(String),
    /// A serialised catalog payload could not be decoded; the string
    /// describes what was wrong with it.
    Corrupt(String),
}
52
53impl Subscriptions {
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    pub fn len(&self) -> usize {
59        self.inner.len()
60    }
61
62    pub fn is_empty(&self) -> bool {
63        self.inner.is_empty()
64    }
65
66    pub fn contains(&self, name: &str) -> bool {
67        self.inner.contains_key(name)
68    }
69
70    pub fn get(&self, name: &str) -> Option<&Subscription> {
71        self.inner.get(name)
72    }
73
74    pub fn iter(&self) -> impl Iterator<Item = (&String, &Subscription)> {
75        self.inner.iter()
76    }
77
78    pub fn create(&mut self, name: String, sub: Subscription) -> Result<(), SubscriptionError> {
79        if self.inner.contains_key(&name) {
80            return Err(SubscriptionError::DuplicateName(name));
81        }
82        self.inner.insert(name, sub);
83        Ok(())
84    }
85
86    pub fn drop(&mut self, name: &str) -> bool {
87        self.inner.remove(name).is_some()
88    }
89
90    /// v6.1.4 — update the worker's last-applied master-WAL
91    /// offset. Called by the subscription worker after each apply
92    /// batch. Returns false when the subscription was dropped
93    /// between when the worker fetched the record and when this
94    /// call landed (so the worker can shut down cleanly).
95    pub fn update_last_received_pos(&mut self, name: &str, pos: u64) -> bool {
96        if let Some(s) = self.inner.get_mut(name) {
97            // Monotone: ignore stale updates (a future restart
98            // resuming from a sidecar may send an older pos than
99            // the live worker has already passed).
100            if pos > s.last_received_pos {
101                s.last_received_pos = pos;
102            }
103            true
104        } else {
105            false
106        }
107    }
108
109    // ── serialisation (envelope v4 trailer) ─────────────────────
110
111    /// Format:
112    ///   [u16 num_subscriptions]
113    ///   for each:
114    ///     [u16 name_len][name bytes]
115    ///     [u32 conn_str_len][conn_str bytes]
116    ///     [u16 num_pubs]
117    ///     for each: [u16 p_len][p bytes]
118    ///     [u8 enabled]
119    ///     [u64 last_received_pos]
120    pub fn serialize(&self) -> Vec<u8> {
121        let mut out = Vec::with_capacity(2 + self.inner.len() * 64);
122        let n = u16::try_from(self.inner.len()).expect("≤ 65,535 subscriptions per cluster");
123        out.extend_from_slice(&n.to_le_bytes());
124        for (name, sub) in &self.inner {
125            write_short_str(&mut out, name);
126            write_long_str(&mut out, &sub.conn_str);
127            let np = u16::try_from(sub.publications.len())
128                .expect("≤ 65,535 publications per subscription");
129            out.extend_from_slice(&np.to_le_bytes());
130            for p in &sub.publications {
131                write_short_str(&mut out, p);
132            }
133            out.push(u8::from(sub.enabled));
134            out.extend_from_slice(&sub.last_received_pos.to_le_bytes());
135        }
136        out
137    }
138
139    pub fn deserialize(buf: &[u8]) -> Result<Self, SubscriptionError> {
140        let mut p = 0usize;
141        let n = read_u16(buf, &mut p)? as usize;
142        let mut inner = BTreeMap::new();
143        for _ in 0..n {
144            let name = read_short_str(buf, &mut p)?;
145            let conn_str = read_long_str(buf, &mut p)?;
146            let np = read_u16(buf, &mut p)? as usize;
147            let mut publications = Vec::with_capacity(np);
148            for _ in 0..np {
149                publications.push(read_short_str(buf, &mut p)?);
150            }
151            let enabled_byte = read_u8(buf, &mut p)?;
152            let enabled = match enabled_byte {
153                0 => false,
154                1 => true,
155                other => {
156                    return Err(SubscriptionError::Corrupt(alloc::format!(
157                        "invalid `enabled` byte {other}, expected 0 or 1"
158                    )));
159                }
160            };
161            let last_received_pos = read_u64(buf, &mut p)?;
162            if inner
163                .insert(
164                    name.clone(),
165                    Subscription {
166                        conn_str,
167                        publications,
168                        enabled,
169                        last_received_pos,
170                    },
171                )
172                .is_some()
173            {
174                return Err(SubscriptionError::Corrupt(alloc::format!(
175                    "duplicate subscription name {name:?} in serialised payload"
176                )));
177            }
178        }
179        if p != buf.len() {
180            return Err(SubscriptionError::Corrupt(alloc::format!(
181                "trailing bytes in subscriptions payload: read {p}, len {}",
182                buf.len()
183            )));
184        }
185        Ok(Self { inner })
186    }
187}
188
/// Appends `s` as `[u16 little-endian byte length][UTF-8 bytes]`.
///
/// Panics if `s` is longer than `u16::MAX` bytes.
fn write_short_str(out: &mut Vec<u8>, s: &str) {
    let payload = s.as_bytes();
    let prefix = u16::try_from(payload.len())
        .expect("subscription / publication name fits in u16")
        .to_le_bytes();
    out.extend_from_slice(&prefix);
    out.extend_from_slice(payload);
}
194
/// Appends `s` as `[u32 little-endian byte length][UTF-8 bytes]`.
///
/// Panics if `s` is longer than `u32::MAX` bytes.
fn write_long_str(out: &mut Vec<u8>, s: &str) {
    let payload = s.as_bytes();
    // conn_str may be up to a few hundred bytes; u32 keeps headroom.
    let prefix = u32::try_from(payload.len())
        .expect("conn_str fits in u32")
        .to_le_bytes();
    out.extend_from_slice(&prefix);
    out.extend_from_slice(payload);
}
201
202fn read_u8(buf: &[u8], p: &mut usize) -> Result<u8, SubscriptionError> {
203    let v = buf
204        .get(*p)
205        .copied()
206        .ok_or_else(|| SubscriptionError::Corrupt("short read (u8)".to_string()))?;
207    *p += 1;
208    Ok(v)
209}
210
211fn read_u16(buf: &[u8], p: &mut usize) -> Result<u16, SubscriptionError> {
212    let slice = buf
213        .get(*p..*p + 2)
214        .ok_or_else(|| SubscriptionError::Corrupt("short read (u16)".to_string()))?;
215    let arr: [u8; 2] = slice
216        .try_into()
217        .map_err(|_| SubscriptionError::Corrupt("u16 slice".to_string()))?;
218    *p += 2;
219    Ok(u16::from_le_bytes(arr))
220}
221
222fn read_u32_as_usize(buf: &[u8], p: &mut usize) -> Result<usize, SubscriptionError> {
223    let slice = buf
224        .get(*p..*p + 4)
225        .ok_or_else(|| SubscriptionError::Corrupt("short read (u32)".to_string()))?;
226    let arr: [u8; 4] = slice
227        .try_into()
228        .map_err(|_| SubscriptionError::Corrupt("u32 slice".to_string()))?;
229    *p += 4;
230    Ok(u32::from_le_bytes(arr) as usize)
231}
232
233fn read_u64(buf: &[u8], p: &mut usize) -> Result<u64, SubscriptionError> {
234    let slice = buf
235        .get(*p..*p + 8)
236        .ok_or_else(|| SubscriptionError::Corrupt("short read (u64)".to_string()))?;
237    let arr: [u8; 8] = slice
238        .try_into()
239        .map_err(|_| SubscriptionError::Corrupt("u64 slice".to_string()))?;
240    *p += 8;
241    Ok(u64::from_le_bytes(arr))
242}
243
244fn read_short_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
245    let n = read_u16(buf, p)? as usize;
246    let slice = buf.get(*p..*p + n).ok_or_else(|| {
247        SubscriptionError::Corrupt(alloc::format!("short read (short str, {n} bytes)"))
248    })?;
249    *p += n;
250    core::str::from_utf8(slice)
251        .map(ToString::to_string)
252        .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 str: {e}")))
253}
254
255fn read_long_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
256    let n = read_u32_as_usize(buf, p)?;
257    let slice = buf.get(*p..*p + n).ok_or_else(|| {
258        SubscriptionError::Corrupt(alloc::format!("short read (long str, {n} bytes)"))
259    })?;
260    *p += n;
261    core::str::from_utf8(slice)
262        .map(ToString::to_string)
263        .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 conn_str: {e}")))
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(name, Subscription)` pair whose `conn_str` targets
    /// `127.0.0.1` on the given `port`.
    fn mk(
        name: &str,
        port: &str,
        pubs: &[&str],
        enabled: bool,
        pos: u64,
    ) -> (String, Subscription) {
        (
            name.to_string(),
            Subscription {
                conn_str: alloc::format!("host=127.0.0.1 port={port}"),
                publications: pubs.iter().map(|s| (*s).to_string()).collect(),
                enabled,
                last_received_pos: pos,
            },
        )
    }

    /// One-entry catalog used by the corruption tests.
    fn one_sub() -> Subscriptions {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a", "pub_b"], true, 7);
        s.create(n, sub).unwrap();
        s
    }

    #[test]
    fn empty_roundtrips() {
        let s = Subscriptions::new();
        let bytes = s.serialize();
        assert_eq!(Subscriptions::deserialize(&bytes).unwrap(), s);
    }

    #[test]
    fn single_subscription_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        let bytes = s.serialize();
        let s2 = Subscriptions::deserialize(&bytes).unwrap();
        assert_eq!(s2, s);
        assert!(s2.contains("sub_a"));
    }

    #[test]
    fn multi_publication_roundtrips_with_nontrivial_last_pos() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_z", "20002", &["p1", "p2", "p3"], true, 1_234_567_890);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert_eq!(s2, s);
        let r = s2.get("sub_z").unwrap();
        assert_eq!(r.publications, alloc::vec!["p1", "p2", "p3"]);
        assert_eq!(r.last_received_pos, 1_234_567_890);
    }

    #[test]
    fn disabled_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_off", "20002", &["pub_a"], false, 42);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert!(!s2.get("sub_off").unwrap().enabled);
    }

    #[test]
    fn duplicate_name_errors() {
        let mut s = Subscriptions::new();
        let (n1, sub1) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n1, sub1).unwrap();
        let (n2, sub2) = mk("sub_a", "20003", &["pub_b"], true, 0);
        assert_eq!(
            s.create(n2, sub2).unwrap_err(),
            SubscriptionError::DuplicateName("sub_a".into())
        );
    }

    #[test]
    fn drop_present_and_absent() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        assert!(s.drop("sub_a"));
        assert!(!s.drop("sub_a"));
        assert!(!s.drop("never"));
    }

    #[test]
    fn update_last_pos_monotone_and_absent_returns_false() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 100);
        s.create(n, sub).unwrap();
        assert!(s.update_last_received_pos("sub_a", 50)); // ignored (older)
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 100);
        assert!(s.update_last_received_pos("sub_a", 200));
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 200);
        assert!(!s.update_last_received_pos("missing", 1));
    }

    #[test]
    fn corrupt_enabled_byte_errors() {
        // Forge a payload with an invalid enabled byte (2).
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes()); // n = 1
        // name
        buf.extend_from_slice(&3u16.to_le_bytes());
        buf.extend_from_slice(b"bad");
        // conn_str
        buf.extend_from_slice(&0u32.to_le_bytes()); // empty
        // pubs (zero)
        buf.extend_from_slice(&0u16.to_le_bytes());
        // bogus enabled
        buf.push(2);
        // last_received_pos
        buf.extend_from_slice(&0u64.to_le_bytes());
        let err = Subscriptions::deserialize(&buf).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn deterministic_order_independent_of_insert_sequence() {
        let mut s1 = Subscriptions::new();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s1.create(n, sub).unwrap();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s1.create(n, sub).unwrap();
        let mut s2 = Subscriptions::new();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s2.create(n, sub).unwrap();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s2.create(n, sub).unwrap();
        assert_eq!(s1.serialize(), s2.serialize());
    }

    #[test]
    fn trailing_bytes_error() {
        let mut bytes = one_sub().serialize();
        bytes.push(0);
        let err = Subscriptions::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn every_truncated_prefix_errors() {
        let bytes = one_sub().serialize();
        for cut in 0..bytes.len() {
            let res = Subscriptions::deserialize(&bytes[..cut]);
            assert!(
                matches!(res, Err(SubscriptionError::Corrupt(_))),
                "prefix of {cut} bytes should not decode"
            );
        }
    }

    #[test]
    fn duplicate_name_in_payload_errors() {
        let bytes = one_sub().serialize();
        // Drop the u16 count, then emit the single record twice under n = 2.
        let record = &bytes[2..];
        let mut forged = Vec::new();
        forged.extend_from_slice(&2u16.to_le_bytes());
        forged.extend_from_slice(record);
        forged.extend_from_slice(record);
        let err = Subscriptions::deserialize(&forged).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn non_utf8_name_errors() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes()); // n = 1
        buf.extend_from_slice(&1u16.to_le_bytes()); // name_len = 1
        buf.push(0xFF); // not valid UTF-8
        let err = Subscriptions::deserialize(&buf).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }
}