1use crate::error::{KernelError, KernelResult};
58use parking_lot::RwLock;
59use std::collections::HashMap;
60use std::sync::atomic::{AtomicU64, Ordering};
61use std::sync::Arc;
62use std::time::{Duration, Instant};
63
/// Settings handed to [`PyodideRuntime::new`].
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Memory limit in bytes.
    ///
    /// NOTE(review): not read by the runtime code in this file; confirm where
    /// it is enforced.
    pub memory_limit_bytes: u64,
    /// Execution budget in milliseconds; `execute_plugin` turns it into a
    /// `Duration` and passes it to the executor.
    pub timeout_ms: u64,
    /// Package names attached to this configuration (presumably packages to
    /// preload; not read by the runtime code in this file, so verify).
    pub packages: Vec<String>,
    /// When set, the runtime prints `[Pyodide] ...` diagnostics to stderr.
    pub debug: bool,
    /// Network permission flag.
    ///
    /// NOTE(review): not read by the runtime code in this file.
    pub allow_network: bool,
    /// Wheel URLs attached to this configuration.
    ///
    /// NOTE(review): not read by the runtime code in this file.
    pub wheel_urls: Vec<String>,
}

impl Default for RuntimeConfig {
    /// 64 MiB, 5000 ms, no packages, no wheels, debug and network off.
    fn default() -> Self {
        RuntimeConfig {
            memory_limit_bytes: 64 << 20,
            timeout_ms: 5_000,
            packages: Vec::new(),
            debug: false,
            allow_network: false,
            wheel_urls: Vec::new(),
        }
    }
}

impl RuntimeConfig {
    /// Preset listing `numpy`, `pandas` and `scikit-learn`, with 256 MiB and a
    /// 30000 ms budget; every other field keeps its default.
    pub fn with_ml_packages() -> Self {
        let packages = ["numpy", "pandas", "scikit-learn"]
            .iter()
            .map(|&name| String::from(name))
            .collect();

        Self {
            timeout_ms: 30_000,
            memory_limit_bytes: 256 << 20,
            packages,
            ..Self::default()
        }
    }

    /// Preset with 16 MiB, a 100 ms budget and no packages; every other field
    /// keeps its default.
    pub fn lightweight() -> Self {
        Self {
            timeout_ms: 100,
            memory_limit_bytes: 16 << 20,
            packages: Vec::new(),
            ..Self::default()
        }
    }
}
123
/// Events a plugin can be attached to; together with a table name this is the
/// key of the runtime's trigger map.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TriggerEvent {
    BeforeInsert,
    AfterInsert,
    BeforeUpdate,
    AfterUpdate,
    BeforeDelete,
    AfterDelete,
    OnBatch,
}

impl TriggerEvent {
    /// Parses an event name.
    ///
    /// The input is upper-cased and every space is replaced by `_` before the
    /// comparison, so `"before insert"` and `"BEFORE_INSERT"` both match.
    /// Returns `None` for anything else.
    pub fn from_str(s: &str) -> Option<Self> {
        const NAMES: &[(&str, TriggerEvent)] = &[
            ("BEFORE_INSERT", TriggerEvent::BeforeInsert),
            ("AFTER_INSERT", TriggerEvent::AfterInsert),
            ("BEFORE_UPDATE", TriggerEvent::BeforeUpdate),
            ("AFTER_UPDATE", TriggerEvent::AfterUpdate),
            ("BEFORE_DELETE", TriggerEvent::BeforeDelete),
            ("AFTER_DELETE", TriggerEvent::AfterDelete),
            ("ON_BATCH", TriggerEvent::OnBatch),
        ];

        let key: String = s
            .to_uppercase()
            .chars()
            .map(|c| if c == ' ' { '_' } else { c })
            .collect();

        NAMES
            .iter()
            .find(|&&(name, _)| name == key.as_str())
            .map(|&(_, event)| event)
    }

