1use crate::error::{KernelError, KernelResult};
58use parking_lot::RwLock;
59use std::collections::HashMap;
60use std::sync::Arc;
61use std::sync::atomic::{AtomicU64, Ordering};
62use std::time::{Duration, Instant};
63
/// Settings for a Python plugin runtime: resource budgets, packages and flags.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Memory budget in bytes.
    pub memory_limit_bytes: u64,
    /// Execution timeout in milliseconds.
    pub timeout_ms: u64,
    /// Names of the packages associated with this configuration.
    pub packages: Vec<String>,
    /// When `true`, the runtime prints diagnostics to stderr.
    pub debug: bool,
    /// Whether network access is permitted.
    pub allow_network: bool,
    /// Wheel URLs associated with this configuration.
    pub wheel_urls: Vec<String>,
}

impl Default for RuntimeConfig {
    /// 64 MiB of memory, a 5 second timeout, no packages, and every flag off.
    fn default() -> Self {
        Self {
            memory_limit_bytes: 64 * Self::MIB,
            timeout_ms: 5_000,
            packages: Vec::new(),
            debug: false,
            allow_network: false,
            wheel_urls: Vec::new(),
        }
    }
}

impl RuntimeConfig {
    /// Number of bytes in one mebibyte.
    const MIB: u64 = 1024 * 1024;

    /// Preset carrying `numpy`, `pandas` and `scikit-learn`, with 256 MiB of
    /// memory and a 30 second timeout; everything else is the default.
    pub fn with_ml_packages() -> Self {
        let packages = ["numpy", "pandas", "scikit-learn"]
            .iter()
            .map(|name| String::from(*name))
            .collect();

        Self {
            memory_limit_bytes: 256 * Self::MIB,
            timeout_ms: 30_000,
            packages,
            ..Self::default()
        }
    }

    /// Preset with 16 MiB of memory, a 100 ms timeout and no packages;
    /// everything else is the default.
    pub fn lightweight() -> Self {
        Self {
            memory_limit_bytes: 16 * Self::MIB,
            timeout_ms: 100,
            packages: Vec::new(),
            ..Self::default()
        }
    }
}
119
/// Table events that a plugin handler can be attached to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TriggerEvent {
    BeforeInsert,
    AfterInsert,
    BeforeUpdate,
    AfterUpdate,
    BeforeDelete,
    AfterDelete,
    OnBatch,
}

impl TriggerEvent {
    /// Parses an event name such as `"before insert"` or `"AFTER_UPDATE"`.
    ///
    /// The input is upper-cased and every space is replaced by `_` before it is
    /// compared against the known names. Returns `None` for anything else.
    pub fn from_str(s: &str) -> Option<Self> {
        const NAMES: [(&str, TriggerEvent); 7] = [
            ("BEFORE_INSERT", TriggerEvent::BeforeInsert),
            ("AFTER_INSERT", TriggerEvent::AfterInsert),
            ("BEFORE_UPDATE", TriggerEvent::BeforeUpdate),
            ("AFTER_UPDATE", TriggerEvent::AfterUpdate),
            ("BEFORE_DELETE", TriggerEvent::BeforeDelete),
            ("AFTER_DELETE", TriggerEvent::AfterDelete),
            ("ON_BATCH", TriggerEvent::OnBatch),
        ];

        let normalized = s.to_uppercase().replace(' ', "_");
        NAMES
            .iter()
            .find(|(key, _)| *key == normalized)
            .map(|&(_, event)| event)
    }

