1use std::sync::Arc;
37
38use crate::errors::{FnError, ReloadError};
39use crate::registry::{PluginRecordSnapshot, PluginRegistry};
40use crate::traits::cdc::{CdcLsn, CdcOutputProvider, CdcStartContext, CdcStream};
41use crate::traits::crdt::CrdtKindProvider;
42use crate::traits::index::{IndexHandle, IndexKindProvider};
43use crate::traits::types::LogicalTypeProvider;
44
45#[derive(Default)]
53pub struct ReloadKindHandlers {
54 pub index_handles: Vec<IndexHandoff>,
56 pub cdc_streams: Vec<CdcHandoff>,
58}
59
60impl std::fmt::Debug for ReloadKindHandlers {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.debug_struct("ReloadKindHandlers")
63 .field("index_handles", &self.index_handles.len())
64 .field("cdc_streams", &self.cdc_streams.len())
65 .finish()
66 }
67}
68
69pub struct IndexHandoff {
71 pub name: String,
73 pub old: Box<dyn IndexHandle>,
75 pub new: Arc<dyn IndexKindProvider>,
77}
78
79impl std::fmt::Debug for IndexHandoff {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 f.debug_struct("IndexHandoff")
82 .field("name", &self.name)
83 .finish_non_exhaustive()
84 }
85}
86
87pub struct CdcHandoff {
89 pub name: String,
91 pub old: Box<dyn CdcStream>,
93 pub new: Arc<dyn CdcOutputProvider>,
96}
97
98impl std::fmt::Debug for CdcHandoff {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("CdcHandoff")
101 .field("name", &self.name)
102 .finish_non_exhaustive()
103 }
104}
105
106#[derive(Default)]
110pub struct ReloadOutcome {
111 pub index_handles: Vec<(String, Box<dyn IndexHandle>)>,
113 pub cdc_streams: Vec<(String, Box<dyn CdcStream>)>,
115}
116
117impl std::fmt::Debug for ReloadOutcome {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("ReloadOutcome")
120 .field("index_handles", &self.index_handles.len())
121 .field("cdc_streams", &self.cdc_streams.len())
122 .finish()
123 }
124}
125
126#[derive(Debug)]
133pub struct ReloadDispatcher<'a> {
134 old: &'a PluginRecordSnapshot,
136 new_registry: &'a PluginRegistry,
139 handlers: ReloadKindHandlers,
141}
142
143impl<'a> ReloadDispatcher<'a> {
144 #[must_use]
147 pub fn new(old: &'a PluginRecordSnapshot, new_registry: &'a PluginRegistry) -> Self {
148 Self {
149 old,
150 new_registry,
151 handlers: ReloadKindHandlers::default(),
152 }
153 }
154
155 #[must_use]
157 pub fn with_handlers(mut self, handlers: ReloadKindHandlers) -> Self {
158 self.handlers = handlers;
159 self
160 }
161
162 pub fn check_compat(&self, old_providers: &OldProviders) -> Result<(), ReloadError> {
175 for kind in &self.old.crdt_kinds {
176 let Some(old) = old_providers.crdt_kinds.get(kind) else {
177 continue;
178 };
179 let Some(new) = self.new_registry.crdt_kind(kind) else {
180 continue;
183 };
184 new.schema_compat_check(old.as_ref())
185 .map_err(|e: FnError| {
186 ReloadError::schema_incompat(format!("crdt:{}", kind.0), e.message)
187 })?;
188 }
189 for name in &old_providers.logical_type_names {
190 let Some(old) = old_providers.logical_types.get(name) else {
191 continue;
192 };
193 let Some(new) = self.new_registry.logical_type(name) else {
194 continue;
195 };
196 new.compat_check(old.as_ref()).map_err(|e: FnError| {
197 ReloadError::schema_incompat(format!("logical-type:{name}"), e.message)
198 })?;
199 }
200 Ok(())
201 }
202
203 pub fn dispatch(mut self) -> Result<ReloadOutcome, ReloadError> {
213 let mut outcome = ReloadOutcome::default();
214 for handoff in self.handlers.index_handles.drain(..) {
215 let bytes = handoff.old.persist().map_err(ReloadError::Persist)?;
216 drop(handoff.old);
220 let reopened = handoff.new.open(&bytes).map_err(ReloadError::Persist)?;
221 outcome.index_handles.push((handoff.name, reopened));
222 }
223 for mut handoff in self.handlers.cdc_streams.drain(..) {
224 let lsn: CdcLsn = handoff.old.checkpoint().map_err(ReloadError::Persist)?;
225 handoff.old.shutdown().map_err(ReloadError::Persist)?;
228 drop(handoff.old);
229 let resumed = handoff
230 .new
231 .start(CdcStartContext::new(Some(lsn)))
232 .map_err(ReloadError::Persist)?;
233 outcome.cdc_streams.push((handoff.name, resumed));
234 }
235 Ok(outcome)
236 }
237}
238
239#[derive(Default)]
247pub struct OldProviders {
248 pub crdt_kinds:
250 std::collections::HashMap<crate::traits::crdt::CrdtKind, Arc<dyn CrdtKindProvider>>,
251 pub logical_type_names: Vec<smol_str::SmolStr>,
253 pub logical_types: std::collections::HashMap<smol_str::SmolStr, Arc<dyn LogicalTypeProvider>>,
255}
256
257impl std::fmt::Debug for OldProviders {
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 f.debug_struct("OldProviders")
260 .field("crdt_kinds", &self.crdt_kinds.len())
261 .field("logical_types", &self.logical_types.len())
262 .finish()
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::crdt::{CrdtKind, CrdtOp, CrdtState};
    use datafusion::scalar::ScalarValue;

    /// Test CRDT state. `apply` adds the op's byte length; `merge` keeps the
    /// larger count; `persist` writes the count as 8 little-endian bytes.
    #[derive(Default)]
    struct CountState {
        v: i64,
    }

    impl CrdtState for CountState {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
        fn apply(&mut self, op: &CrdtOp) -> Result<(), FnError> {
            self.v += op.bytes.len() as i64;
            Ok(())
        }
        fn merge(&mut self, other: &dyn CrdtState) -> Result<(), FnError> {
            let other = other
                .as_any()
                .downcast_ref::<CountState>()
                .ok_or_else(|| FnError::new(0x100, "merge: wrong state type"))?;
            if other.v > self.v {
                self.v = other.v;
            }
            Ok(())
        }
        fn value(&self) -> Result<ScalarValue, FnError> {
            Ok(ScalarValue::Int64(Some(self.v)))
        }
        fn persist(&self) -> Result<Vec<u8>, FnError> {
            Ok(self.v.to_le_bytes().to_vec())
        }
    }

    /// Provider for `CountState`. `from_persisted` accepts exactly the 8-byte
    /// encoding produced by `CountState::persist`.
    struct CountProvider {
        kind_str: &'static str,
    }

    impl CrdtKindProvider for CountProvider {
        fn kind(&self) -> CrdtKind {
            CrdtKind::new(self.kind_str)
        }
        fn empty(&self) -> Box<dyn CrdtState> {
            Box::new(CountState::default())
        }
        fn from_persisted(&self, bytes: &[u8]) -> Result<Box<dyn CrdtState>, FnError> {
            if bytes.len() != 8 {
                return Err(FnError::new(
                    0x101,
                    format!("expected 8 bytes, got {}", bytes.len()),
                ));
            }
            let mut arr = [0u8; 8];
            arr.copy_from_slice(bytes);
            Ok(Box::new(CountState {
                v: i64::from_le_bytes(arr),
            }))
        }
    }

    /// Provider for the same `count` kind whose `from_persisted` always fails.
    /// The error surfaced by `schema_compat_check` in
    /// `schema_compat_rejects_incompatible_round_trip` comes from this method.
    struct RejectingProvider;

    impl CrdtKindProvider for RejectingProvider {
        fn kind(&self) -> CrdtKind {
            CrdtKind::new("count")
        }
        fn empty(&self) -> Box<dyn CrdtState> {
            Box::new(CountState::default())
        }
        fn from_persisted(&self, _bytes: &[u8]) -> Result<Box<dyn CrdtState>, FnError> {
            Err(FnError::new(0x102, "rejecting all persisted bytes"))
        }
    }

    #[test]
    fn schema_compat_accepts_round_trip() {
        let old = CountProvider { kind_str: "count" };
        let new = CountProvider { kind_str: "count" };
        new.schema_compat_check(&old).expect("compatible");
    }

    #[test]
    fn schema_compat_rejects_incompatible_round_trip() {
        let old = CountProvider { kind_str: "count" };
        let new = RejectingProvider;
        let err = new.schema_compat_check(&old).unwrap_err();
        assert!(err.message.contains("rejecting"));
    }

    /// The new registry is empty, so no CRDT kind has a new provider and every
    /// check is skipped rather than failing.
    #[test]
    fn dispatcher_check_compat_skips_kinds_absent_from_new_registry() {
        let registry = PluginRegistry::new();
        let snap = PluginRecordSnapshot {
            crdt_kinds: vec![CrdtKind::new("count")],
            ..Default::default()
        };
        let mut olds = OldProviders::default();
        olds.crdt_kinds.insert(
            CrdtKind::new("count"),
            Arc::new(CountProvider { kind_str: "count" }),
        );
        let d = ReloadDispatcher::new(&snap, &registry);
        d.check_compat(&olds).expect("absence is OK");
    }

    /// A logical-type name with no old provider is skipped, not reported.
    #[test]
    fn dispatcher_check_compat_skips_logical_type_names_without_provider() {
        let registry = PluginRegistry::new();
        let snap = PluginRecordSnapshot::default();
        let mut olds = OldProviders::default();
        olds.logical_type_names.push(smol_str::SmolStr::new("t"));
        ReloadDispatcher::new(&snap, &registry)
            .check_compat(&olds)
            .expect("names without providers are skipped");
    }

    #[test]
    fn dispatcher_dispatch_without_handlers_is_empty() {
        let snap = PluginRecordSnapshot::default();
        let registry = PluginRegistry::new();
        let outcome = ReloadDispatcher::new(&snap, &registry)
            .dispatch()
            .expect("no handoffs to fail");
        assert!(outcome.index_handles.is_empty());
        assert!(outcome.cdc_streams.is_empty());
    }

    #[test]
    fn dispatcher_dispatch_handles_index_handoff() {
        /// Index handle whose persisted form is just its stored bytes.
        struct DummyHandle {
            bytes: Vec<u8>,
        }
        impl IndexHandle for DummyHandle {
            fn probe(
                &self,
                _query: &datafusion::arrow::record_batch::RecordBatch,
                _k: usize,
            ) -> Result<datafusion::arrow::record_batch::RecordBatch, FnError> {
                Err(FnError::new(0, "unused"))
            }
            fn persist(&self) -> Result<Vec<u8>, FnError> {
                Ok(self.bytes.clone())
            }
            fn schema(&self) -> arrow_schema::SchemaRef {
                Arc::new(arrow_schema::Schema::empty())
            }
        }

        /// Index provider whose `open` wraps the persisted bytes unchanged.
        struct DummyProvider;
        impl IndexKindProvider for DummyProvider {
            fn kind(&self) -> crate::traits::index::IndexKind {
                crate::traits::index::IndexKind::new("dummy")
            }
            fn build(
                &self,
                _source: &datafusion::arrow::record_batch::RecordBatch,
                _options: &str,
            ) -> Result<Box<dyn crate::traits::index::IndexBuild>, FnError> {
                Err(FnError::new(0, "unused"))
            }
            fn open(&self, persisted: &[u8]) -> Result<Box<dyn IndexHandle>, FnError> {
                Ok(Box::new(DummyHandle {
                    bytes: persisted.to_vec(),
                }))
            }
        }

        let snap = PluginRecordSnapshot::default();
        let registry = PluginRegistry::new();
        let mut handlers = ReloadKindHandlers::default();
        handlers.index_handles.push(IndexHandoff {
            name: "i1".to_owned(),
            old: Box::new(DummyHandle {
                bytes: vec![1, 2, 3, 4],
            }),
            new: Arc::new(DummyProvider),
        });
        let outcome = ReloadDispatcher::new(&snap, &registry)
            .with_handlers(handlers)
            .dispatch()
            .expect("handoff");
        assert_eq!(outcome.index_handles.len(), 1);
        assert_eq!(outcome.index_handles[0].0, "i1");
        // The reopened handle must carry exactly the bytes the old one persisted.
        assert_eq!(
            outcome.index_handles[0].1.persist().unwrap(),
            vec![1, 2, 3, 4]
        );
    }
}