    /// Name of the Python function associated with this event.
    pub fn handler_name(&self) -> &'static str {
        match *self {
            TriggerEvent::BeforeInsert => "on_before_insert",
            TriggerEvent::AfterInsert => "on_after_insert",
            TriggerEvent::BeforeUpdate => "on_before_update",
            TriggerEvent::AfterUpdate => "on_after_update",
            TriggerEvent::BeforeDelete => "on_before_delete",
            TriggerEvent::AfterDelete => "on_after_delete",
            TriggerEvent::OnBatch => "on_batch",
        }
    }

    /// `true` for the three `Before*` variants, `false` for everything else
    /// (including `OnBatch`).
    pub fn is_before(&self) -> bool {
        match *self {
            TriggerEvent::BeforeInsert
            | TriggerEvent::BeforeUpdate
            | TriggerEvent::BeforeDelete => true,
            TriggerEvent::AfterInsert
            | TriggerEvent::AfterUpdate
            | TriggerEvent::AfterDelete
            | TriggerEvent::OnBatch => false,
        }
    }
}
171
172#[derive(Debug, Clone)]
178pub struct PythonPlugin {
179 pub name: String,
181 pub version: String,
183 pub code: String,
185 pub packages: Vec<String>,
187 pub wheels: Vec<String>,
189 pub triggers: HashMap<String, Vec<TriggerEvent>>,
191 pub config: Option<RuntimeConfig>,
193}
194
195impl PythonPlugin {
196 pub fn new(name: &str) -> Self {
197 Self {
198 name: name.to_string(),
199 version: "1.0.0".to_string(),
200 code: String::new(),
201 packages: vec![],
202 wheels: vec![],
203 triggers: HashMap::new(),
204 config: None,
205 }
206 }
207
208 pub fn with_version(mut self, version: &str) -> Self {
209 self.version = version.to_string();
210 self
211 }
212
213 pub fn with_code(mut self, code: &str) -> Self {
214 self.code = code.to_string();
215 self
216 }
217
218 pub fn with_packages(mut self, packages: Vec<&str>) -> Self {
219 self.packages = packages.into_iter().map(String::from).collect();
220 self
221 }
222
223 pub fn with_trigger(mut self, table: &str, event: TriggerEvent) -> Self {
224 self.triggers
225 .entry(table.to_string())
226 .or_default()
227 .push(event);
228 self
229 }
230
231 pub fn with_config(mut self, config: RuntimeConfig) -> Self {
232 self.config = Some(config);
233 self
234 }
235}
236
237#[derive(Debug, Clone)]
243pub struct TriggerContext {
244 pub table: String,
246 pub event: TriggerEvent,
248 pub row_json: String,
250 pub old_row_json: Option<String>,
252 pub txn_id: u64,
254 pub batch_json: Option<String>,
256}
257
/// Outcome of running a plugin (or a chain of plugins) for one event.
#[derive(Debug, Clone)]
pub enum TriggerResult {
    /// Carry on. `Some(json)` supplies a replacement row; `None` leaves the
    /// row as it was.
    Continue(Option<String>),
    /// Reject the operation with a human-readable message and a short code.
    Abort { message: String, code: String },
    /// Stop the plugin chain; `fire` returns this immediately to its caller.
    Skip,
    /// Batch payload. Not produced by the simulated executor in this file, and
    /// `fire` ignores it when a plugin returns it.
    Batch(String),
}
270
/// Lock-free counters describing runtime activity; all start at zero.
#[derive(Debug, Default)]
pub struct RuntimeStats {
    /// Incremented once per call to `fire`.
    pub total_executions: AtomicU64,
    /// Microseconds spent in `fire`, accumulated across calls.
    pub total_time_us: AtomicU64,
    /// Error counter (meaning inferred from the name; confirm which paths
    /// increment it).
    pub errors: AtomicU64,
    /// Incremented by `fire` each time a plugin aborts the operation.
    pub aborts: AtomicU64,
    /// Incremented by `install_packages` for each newly recorded package.
    pub packages_installed: AtomicU64,
}
284
285pub struct PyodideRuntime {
290 config: RuntimeConfig,
291 plugins: RwLock<HashMap<String, PythonPlugin>>,
293 trigger_map: RwLock<HashMap<(String, TriggerEvent), Vec<String>>>,
295 installed_packages: RwLock<Vec<String>>,
297 stats: Arc<RuntimeStats>,
299 #[allow(dead_code)]
301 instances: RwLock<HashMap<String, PluginInstance>>,
302}
303
/// Bookkeeping record for a loaded plugin.
///
/// Nothing in this file constructs or reads it yet, hence the lint exemption.
/// The units of the numeric fields are not established here.
#[allow(dead_code)]
struct PluginInstance {
    /// Name of the plugin this record belongs to.
    plugin_name: String,
    /// Load time (presumably a timestamp; TODO confirm the unit).
    loaded_at: u64,
    /// Memory attributed to the instance (presumably bytes; TODO confirm).
    memory_used: u64,
    /// Number of calls made to the instance.
    call_count: u64,
}
312
313impl PyodideRuntime {
314 pub fn new(config: RuntimeConfig) -> Self {
316 Self {
317 config,
318 plugins: RwLock::new(HashMap::new()),
319 trigger_map: RwLock::new(HashMap::new()),
320 installed_packages: RwLock::new(vec![]),
321 stats: Arc::new(RuntimeStats::default()),
322 instances: RwLock::new(HashMap::new()),
323 }
324 }
325
326 pub async fn install_packages(&self, packages: &[&str]) -> KernelResult<()> {
330 let mut installed = self.installed_packages.write();
331 for pkg in packages {
332 if !installed.contains(&pkg.to_string()) {
333 if self.config.debug {
335 eprintln!("[Pyodide] Installing package: {}", pkg);
336 }
337 installed.push(pkg.to_string());
338 self.stats.packages_installed.fetch_add(1, Ordering::Relaxed);
339 }
340 }
341 Ok(())
342 }
343
344 pub fn register(&self, plugin: PythonPlugin) -> KernelResult<()> {
346 self.validate_code(&plugin.code)?;
348
349 let name = plugin.name.clone();
351 {
352 let mut plugins = self.plugins.write();
353 plugins.insert(name.clone(), plugin.clone());
354 }
355
356 {
358 let mut trigger_map = self.trigger_map.write();
359 for (table, events) in &plugin.triggers {
360 for event in events {
361 trigger_map
362 .entry((table.clone(), *event))
363 .or_default()
364 .push(name.clone());
365 }
366 }
367 }
368
369 if self.config.debug {
370 eprintln!("[Pyodide] Registered plugin: {}", name);
371 }
372
373 Ok(())
374 }
375
376 pub fn unregister(&self, name: &str) -> KernelResult<()> {
378 let mut plugins = self.plugins.write();
379 if let Some(plugin) = plugins.remove(name) {
380 let mut trigger_map = self.trigger_map.write();
382 for (table, events) in &plugin.triggers {
383 for event in events {
384 if let Some(names) = trigger_map.get_mut(&(table.clone(), *event)) {
385 names.retain(|n| n != name);
386 }
387 }
388 }
389 Ok(())
390 } else {
391 Err(KernelError::Plugin {
392 message: format!("Plugin not found: {}", name),
393 })
394 }
395 }
396
397 pub async fn fire(
399 &self,
400 table: &str,
401 event: TriggerEvent,
402 context: &TriggerContext,
403 ) -> KernelResult<TriggerResult> {
404 let start = Instant::now();
405 self.stats.total_executions.fetch_add(1, Ordering::Relaxed);
406
407 let plugin_names = {
409 let trigger_map = self.trigger_map.read();
410 trigger_map
411 .get(&(table.to_string(), event))
412 .cloned()
413 .unwrap_or_default()
414 };
415
416 if plugin_names.is_empty() {
417 return Ok(TriggerResult::Continue(None));
418 }
419
420 let mut current_row = context.row_json.clone();
422
423 for name in plugin_names {
424 let plugins = self.plugins.read();
425 if let Some(plugin) = plugins.get(&name) {
426 let result = self.execute_plugin(plugin, event, ¤t_row).await?;
427
428 match result {
429 TriggerResult::Continue(Some(modified)) => {
430 current_row = modified;
431 }
432 TriggerResult::Abort { message, code } => {
433 self.stats.aborts.fetch_add(1, Ordering::Relaxed);
434 return Ok(TriggerResult::Abort { message, code });
435 }
436 TriggerResult::Skip => {
437 return Ok(TriggerResult::Skip);
438 }
439 _ => {}
440 }
441 }
442 }
443
444 let elapsed = start.elapsed().as_micros() as u64;
445 self.stats.total_time_us.fetch_add(elapsed, Ordering::Relaxed);
446
447 Ok(TriggerResult::Continue(Some(current_row)))
448 }
449
450 async fn execute_plugin(
452 &self,
453 plugin: &PythonPlugin,
454 event: TriggerEvent,
455 row_json: &str,
456 ) -> KernelResult<TriggerResult> {
457 let timeout = Duration::from_millis(self.config.timeout_ms);
458 let start = Instant::now();
459
460 let result = self.simulate_execution(plugin, event, row_json, timeout)?;
467
468 if self.config.debug {
469 eprintln!(
470 "[Pyodide] {} executed in {:?}",
471 plugin.name,
472 start.elapsed()
473 );
474 }
475
476 Ok(result)
477 }
478
479 fn simulate_execution(
481 &self,
482 plugin: &PythonPlugin,
483 event: TriggerEvent,
484 row_json: &str,
485 timeout: Duration,
486 ) -> KernelResult<TriggerResult> {
487 let start = Instant::now();
488
489 if start.elapsed() > timeout {
491 return Err(KernelError::Plugin {
492 message: "Execution timed out".to_string(),
493 });
494 }
495
496 let code = &plugin.code;
498
499 if code.contains("TriggerAbort") || code.contains("raise") {
501 if code.contains("amount") && code.contains("> 10000") {
503 if row_json.contains("\"amount\":") {
505 if let Some(amount) = self.extract_amount(row_json) {
506 if amount > 10000.0 {
507 return Ok(TriggerResult::Abort {
508 message: "Amount too high".to_string(),
509 code: "LIMIT_EXCEEDED".to_string(),
510 });
511 }
512 }
513 }
514 }
515 }
516
517 if code.contains(".lower()") {
519 let modified = row_json.to_lowercase();
521 return Ok(TriggerResult::Continue(Some(modified)));
522 }
523
524 if event.is_before() {
526 Ok(TriggerResult::Continue(Some(row_json.to_string())))
527 } else {
528 Ok(TriggerResult::Continue(None))
529 }
530 }
531
532 fn extract_amount(&self, json: &str) -> Option<f64> {
533 if let Some(start) = json.find("\"amount\":") {
535 let rest = &json[start + 9..].trim_start();
536 let end = rest.find(|c: char| !c.is_numeric() && c != '.' && c != '-');
537 let num_str = match end {
538 Some(e) => &rest[..e],
539 None => rest,
540 };
541 num_str.trim().parse().ok()
542 } else {
543 None
544 }
545 }
546
547 fn validate_code(&self, code: &str) -> KernelResult<()> {
549 let forbidden = [
551 "__import__('os')",
552 "subprocess",
553 "eval(",
554 "exec(",
555 "compile(",
556 "open(",
557 "__builtins__",
558 ];
559
560 for pattern in forbidden {
561 if code.contains(pattern) {
562 return Err(KernelError::Plugin {
563 message: format!("Forbidden pattern in code: {}", pattern),
564 });
565 }
566 }
567
568 let handlers = ["on_insert", "on_before_insert", "on_after_insert",
570 "on_update", "on_delete", "on_batch", "handler"];
571 if !handlers.iter().any(|h| code.contains(&format!("def {}(", h))) {
572 return Err(KernelError::Plugin {
573 message: "Code must define a handler function".to_string(),
574 });
575 }
576
577 Ok(())
578 }
579
580 pub fn stats(&self) -> &RuntimeStats {
582 &self.stats
583 }
584
585 pub fn list_plugins(&self) -> Vec<String> {
587 self.plugins.read().keys().cloned().collect()
588 }
589}
590
591#[allow(dead_code)]
597pub struct AiTriggerGenerator {
598 model: String,
600 endpoint: Option<String>,
602}
603
604#[allow(dead_code)]
605impl AiTriggerGenerator {
606 pub fn new(model: &str) -> Self {
607 Self {
608 model: model.to_string(),
609 endpoint: None,
610 }
611 }
612
613 pub async fn generate(&self, instruction: &str, table_schema: &str) -> KernelResult<String> {
615 let code = format!(
618 r#"
619# Generated from: {}
620# Table schema: {}
621
622def on_before_insert(row: dict) -> dict:
623 # TODO: Implement validation logic
624 return row
625"#,
626 instruction, table_schema
627 );
628 Ok(code)
629 }
630}
631
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest plugin source that passes `validate_code`.
    const PASSTHROUGH_SOURCE: &str = "def on_insert(row): return row";

    /// Builds a `BeforeInsert` context on the `orders` table.
    fn order_insert(row_json: &str, txn_id: u64) -> TriggerContext {
        TriggerContext {
            table: "orders".to_string(),
            event: TriggerEvent::BeforeInsert,
            row_json: row_json.to_string(),
            old_row_json: None,
            txn_id,
            batch_json: None,
        }
    }

    /// The builder methods set the fields they are named after.
    #[test]
    fn test_plugin_builder() {
        let built = PythonPlugin::new("test")
            .with_version("2.0.0")
            .with_code(PASSTHROUGH_SOURCE)
            .with_packages(vec!["numpy", "pandas"])
            .with_trigger("users", TriggerEvent::BeforeInsert);

        assert_eq!(built.name, "test");
        assert_eq!(built.version, "2.0.0");
        assert!(built.packages.contains(&"numpy".to_string()));
        assert!(built.triggers.contains_key("users"));
    }

    /// The presets carry their documented limits.
    #[test]
    fn test_runtime_config() {
        let ml = RuntimeConfig::with_ml_packages();
        assert!(ml.packages.contains(&"numpy".to_string()));
        assert_eq!(ml.memory_limit_bytes, 256 * 1024 * 1024);

        let light = RuntimeConfig::lightweight();
        assert_eq!(light.timeout_ms, 100);
    }

    /// A valid plugin shows up in `list_plugins` after registration.
    #[tokio::test]
    async fn test_runtime_register() {
        let runtime = PyodideRuntime::new(RuntimeConfig::default());

        let validator = PythonPlugin::new("validator")
            .with_code(PASSTHROUGH_SOURCE)
            .with_trigger("users", TriggerEvent::BeforeInsert);

        runtime.register(validator).unwrap();
        assert!(runtime.list_plugins().contains(&"validator".to_string()));
    }

    /// Amounts under the limit continue; amounts over it abort.
    #[tokio::test]
    async fn test_runtime_fire_trigger() {
        let runtime = PyodideRuntime::new(RuntimeConfig::default());

        let amount_check = PythonPlugin::new("amount_check")
            .with_code(r#"
def on_insert(row):
 if row["amount"] > 10000:
 raise TriggerAbort("Amount too high")
 return row
"#)
            .with_trigger("orders", TriggerEvent::BeforeInsert);

        runtime.register(amount_check).unwrap();

        let small_order = order_insert(r#"{"amount": 500}"#, 1);
        let accepted = runtime
            .fire("orders", TriggerEvent::BeforeInsert, &small_order)
            .await;
        assert!(matches!(accepted, Ok(TriggerResult::Continue(_))));

        let large_order = order_insert(r#"{"amount": 50000}"#, 2);
        let rejected = runtime
            .fire("orders", TriggerEvent::BeforeInsert, &large_order)
            .await;
        assert!(matches!(rejected, Ok(TriggerResult::Abort { .. })));
    }

    /// Validation accepts a handler, rejects forbidden text and rejects
    /// source without a handler.
    #[test]
    fn test_code_validation() {
        let runtime = PyodideRuntime::new(RuntimeConfig::default());

        assert!(runtime.validate_code(PASSTHROUGH_SOURCE).is_ok());

        assert!(runtime.validate_code("import subprocess").is_err());

        assert!(runtime.validate_code("x = 42").is_err());
    }
}