    /// Name of the Python function that handles this event.
    pub fn handler_name(&self) -> &'static str {
        match *self {
            TriggerEvent::BeforeInsert => "on_before_insert",
            TriggerEvent::AfterInsert => "on_after_insert",
            TriggerEvent::BeforeUpdate => "on_before_update",
            TriggerEvent::AfterUpdate => "on_after_update",
            TriggerEvent::BeforeDelete => "on_before_delete",
            TriggerEvent::AfterDelete => "on_after_delete",
            TriggerEvent::OnBatch => "on_batch",
        }
    }

    /// `true` for the three `Before*` events, `false` for every other event
    /// (including `OnBatch`).
    pub fn is_before(&self) -> bool {
        match *self {
            TriggerEvent::BeforeInsert
            | TriggerEvent::BeforeUpdate
            | TriggerEvent::BeforeDelete => true,
            TriggerEvent::AfterInsert
            | TriggerEvent::AfterUpdate
            | TriggerEvent::AfterDelete
            | TriggerEvent::OnBatch => false,
        }
    }
}
170
171#[derive(Debug, Clone)]
177pub struct PythonPlugin {
178 pub name: String,
180 pub version: String,
182 pub code: String,
184 pub packages: Vec<String>,
186 pub wheels: Vec<String>,
188 pub triggers: HashMap<String, Vec<TriggerEvent>>,
190 pub config: Option<RuntimeConfig>,
192}
193
194impl PythonPlugin {
195 pub fn new(name: &str) -> Self {
196 Self {
197 name: name.to_string(),
198 version: "1.0.0".to_string(),
199 code: String::new(),
200 packages: vec![],
201 wheels: vec![],
202 triggers: HashMap::new(),
203 config: None,
204 }
205 }
206
207 pub fn with_version(mut self, version: &str) -> Self {
208 self.version = version.to_string();
209 self
210 }
211
212 pub fn with_code(mut self, code: &str) -> Self {
213 self.code = code.to_string();
214 self
215 }
216
217 pub fn with_packages(mut self, packages: Vec<&str>) -> Self {
218 self.packages = packages.into_iter().map(String::from).collect();
219 self
220 }
221
222 pub fn with_trigger(mut self, table: &str, event: TriggerEvent) -> Self {
223 self.triggers
224 .entry(table.to_string())
225 .or_default()
226 .push(event);
227 self
228 }
229
230 pub fn with_config(mut self, config: RuntimeConfig) -> Self {
231 self.config = Some(config);
232 self
233 }
234}
235
/// Data handed to `PyodideRuntime::fire` for one trigger invocation.
///
/// Of these fields, the code in this file reads only `row_json`; the
/// descriptions of the others are inferred from their names.
#[derive(Debug, Clone)]
pub struct TriggerContext {
    /// Table the event occurred on.
    pub table: String,
    /// Event being fired.
    pub event: TriggerEvent,
    /// The row as JSON text; `fire` uses it as the starting row for the plugin chain.
    pub row_json: String,
    /// Presumably the previous row image for update/delete events — confirm with callers.
    pub old_row_json: Option<String>,
    /// Presumably the id of the enclosing transaction — confirm with callers.
    pub txn_id: u64,
    /// Presumably the batch payload for `OnBatch` events — confirm with callers.
    pub batch_json: Option<String>,
}
256
/// Outcome of running one plugin, or a chain of plugins, for a trigger.
#[derive(Debug, Clone)]
pub enum TriggerResult {
    /// Proceed with the operation; `Some(json)` carries the row JSON to use from here on.
    Continue(Option<String>),
    /// Stop the operation, reporting a human-readable `message` and a `code`.
    Abort { message: String, code: String },
    /// Stop running further plugins and report `Skip` to the caller.
    Skip,
    /// Batch payload. NOTE(review): nothing in this file produces this variant — confirm intended use.
    Batch(String),
}
269
/// Counters describing runtime activity; all fields are atomics.
#[derive(Debug, Default)]
pub struct RuntimeStats {
    /// Number of calls to `PyodideRuntime::fire`.
    pub total_executions: AtomicU64,
    /// Cumulative trigger execution time in microseconds, accumulated by `fire`.
    pub total_time_us: AtomicU64,
    /// Number of failed executions.
    pub errors: AtomicU64,
    /// Number of `fire` calls that ended in `TriggerResult::Abort`.
    pub aborts: AtomicU64,
    /// Number of distinct packages recorded by `install_packages`.
    pub packages_installed: AtomicU64,
}
283
/// Registry of Python plugins and the table/event triggers that invoke them.
pub struct PyodideRuntime {
    /// Runtime-wide configuration (timeout, debug flag, ...).
    config: RuntimeConfig,
    /// Registered plugins keyed by plugin name.
    plugins: RwLock<HashMap<String, PythonPlugin>>,
    /// `(table, event)` mapped to the names of the plugins attached to it.
    trigger_map: RwLock<HashMap<(String, TriggerEvent), Vec<String>>>,
    /// Names of packages recorded as installed.
    installed_packages: RwLock<Vec<String>>,
    /// Shared activity counters.
    stats: Arc<RuntimeStats>,
    /// Per-plugin instance records; not used by any code visible in this file.
    #[allow(dead_code)]
    instances: RwLock<HashMap<String, PluginInstance>>,
}
302
/// Bookkeeping record for a loaded plugin instance.
///
/// Never constructed in the code visible here, so units and meanings below
/// are inferred from the field names — confirm before relying on them.
#[allow(dead_code)]
struct PluginInstance {
    /// Name of the plugin this instance belongs to.
    plugin_name: String,
    /// Load time (unit/epoch not shown in this file — TODO confirm).
    loaded_at: u64,
    /// Memory in use (presumably bytes — TODO confirm).
    memory_used: u64,
    /// Number of calls made to this instance.
    call_count: u64,
}
311
312impl PyodideRuntime {
313 pub fn new(config: RuntimeConfig) -> Self {
315 Self {
316 config,
317 plugins: RwLock::new(HashMap::new()),
318 trigger_map: RwLock::new(HashMap::new()),
319 installed_packages: RwLock::new(vec![]),
320 stats: Arc::new(RuntimeStats::default()),
321 instances: RwLock::new(HashMap::new()),
322 }
323 }
324
325 pub async fn install_packages(&self, packages: &[&str]) -> KernelResult<()> {
329 let mut installed = self.installed_packages.write();
330 for pkg in packages {
331 if !installed.contains(&pkg.to_string()) {
332 if self.config.debug {
334 eprintln!("[Pyodide] Installing package: {}", pkg);
335 }
336 installed.push(pkg.to_string());
337 self.stats
338 .packages_installed
339 .fetch_add(1, Ordering::Relaxed);
340 }
341 }
342 Ok(())
343 }
344
345 pub fn register(&self, plugin: PythonPlugin) -> KernelResult<()> {
347 self.validate_code(&plugin.code)?;
349
350 let name = plugin.name.clone();
352 {
353 let mut plugins = self.plugins.write();
354 plugins.insert(name.clone(), plugin.clone());
355 }
356
357 {
359 let mut trigger_map = self.trigger_map.write();
360 for (table, events) in &plugin.triggers {
361 for event in events {
362 trigger_map
363 .entry((table.clone(), *event))
364 .or_default()
365 .push(name.clone());
366 }
367 }
368 }
369
370 if self.config.debug {
371 eprintln!("[Pyodide] Registered plugin: {}", name);
372 }
373
374 Ok(())
375 }
376
377 pub fn unregister(&self, name: &str) -> KernelResult<()> {
379 let mut plugins = self.plugins.write();
380 if let Some(plugin) = plugins.remove(name) {
381 let mut trigger_map = self.trigger_map.write();
383 for (table, events) in &plugin.triggers {
384 for event in events {
385 if let Some(names) = trigger_map.get_mut(&(table.clone(), *event)) {
386 names.retain(|n| n != name);
387 }
388 }
389 }
390 Ok(())
391 } else {
392 Err(KernelError::Plugin {
393 message: format!("Plugin not found: {}", name),
394 })
395 }
396 }
397
398 pub async fn fire(
400 &self,
401 table: &str,
402 event: TriggerEvent,
403 context: &TriggerContext,
404 ) -> KernelResult<TriggerResult> {
405 let start = Instant::now();
406 self.stats.total_executions.fetch_add(1, Ordering::Relaxed);
407
408 let plugin_names = {
410 let trigger_map = self.trigger_map.read();
411 trigger_map
412 .get(&(table.to_string(), event))
413 .cloned()
414 .unwrap_or_default()
415 };
416
417 if plugin_names.is_empty() {
418 return Ok(TriggerResult::Continue(None));
419 }
420
421 let mut current_row = context.row_json.clone();
423
424 for name in plugin_names {
425 let plugins = self.plugins.read();
426 if let Some(plugin) = plugins.get(&name) {
427 let result = self.execute_plugin(plugin, event, ¤t_row).await?;
428
429 match result {
430 TriggerResult::Continue(Some(modified)) => {
431 current_row = modified;
432 }
433 TriggerResult::Abort { message, code } => {
434 self.stats.aborts.fetch_add(1, Ordering::Relaxed);
435 return Ok(TriggerResult::Abort { message, code });
436 }
437 TriggerResult::Skip => {
438 return Ok(TriggerResult::Skip);
439 }
440 _ => {}
441 }
442 }
443 }
444
445 let elapsed = start.elapsed().as_micros() as u64;
446 self.stats
447 .total_time_us
448 .fetch_add(elapsed, Ordering::Relaxed);
449
450 Ok(TriggerResult::Continue(Some(current_row)))
451 }
452
453 async fn execute_plugin(
455 &self,
456 plugin: &PythonPlugin,
457 event: TriggerEvent,
458 row_json: &str,
459 ) -> KernelResult<TriggerResult> {
460 let timeout = Duration::from_millis(self.config.timeout_ms);
461 let start = Instant::now();
462
463 let result = self.simulate_execution(plugin, event, row_json, timeout)?;
470
471 if self.config.debug {
472 eprintln!(
473 "[Pyodide] {} executed in {:?}",
474 plugin.name,
475 start.elapsed()
476 );
477 }
478
479 Ok(result)
480 }
481
482 fn simulate_execution(
484 &self,
485 plugin: &PythonPlugin,
486 event: TriggerEvent,
487 row_json: &str,
488 timeout: Duration,
489 ) -> KernelResult<TriggerResult> {
490 let start = Instant::now();
491
492 if start.elapsed() > timeout {
494 return Err(KernelError::Plugin {
495 message: "Execution timed out".to_string(),
496 });
497 }
498
499 let code = &plugin.code;
501
502 if code.contains("TriggerAbort") || code.contains("raise") {
504 if code.contains("amount") && code.contains("> 10000") {
506 if row_json.contains("\"amount\":") {
508 if let Some(amount) = self.extract_amount(row_json) {
509 if amount > 10000.0 {
510 return Ok(TriggerResult::Abort {
511 message: "Amount too high".to_string(),
512 code: "LIMIT_EXCEEDED".to_string(),
513 });
514 }
515 }
516 }
517 }
518 }
519
520 if code.contains(".lower()") {
522 let modified = row_json.to_lowercase();
524 return Ok(TriggerResult::Continue(Some(modified)));
525 }
526
527 if event.is_before() {
529 Ok(TriggerResult::Continue(Some(row_json.to_string())))
530 } else {
531 Ok(TriggerResult::Continue(None))
532 }
533 }
534
535 fn extract_amount(&self, json: &str) -> Option<f64> {
536 if let Some(start) = json.find("\"amount\":") {
538 let rest = &json[start + 9..].trim_start();
539 let end = rest.find(|c: char| !c.is_numeric() && c != '.' && c != '-');
540 let num_str = match end {
541 Some(e) => &rest[..e],
542 None => rest,
543 };
544 num_str.trim().parse().ok()
545 } else {
546 None
547 }
548 }
549
550 fn validate_code(&self, code: &str) -> KernelResult<()> {
552 let forbidden = [
554 "__import__('os')",
555 "subprocess",
556 "eval(",
557 "exec(",
558 "compile(",
559 "open(",
560 "__builtins__",
561 ];
562
563 for pattern in forbidden {
564 if code.contains(pattern) {
565 return Err(KernelError::Plugin {
566 message: format!("Forbidden pattern in code: {}", pattern),
567 });
568 }
569 }
570
571 let handlers = [
573 "on_insert",
574 "on_before_insert",
575 "on_after_insert",
576 "on_update",
577 "on_delete",
578 "on_batch",
579 "handler",
580 ];
581 if !handlers
582 .iter()
583 .any(|h| code.contains(&format!("def {}(", h)))
584 {
585 return Err(KernelError::Plugin {
586 message: "Code must define a handler function".to_string(),
587 });
588 }
589
590 Ok(())
591 }
592
593 pub fn stats(&self) -> &RuntimeStats {
595 &self.stats
596 }
597
598 pub fn list_plugins(&self) -> Vec<String> {
600 self.plugins.read().keys().cloned().collect()
601 }
602}
603
604#[allow(dead_code)]
610pub struct AiTriggerGenerator {
611 model: String,
613 endpoint: Option<String>,
615}
616
617#[allow(dead_code)]
618impl AiTriggerGenerator {
619 pub fn new(model: &str) -> Self {
620 Self {
621 model: model.to_string(),
622 endpoint: None,
623 }
624 }
625
626 pub async fn generate(&self, instruction: &str, table_schema: &str) -> KernelResult<String> {
628 let code = format!(
631 r#"
632# Generated from: {}
633# Table schema: {}
634
635def on_before_insert(row: dict) -> dict:
636 # TODO: Implement validation logic
637 return row
638"#,
639 instruction, table_schema
640 );
641 Ok(code)
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest source that passes `validate_code`.
    const PASSTHROUGH: &str = "def on_insert(row): return row";

    /// Plugin source that rejects orders whose amount exceeds 10000.
    const AMOUNT_GUARD: &str = r#"
def on_insert(row):
    if row["amount"] > 10000:
        raise TriggerAbort("Amount too high")
    return row
"#;

    /// Builds a `BeforeInsert` context on the `orders` table.
    fn order_context(row_json: &str, txn_id: u64) -> TriggerContext {
        TriggerContext {
            table: "orders".to_string(),
            event: TriggerEvent::BeforeInsert,
            row_json: row_json.to_string(),
            old_row_json: None,
            txn_id,
            batch_json: None,
        }
    }

    // The builder records every field it is given.
    #[test]
    fn test_plugin_builder() {
        let built = PythonPlugin::new("test")
            .with_version("2.0.0")
            .with_code(PASSTHROUGH)
            .with_packages(vec!["numpy", "pandas"])
            .with_trigger("users", TriggerEvent::BeforeInsert);

        assert_eq!(built.name, "test");
        assert_eq!(built.version, "2.0.0");
        assert!(built.packages.contains(&"numpy".to_string()));
        assert!(built.triggers.contains_key("users"));
    }

    // Presets override the expected defaults.
    #[test]
    fn test_runtime_config() {
        let ml = RuntimeConfig::with_ml_packages();
        assert!(ml.packages.contains(&"numpy".to_string()));
        assert_eq!(ml.memory_limit_bytes, 256 * 1024 * 1024);

        let light = RuntimeConfig::lightweight();
        assert_eq!(light.timeout_ms, 100);
    }

    // A registered plugin shows up in the plugin list.
    #[tokio::test]
    async fn test_runtime_register() {
        let rt = PyodideRuntime::new(RuntimeConfig::default());
        let validator = PythonPlugin::new("validator")
            .with_code(PASSTHROUGH)
            .with_trigger("users", TriggerEvent::BeforeInsert);

        rt.register(validator).unwrap();

        let names = rt.list_plugins();
        assert!(names.contains(&"validator".to_string()));
    }

    // Small amounts pass through; large amounts abort.
    #[tokio::test]
    async fn test_runtime_fire_trigger() {
        let rt = PyodideRuntime::new(RuntimeConfig::default());
        let guard = PythonPlugin::new("amount_check")
            .with_code(AMOUNT_GUARD)
            .with_trigger("orders", TriggerEvent::BeforeInsert);
        rt.register(guard).unwrap();

        let small = order_context(r#"{"amount": 500}"#, 1);
        let accepted = rt
            .fire("orders", TriggerEvent::BeforeInsert, &small)
            .await;
        assert!(matches!(accepted, Ok(TriggerResult::Continue(_))));

        let large = order_context(r#"{"amount": 50000}"#, 2);
        let rejected = rt
            .fire("orders", TriggerEvent::BeforeInsert, &large)
            .await;
        assert!(matches!(rejected, Ok(TriggerResult::Abort { .. })));
    }

    // Validation accepts a handler and rejects forbidden or handler-less code.
    #[test]
    fn test_code_validation() {
        let rt = PyodideRuntime::new(RuntimeConfig::default());

        let accepted = rt.validate_code(PASSTHROUGH);
        assert!(accepted.is_ok());

        let forbidden = rt.validate_code("import subprocess");
        assert!(forbidden.is_err());

        let no_handler = rt.validate_code("x = 42");
        assert!(no_handler.is_err());
    }
}