Skip to main content

uni_plugin/
reload.rs

1//! Per-kind hot-reload discipline orchestration.
2//!
3//! [`ReloadDispatcher`] is invoked between the *drain* and *commit*
4//! phases of `Uni::reload`. By the time it runs, the old plugin has
5//! already been removed from the registry (so new captures cannot see
6//! its surfaces), but in-flight queries that captured `Arc<dyn Foo>`
7//! before the swap still operate against the old instances.
8//!
9//! The dispatcher's job is to run the per-kind handoff each surface
10//! needs **before** committing the new plugin's registrations. The
11//! handoffs are spelled out below:
12//!
13//! | Surface                 | Discipline                                         |
14//! |-------------------------|----------------------------------------------------|
15//! | Scalar / aggregate / …  | Clean — no protocol step needed.                   |
16//! | `StorageBackend`        | Clean — old `Storage` continues until drained.     |
17//! | `IndexKindProvider`     | `persist()` on old handle → `open()` on new.       |
18//! | `BackgroundJobProvider` | Clean — next tick picks up the new provider.       |
19//! | `CdcOutputProvider`     | `checkpoint()` on old → `start(lsn)` on new.       |
20//! | `CrdtKindProvider`      | Schema-compat round-trip — hard error on mismatch. |
21//! | `LogicalTypeProvider`   | Arrow extension contract unchanged — hard error.   |
22//!
23//! Per-kind handoffs that the trait surface already exposes (e.g.,
24//! `CdcStream::checkpoint`) are invoked directly. Surfaces that need
25//! a richer contract (CRDT / logical-type schema-compat) get a
26//! default-method on the trait (`schema_compat_check`, `compat_check`)
27//! that providers can override.
28//!
29//! Stateful surfaces with **in-flight private resources** (live
30//! `IndexHandle`s, open `CdcStream`s) are reload-managed by the host
31//! that owns those resources — the dispatcher receives them through
32//! the [`ReloadKindHandlers`] builder rather than by registry walk,
33//! because the registry only tracks *providers*, not the per-instance
34//! resources those providers spawn.
35
36use std::sync::Arc;
37
38use crate::errors::{FnError, ReloadError};
39use crate::registry::{PluginRecordSnapshot, PluginRegistry};
40use crate::traits::cdc::{CdcLsn, CdcOutputProvider, CdcStartContext, CdcStream};
41use crate::traits::crdt::CrdtKindProvider;
42use crate::traits::index::{IndexHandle, IndexKindProvider};
43use crate::traits::types::LogicalTypeProvider;
44
45/// Host-supplied handlers wiring per-kind in-flight resources into the
46/// reload pipeline.
47///
48/// The registry tracks providers, not the per-instance resources those
49/// providers spawn (open index handles, live CDC streams). The host
50/// owns those resources and supplies them through this builder so the
51/// dispatcher can persist / checkpoint them at the right moment.
52#[derive(Default)]
53pub struct ReloadKindHandlers {
54    /// Live index handles to persist and reopen against the new provider.
55    pub index_handles: Vec<IndexHandoff>,
56    /// Live CDC streams to checkpoint and restart against the new provider.
57    pub cdc_streams: Vec<CdcHandoff>,
58}
59
60impl std::fmt::Debug for ReloadKindHandlers {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        f.debug_struct("ReloadKindHandlers")
63            .field("index_handles", &self.index_handles.len())
64            .field("cdc_streams", &self.cdc_streams.len())
65            .finish()
66    }
67}
68
69/// One live index handle and the new provider that will reopen it.
70pub struct IndexHandoff {
71    /// Diagnostic name for the index (typically the registry key).
72    pub name: String,
73    /// The live, in-flight index handle owned by the old plugin.
74    pub old: Box<dyn IndexHandle>,
75    /// The new plugin's provider that will reopen the persisted bytes.
76    pub new: Arc<dyn IndexKindProvider>,
77}
78
79impl std::fmt::Debug for IndexHandoff {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        f.debug_struct("IndexHandoff")
82            .field("name", &self.name)
83            .finish_non_exhaustive()
84    }
85}
86
87/// One live CDC stream and the new provider that will resume it.
88pub struct CdcHandoff {
89    /// Diagnostic name for the stream (typically the registry key).
90    pub name: String,
91    /// The live CDC stream owned by the old plugin.
92    pub old: Box<dyn CdcStream>,
93    /// The new plugin's provider that will start a fresh stream at the
94    /// checkpointed LSN.
95    pub new: Arc<dyn CdcOutputProvider>,
96}
97
98impl std::fmt::Debug for CdcHandoff {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.debug_struct("CdcHandoff")
101            .field("name", &self.name)
102            .finish_non_exhaustive()
103    }
104}
105
106/// The outcome of a successful reload — opaque container for any
107/// new in-flight resources the dispatcher constructed (reopened index
108/// handles, restarted CDC streams) so the host can re-attach them.
109#[derive(Default)]
110pub struct ReloadOutcome {
111    /// Reopened index handles, paired with their registry name.
112    pub index_handles: Vec<(String, Box<dyn IndexHandle>)>,
113    /// Restarted CDC streams, paired with their registry name.
114    pub cdc_streams: Vec<(String, Box<dyn CdcStream>)>,
115}
116
117impl std::fmt::Debug for ReloadOutcome {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        f.debug_struct("ReloadOutcome")
120            .field("index_handles", &self.index_handles.len())
121            .field("cdc_streams", &self.cdc_streams.len())
122            .finish()
123    }
124}
125
126/// Orchestrates per-kind reload discipline between drain and commit.
127///
128/// Construct with [`ReloadDispatcher::new`], optionally populate live
129/// resources via [`ReloadKindHandlers`], then call [`Self::dispatch`].
130/// Failures abort the reload before the new plugin's registrations
131/// commit.
132#[derive(Debug)]
133pub struct ReloadDispatcher<'a> {
134    /// Snapshot of the old plugin's registry footprint.
135    old: &'a PluginRecordSnapshot,
136    /// The *new* registry view — already updated to point at the new
137    /// plugin's providers for any surfaces both registered.
138    new_registry: &'a PluginRegistry,
139    /// Optional live-resource handoffs.
140    handlers: ReloadKindHandlers,
141}
142
143impl<'a> ReloadDispatcher<'a> {
144    /// Construct a dispatcher over the old plugin's snapshot and the
145    /// new plugin's already-committed surface registry.
146    #[must_use]
147    pub fn new(old: &'a PluginRecordSnapshot, new_registry: &'a PluginRegistry) -> Self {
148        Self {
149            old,
150            new_registry,
151            handlers: ReloadKindHandlers::default(),
152        }
153    }
154
155    /// Attach per-kind live-resource handoffs.
156    #[must_use]
157    pub fn with_handlers(mut self, handlers: ReloadKindHandlers) -> Self {
158        self.handlers = handlers;
159        self
160    }
161
162    /// Pre-flight check: run schema-compat checks for CRDT kinds and
163    /// logical types that both old and new plugin register.
164    ///
165    /// `old_providers` supplies the *pre-swap* views of the providers
166    /// the old plugin owned. The dispatcher cannot recover those from
167    /// the registry once the swap has happened, so the host snapshots
168    /// them immediately before evicting the old plugin.
169    ///
170    /// # Errors
171    ///
172    /// Returns [`ReloadError::SchemaIncompat`] when any pair fails its
173    /// compat check.
174    pub fn check_compat(&self, old_providers: &OldProviders) -> Result<(), ReloadError> {
175        for kind in &self.old.crdt_kinds {
176            let Some(old) = old_providers.crdt_kinds.get(kind) else {
177                continue;
178            };
179            let Some(new) = self.new_registry.crdt_kind(kind) else {
180                // New plugin did not re-register this CRDT kind — that
181                // is a plain removal, not an incompat reload.
182                continue;
183            };
184            new.schema_compat_check(old.as_ref())
185                .map_err(|e: FnError| {
186                    ReloadError::schema_incompat(format!("crdt:{}", kind.0), e.message)
187                })?;
188        }
189        for name in &old_providers.logical_type_names {
190            let Some(old) = old_providers.logical_types.get(name) else {
191                continue;
192            };
193            let Some(new) = self.new_registry.logical_type(name) else {
194                continue;
195            };
196            new.compat_check(old.as_ref()).map_err(|e: FnError| {
197                ReloadError::schema_incompat(format!("logical-type:{name}"), e.message)
198            })?;
199        }
200        Ok(())
201    }
202
203    /// Drive the per-instance handoffs: persist & reopen index handles;
204    /// checkpoint & restart CDC streams.
205    ///
206    /// # Errors
207    ///
208    /// Returns [`ReloadError::Persist`] if a handoff fails. The old
209    /// resources are dropped on failure — the host must be prepared to
210    /// surface the failure and continue serving against the new
211    /// providers' freshly-initialized resources.
212    pub fn dispatch(mut self) -> Result<ReloadOutcome, ReloadError> {
213        let mut outcome = ReloadOutcome::default();
214        for handoff in self.handlers.index_handles.drain(..) {
215            let bytes = handoff.old.persist().map_err(ReloadError::Persist)?;
216            // Drop the old handle now — the new one is about to take its
217            // place. RAII closes any underlying resources (mmaps, file
218            // handles) the old handle owned.
219            drop(handoff.old);
220            let reopened = handoff.new.open(&bytes).map_err(ReloadError::Persist)?;
221            outcome.index_handles.push((handoff.name, reopened));
222        }
223        for mut handoff in self.handlers.cdc_streams.drain(..) {
224            let lsn: CdcLsn = handoff.old.checkpoint().map_err(ReloadError::Persist)?;
225            // Best-effort shutdown of the old stream; surface failure as
226            // Persist (the spec treats shutdown failure as fatal).
227            handoff.old.shutdown().map_err(ReloadError::Persist)?;
228            drop(handoff.old);
229            let resumed = handoff
230                .new
231                .start(CdcStartContext::new(Some(lsn)))
232                .map_err(ReloadError::Persist)?;
233            outcome.cdc_streams.push((handoff.name, resumed));
234        }
235        Ok(outcome)
236    }
237}
238
239/// Pre-swap snapshot of the old plugin's stateful providers.
240///
241/// The host populates this immediately before evicting the old plugin
242/// from the registry so the dispatcher's schema-compat checks have
243/// the old providers to compare against. The vectors are keyed by
244/// the same names the registry uses (`CrdtKind` for CRDTs, extension
245/// `name()` for logical types).
246#[derive(Default)]
247pub struct OldProviders {
248    /// CRDT kind providers the old plugin owned, keyed by kind.
249    pub crdt_kinds:
250        std::collections::HashMap<crate::traits::crdt::CrdtKind, Arc<dyn CrdtKindProvider>>,
251    /// Names of logical types the old plugin owned (preserves order).
252    pub logical_type_names: Vec<smol_str::SmolStr>,
253    /// Logical type providers keyed by extension name.
254    pub logical_types: std::collections::HashMap<smol_str::SmolStr, Arc<dyn LogicalTypeProvider>>,
255}
256
257impl std::fmt::Debug for OldProviders {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        f.debug_struct("OldProviders")
260            .field("crdt_kinds", &self.crdt_kinds.len())
261            .field("logical_types", &self.logical_types.len())
262            .finish()
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::crdt::{CrdtKind, CrdtOp, CrdtState};
    use crate::traits::index::{IndexBuild, IndexKind};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::scalar::ScalarValue;

    // ── Test fixtures ───────────────────────────────────────────────

    /// Minimal CRDT state: a single counter persisted as 8 LE bytes.
    #[derive(Default)]
    struct CountState {
        v: i64,
    }

    impl CrdtState for CountState {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
        fn apply(&mut self, op: &CrdtOp) -> Result<(), FnError> {
            self.v += op.bytes.len() as i64;
            Ok(())
        }
        fn merge(&mut self, other: &dyn CrdtState) -> Result<(), FnError> {
            let other = other
                .as_any()
                .downcast_ref::<CountState>()
                .ok_or_else(|| FnError::new(0x100, "merge: wrong state type"))?;
            // Merge keeps the larger of the two counters.
            if other.v > self.v {
                self.v = other.v;
            }
            Ok(())
        }
        fn value(&self) -> Result<ScalarValue, FnError> {
            Ok(ScalarValue::Int64(Some(self.v)))
        }
        fn persist(&self) -> Result<Vec<u8>, FnError> {
            Ok(self.v.to_le_bytes().to_vec())
        }
    }

    /// Provider whose persisted format is exactly the 8 bytes `CountState` writes.
    struct CountProvider {
        kind_str: &'static str,
    }

    impl CrdtKindProvider for CountProvider {
        fn kind(&self) -> CrdtKind {
            CrdtKind::new(self.kind_str)
        }
        fn empty(&self) -> Box<dyn CrdtState> {
            Box::new(CountState::default())
        }
        fn from_persisted(&self, bytes: &[u8]) -> Result<Box<dyn CrdtState>, FnError> {
            if bytes.len() != 8 {
                return Err(FnError::new(
                    0x101,
                    format!("expected 8 bytes, got {}", bytes.len()),
                ));
            }
            let mut arr = [0u8; 8];
            arr.copy_from_slice(bytes);
            Ok(Box::new(CountState {
                v: i64::from_le_bytes(arr),
            }))
        }
    }

    /// Provider that refuses every persisted payload.
    struct RejectingProvider;

    impl CrdtKindProvider for RejectingProvider {
        fn kind(&self) -> CrdtKind {
            CrdtKind::new("count")
        }
        fn empty(&self) -> Box<dyn CrdtState> {
            Box::new(CountState::default())
        }
        fn from_persisted(&self, _bytes: &[u8]) -> Result<Box<dyn CrdtState>, FnError> {
            Err(FnError::new(0x102, "rejecting all persisted bytes"))
        }
    }

    /// Index handle that persists to a fixed byte payload.
    struct DummyHandle {
        bytes: Vec<u8>,
    }

    impl IndexHandle for DummyHandle {
        fn probe(&self, _query: &RecordBatch, _k: usize) -> Result<RecordBatch, FnError> {
            Err(FnError::new(0, "unused"))
        }
        fn persist(&self) -> Result<Vec<u8>, FnError> {
            Ok(self.bytes.clone())
        }
        fn schema(&self) -> arrow_schema::SchemaRef {
            Arc::new(arrow_schema::Schema::empty())
        }
    }

    /// Index handle whose `persist` always fails.
    struct FailingHandle;

    impl IndexHandle for FailingHandle {
        fn probe(&self, _query: &RecordBatch, _k: usize) -> Result<RecordBatch, FnError> {
            Err(FnError::new(0, "unused"))
        }
        fn persist(&self) -> Result<Vec<u8>, FnError> {
            Err(FnError::new(0x103, "persist failed"))
        }
        fn schema(&self) -> arrow_schema::SchemaRef {
            Arc::new(arrow_schema::Schema::empty())
        }
    }

    /// Index provider that reopens persisted bytes as a `DummyHandle`.
    struct DummyProvider;

    impl IndexKindProvider for DummyProvider {
        fn kind(&self) -> IndexKind {
            IndexKind::new("dummy")
        }
        fn build(
            &self,
            _source: &RecordBatch,
            _options: &str,
        ) -> Result<Box<dyn IndexBuild>, FnError> {
            Err(FnError::new(0, "unused"))
        }
        fn open(&self, persisted: &[u8]) -> Result<Box<dyn IndexHandle>, FnError> {
            Ok(Box::new(DummyHandle {
                bytes: persisted.to_vec(),
            }))
        }
    }

    // ── Tests ───────────────────────────────────────────────────────

    #[test]
    fn schema_compat_accepts_round_trip() {
        let old = CountProvider { kind_str: "count" };
        let new = CountProvider { kind_str: "count" };
        new.schema_compat_check(&old).expect("compatible");
    }

    #[test]
    fn schema_compat_rejects_incompatible_round_trip() {
        let old = CountProvider { kind_str: "count" };
        let new = RejectingProvider;
        let err = new.schema_compat_check(&old).unwrap_err();
        assert!(err.message.contains("rejecting"));
    }

    #[test]
    fn dispatcher_check_compat_treats_missing_new_provider_as_removal() {
        // The old plugin registered `count` and the host snapshotted its
        // provider, but the new registry is left empty: nothing is
        // inserted into it here.
        let registry = PluginRegistry::new();
        let snap = PluginRecordSnapshot {
            crdt_kinds: vec![CrdtKind::new("count")],
            ..Default::default()
        };
        let mut olds = OldProviders::default();
        olds.crdt_kinds.insert(
            CrdtKind::new("count"),
            Arc::new(CountProvider { kind_str: "count" }),
        );
        // With no provider in the new registry the kind counts as a clean
        // removal, so the pre-flight must succeed.
        let d = ReloadDispatcher::new(&snap, &registry);
        d.check_compat(&olds).expect("absence is OK");
    }

    #[test]
    fn dispatcher_dispatch_handles_index_handoff() {
        let snap = PluginRecordSnapshot::default();
        let registry = PluginRegistry::new();
        let mut handlers = ReloadKindHandlers::default();
        handlers.index_handles.push(IndexHandoff {
            name: "i1".to_owned(),
            old: Box::new(DummyHandle {
                bytes: vec![1, 2, 3, 4],
            }),
            new: Arc::new(DummyProvider),
        });
        let outcome = ReloadDispatcher::new(&snap, &registry)
            .with_handlers(handlers)
            .dispatch()
            .expect("handoff");
        assert_eq!(outcome.index_handles.len(), 1);
        assert_eq!(outcome.index_handles[0].0, "i1");
        // The reopened handle must carry the bytes the old one persisted.
        assert_eq!(
            outcome.index_handles[0].1.persist().unwrap(),
            vec![1, 2, 3, 4]
        );
    }

    #[test]
    fn dispatcher_dispatch_surfaces_persist_failure() {
        let snap = PluginRecordSnapshot::default();
        let registry = PluginRegistry::new();
        let mut handlers = ReloadKindHandlers::default();
        handlers.index_handles.push(IndexHandoff {
            name: "broken".to_owned(),
            old: Box::new(FailingHandle),
            new: Arc::new(DummyProvider),
        });
        let err = ReloadDispatcher::new(&snap, &registry)
            .with_handlers(handlers)
            .dispatch()
            .unwrap_err();
        // A failing `persist()` on the old handle aborts the dispatch.
        assert!(matches!(err, ReloadError::Persist(_)));
    }
}