Skip to main content

sochdb_kernel/
plugin_hot_reload.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Hot-Reload Without Restart
19//!
20//! This module implements zero-downtime plugin upgrades using atomic
21//! swapping and epoch-based draining.
22//!
23//! ## Design
24//!
25//! ```text
26//!                      ┌───────────────────────────┐
27//!                      │    HotReloadablePlugin    │
28//!                      │                           │
29//!                      │  ┌─────────────────────┐  │
30//!                      │  │  Arc<Current Plugin> │  │
31//!                      │  └──────────┬──────────┘  │
32//!                      │             │             │
33//!   New Version ──────►│  ┌──────────▼──────────┐  │
34//!                      │  │  prepare_upgrade()   │  │
35//!                      │  └──────────┬──────────┘  │
36//!                      │             │             │
37//!                      │  ┌──────────▼──────────┐  │
38//!                      │  │  drain_in_flight()   │  │
39//!                      │  └──────────┬──────────┘  │
40//!                      │             │             │
41//!                      │  ┌──────────▼──────────┐  │
42//!                      │  │  atomic_swap()       │  │
43//!                      │  └──────────┬──────────┘  │
44//!                      │             │             │
45//!                      │  ┌──────────▼──────────┐  │
46//!                      │  │  cleanup_old()       │  │
47//!                      │  └─────────────────────┘  │
48//!                      └───────────────────────────┘
49//! ```
50//!
51//! ## Safety Properties
52//!
53//! 1. **No Request Drops**: In-flight calls complete on old version
54//! 2. **Atomic Transition**: New calls immediately use new version
55//! 3. **Memory Safety**: Old version freed only when refs drop to zero
56//! 4. **Rollback**: If new version fails, old version remains active
57
58use crate::error::{KernelError, KernelResult};
59use crate::plugin_manifest::PluginManifest;
60use crate::wasm_runtime::{WasmInstanceConfig, WasmPluginInstance, WasmValue};
61use parking_lot::{Mutex, RwLock};
62use std::sync::Arc;
63use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
64use std::time::{Duration, Instant};
65
66// ============================================================================
67// Epoch-Based Draining
68// ============================================================================
69
/// Number of slots in the circular per-epoch reference-count buffer.
const EPOCH_SLOTS: usize = 8;

/// Pause between successive polls of a reference count in
/// [`EpochTracker::wait_drain`].
const DRAIN_POLL_INTERVAL: Duration = Duration::from_micros(100);

/// Epoch counter for tracking in-flight operations
///
/// Every in-flight operation holds an [`EpochGuard`] that pins the epoch it
/// started in. A writer calls [`advance`](Self::advance) and then
/// [`wait_drain`](Self::wait_drain) on the previous epoch to learn when all
/// operations that started before the advance have finished.
///
/// Counts are stored in a ring of `EPOCH_SLOTS` slots indexed by
/// `epoch % EPOCH_SLOTS`, so two epochs that are a multiple of `EPOCH_SLOTS`
/// apart share a counter.
#[derive(Debug)]
pub struct EpochTracker {
    /// Current epoch number
    epoch: AtomicU64,
    /// Reference counts per epoch (circular buffer of `EPOCH_SLOTS` epochs)
    epoch_refs: [AtomicUsize; EPOCH_SLOTS],
}

impl Default for EpochTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl EpochTracker {
    /// Create a new epoch tracker starting at epoch 0 with all counts at zero
    pub fn new() -> Self {
        Self {
            epoch: AtomicU64::new(0),
            epoch_refs: std::array::from_fn(|_| AtomicUsize::new(0)),
        }
    }

    /// Reference counter backing `epoch`.
    fn slot(&self, epoch: u64) -> &AtomicUsize {
        // The remainder is always below EPOCH_SLOTS, so narrowing to usize
        // cannot truncate.
        &self.epoch_refs[(epoch % EPOCH_SLOTS as u64) as usize]
    }

    /// Get current epoch
    pub fn current(&self) -> u64 {
        self.epoch.load(Ordering::SeqCst)
    }

