Skip to main content

spg_engine/
subscriptions.rs

// `clippy::doc_markdown` (pedantic) would flag the wire-format spec
// embedded in this module's serialisation docs, plus a handful of
// proper nouns. Allowing it module-wide keeps the spec readable
// without backtick-wrapping every token.
#![allow(clippy::doc_markdown)]
5
//! v6.1.4 — logical-replication subscription catalog.
//!
//! In-memory table of subscriptions, owned by the engine. The
//! catalog persists across restarts via the snapshot envelope's
//! v4 trailer block (see `build_envelope` in the crate root,
//! `lib.rs`). This is the same mechanism v6.1.2 added for
//! publications, with one extra section.
//!
//! Subscriptions are the receive side of logical replication. A
//! `CreateSubscription` row holds:
//!   - `name`              the local identifier (the catalog key)
//!   - `conn_str`          PG keyword=value string the worker
//!                         parses for `host=…` and `port=…`
//!   - `publications`      list of remote publication names
//!   - `enabled`           v6.1.4 hard-codes this to `true`; ALTER
//!                         SUBSCRIPTION ENABLE / DISABLE lands
//!                         in a future sub-version
//!   - `last_received_pos` master-WAL byte offset the worker has
//!                         applied through (updated live by the
//!                         worker, persisted at the next snapshot)
//!
//! The worker itself lives in
//! `spg_server::replication::run_subscription_worker`. The engine
//! layer only owns the catalog state, snapshots it, and answers
//! `SHOW SUBSCRIPTIONS`.
29
30use alloc::collections::BTreeMap;
31use alloc::string::{String, ToString};
32use alloc::vec::Vec;
33
34use spg_sql::ast::CreateSubscriptionStatement;
35use spg_storage::{ColumnSchema, DataType, Row, Value};
36
37use crate::{Engine, EngineError, QueryResult};
38
/// One row of the subscription catalog.
///
/// The subscription's name is not stored in the row itself; it is the
/// key under which the row lives in the catalog map.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Subscription {
    /// PG keyword=value connection string (`host=… port=…`).
    pub conn_str: String,
    /// Names of the remote publications this subscription pulls from.
    pub publications: Vec<String>,
    /// Whether the subscription is enabled. `CREATE SUBSCRIPTION`
    /// currently always sets this to `true`.
    pub enabled: bool,
    /// Master-WAL byte offset the worker has applied through.
    pub last_received_pos: u64,
}
46
47#[derive(Debug, Clone, PartialEq, Eq, Default)]
48pub struct Subscriptions {
49    inner: BTreeMap<String, Subscription>,
50}
51
/// Errors raised by the subscription catalog.
#[derive(Debug, PartialEq, Eq)]
pub enum SubscriptionError {
    /// `create` was called with a name that is already in the catalog.
    DuplicateName(String),
    /// A serialised catalog payload failed to decode; the string
    /// describes what was wrong.
    Corrupt(String),
}

