1use crate::error::{KernelError, KernelResult};
59use crate::plugin_manifest::PluginManifest;
60use crate::wasm_runtime::{WasmInstanceConfig, WasmPluginInstance, WasmValue};
61use parking_lot::{Mutex, RwLock};
62use std::sync::Arc;
63use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
64use std::time::{Duration, Instant};
65
/// Number of reference-count slots in an [`EpochTracker`].
///
/// Epochs map onto slots modulo this value, so epochs that are `EPOCH_SLOTS`
/// apart share one counter.
const EPOCH_SLOTS: usize = 8;

/// Counts live readers per epoch so a writer can wait for the readers of an
/// earlier epoch to finish.
///
/// Readers call [`EpochTracker::enter`] and keep the returned guard alive while
/// they work. A writer calls [`EpochTracker::advance`] and then
/// [`EpochTracker::wait_drain`] on the epoch it just left.
#[derive(Debug)]
pub struct EpochTracker {
    /// Epoch counter; only ever incremented.
    epoch: AtomicU64,
    /// Live-guard counts, indexed by `epoch % EPOCH_SLOTS`.
    epoch_refs: [AtomicUsize; EPOCH_SLOTS],
}

impl Default for EpochTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl EpochTracker {
    /// Interval at which [`EpochTracker::wait_drain`] re-checks the counter.
    const DRAIN_POLL_INTERVAL: Duration = Duration::from_micros(100);

    /// Creates a tracker at epoch 0 with no registered readers.
    pub fn new() -> Self {
        Self {
            epoch: AtomicU64::new(0),
            epoch_refs: Default::default(),
        }
    }

    /// Returns the counter used for `epoch`.
    ///
    /// The same counter is shared by every epoch congruent to `epoch` modulo
    /// `EPOCH_SLOTS`.
    fn slot(&self, epoch: u64) -> &AtomicUsize {
        // The remainder is < EPOCH_SLOTS, so the cast back to usize cannot truncate.
        &self.epoch_refs[(epoch % EPOCH_SLOTS as u64) as usize]
    }

    /// Returns the current epoch.
    pub fn current(&self) -> u64 {
        self.epoch.load(Ordering::SeqCst)
    }

    /// Registers a reader in the current epoch.
    ///
    /// The reader stays registered until the returned guard is dropped. It is
    /// counted under the epoch that was current when registration completed,
    /// even if the epoch advances afterwards.
    pub fn enter(&self) -> EpochGuard<'_> {
        loop {
            let epoch = self.epoch.load(Ordering::SeqCst);
            let slot = self.slot(epoch);
            slot.fetch_add(1, Ordering::SeqCst);

            // If `advance` ran between the load and the increment, a concurrent
            // `wait_drain(epoch)` may already have seen zero and returned. Only
            // keep the registration if the epoch is still the one we counted
            // under; otherwise undo it and register in the newer epoch. SeqCst
            // is required because this is a store-then-load handshake with
            // `advance` + `wait_drain`.
            if self.epoch.load(Ordering::SeqCst) == epoch {
                return EpochGuard {
                    tracker: self,
                    epoch,
                };
            }
            slot.fetch_sub(1, Ordering::SeqCst);
        }
    }

    /// Moves to the next epoch and returns it.
    pub fn advance(&self) -> u64 {
        // fetch_add wraps on overflow; mirror that instead of panicking in debug.
        self.epoch.fetch_add(1, Ordering::SeqCst).wrapping_add(1)
    }

    /// Waits until no guard is registered in `target_epoch`'s slot, polling
    /// every 100 µs.
    ///
    /// The counter is checked at least once, so a zero `timeout` still reports
    /// an already-drained slot as drained. Returns `true` if the slot drained,
    /// or `false` if `timeout` elapsed first.
    pub fn wait_drain(&self, target_epoch: u64, timeout: Duration) -> bool {
        let slot = self.slot(target_epoch);
        let start = Instant::now();
        loop {
            if slot.load(Ordering::SeqCst) == 0 {
                return true;
            }
            if start.elapsed() >= timeout {
                return false;
            }
            std::thread::sleep(Self::DRAIN_POLL_INTERVAL);
        }
    }

    /// Returns the number of live guards in `epoch`'s slot.
    ///
    /// This includes guards from any epoch that shares the slot.
    pub fn refs_for_epoch(&self, epoch: u64) -> usize {
        self.slot(epoch).load(Ordering::SeqCst)
    }
}