    /// Enter the current epoch (increment its ref count)
    ///
    /// The returned guard releases the reference when dropped.
    #[must_use = "the epoch reference is released as soon as the guard is dropped"]
    pub fn enter(&self) -> EpochGuard<'_> {
        loop {
            let epoch = self.epoch.load(Ordering::SeqCst);
            let slot = self.slot(epoch);
            slot.fetch_add(1, Ordering::SeqCst);

            // Re-validate: if the epoch moved between the load and the
            // increment, a drainer may already have observed this slot at
            // zero and declared the epoch drained. Undo and retry on the new
            // epoch. SeqCst gives a single total order, so either this load
            // sees the advance, or the drainer's later load sees our
            // increment.
            if self.epoch.load(Ordering::SeqCst) == epoch {
                return EpochGuard {
                    tracker: self,
                    epoch,
                };
            }
            slot.fetch_sub(1, Ordering::SeqCst);
        }
    }

    /// Advance to next epoch, returning the new epoch number
    pub fn advance(&self) -> u64 {
        self.epoch.fetch_add(1, Ordering::SeqCst) + 1
    }

    /// Wait for an epoch to drain (all refs released)
    ///
    /// Returns `true` as soon as the reference count for `target_epoch` is
    /// observed at zero, or `false` once `timeout` has elapsed. The count is
    /// always checked at least once, so a zero timeout acts as a non-blocking
    /// probe.
    pub fn wait_drain(&self, target_epoch: u64, timeout: Duration) -> bool {
        let slot = self.slot(target_epoch);
        let start = Instant::now();

        loop {
            if slot.load(Ordering::SeqCst) == 0 {
                return true;
            }

            let elapsed = start.elapsed();
            if elapsed >= timeout {
                return false;
            }

            // Never sleep past the deadline.
            std::thread::sleep(DRAIN_POLL_INTERVAL.min(timeout.saturating_sub(elapsed)));
        }
    }

    /// Get reference count for an epoch
    pub fn refs_for_epoch(&self, epoch: u64) -> usize {
        self.slot(epoch).load(Ordering::SeqCst)
    }
}

/// Guard that releases epoch reference on drop
#[derive(Debug)]
pub struct EpochGuard<'a> {
    /// Tracker whose counter was incremented
    tracker: &'a EpochTracker,
    /// Epoch this guard is registered on
    epoch: u64,
}

