1use crate::{ZoeyError, Result};
16use serde::{Deserialize, Serialize};
17use std::collections::{HashMap, HashSet};
18use std::path::{Path, PathBuf};
19use std::process::{Command, Stdio};
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::sync::RwLock;
23use tokio::time::timeout;
24use tracing::{debug, error, info, instrument, warn};
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
28pub enum MLFramework {
29 PyTorch,
31 TensorFlow,
33 Custom,
35}
36
37impl MLFramework {
38 pub fn package_name(&self) -> &str {
40 match self {
41 MLFramework::PyTorch => "torch",
42 MLFramework::TensorFlow => "tensorflow",
43 MLFramework::Custom => "custom",
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SecurityConfig {
51 pub allowed_script_dirs: Vec<PathBuf>,
53
54 pub max_timeout_secs: u64,
56
57 pub max_code_length: usize,
59
60 pub allow_direct_code: bool,
62
63 pub blocked_modules: HashSet<String>,
65
66 pub blocked_operations: HashSet<String>,
68}
69
70impl Default for SecurityConfig {
71 fn default() -> Self {
72 let mut blocked_modules = HashSet::new();
73 blocked_modules.insert("os".to_string());
74 blocked_modules.insert("sys".to_string());
75 blocked_modules.insert("subprocess".to_string());
76 blocked_modules.insert("shutil".to_string());
77 blocked_modules.insert("socket".to_string());
78 blocked_modules.insert("urllib".to_string());
79 blocked_modules.insert("requests".to_string());
80 blocked_modules.insert("http".to_string());
81 blocked_modules.insert("ftplib".to_string());
82 blocked_modules.insert("smtplib".to_string());
83 blocked_modules.insert("pickle".to_string());
84 blocked_modules.insert("marshal".to_string());
85 blocked_modules.insert("ctypes".to_string());
86 blocked_modules.insert("__builtin__".to_string());
87 blocked_modules.insert("builtins".to_string());
88
89 let mut blocked_operations = HashSet::new();
90 blocked_operations.insert("eval".to_string());
91 blocked_operations.insert("exec".to_string());
92 blocked_operations.insert("compile".to_string());
93 blocked_operations.insert("open".to_string());
94 blocked_operations.insert("file".to_string());
95 blocked_operations.insert("input".to_string());
96 blocked_operations.insert("raw_input".to_string());
97 blocked_operations.insert("__import__".to_string());
98 blocked_operations.insert("reload".to_string());
99 blocked_operations.insert("execfile".to_string());
100
101 Self {
102 allowed_script_dirs: vec![],
103 max_timeout_secs: 300, max_code_length: 1_000_000, allow_direct_code: false, blocked_modules,
107 blocked_operations,
108 }
109 }
110}
111
112impl SecurityConfig {
113 pub fn strict() -> Self {
115 Self {
116 max_timeout_secs: 60, max_code_length: 100_000, allow_direct_code: false,
119 ..Default::default()
120 }
121 }
122
123 pub fn permissive() -> Self {
125 Self {
126 allowed_script_dirs: vec![],
127 max_timeout_secs: 3600, max_code_length: 10_000_000, allow_direct_code: true,
130 blocked_modules: HashSet::new(),
131 blocked_operations: HashSet::new(),
132 }
133 }
134
135 pub fn with_allowed_dir(mut self, dir: PathBuf) -> Self {
137 self.allowed_script_dirs.push(dir);
138 self
139 }
140
141 pub fn is_path_allowed(&self, path: &Path) -> bool {
143 if self.allowed_script_dirs.is_empty() {
144 return true;
146 }
147
148 for allowed_dir in &self.allowed_script_dirs {
150 if path.starts_with(allowed_dir) {
151 return true;
152 }
153 }
154
155 false
156 }
157
158 pub fn validate_code(&self, code: &str) -> Result<()> {
160 if code.len() > self.max_code_length {
162 return Err(ZoeyError::Runtime(format!(
163 "Code too long: {} bytes (max: {})",
164 code.len(),
165 self.max_code_length
166 )));
167 }
168
169 for blocked_op in &self.blocked_operations {
171 if code.contains(blocked_op) {
172 return Err(ZoeyError::Runtime(format!(
173 "Blocked operation detected: {}",
174 blocked_op
175 )));
176 }
177 }
178
179 for blocked_mod in &self.blocked_modules {
181 let import_patterns = vec![
182 format!("import {}", blocked_mod),
183 format!("from {} import", blocked_mod),
184 format!(
185 "import {}",
186 blocked_mod.replace("__builtin__", "__builtins__")
187 ),
188 ];
189
190 for pattern in import_patterns {
191 if code.contains(&pattern) {
192 return Err(ZoeyError::Runtime(format!(
193 "Blocked module import detected: {}",
194 blocked_mod
195 )));
196 }
197 }
198 }
199
200 let dangerous_patterns = vec![
202 "__import__",
203 "eval(",
204 "exec(",
205 "compile(",
206 "open(",
207 "file(",
208 "input(",
209 "raw_input(",
210 "reload(",
211 "execfile(",
212 "subprocess",
213 "os.system",
214 "os.popen",
215 "os.exec",
216 "shutil.",
217 "socket.",
218 "urllib.",
219 "requests.",
220 "pickle.load",
221 "marshal.load",
222 "ctypes.",
223 ];
224
225 for pattern in dangerous_patterns {
226 if code.contains(pattern) {
227 return Err(ZoeyError::Runtime(format!(
228 "Dangerous pattern detected: {}",
229 pattern
230 )));
231 }
232 }
233
234 Ok(())
235 }
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct PythonEnvironment {
241 pub python_path: String,
243
244 pub venv_path: Option<PathBuf>,
246
247 pub env_vars: HashMap<String, String>,
249
250 pub working_dir: Option<PathBuf>,
252
253 pub security: SecurityConfig,
255}
256
257impl Default for PythonEnvironment {
258 fn default() -> Self {
259 Self {
260 python_path: "python3".to_string(),
261 venv_path: None,
262 env_vars: HashMap::new(),
263 working_dir: None,
264 security: SecurityConfig::default(),
265 }
266 }
267}
268
269impl PythonEnvironment {
270 pub fn new(python_path: String) -> Self {
272 Self {
273 python_path,
274 ..Default::default()
275 }
276 }
277
278 pub fn with_venv(mut self, venv_path: PathBuf) -> Self {
280 self.venv_path = Some(venv_path);
281 self
282 }
283
284 pub fn with_env_var(mut self, key: String, value: String) -> Self {
286 self.env_vars.insert(key, value);
287 self
288 }
289
290 pub fn with_working_dir(mut self, dir: PathBuf) -> Self {
292 self.working_dir = Some(dir);
293 self
294 }
295
296 pub fn with_security(mut self, security: SecurityConfig) -> Self {
298 self.security = security;
299 self
300 }
301
302 pub async fn check_availability(&self) -> Result<bool> {
304 let output = Command::new(&self.python_path).arg("--version").output();
305
306 match output {
307 Ok(output) => {
308 if output.status.success() {
309 let version = String::from_utf8_lossy(&output.stdout);
310 debug!("Python available: {}", version.trim());
311 Ok(true)
312 } else {
313 warn!("Python command failed: {:?}", output.stderr);
314 Ok(false)
315 }
316 }
317 Err(e) => {
318 warn!("Failed to execute Python: {}", e);
319 Ok(false)
320 }
321 }
322 }
323
324 #[instrument(skip(self, script_path), level = "debug")]
326 pub async fn run_script(&self, script_path: &Path, args: &[&str]) -> Result<String> {
327 let script_path = if script_path.exists() {
329 script_path.canonicalize().map_err(|e| {
330 ZoeyError::Runtime(format!("Failed to canonicalize script path: {}", e))
331 })?
332 } else {
333 if let Some(parent) = script_path.parent() {
335 if let Ok(canonical_parent) = parent.canonicalize() {
336 if !self.security.is_path_allowed(&canonical_parent) {
337 return Err(ZoeyError::Runtime(format!(
338 "Script path not allowed: {:?}. Allowed directories: {:?}",
339 script_path, self.security.allowed_script_dirs
340 )));
341 }
342 }
343 }
344 script_path.to_path_buf()
345 };
346
347 if !self.security.allowed_script_dirs.is_empty()
348 && !self.security.is_path_allowed(&script_path)
349 {
350 return Err(ZoeyError::Runtime(format!(
351 "Script path not allowed: {:?}. Allowed directories: {:?}",
352 script_path, self.security.allowed_script_dirs
353 )));
354 }
355
356 if let Ok(contents) = std::fs::read_to_string(&script_path) {
358 self.security.validate_code(&contents)?;
359 }
360
361 for arg in args {
363 if arg.contains(";") || arg.contains("&") || arg.contains("|") || arg.contains("`") {
364 return Err(ZoeyError::Runtime(format!(
365 "Dangerous argument detected: {}",
366 arg
367 )));
368 }
369 }
370
371 let mut cmd = Command::new(&self.python_path);
372 cmd.arg(&script_path);
373
374 for arg in args {
375 cmd.arg(arg);
376 }
377
378 for (key, value) in &self.env_vars {
380 if key == "PATH" || key == "LD_LIBRARY_PATH" || key == "PYTHONPATH" {
382 warn!("Blocked dangerous environment variable: {}", key);
383 continue;
384 }
385 cmd.env(key, value);
386 }
387
388 cmd.env("PYTHONUNBUFFERED", "1");
390 cmd.env("PYTHONDONTWRITEBYTECODE", "1");
391
392 if let Some(ref dir) = self.working_dir {
394 cmd.current_dir(dir);
395 }
396
397 if let Some(ref venv) = self.venv_path {
399 let venv_python = venv.join("bin").join("python");
400 if venv_python.exists() {
401 cmd = Command::new(&venv_python);
402 cmd.arg(&script_path);
403 for arg in args {
404 cmd.arg(arg);
405 }
406 }
407 }
408
409 debug!("Executing Python script: {:?}", script_path);
410
411 let timeout_duration = Duration::from_secs(self.security.max_timeout_secs);
412 let python_path = self.python_path.clone();
413 let script_path_clone = script_path.clone();
414 let args_vec: Vec<String> = args.iter().map(|s| s.to_string()).collect();
415 let env_vars_clone: HashMap<String, String> = self
416 .env_vars
417 .iter()
418 .filter(|(k, _)| {
419 k.as_str() != "PATH"
420 && k.as_str() != "LD_LIBRARY_PATH"
421 && k.as_str() != "PYTHONPATH"
422 })
423 .map(|(k, v)| (k.clone(), v.clone()))
424 .collect();
425 let venv_path_clone = self.venv_path.clone();
426 let working_dir_clone = self.working_dir.clone();
427
428 let output_future = tokio::task::spawn_blocking(move || {
429 let mut cmd = Command::new(&python_path);
430 cmd.arg(&script_path_clone);
431 for arg in &args_vec {
432 cmd.arg(arg);
433 }
434 for (key, value) in &env_vars_clone {
435 cmd.env(key, value);
436 }
437 cmd.env("PYTHONUNBUFFERED", "1");
438 cmd.env("PYTHONDONTWRITEBYTECODE", "1");
439 if let Some(ref dir) = working_dir_clone {
440 cmd.current_dir(dir);
441 }
442 if let Some(ref venv) = venv_path_clone {
443 let venv_python = venv.join("bin").join("python");
444 if venv_python.exists() {
445 cmd = Command::new(&venv_python);
446 cmd.arg(&script_path_clone);
447 for arg in &args_vec {
448 cmd.arg(arg);
449 }
450 for (key, value) in &env_vars_clone {
451 cmd.env(key, value);
452 }
453 cmd.env("PYTHONUNBUFFERED", "1");
454 cmd.env("PYTHONDONTWRITEBYTECODE", "1");
455 }
456 }
457 cmd.stdout(Stdio::piped()).stderr(Stdio::piped()).output()
458 });
459
460 let output = timeout(timeout_duration, output_future)
461 .await
462 .map_err(|_| {
463 ZoeyError::Runtime(format!(
464 "Python script execution timed out after {} seconds",
465 self.security.max_timeout_secs
466 ))
467 })?
468 .map_err(|e| ZoeyError::Runtime(format!("Failed to spawn Python process: {}", e)))?
469 .map_err(|e| ZoeyError::Runtime(format!("Failed to execute Python script: {}", e)))?;
470
471 if output.status.success() {
472 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
473 debug!("Python script output: {} bytes", stdout.len());
474 Ok(stdout)
475 } else {
476 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
477 error!("Python script failed: {}", stderr);
478 Err(ZoeyError::Runtime(format!(
479 "Python script failed: {}",
480 stderr
481 )))
482 }
483 }
484
485 #[instrument(skip(self, code), level = "debug")]
491 pub async fn run_code(&self, code: &str) -> Result<String> {
492 if !self.security.allow_direct_code {
494 return Err(ZoeyError::Runtime(
495 "Direct code execution is disabled for security. Use run_script() with a whitelisted script instead.".to_string()
496 ));
497 }
498
499 self.security.validate_code(code)?;
501
502 let mut cmd = Command::new(&self.python_path);
503 cmd.arg("-c").arg(code);
504
505 for (key, value) in &self.env_vars {
507 if key == "PATH" || key == "LD_LIBRARY_PATH" || key == "PYTHONPATH" {
509 warn!("Blocked dangerous environment variable: {}", key);
510 continue;
511 }
512 cmd.env(key, value);
513 }
514
515 cmd.env("PYTHONUNBUFFERED", "1");
517 cmd.env("PYTHONDONTWRITEBYTECODE", "1");
518
519 if let Some(ref venv) = self.venv_path {
520 let venv_python = venv.join("bin").join("python");
521 if venv_python.exists() {
522 cmd = Command::new(&venv_python);
523 cmd.arg("-c").arg(code);
524 }
525 }
526
527 debug!("Executing Python code: {} bytes", code.len());
528
529 let timeout_duration = Duration::from_secs(self.security.max_timeout_secs);
530 let python_path = self.python_path.clone();
531 let code_clone = code.to_string();
532 let env_vars_clone: HashMap<String, String> = self
533 .env_vars
534 .iter()
535 .filter(|(k, _)| {
536 k.as_str() != "PATH"
537 && k.as_str() != "LD_LIBRARY_PATH"
538 && k.as_str() != "PYTHONPATH"
539 })
540 .map(|(k, v)| (k.clone(), v.clone()))
541 .collect();
542 let venv_path_clone = self.venv_path.clone();
543
544 let output_future = tokio::task::spawn_blocking(move || {
545 let mut cmd = Command::new(&python_path);
546 cmd.arg("-c").arg(&code_clone);
547 for (key, value) in &env_vars_clone {
548 cmd.env(key, value);
549 }
550 cmd.env("PYTHONUNBUFFERED", "1");
551 cmd.env("PYTHONDONTWRITEBYTECODE", "1");
552 if let Some(ref venv) = venv_path_clone {
553 let venv_python = venv.join("bin").join("python");
554 if venv_python.exists() {
555 cmd = Command::new(&venv_python);
556 cmd.arg("-c").arg(&code_clone);
557 for (key, value) in &env_vars_clone {
558 cmd.env(key, value);
559 }
560 cmd.env("PYTHONUNBUFFERED", "1");
561 cmd.env("PYTHONDONTWRITEBYTECODE", "1");
562 }
563 }
564 cmd.stdout(Stdio::piped()).stderr(Stdio::piped()).output()
565 });
566
567 let output = timeout(timeout_duration, output_future)
568 .await
569 .map_err(|_| {
570 ZoeyError::Runtime(format!(
571 "Python code execution timed out after {} seconds",
572 self.security.max_timeout_secs
573 ))
574 })?
575 .map_err(|e| ZoeyError::Runtime(format!("Failed to spawn Python process: {}", e)))?
576 .map_err(|e| ZoeyError::Runtime(format!("Failed to execute Python code: {}", e)))?;
577
578 if output.status.success() {
579 Ok(String::from_utf8_lossy(&output.stdout).to_string())
580 } else {
581 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
582 Err(ZoeyError::Runtime(format!(
583 "Python code failed: {}",
584 stderr
585 )))
586 }
587 }
588
589 pub async fn check_package(&self, package: &str) -> Result<bool> {
591 if package.contains(";")
593 || package.contains("&")
594 || package.contains("|")
595 || package.contains("`")
596 {
597 return Err(ZoeyError::Runtime(format!(
598 "Invalid package name: {}",
599 package
600 )));
601 }
602
603 if self.security.blocked_modules.contains(package) {
605 return Err(ZoeyError::Runtime(format!(
606 "Package is blocked: {}",
607 package
608 )));
609 }
610
611 let mut cmd = Command::new(&self.python_path);
613 cmd.arg("-c");
614 cmd.arg(format!("import {}; print('installed')", package));
615
616 cmd.env("PYTHONUNBUFFERED", "1");
618 cmd.env("PYTHONDONTWRITEBYTECODE", "1");
619
620 if let Some(ref venv) = self.venv_path {
621 let venv_python = venv.join("bin").join("python");
622 if venv_python.exists() {
623 cmd = Command::new(&venv_python);
624 cmd.arg("-c");
625 cmd.arg(format!("import {}; print('installed')", package));
626 }
627 }
628
629 let timeout_duration = Duration::from_secs(10); let python_path = self.python_path.clone();
631 let package_clone = package.to_string();
632 let venv_path_clone = self.venv_path.clone();
633
634 let output_future = tokio::task::spawn_blocking(move || {
635 let mut cmd = Command::new(&python_path);
636 cmd.arg("-c");
637 cmd.arg(format!("import {}; print('installed')", package_clone));
638 cmd.env("PYTHONUNBUFFERED", "1");
639 cmd.env("PYTHONDONTWRITEBYTECODE", "1");
640 if let Some(ref venv) = venv_path_clone {
641 let venv_python = venv.join("bin").join("python");
642 if venv_python.exists() {
643 cmd = Command::new(&venv_python);
644 cmd.arg("-c");
645 cmd.arg(format!("import {}; print('installed')", package_clone));
646 cmd.env("PYTHONUNBUFFERED", "1");
647 cmd.env("PYTHONDONTWRITEBYTECODE", "1");
648 }
649 }
650 cmd.stdout(Stdio::piped()).stderr(Stdio::piped()).output()
651 });
652
653 match timeout(timeout_duration, output_future).await {
654 Ok(join_result) => match join_result {
655 Ok(output_result) => match output_result {
656 Ok(output) => {
657 if output.status.success() {
658 let stdout = String::from_utf8_lossy(&output.stdout);
659 Ok(stdout.contains("installed"))
660 } else {
661 Ok(false)
662 }
663 }
664 Err(_) => Ok(false),
665 },
666 Err(_) => Ok(false),
667 },
668 Err(_) => Ok(false),
669 }
670 }
671}
672
673#[async_trait::async_trait]
675pub trait ModelInterface: Send + Sync {
676 fn name(&self) -> &str;
678
679 fn framework(&self) -> MLFramework;
681
682 async fn load(&mut self, path: &Path) -> Result<()>;
684
685 async fn save(&self, path: &Path) -> Result<()>;
687
688 async fn predict(&self, input: &[f32]) -> Result<Vec<f32>>;
690
691 fn metadata(&self) -> HashMap<String, String>;
693}
694
695#[derive(Debug, Clone)]
697pub struct TrainedModel {
698 pub name: String,
700
701 pub framework: MLFramework,
703
704 pub path: PathBuf,
706
707 pub metadata: HashMap<String, String>,
709}
710
711impl TrainedModel {
712 pub fn new(name: String, framework: MLFramework, path: PathBuf) -> Self {
714 Self {
715 name,
716 framework,
717 path,
718 metadata: HashMap::new(),
719 }
720 }
721
722 pub fn with_metadata(mut self, key: String, value: String) -> Self {
724 self.metadata.insert(key, value);
725 self
726 }
727}
728
729pub struct MLBridge {
731 python_env: PythonEnvironment,
733
734 models: Arc<RwLock<HashMap<String, TrainedModel>>>,
736
737 frameworks: Arc<RwLock<HashMap<MLFramework, bool>>>,
739}
740
741impl MLBridge {
742 pub fn new(python_env: PythonEnvironment) -> Self {
744 Self {
745 python_env,
746 models: Arc::new(RwLock::new(HashMap::new())),
747 frameworks: Arc::new(RwLock::new(HashMap::new())),
748 }
749 }
750
751 pub fn python_env(&self) -> &PythonEnvironment {
753 &self.python_env
754 }
755
756 #[instrument(skip(self), level = "info")]
758 pub async fn check_framework(&self, framework: MLFramework) -> Result<bool> {
759 {
761 let cache = self.frameworks.read().await;
762 if let Some(&available) = cache.get(&framework) {
763 return Ok(available);
764 }
765 }
766
767 let available = self
769 .python_env
770 .check_package(framework.package_name())
771 .await?;
772
773 {
775 let mut cache = self.frameworks.write().await;
776 cache.insert(framework, available);
777 }
778
779 if available {
780 info!("✓ {} is available", framework.package_name());
781 } else {
782 warn!("✗ {} is not installed", framework.package_name());
783 }
784
785 Ok(available)
786 }
787
788 pub async fn register_model(&self, model: TrainedModel) -> Result<()> {
790 let name = model.name.clone();
791 let mut models = self.models.write().await;
792 models.insert(name.clone(), model);
793 info!("Registered model: {}", name);
794 Ok(())
795 }
796
797 pub async fn get_model(&self, name: &str) -> Option<TrainedModel> {
799 let models = self.models.read().await;
800 models.get(name).cloned()
801 }
802
803 pub async fn list_models(&self) -> Vec<String> {
805 let models = self.models.read().await;
806 models.keys().cloned().collect()
807 }
808
809 pub async fn unregister_model(&self, name: &str) -> Result<()> {
811 let mut models = self.models.write().await;
812 models.remove(name);
813 info!("Unregistered model: {}", name);
814 Ok(())
815 }
816
817 pub async fn execute_script(&self, script_path: &Path, args: &[&str]) -> Result<String> {
819 self.python_env.run_script(script_path, args).await
820 }
821
822 pub async fn execute_code(&self, code: &str) -> Result<String> {
828 self.python_env.run_code(code).await
829 }
830
831 pub fn security_config(&self) -> &SecurityConfig {
833 &self.python_env.security
834 }
835}
836
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Sample model record shared by the registry tests.
    fn sample_model() -> TrainedModel {
        TrainedModel::new(
            "test_model".to_string(),
            MLFramework::PyTorch,
            PathBuf::from("/tmp/model.pt"),
        )
    }

    #[test]
    fn test_python_environment_creation() {
        let environment = PythonEnvironment::default();
        assert_eq!(environment.python_path, "python3");
        assert!(environment.venv_path.is_none());
        assert!(!environment.security.allow_direct_code);
    }

    #[test]
    fn test_ml_framework_package_names() {
        assert_eq!(MLFramework::PyTorch.package_name(), "torch");
        assert_eq!(MLFramework::TensorFlow.package_name(), "tensorflow");
    }

    #[test]
    fn test_trained_model_creation() {
        let record = sample_model();
        assert_eq!(record.name, "test_model");
        assert_eq!(record.framework, MLFramework::PyTorch);
    }

    #[tokio::test]
    async fn test_ml_bridge_creation() {
        let bridge = MLBridge::new(PythonEnvironment::default());
        assert!(bridge.list_models().await.is_empty());
    }

    #[test]
    fn test_security_config_default() {
        let policy = SecurityConfig::default();
        assert!(!policy.allow_direct_code);
        assert_eq!(policy.max_timeout_secs, 300);
        assert!(!policy.blocked_modules.is_empty());
        assert!(!policy.blocked_operations.is_empty());
    }

    #[test]
    fn test_security_config_strict() {
        let policy = SecurityConfig::strict();
        assert!(!policy.allow_direct_code);
        assert_eq!(policy.max_timeout_secs, 60);
        assert_eq!(policy.max_code_length, 100_000);
    }

    #[test]
    fn test_security_path_validation() {
        let scratch = TempDir::new().unwrap();
        let allowed_dir = scratch.path().join("allowed");
        fs::create_dir_all(&allowed_dir).unwrap();

        let policy = SecurityConfig {
            allowed_script_dirs: vec![allowed_dir.clone()],
            ..SecurityConfig::default()
        };

        // A script below the allowed directory is accepted.
        assert!(policy.is_path_allowed(&allowed_dir.join("script.py")));

        // A script anywhere else is refused.
        assert!(!policy.is_path_allowed(&PathBuf::from("/tmp/evil.py")));
    }

    #[test]
    fn test_security_code_validation() {
        let policy = SecurityConfig::default();

        // Harmless code passes.
        assert!(policy.validate_code("print('hello')").is_ok());

        // Each of these must be refused.
        let rejected = [
            "import os; os.system('rm -rf /')",
            "eval('malicious code')",
            "exec('dangerous')",
            "import subprocess",
            "import socket",
            "import pickle",
        ];
        for snippet in rejected.iter() {
            assert!(policy.validate_code(snippet).is_err());
        }
    }

    #[test]
    fn test_security_code_length_limit() {
        let policy = SecurityConfig {
            max_code_length: 100,
            ..SecurityConfig::default()
        };

        assert!(policy.validate_code("print('hello')").is_ok());

        let oversized = "x".repeat(200);
        assert!(policy.validate_code(&oversized).is_err());
    }

    #[test]
    fn test_security_direct_code_blocked() {
        let environment = PythonEnvironment::default();
        assert!(!environment.security.allow_direct_code);
    }

    #[tokio::test]
    async fn test_security_direct_code_execution_blocked() {
        let environment = PythonEnvironment::default();
        let outcome = environment.run_code("print('test')").await;
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("disabled"));
    }

    #[tokio::test]
    async fn test_security_allowed_direct_code() {
        let policy = SecurityConfig {
            allow_direct_code: true,
            ..SecurityConfig::default()
        };
        let environment = PythonEnvironment::default().with_security(policy);

        // Inline code is enabled, but validation must still refuse this.
        let outcome = environment.run_code("import os").await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_security_package_validation() {
        let policy = SecurityConfig::default();
        assert!(policy.blocked_modules.contains("os"));
        assert!(policy.blocked_modules.contains("subprocess"));
    }

    #[tokio::test]
    async fn test_ml_bridge_model_registration() {
        let bridge = MLBridge::new(PythonEnvironment::default());

        bridge.register_model(sample_model()).await.unwrap();
        assert_eq!(bridge.list_models().await.len(), 1);
        assert!(bridge.get_model("test_model").await.is_some());

        bridge.unregister_model("test_model").await.unwrap();
        assert!(bridge.list_models().await.is_empty());
    }

    #[test]
    fn test_security_config_with_allowed_dir() {
        let scratch = TempDir::new().unwrap();
        let allowed_dir = scratch.path().to_path_buf();

        let policy = SecurityConfig::default().with_allowed_dir(allowed_dir.clone());
        assert!(policy.allowed_script_dirs.contains(&allowed_dir));
    }
}