/// Registration of one reader in an [`EpochTracker`] epoch.
///
/// The reader is unregistered when the guard is dropped.
#[derive(Debug)]
#[must_use = "dropping the guard immediately unregisters the reader"]
pub struct EpochGuard<'a> {
    /// Tracker the reader is registered with.
    tracker: &'a EpochTracker,
    /// Epoch whose slot was incremented by `enter`.
    epoch: u64,
}

impl Drop for EpochGuard<'_> {
    fn drop(&mut self) {
        self.tracker.slot(self.epoch).fetch_sub(1, Ordering::SeqCst);
    }
}
148
/// Lifecycle phase of a hot-reloadable plugin's upgrade.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HotReloadState {
    /// Serving calls; no upgrade is in progress.
    Active,
    /// A replacement instance is being built or is staged, awaiting execution.
    PreparingUpgrade,
    /// Waiting for calls registered in the previous epoch to finish.
    Draining,
    /// Installing the staged instance as the current one.
    Swapping,
    /// The swap finished; set just before returning to `Active`.
    UpgradeComplete,
    /// The upgrade was abandoned because draining timed out.
    RolledBack,
}
169
/// Cumulative upgrade counters for one hot-reloadable plugin.
#[derive(Clone, Debug, Default)]
pub struct HotReloadStats {
    /// Upgrades that swapped in a new instance.
    pub successful_upgrades: u64,
    /// Upgrades rolled back because draining timed out.
    pub failed_upgrades: u64,
    /// Sum of the drain durations of successful upgrades, in microseconds.
    pub total_drain_time_us: u64,
    /// Longest single drain of a successful upgrade, in microseconds.
    pub max_drain_time_us: u64,
    /// Incremented once per successful upgrade.
    pub version: u64,
}
184
185pub struct HotReloadablePlugin {
187 name: String,
189 current: RwLock<Arc<WasmPluginInstance>>,
191 pending: Mutex<Option<Arc<WasmPluginInstance>>>,
193 epochs: EpochTracker,
195 state: RwLock<HotReloadState>,
197 stats: RwLock<HotReloadStats>,
199 manifest: RwLock<PluginManifest>,
201 drain_timeout: Duration,
203}
204
205impl HotReloadablePlugin {
206 pub fn new(name: &str, instance: Arc<WasmPluginInstance>, manifest: PluginManifest) -> Self {
208 Self {
209 name: name.to_string(),
210 current: RwLock::new(instance),
211 pending: Mutex::new(None),
212 epochs: EpochTracker::new(),
213 state: RwLock::new(HotReloadState::Active),
214 stats: RwLock::new(HotReloadStats::default()),
215 manifest: RwLock::new(manifest),
216 drain_timeout: Duration::from_secs(5),
217 }
218 }
219
220 pub fn current(&self) -> Arc<WasmPluginInstance> {
222 self.current.read().clone()
223 }
224
225 pub fn call(&self, func_name: &str, args: &[WasmValue]) -> KernelResult<Vec<WasmValue>> {
227 let _guard = self.epochs.enter();
229
230 let instance = self.current();
232
233 instance.call(func_name, args)
235 }
236
237 pub fn prepare_upgrade(
239 &self,
240 wasm_bytes: &[u8],
241 new_manifest: PluginManifest,
242 ) -> KernelResult<()> {
243 {
245 let state = self.state.read();
246 if *state != HotReloadState::Active {
247 return Err(KernelError::Plugin {
248 message: format!("cannot upgrade in state {:?}, must be Active", *state),
249 });
250 }
251 }
252
253 *self.state.write() = HotReloadState::PreparingUpgrade;
255
256 self.validate_upgrade(&new_manifest)?;
258
259 let config = WasmInstanceConfig {
261 capabilities: new_manifest.to_capabilities(),
262 ..Default::default()
263 };
264
265 let new_instance = WasmPluginInstance::new(&self.name, wasm_bytes, config)?;
266 new_instance.init()?;
267
268 *self.pending.lock() = Some(Arc::new(new_instance));
270 *self.manifest.write() = new_manifest;
271
272 Ok(())
273 }
274
275 pub fn execute_upgrade(&self) -> KernelResult<()> {
277 {
279 let state = self.state.read();
280 if *state != HotReloadState::PreparingUpgrade {
281 return Err(KernelError::Plugin {
282 message: "must call prepare_upgrade first".to_string(),
283 });
284 }
285 }
286
287 let drain_start = Instant::now();
288
289 *self.state.write() = HotReloadState::Draining;
291
292 let old_epoch = self.epochs.current();
294 self.epochs.advance();
295
296 if !self.epochs.wait_drain(old_epoch, self.drain_timeout) {
298 *self.state.write() = HotReloadState::RolledBack;
300 *self.pending.lock() = None;
301
302 let mut stats = self.stats.write();
303 stats.failed_upgrades += 1;
304
305 return Err(KernelError::Plugin {
306 message: format!(
307 "drain timeout: {} refs still held after {:?}",
308 self.epochs.refs_for_epoch(old_epoch),
309 self.drain_timeout
310 ),
311 });
312 }
313
314 let drain_time = drain_start.elapsed();
315
316 *self.state.write() = HotReloadState::Swapping;
318
319 let new_instance = self
321 .pending
322 .lock()
323 .take()
324 .ok_or_else(|| KernelError::Plugin {
325 message: "pending instance missing during upgrade".to_string(),
326 })?;
327
328 *self.current.write() = new_instance;
330
331 {
333 let mut stats = self.stats.write();
334 stats.successful_upgrades += 1;
335 stats.version += 1;
336 let drain_us = drain_time.as_micros() as u64;
337 stats.total_drain_time_us += drain_us;
338 stats.max_drain_time_us = stats.max_drain_time_us.max(drain_us);
339 }
340
341 *self.state.write() = HotReloadState::UpgradeComplete;
343
344 *self.state.write() = HotReloadState::Active;
346
347 Ok(())
348 }
349
350 pub fn upgrade(&self, wasm_bytes: &[u8], new_manifest: PluginManifest) -> KernelResult<()> {
352 self.prepare_upgrade(wasm_bytes, new_manifest)?;
353 self.execute_upgrade()
354 }
355
356 pub fn cancel_upgrade(&self) -> KernelResult<()> {
358 let state = *self.state.read();
359
360 match state {
361 HotReloadState::PreparingUpgrade => {
362 *self.pending.lock() = None;
363 *self.state.write() = HotReloadState::Active;
364 Ok(())
365 }
366 HotReloadState::Active => {
367 Ok(())
369 }
370 _ => Err(KernelError::Plugin {
371 message: format!("cannot cancel in state {:?}", state),
372 }),
373 }
374 }
375
376 pub fn state(&self) -> HotReloadState {
378 *self.state.read()
379 }
380
381 pub fn stats(&self) -> HotReloadStats {
383 self.stats.read().clone()
384 }
385
386 pub fn name(&self) -> &str {
388 &self.name
389 }
390
391 pub fn manifest(&self) -> PluginManifest {
393 self.manifest.read().clone()
394 }
395
396 pub fn set_drain_timeout(&mut self, timeout: Duration) {
398 self.drain_timeout = timeout;
399 }
400
401 fn validate_upgrade(&self, new_manifest: &PluginManifest) -> KernelResult<()> {
403 let current = self.manifest.read();
404
405 if current.plugin.name != new_manifest.plugin.name {
407 return Err(KernelError::Plugin {
408 message: format!(
409 "plugin name mismatch: {} vs {}",
410 current.plugin.name, new_manifest.plugin.name
411 ),
412 });
413 }
414
415 if current.plugin.version == new_manifest.plugin.version {
417 }
419
420 for hook in ¤t.hooks.before_insert {
422 if !new_manifest.exports.functions.contains(hook) {
423 return Err(KernelError::Plugin {
424 message: format!("new version missing hook function: {}", hook),
425 });
426 }
427 }
428
429 Ok(())
430 }
431}
432
433pub struct HotReloadManager {
439 plugins: RwLock<std::collections::HashMap<String, Arc<HotReloadablePlugin>>>,
441}
442
443impl Default for HotReloadManager {
444 fn default() -> Self {
445 Self::new()
446 }
447}
448
449impl HotReloadManager {
450 pub fn new() -> Self {
452 Self {
453 plugins: RwLock::new(std::collections::HashMap::new()),
454 }
455 }
456
457 pub fn register(
459 &self,
460 name: &str,
461 instance: Arc<WasmPluginInstance>,
462 manifest: PluginManifest,
463 ) -> KernelResult<()> {
464 let mut plugins = self.plugins.write();
465
466 if plugins.contains_key(name) {
467 return Err(KernelError::Plugin {
468 message: format!("plugin '{}' already registered", name),
469 });
470 }
471
472 let plugin = Arc::new(HotReloadablePlugin::new(name, instance, manifest));
473 plugins.insert(name.to_string(), plugin);
474
475 Ok(())
476 }
477
478 pub fn get(&self, name: &str) -> Option<Arc<HotReloadablePlugin>> {
480 self.plugins.read().get(name).cloned()
481 }
482
483 pub fn upgrade(
485 &self,
486 name: &str,
487 wasm_bytes: &[u8],
488 new_manifest: PluginManifest,
489 ) -> KernelResult<()> {
490 let plugin = self.get(name).ok_or_else(|| KernelError::Plugin {
491 message: format!("plugin '{}' not found", name),
492 })?;
493
494 plugin.upgrade(wasm_bytes, new_manifest)
495 }
496
497 pub fn unregister(&self, name: &str) -> KernelResult<()> {
499 let mut plugins = self.plugins.write();
500
501 if plugins.remove(name).is_none() {
502 return Err(KernelError::Plugin {
503 message: format!("plugin '{}' not found", name),
504 });
505 }
506
507 Ok(())
508 }
509
510 pub fn list(&self) -> Vec<String> {
512 self.plugins.read().keys().cloned().collect()
513 }
514
515 pub fn all_stats(&self) -> Vec<(String, HotReloadStats)> {
517 self.plugins
518 .read()
519 .iter()
520 .map(|(name, plugin)| (name.clone(), plugin.stats()))
521 .collect()
522 }
523}
524
#[cfg(test)]
mod tests {
    //! Unit tests for epoch tracking, hot reload, and the plugin manager.