impl Drop for EpochGuard<'_> {
    fn drop(&mut self) {
        self.tracker.slot(self.epoch).fetch_sub(1, Ordering::AcqRel);
    }
}
148
149// ============================================================================
150// Hot-Reloadable Plugin
151// ============================================================================
152
/// Lifecycle phase of a [`HotReloadablePlugin`].
///
/// The upgrade methods in this file step through the variants in declaration
/// order and finish back at [`HotReloadState::Active`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HotReloadState {
    /// Serving calls; no upgrade in progress.
    Active,
    /// A replacement instance is being built or is staged and waiting.
    PreparingUpgrade,
    /// Waiting for calls that started before the upgrade to finish.
    Draining,
    /// The replacement instance is being installed.
    Swapping,
    /// The replacement instance has been installed.
    UpgradeComplete,
    /// The upgrade was abandoned and the previous instance kept.
    RolledBack,
}
169
/// Counters describing the upgrade history of one plugin.
///
/// All fields start at zero via [`Default`].
#[derive(Debug, Clone, Default)]
pub struct HotReloadStats {
    /// How many upgrades completed.
    pub successful_upgrades: u64,
    /// How many upgrades were abandoned.
    pub failed_upgrades: u64,
    /// Sum of the drain durations of all completed upgrades, in microseconds.
    pub total_drain_time_us: u64,
    /// Longest single drain duration seen so far, in microseconds.
    pub max_drain_time_us: u64,
    /// Version counter, bumped once per completed upgrade.
    pub version: u64,
}
184
185/// A hot-reloadable plugin wrapper
186pub struct HotReloadablePlugin {
187    /// Plugin name
188    name: String,
189    /// Current active instance (wrapped in Arc for atomic swap)
190    current: RwLock<Arc<WasmPluginInstance>>,
191    /// Pending new instance (during upgrade)
192    pending: Mutex<Option<Arc<WasmPluginInstance>>>,
193    /// Epoch tracker for draining
194    epochs: EpochTracker,
195    /// Current state
196    state: RwLock<HotReloadState>,
197    /// Statistics
198    stats: RwLock<HotReloadStats>,
199    /// Current manifest
200    manifest: RwLock<PluginManifest>,
201    /// Upgrade timeout
202    drain_timeout: Duration,
203}
204
205impl HotReloadablePlugin {
206    /// Create a new hot-reloadable plugin
207    pub fn new(name: &str, instance: Arc<WasmPluginInstance>, manifest: PluginManifest) -> Self {
208        Self {
209            name: name.to_string(),
210            current: RwLock::new(instance),
211            pending: Mutex::new(None),
212            epochs: EpochTracker::new(),
213            state: RwLock::new(HotReloadState::Active),
214            stats: RwLock::new(HotReloadStats::default()),
215            manifest: RwLock::new(manifest),
216            drain_timeout: Duration::from_secs(5),
217        }
218    }
219
220    /// Get the current active instance
221    pub fn current(&self) -> Arc<WasmPluginInstance> {
222        self.current.read().clone()
223    }
224
225    /// Call a function, tracking the epoch
226    pub fn call(&self, func_name: &str, args: &[WasmValue]) -> KernelResult<Vec<WasmValue>> {
227        // Enter epoch
228        let _guard = self.epochs.enter();
229
230        // Get current instance
231        let instance = self.current();
232
233        // Execute call
234        instance.call(func_name, args)
235    }
236
237    /// Prepare an upgrade with new WASM bytes
238    pub fn prepare_upgrade(
239        &self,
240        wasm_bytes: &[u8],
241        new_manifest: PluginManifest,
242    ) -> KernelResult<()> {
243        // Check current state
244        {
245            let state = self.state.read();
246            if *state != HotReloadState::Active {
247                return Err(KernelError::Plugin {
248                    message: format!("cannot upgrade in state {:?}, must be Active", *state),
249                });
250            }
251        }
252
253        // Set preparing state
254        *self.state.write() = HotReloadState::PreparingUpgrade;
255
256        // Validate new manifest against current
257        self.validate_upgrade(&new_manifest)?;
258
259        // Create new instance
260        let config = WasmInstanceConfig {
261            capabilities: new_manifest.to_capabilities(),
262            ..Default::default()
263        };
264
265        let new_instance = WasmPluginInstance::new(&self.name, wasm_bytes, config)?;
266        new_instance.init()?;
267
268        // Store pending
269        *self.pending.lock() = Some(Arc::new(new_instance));
270        *self.manifest.write() = new_manifest;
271
272        Ok(())
273    }
274
275    /// Execute the upgrade
276    pub fn execute_upgrade(&self) -> KernelResult<()> {
277        // Check state
278        {
279            let state = self.state.read();
280            if *state != HotReloadState::PreparingUpgrade {
281                return Err(KernelError::Plugin {
282                    message: "must call prepare_upgrade first".to_string(),
283                });
284            }
285        }
286
287        let drain_start = Instant::now();
288
289        // Enter draining state
290        *self.state.write() = HotReloadState::Draining;
291
292        // Advance epoch
293        let old_epoch = self.epochs.current();
294        self.epochs.advance();
295
296        // Wait for old epoch to drain
297        if !self.epochs.wait_drain(old_epoch, self.drain_timeout) {
298            // Rollback
299            *self.state.write() = HotReloadState::RolledBack;
300            *self.pending.lock() = None;
301
302            let mut stats = self.stats.write();
303            stats.failed_upgrades += 1;
304
305            return Err(KernelError::Plugin {
306                message: format!(
307                    "drain timeout: {} refs still held after {:?}",
308                    self.epochs.refs_for_epoch(old_epoch),
309                    self.drain_timeout
310                ),
311            });
312        }
313
314        let drain_time = drain_start.elapsed();
315
316        // Enter swapping state
317        *self.state.write() = HotReloadState::Swapping;
318
319        // Get pending instance
320        let new_instance = self
321            .pending
322            .lock()
323            .take()
324            .ok_or_else(|| KernelError::Plugin {
325                message: "pending instance missing during upgrade".to_string(),
326            })?;
327
328        // Atomic swap
329        *self.current.write() = new_instance;
330
331        // Update stats
332        {
333            let mut stats = self.stats.write();
334            stats.successful_upgrades += 1;
335            stats.version += 1;
336            let drain_us = drain_time.as_micros() as u64;
337            stats.total_drain_time_us += drain_us;
338            stats.max_drain_time_us = stats.max_drain_time_us.max(drain_us);
339        }
340
341        // Mark complete
342        *self.state.write() = HotReloadState::UpgradeComplete;
343
344        // Reset to active
345        *self.state.write() = HotReloadState::Active;
346
347        Ok(())
348    }
349
350    /// Perform a full upgrade (prepare + execute)
351    pub fn upgrade(&self, wasm_bytes: &[u8], new_manifest: PluginManifest) -> KernelResult<()> {
352        self.prepare_upgrade(wasm_bytes, new_manifest)?;
353        self.execute_upgrade()
354    }
355
356    /// Cancel a pending upgrade
357    pub fn cancel_upgrade(&self) -> KernelResult<()> {
358        let state = *self.state.read();
359
360        match state {
361            HotReloadState::PreparingUpgrade => {
362                *self.pending.lock() = None;
363                *self.state.write() = HotReloadState::Active;
364                Ok(())
365            }
366            HotReloadState::Active => {
367                // Nothing to cancel
368                Ok(())
369            }
370            _ => Err(KernelError::Plugin {
371                message: format!("cannot cancel in state {:?}", state),
372            }),
373        }
374    }
375
376    /// Get current state
377    pub fn state(&self) -> HotReloadState {
378        *self.state.read()
379    }
380
381    /// Get statistics
382    pub fn stats(&self) -> HotReloadStats {
383        self.stats.read().clone()
384    }
385
386    /// Get plugin name
387    pub fn name(&self) -> &str {
388        &self.name
389    }
390
391    /// Get current manifest
392    pub fn manifest(&self) -> PluginManifest {
393        self.manifest.read().clone()
394    }
395
396    /// Set drain timeout
397    pub fn set_drain_timeout(&mut self, timeout: Duration) {
398        self.drain_timeout = timeout;
399    }
400
401    /// Validate that an upgrade is compatible
402    fn validate_upgrade(&self, new_manifest: &PluginManifest) -> KernelResult<()> {
403        let current = self.manifest.read();
404
405        // Name must match
406        if current.plugin.name != new_manifest.plugin.name {
407            return Err(KernelError::Plugin {
408                message: format!(
409                    "plugin name mismatch: {} vs {}",
410                    current.plugin.name, new_manifest.plugin.name
411                ),
412            });
413        }
414
415        // Version should be different (warning, not error)
416        if current.plugin.version == new_manifest.plugin.version {
417            // Just a warning in production
418        }
419
420        // All existing hooks must still be present
421        for hook in &current.hooks.before_insert {
422            if !new_manifest.exports.functions.contains(hook) {
423                return Err(KernelError::Plugin {
424                    message: format!("new version missing hook function: {}", hook),
425                });
426            }
427        }
428
429        Ok(())
430    }
431}
432
433// ============================================================================
434// Hot-Reload Manager
435// ============================================================================
436
437/// Manager for all hot-reloadable plugins
438pub struct HotReloadManager {
439    /// Plugins by name
440    plugins: RwLock<std::collections::HashMap<String, Arc<HotReloadablePlugin>>>,
441}
442
443impl Default for HotReloadManager {
444    fn default() -> Self {
445        Self::new()
446    }
447}
448
449impl HotReloadManager {
450    /// Create a new manager
451    pub fn new() -> Self {
452        Self {
453            plugins: RwLock::new(std::collections::HashMap::new()),
454        }
455    }
456
457    /// Register a new hot-reloadable plugin
458    pub fn register(
459        &self,
460        name: &str,
461        instance: Arc<WasmPluginInstance>,
462        manifest: PluginManifest,
463    ) -> KernelResult<()> {
464        let mut plugins = self.plugins.write();
465
466        if plugins.contains_key(name) {
467            return Err(KernelError::Plugin {
468                message: format!("plugin '{}' already registered", name),
469            });
470        }
471
472        let plugin = Arc::new(HotReloadablePlugin::new(name, instance, manifest));
473        plugins.insert(name.to_string(), plugin);
474
475        Ok(())
476    }
477
478    /// Get a plugin by name
479    pub fn get(&self, name: &str) -> Option<Arc<HotReloadablePlugin>> {
480        self.plugins.read().get(name).cloned()
481    }
482
483    /// Upgrade a plugin
484    pub fn upgrade(
485        &self,
486        name: &str,
487        wasm_bytes: &[u8],
488        new_manifest: PluginManifest,
489    ) -> KernelResult<()> {
490        let plugin = self.get(name).ok_or_else(|| KernelError::Plugin {
491            message: format!("plugin '{}' not found", name),
492        })?;
493
494        plugin.upgrade(wasm_bytes, new_manifest)
495    }
496
497    /// Unregister a plugin
498    pub fn unregister(&self, name: &str) -> KernelResult<()> {
499        let mut plugins = self.plugins.write();
500
501        if plugins.remove(name).is_none() {
502            return Err(KernelError::Plugin {
503                message: format!("plugin '{}' not found", name),
504            });
505        }
506
507        Ok(())
508    }
509
510    /// List all plugins
511    pub fn list(&self) -> Vec<String> {
512        self.plugins.read().keys().cloned().collect()
513    }
514
515    /// Get stats for all plugins
516    pub fn all_stats(&self) -> Vec<(String, HotReloadStats)> {
517        self.plugins
518            .read()
519            .iter()
520            .map(|(name, plugin)| (name.clone(), plugin.stats()))
521            .collect()
522    }
523}
524
525// ============================================================================
526// Tests
527// ============================================================================
528
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugin_manifest::ManifestBuilder;

    /// Build and initialise a plugin instance from placeholder bytes.
    fn create_test_instance(name: &str) -> Arc<WasmPluginInstance> {
        let instance =
            WasmPluginInstance::new(name, b"test wasm", WasmInstanceConfig::default()).unwrap();
        instance.init().unwrap();
        Arc::new(instance)
    }

    /// Build a manifest that exports a single `on_insert` function.
    fn create_test_manifest(name: &str, version: &str) -> PluginManifest {
        let builder = ManifestBuilder::new(name, version).export("on_insert");
        builder.build().unwrap()
    }

    /// Build a version-1.0.0 plugin whose instance, manifest and wrapper all
    /// share `name`.
    fn create_test_plugin(name: &str) -> HotReloadablePlugin {
        let instance = create_test_instance(name);
        let manifest = create_test_manifest(name, "1.0.0");
        HotReloadablePlugin::new(name, instance, manifest)
    }

    #[test]
    fn test_epoch_tracker() {
        let epochs = EpochTracker::new();
        assert_eq!(epochs.current(), 0);

        // A guard taken in epoch 0 is counted there.
        let held = epochs.enter();
        assert_eq!(epochs.refs_for_epoch(0), 1);

        // Moving on to epoch 1 ...
        epochs.advance();
        assert_eq!(epochs.current(), 1);

        // ... does not release the reference on epoch 0.
        assert_eq!(epochs.refs_for_epoch(0), 1);

        // Only dropping the guard does.
        drop(held);
        assert_eq!(epochs.refs_for_epoch(0), 0);
    }

    #[test]
    fn test_epoch_drain() {
        let epochs = EpochTracker::new();
        let budget = Duration::from_millis(10);

        // Nothing is held, so the drain completes.
        assert!(epochs.wait_drain(0, budget));

        // While a guard is held, the drain runs into the timeout.
        let _held = epochs.enter();
        assert!(!epochs.wait_drain(0, budget));
    }

    #[test]
    fn test_hot_reload_creation() {
        let plugin = create_test_plugin("test");

        assert_eq!(plugin.name(), "test");
        assert_eq!(plugin.state(), HotReloadState::Active);
    }

    #[test]
    fn test_hot_reload_call() {
        let plugin = create_test_plugin("test");

        // Calling the exported function goes through.
        let outcome = plugin.call("on_insert", &[]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_hot_reload_prepare() {
        let plugin = create_test_plugin("test");

        let next = create_test_manifest("test", "2.0.0");
        plugin.prepare_upgrade(b"new wasm", next).unwrap();

        assert_eq!(plugin.state(), HotReloadState::PreparingUpgrade);
    }

    #[test]
    fn test_hot_reload_full_upgrade() {
        let plugin = create_test_plugin("test");

        let next = create_test_manifest("test", "2.0.0");
        plugin.upgrade(b"new wasm", next).unwrap();

        assert_eq!(plugin.state(), HotReloadState::Active);

        let counters = plugin.stats();
        assert_eq!(counters.successful_upgrades, 1);
        assert_eq!(counters.version, 1);
    }

    #[test]
    fn test_hot_reload_cancel() {
        let plugin = create_test_plugin("test");

        let next = create_test_manifest("test", "2.0.0");
        plugin.prepare_upgrade(b"new wasm", next).unwrap();

        plugin.cancel_upgrade().unwrap();
        assert_eq!(plugin.state(), HotReloadState::Active);
    }

    #[test]
    fn test_hot_reload_name_mismatch() {
        let plugin = create_test_plugin("test");

        // A manifest for another plugin name is rejected.
        let foreign = create_test_manifest("different", "2.0.0");
        let outcome = plugin.prepare_upgrade(b"new wasm", foreign);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_manager_operations() {
        let manager = HotReloadManager::new();

        let instance = create_test_instance("plugin1");
        let manifest = create_test_manifest("plugin1", "1.0.0");

        // Register
        manager.register("plugin1", instance, manifest).unwrap();
        assert_eq!(manager.list().len(), 1);

        // Get
        let found = manager.get("plugin1").unwrap();
        assert_eq!(found.name(), "plugin1");

        // Upgrade
        let next = create_test_manifest("plugin1", "2.0.0");
        manager.upgrade("plugin1", b"new wasm", next).unwrap();

        // Stats
        let report = manager.all_stats();
        assert_eq!(report.len(), 1);
        assert_eq!(report[0].1.successful_upgrades, 1);

        // Unregister
        manager.unregister("plugin1").unwrap();
        assert!(manager.list().is_empty());
    }

    #[test]
    fn test_manager_duplicate() {
        let manager = HotReloadManager::new();

        let first_instance = create_test_instance("dup");
        let first_manifest = create_test_manifest("dup", "1.0.0");
        manager
            .register("dup", first_instance, first_manifest)
            .unwrap();

        // A second registration under the same name is refused.
        let second_instance = create_test_instance("dup");
        let second_manifest = create_test_manifest("dup", "1.0.0");
        let outcome = manager.register("dup", second_instance, second_manifest);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_concurrent_calls_during_upgrade() {
        use std::sync::atomic::AtomicBool;
        use std::thread;

        let plugin = Arc::new(create_test_plugin("concurrent"));

        // Raised once the upgrade has finished to stop the workers early.
        let stop = Arc::new(AtomicBool::new(false));

        // Four workers, each making at most 101 calls.
        let workers: Vec<_> = (0..4)
            .map(|_| {
                let worker_plugin = plugin.clone();
                let worker_stop = stop.clone();
                thread::spawn(move || {
                    for _ in 0..=100 {
                        if worker_stop.load(Ordering::Relaxed) {
                            break;
                        }
                        let _ = worker_plugin.call("on_insert", &[]);
                    }
                })
            })
            .collect();

        // Upgrade while the workers are (possibly still) calling.
        thread::sleep(Duration::from_millis(5));
        let next = create_test_manifest("concurrent", "2.0.0");
        let outcome = plugin.upgrade(b"new wasm", next);

        // Stop and collect the workers.
        stop.store(true, Ordering::Relaxed);
        for worker in workers {
            worker.join().unwrap();
        }

        // The upgrade must have gone through.
        assert!(outcome.is_ok());
        assert_eq!(plugin.stats().successful_upgrades, 1);
    }
}