// Human-readable rendering so callers are not limited to the `Debug`
// form (`DuplicateName("x")`) when surfacing these errors.
impl core::fmt::Display for SubscriptionError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::DuplicateName(name) => write!(f, "subscription {name:?} already exists"),
            Self::Corrupt(detail) => write!(f, "corrupt subscriptions payload: {detail}"),
        }
    }
}
57
58impl Subscriptions {
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    pub fn len(&self) -> usize {
64        self.inner.len()
65    }
66
67    pub fn is_empty(&self) -> bool {
68        self.inner.is_empty()
69    }
70
71    pub fn contains(&self, name: &str) -> bool {
72        self.inner.contains_key(name)
73    }
74
75    pub fn get(&self, name: &str) -> Option<&Subscription> {
76        self.inner.get(name)
77    }
78
79    pub fn iter(&self) -> impl Iterator<Item = (&String, &Subscription)> {
80        self.inner.iter()
81    }
82
83    pub fn create(&mut self, name: String, sub: Subscription) -> Result<(), SubscriptionError> {
84        if self.inner.contains_key(&name) {
85            return Err(SubscriptionError::DuplicateName(name));
86        }
87        self.inner.insert(name, sub);
88        Ok(())
89    }
90
91    pub fn drop(&mut self, name: &str) -> bool {
92        self.inner.remove(name).is_some()
93    }
94
95    /// v6.1.4 — update the worker's last-applied master-WAL
96    /// offset. Called by the subscription worker after each apply
97    /// batch. Returns false when the subscription was dropped
98    /// between when the worker fetched the record and when this
99    /// call landed (so the worker can shut down cleanly).
100    pub fn update_last_received_pos(&mut self, name: &str, pos: u64) -> bool {
101        if let Some(s) = self.inner.get_mut(name) {
102            // Monotone: ignore stale updates (a future restart
103            // resuming from a sidecar may send an older pos than
104            // the live worker has already passed).
105            if pos > s.last_received_pos {
106                s.last_received_pos = pos;
107            }
108            true
109        } else {
110            false
111        }
112    }
113
114    // ── serialisation (envelope v4 trailer) ─────────────────────
115
116    /// Format:
117    ///   [u16 num_subscriptions]
118    ///   for each:
119    ///     [u16 name_len][name bytes]
120    ///     [u32 conn_str_len][conn_str bytes]
121    ///     [u16 num_pubs]
122    ///     for each: [u16 p_len][p bytes]
123    ///     [u8 enabled]
124    ///     [u64 last_received_pos]
125    pub fn serialize(&self) -> Vec<u8> {
126        let mut out = Vec::with_capacity(2 + self.inner.len() * 64);
127        let n = u16::try_from(self.inner.len()).expect("≤ 65,535 subscriptions per cluster");
128        out.extend_from_slice(&n.to_le_bytes());
129        for (name, sub) in &self.inner {
130            write_short_str(&mut out, name);
131            write_long_str(&mut out, &sub.conn_str);
132            let np = u16::try_from(sub.publications.len())
133                .expect("≤ 65,535 publications per subscription");
134            out.extend_from_slice(&np.to_le_bytes());
135            for p in &sub.publications {
136                write_short_str(&mut out, p);
137            }
138            out.push(u8::from(sub.enabled));
139            out.extend_from_slice(&sub.last_received_pos.to_le_bytes());
140        }
141        out
142    }
143
144    pub fn deserialize(buf: &[u8]) -> Result<Self, SubscriptionError> {
145        let mut p = 0usize;
146        let n = read_u16(buf, &mut p)? as usize;
147        let mut inner = BTreeMap::new();
148        for _ in 0..n {
149            let name = read_short_str(buf, &mut p)?;
150            let conn_str = read_long_str(buf, &mut p)?;
151            let np = read_u16(buf, &mut p)? as usize;
152            let mut publications = Vec::with_capacity(np);
153            for _ in 0..np {
154                publications.push(read_short_str(buf, &mut p)?);
155            }
156            let enabled_byte = read_u8(buf, &mut p)?;
157            let enabled = match enabled_byte {
158                0 => false,
159                1 => true,
160                other => {
161                    return Err(SubscriptionError::Corrupt(alloc::format!(
162                        "invalid `enabled` byte {other}, expected 0 or 1"
163                    )));
164                }
165            };
166            let last_received_pos = read_u64(buf, &mut p)?;
167            if inner
168                .insert(
169                    name.clone(),
170                    Subscription {
171                        conn_str,
172                        publications,
173                        enabled,
174                        last_received_pos,
175                    },
176                )
177                .is_some()
178            {
179                return Err(SubscriptionError::Corrupt(alloc::format!(
180                    "duplicate subscription name {name:?} in serialised payload"
181                )));
182            }
183        }
184        if p != buf.len() {
185            return Err(SubscriptionError::Corrupt(alloc::format!(
186                "trailing bytes in subscriptions payload: read {p}, len {}",
187                buf.len()
188            )));
189        }
190        Ok(Self { inner })
191    }
192}
193
/// Appends `s` to `out` as a little-endian `u16` byte length followed by
/// the raw UTF-8 bytes (the encoding used for names).
///
/// # Panics
/// Panics if `s` is longer than `u16::MAX` bytes.
fn write_short_str(out: &mut Vec<u8>, s: &str) {
    let bytes = s.as_bytes();
    let len_prefix: u16 = bytes
        .len()
        .try_into()
        .expect("subscription / publication name fits in u16");
    out.reserve(2 + bytes.len());
    out.extend_from_slice(&len_prefix.to_le_bytes());
    out.extend_from_slice(bytes);
}
199
/// Appends `s` to `out` as a little-endian `u32` byte length followed by
/// the raw UTF-8 bytes (the encoding used for `conn_str`).
///
/// # Panics
/// Panics if `s` is longer than `u32::MAX` bytes.
fn write_long_str(out: &mut Vec<u8>, s: &str) {
    // Connection strings are typically a few hundred bytes at most; the
    // u32 prefix just leaves generous headroom.
    let bytes = s.as_bytes();
    let len_prefix: u32 = bytes.len().try_into().expect("conn_str fits in u32");
    out.reserve(4 + bytes.len());
    out.extend_from_slice(&len_prefix.to_le_bytes());
    out.extend_from_slice(bytes);
}
206
207fn read_u8(buf: &[u8], p: &mut usize) -> Result<u8, SubscriptionError> {
208    let v = buf
209        .get(*p)
210        .copied()
211        .ok_or_else(|| SubscriptionError::Corrupt("short read (u8)".to_string()))?;
212    *p += 1;
213    Ok(v)
214}
215
216fn read_u16(buf: &[u8], p: &mut usize) -> Result<u16, SubscriptionError> {
217    let slice = buf
218        .get(*p..*p + 2)
219        .ok_or_else(|| SubscriptionError::Corrupt("short read (u16)".to_string()))?;
220    let arr: [u8; 2] = slice
221        .try_into()
222        .map_err(|_| SubscriptionError::Corrupt("u16 slice".to_string()))?;
223    *p += 2;
224    Ok(u16::from_le_bytes(arr))
225}
226
227fn read_u32_as_usize(buf: &[u8], p: &mut usize) -> Result<usize, SubscriptionError> {
228    let slice = buf
229        .get(*p..*p + 4)
230        .ok_or_else(|| SubscriptionError::Corrupt("short read (u32)".to_string()))?;
231    let arr: [u8; 4] = slice
232        .try_into()
233        .map_err(|_| SubscriptionError::Corrupt("u32 slice".to_string()))?;
234    *p += 4;
235    Ok(u32::from_le_bytes(arr) as usize)
236}
237
238fn read_u64(buf: &[u8], p: &mut usize) -> Result<u64, SubscriptionError> {
239    let slice = buf
240        .get(*p..*p + 8)
241        .ok_or_else(|| SubscriptionError::Corrupt("short read (u64)".to_string()))?;
242    let arr: [u8; 8] = slice
243        .try_into()
244        .map_err(|_| SubscriptionError::Corrupt("u64 slice".to_string()))?;
245    *p += 8;
246    Ok(u64::from_le_bytes(arr))
247}
248
249fn read_short_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
250    let n = read_u16(buf, p)? as usize;
251    let slice = buf.get(*p..*p + n).ok_or_else(|| {
252        SubscriptionError::Corrupt(alloc::format!("short read (short str, {n} bytes)"))
253    })?;
254    *p += n;
255    core::str::from_utf8(slice)
256        .map(ToString::to_string)
257        .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 str: {e}")))
258}
259
260fn read_long_str(buf: &[u8], p: &mut usize) -> Result<String, SubscriptionError> {
261    let n = read_u32_as_usize(buf, p)?;
262    let slice = buf.get(*p..*p + n).ok_or_else(|| {
263        SubscriptionError::Corrupt(alloc::format!("short read (long str, {n} bytes)"))
264    })?;
265    *p += n;
266    core::str::from_utf8(slice)
267        .map(ToString::to_string)
268        .map_err(|e| SubscriptionError::Corrupt(alloc::format!("non-UTF-8 conn_str: {e}")))
269}
270
271impl Engine {
272    /// v6.1.4 — `SHOW SUBSCRIPTIONS` row materialisation. Returns
273    /// `(name, conn_str, publications, enabled, last_received_pos)`
274    /// ordered by subscription name. The `publications` column is
275    /// the comma-joined list ("p1, p2") for ergonomic SHOW output;
276    /// callers wanting structured access read `Engine::subscriptions`.
277    pub(crate) fn exec_show_subscriptions(&self) -> QueryResult {
278        let columns = alloc::vec![
279            ColumnSchema::new("name", DataType::Text, false),
280            ColumnSchema::new("conn_str", DataType::Text, false),
281            ColumnSchema::new("publications", DataType::Text, false),
282            ColumnSchema::new("enabled", DataType::Bool, false),
283            ColumnSchema::new("last_received_pos", DataType::BigInt, false),
284        ];
285        let rows: Vec<Row> = self
286            .subscriptions
287            .iter()
288            .map(|(name, sub)| {
289                Row::new(alloc::vec![
290                    Value::Text(name.clone()),
291                    Value::Text(sub.conn_str.clone()),
292                    Value::Text(sub.publications.join(", ")),
293                    Value::Bool(sub.enabled),
294                    Value::BigInt(i64::try_from(sub.last_received_pos).unwrap_or(i64::MAX)),
295                ])
296            })
297            .collect();
298        QueryResult::Rows { columns, rows }
299    }
300
301    /// v6.1.4 — `CREATE SUBSCRIPTION` runtime path. Defaults
302    /// `enabled = true` and `last_received_pos = 0` for a freshly-
303    /// created subscription. The actual worker thread is spawned
304    /// by spg-server once the engine returns success.
305    pub(crate) fn exec_create_subscription(
306        &mut self,
307        s: CreateSubscriptionStatement,
308    ) -> Result<QueryResult, EngineError> {
309        // See exec_create_publication — the in_transaction gate
310        // was over-cautious; the auto-commit wrap path holds an
311        // internal TX that this check was incorrectly blocking.
312        let sub = Subscription {
313            conn_str: s.conn_str,
314            publications: s.publications,
315            enabled: true,
316            last_received_pos: 0,
317        };
318        self.subscriptions
319            .create(s.name, sub)
320            .map_err(|e| EngineError::Unsupported(alloc::format!("CREATE SUBSCRIPTION: {e:?}")))?;
321        Ok(QueryResult::CommandOk {
322            affected: 1,
323            modified_catalog: true,
324        })
325    }
326
327    /// v6.1.4 — `DROP SUBSCRIPTION`. Silent no-op when the name
328    /// doesn't exist (PG-compatible). The associated worker is
329    /// torn down by spg-server when it observes the catalog
330    /// change at the next snapshot or via the engine's
331    /// subscriptions accessor (the worker polls the catalog on
332    /// reconnect; v6.1.5's filter-side will tighten this to an
333    /// explicit signal).
334    pub(crate) fn exec_drop_subscription(
335        &mut self,
336        name: &str,
337    ) -> Result<QueryResult, EngineError> {
338        let removed = self.subscriptions.drop(name);
339        Ok(QueryResult::CommandOk {
340            affected: usize::from(removed),
341            modified_catalog: removed,
342        })
343    }
344
345    /// v6.1.4 — read access to the subscription catalog. Used by
346    /// the subscription worker (read its own row to find its
347    /// publications + last applied position), by SHOW SUBSCRIPTIONS,
348    /// and by e2e tests asserting state directly.
349    pub const fn subscriptions(&self) -> &Subscriptions {
350        &self.subscriptions
351    }
352
353    /// v6.1.4 — write access to `last_received_pos`. Worker
354    /// calls this after each apply batch (under the engine's
355    /// write-lock). Returns `false` when the subscription was
356    /// dropped between when the worker received the record and
357    /// when this call landed.
358    pub fn subscription_advance(&mut self, name: &str, pos: u64) -> bool {
359        self.subscriptions.update_last_received_pos(name, pos)
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(name, Subscription)` pair whose `conn_str` points at
    /// `127.0.0.1:<port>`.
    fn mk(
        name: &str,
        port: &str,
        pubs: &[&str],
        enabled: bool,
        pos: u64,
    ) -> (String, Subscription) {
        (
            name.to_string(),
            Subscription {
                conn_str: alloc::format!("host=127.0.0.1 port={port}"),
                publications: pubs.iter().map(|s| (*s).to_string()).collect(),
                enabled,
                last_received_pos: pos,
            },
        )
    }

    #[test]
    fn empty_roundtrips() {
        let s = Subscriptions::new();
        let bytes = s.serialize();
        assert_eq!(Subscriptions::deserialize(&bytes).unwrap(), s);
    }

    #[test]
    fn single_subscription_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        let bytes = s.serialize();
        let s2 = Subscriptions::deserialize(&bytes).unwrap();
        assert_eq!(s2, s);
        assert!(s2.contains("sub_a"));
    }

    #[test]
    fn multi_publication_roundtrips_with_nontrivial_last_pos() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_z", "20002", &["p1", "p2", "p3"], true, 1_234_567_890);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert_eq!(s2, s);
        let r = s2.get("sub_z").unwrap();
        assert_eq!(r.publications, alloc::vec!["p1", "p2", "p3"]);
        assert_eq!(r.last_received_pos, 1_234_567_890);
    }

    #[test]
    fn disabled_roundtrips() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_off", "20002", &["pub_a"], false, 42);
        s.create(n, sub).unwrap();
        let s2 = Subscriptions::deserialize(&s.serialize()).unwrap();
        assert!(!s2.get("sub_off").unwrap().enabled);
    }

    #[test]
    fn duplicate_name_errors() {
        let mut s = Subscriptions::new();
        let (n1, sub1) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n1, sub1).unwrap();
        let (n2, sub2) = mk("sub_a", "20003", &["pub_b"], true, 0);
        assert_eq!(
            s.create(n2, sub2).unwrap_err(),
            SubscriptionError::DuplicateName("sub_a".into())
        );
        // The original entry must survive the rejected create.
        assert_eq!(s.get("sub_a").unwrap().publications, alloc::vec!["pub_a"]);
    }

    #[test]
    fn drop_present_and_absent() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 0);
        s.create(n, sub).unwrap();
        assert!(s.drop("sub_a"));
        assert!(!s.drop("sub_a"));
        assert!(!s.drop("never"));
    }

    #[test]
    fn update_last_pos_monotone_and_absent_returns_false() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 100);
        s.create(n, sub).unwrap();
        assert!(s.update_last_received_pos("sub_a", 50)); // ignored (older)
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 100);
        assert!(s.update_last_received_pos("sub_a", 200));
        assert_eq!(s.get("sub_a").unwrap().last_received_pos, 200);
        assert!(!s.update_last_received_pos("missing", 1));
    }

    #[test]
    fn corrupt_enabled_byte_errors() {
        // Forge a payload with an invalid enabled byte (2).
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes()); // n = 1
        // name
        buf.extend_from_slice(&3u16.to_le_bytes());
        buf.extend_from_slice(b"bad");
        // conn_str
        buf.extend_from_slice(&0u32.to_le_bytes()); // empty
        // pubs (zero)
        buf.extend_from_slice(&0u16.to_le_bytes());
        // bogus enabled
        buf.push(2);
        // last_received_pos
        buf.extend_from_slice(&0u64.to_le_bytes());
        let err = Subscriptions::deserialize(&buf).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn deterministic_order_independent_of_insert_sequence() {
        let mut s1 = Subscriptions::new();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s1.create(n, sub).unwrap();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s1.create(n, sub).unwrap();
        let mut s2 = Subscriptions::new();
        let (n, sub) = mk("a", "20003", &["p2"], true, 0);
        s2.create(n, sub).unwrap();
        let (n, sub) = mk("z", "20002", &["p1"], true, 0);
        s2.create(n, sub).unwrap();
        assert_eq!(s1.serialize(), s2.serialize());
    }

    #[test]
    fn trailing_bytes_error() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["pub_a"], true, 7);
        s.create(n, sub).unwrap();
        let mut bytes = s.serialize();
        bytes.push(0);
        let err = Subscriptions::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn every_truncation_errors() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("sub_a", "20002", &["p1", "p2"], true, 7);
        s.create(n, sub).unwrap();
        let bytes = s.serialize();
        // Every strict prefix is missing part of the single record, so
        // decoding must fail with a short read rather than panic.
        for cut in 0..bytes.len() {
            assert!(
                matches!(
                    Subscriptions::deserialize(&bytes[..cut]),
                    Err(SubscriptionError::Corrupt(_))
                ),
                "cut at {cut}"
            );
        }
    }

    #[test]
    fn duplicate_name_in_payload_errors() {
        let mut s = Subscriptions::new();
        let (n, sub) = mk("dup", "20002", &["p"], true, 0);
        s.create(n, sub).unwrap();
        let one = s.serialize();
        // Splice two copies of the same record behind a count of 2.
        let record = &one[2..];
        let mut buf = Vec::new();
        buf.extend_from_slice(&2u16.to_le_bytes());
        buf.extend_from_slice(record);
        buf.extend_from_slice(record);
        let err = Subscriptions::deserialize(&buf).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }

    #[test]
    fn non_utf8_name_errors() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes()); // n = 1
        buf.extend_from_slice(&1u16.to_le_bytes()); // name_len = 1
        buf.push(0xff); // not valid UTF-8
        let err = Subscriptions::deserialize(&buf).unwrap_err();
        assert!(matches!(err, SubscriptionError::Corrupt(_)));
    }
}