    use super::*;
    use crate::plugin_manifest::ManifestBuilder;
    use std::sync::atomic::AtomicBool;
    use std::thread;

    /// Builds and initializes a test instance named `name`.
    fn instance_named(name: &str) -> Arc<WasmPluginInstance> {
        let instance =
            WasmPluginInstance::new(name, b"test wasm", WasmInstanceConfig::default()).unwrap();
        instance.init().unwrap();
        Arc::new(instance)
    }

    /// Builds a manifest for `name` at `version` that exports `on_insert`.
    fn manifest_for(name: &str, version: &str) -> PluginManifest {
        ManifestBuilder::new(name, version)
            .export("on_insert")
            .build()
            .unwrap()
    }

    /// A hot-reloadable plugin named `name` at version 1.0.0.
    fn plugin_v1(name: &str) -> HotReloadablePlugin {
        HotReloadablePlugin::new(name, instance_named(name), manifest_for(name, "1.0.0"))
    }

    #[test]
    fn test_epoch_tracker() {
        let tracker = EpochTracker::new();
        assert_eq!(tracker.current(), 0);

        let reader = tracker.enter();
        assert_eq!(tracker.refs_for_epoch(0), 1);

        tracker.advance();
        assert_eq!(tracker.current(), 1);
        // The reader stays counted under the epoch it entered.
        assert_eq!(tracker.refs_for_epoch(0), 1);

        drop(reader);
        assert_eq!(tracker.refs_for_epoch(0), 0);
    }

    #[test]
    fn test_epoch_drain() {
        let tracker = EpochTracker::new();
        let window = Duration::from_millis(10);

        // Nothing registered: drains immediately.
        assert!(tracker.wait_drain(0, window));

        let _reader = tracker.enter();
        assert!(!tracker.wait_drain(0, window));
    }

    #[test]
    fn test_hot_reload_creation() {
        let plugin = plugin_v1("test");

        assert_eq!(plugin.name(), "test");
        assert_eq!(plugin.state(), HotReloadState::Active);
    }

    #[test]
    fn test_hot_reload_call() {
        let plugin = plugin_v1("test");

        assert!(plugin.call("on_insert", &[]).is_ok());
    }

    #[test]
    fn test_hot_reload_prepare() {
        let plugin = plugin_v1("test");

        plugin
            .prepare_upgrade(b"new wasm", manifest_for("test", "2.0.0"))
            .unwrap();

        assert_eq!(plugin.state(), HotReloadState::PreparingUpgrade);
    }

    #[test]
    fn test_hot_reload_full_upgrade() {
        let plugin = plugin_v1("test");

        plugin
            .upgrade(b"new wasm", manifest_for("test", "2.0.0"))
            .unwrap();
        assert_eq!(plugin.state(), HotReloadState::Active);

        let counters = plugin.stats();
        assert_eq!(counters.successful_upgrades, 1);
        assert_eq!(counters.version, 1);
    }

    #[test]
    fn test_hot_reload_cancel() {
        let plugin = plugin_v1("test");

        plugin
            .prepare_upgrade(b"new wasm", manifest_for("test", "2.0.0"))
            .unwrap();
        plugin.cancel_upgrade().unwrap();

        assert_eq!(plugin.state(), HotReloadState::Active);
    }

    #[test]
    fn test_hot_reload_name_mismatch() {
        let plugin = plugin_v1("test");

        // A manifest for a different plugin must be rejected.
        let outcome = plugin.prepare_upgrade(b"new wasm", manifest_for("different", "2.0.0"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_manager_operations() {
        let manager = HotReloadManager::new();

        manager
            .register("plugin1", instance_named("plugin1"), manifest_for("plugin1", "1.0.0"))
            .unwrap();
        assert_eq!(manager.list().len(), 1);

        let plugin = manager.get("plugin1").unwrap();
        assert_eq!(plugin.name(), "plugin1");

        manager
            .upgrade("plugin1", b"new wasm", manifest_for("plugin1", "2.0.0"))
            .unwrap();

        let every = manager.all_stats();
        assert_eq!(every.len(), 1);
        assert_eq!(every[0].1.successful_upgrades, 1);

        manager.unregister("plugin1").unwrap();
        assert!(manager.list().is_empty());
    }

    #[test]
    fn test_manager_duplicate() {
        let manager = HotReloadManager::new();

        manager
            .register("dup", instance_named("dup"), manifest_for("dup", "1.0.0"))
            .unwrap();

        let second = manager.register("dup", instance_named("dup"), manifest_for("dup", "1.0.0"));
        assert!(second.is_err());
    }

    #[test]
    fn test_concurrent_calls_during_upgrade() {
        const CALLERS: usize = 4;
        // Each caller makes at most MAX_CALLS + 1 calls.
        const MAX_CALLS: usize = 100;

        let plugin = Arc::new(plugin_v1("concurrent"));
        let stop = Arc::new(AtomicBool::new(false));

        let callers: Vec<_> = (0..CALLERS)
            .map(|_| {
                let plugin = Arc::clone(&plugin);
                let stop = Arc::clone(&stop);
                thread::spawn(move || {
                    for _ in 0..=MAX_CALLS {
                        if stop.load(Ordering::Relaxed) {
                            break;
                        }
                        let _ = plugin.call("on_insert", &[]);
                    }
                })
            })
            .collect();

        // Give the callers a head start so some calls overlap the upgrade.
        thread::sleep(Duration::from_millis(5));
        let outcome = plugin.upgrade(b"new wasm", manifest_for("concurrent", "2.0.0"));

        stop.store(true, Ordering::Relaxed);
        for caller in callers {
            caller.join().unwrap();
        }

        assert!(outcome.is_ok());
        assert_eq!(plugin.stats().successful_upgrades, 1);
    }
}