zoey_core/
ml_bridge.rs

1//! ML Bridge - Python/Rust Interop for Machine Learning
2//!
3//! This module provides seamless integration between Rust and Python ML frameworks (PyTorch, TensorFlow).
4//! It enables running Python ML code from Rust, managing Python environments, and bridging ML models.
5//!
6//! # Security
7//!
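//! This module layers several best-effort safeguards around Python execution. They form a
//! blocklist, not a sandbox, and should not be relied on to contain hostile code:
//! - Path validation: when `allowed_script_dirs` is non-empty, only scripts inside those
//!   directories can be executed
//! - Timeout limits: every Python execution is subject to a maximum timeout
//! - Code validation: code that references blocked modules, blocked builtins or known
//!   dangerous patterns is rejected, as is code longer than `max_code_length`
//! - Environment filtering: sensitive variables such as `PATH` and `PYTHONPATH` are not
//!   forwarded from the configured environment variables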
14
15use crate::{ZoeyError, Result};
16use serde::{Deserialize, Serialize};
17use std::collections::{HashMap, HashSet};
18use std::path::{Path, PathBuf};
19use std::process::{Command, Stdio};
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::sync::RwLock;
23use tokio::time::timeout;
24use tracing::{debug, error, info, instrument, warn};
25
26/// ML Framework types supported
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
28pub enum MLFramework {
29    /// PyTorch framework
30    PyTorch,
31    /// TensorFlow framework
32    TensorFlow,
33    /// Generic/Custom framework
34    Custom,
35}
36
37impl MLFramework {
38    /// Get the Python package name for this framework
39    pub fn package_name(&self) -> &str {
40        match self {
41            MLFramework::PyTorch => "torch",
42            MLFramework::TensorFlow => "tensorflow",
43            MLFramework::Custom => "custom",
44        }
45    }
46}
47
48/// Security configuration for Python execution
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SecurityConfig {
51    /// Allowed script directories (whitelist)
52    pub allowed_script_dirs: Vec<PathBuf>,
53
54    /// Maximum execution timeout in seconds
55    pub max_timeout_secs: u64,
56
57    /// Maximum code length (bytes)
58    pub max_code_length: usize,
59
60    /// Allow direct code execution (dangerous, should be false in production)
61    pub allow_direct_code: bool,
62
63    /// Blocked Python modules (blacklist)
64    pub blocked_modules: HashSet<String>,
65
66    /// Blocked Python operations (blacklist)
67    pub blocked_operations: HashSet<String>,
68}
69
70impl Default for SecurityConfig {
71    fn default() -> Self {
72        let mut blocked_modules = HashSet::new();
73        blocked_modules.insert("os".to_string());
74        blocked_modules.insert("sys".to_string());
75        blocked_modules.insert("subprocess".to_string());
76        blocked_modules.insert("shutil".to_string());
77        blocked_modules.insert("socket".to_string());
78        blocked_modules.insert("urllib".to_string());
79        blocked_modules.insert("requests".to_string());
80        blocked_modules.insert("http".to_string());
81        blocked_modules.insert("ftplib".to_string());
82        blocked_modules.insert("smtplib".to_string());
83        blocked_modules.insert("pickle".to_string());
84        blocked_modules.insert("marshal".to_string());
85        blocked_modules.insert("ctypes".to_string());
86        blocked_modules.insert("__builtin__".to_string());
87        blocked_modules.insert("builtins".to_string());
88
89        let mut blocked_operations = HashSet::new();
90        blocked_operations.insert("eval".to_string());
91        blocked_operations.insert("exec".to_string());
92        blocked_operations.insert("compile".to_string());
93        blocked_operations.insert("open".to_string());
94        blocked_operations.insert("file".to_string());
95        blocked_operations.insert("input".to_string());
96        blocked_operations.insert("raw_input".to_string());
97        blocked_operations.insert("__import__".to_string());
98        blocked_operations.insert("reload".to_string());
99        blocked_operations.insert("execfile".to_string());
100
101        Self {
102            allowed_script_dirs: vec![],
103            max_timeout_secs: 300,      // 5 minutes default
104            max_code_length: 1_000_000, // 1MB default
105            allow_direct_code: false,   // Disabled by default for security
106            blocked_modules,
107            blocked_operations,
108        }
109    }
110}
111
112impl SecurityConfig {
113    /// Create a strict security configuration
114    pub fn strict() -> Self {
115        Self {
116            max_timeout_secs: 60,     // 1 minute
117            max_code_length: 100_000, // 100KB
118            allow_direct_code: false,
119            ..Default::default()
120        }
121    }
122
123    /// Create a permissive configuration (use with caution)
124    pub fn permissive() -> Self {
125        Self {
126            allowed_script_dirs: vec![],
127            max_timeout_secs: 3600,      // 1 hour
128            max_code_length: 10_000_000, // 10MB
129            allow_direct_code: true,
130            blocked_modules: HashSet::new(),
131            blocked_operations: HashSet::new(),
132        }
133    }
134
135    /// Add an allowed script directory
136    pub fn with_allowed_dir(mut self, dir: PathBuf) -> Self {
137        self.allowed_script_dirs.push(dir);
138        self
139    }
140
141    /// Check if a path is allowed
142    pub fn is_path_allowed(&self, path: &Path) -> bool {
143        if self.allowed_script_dirs.is_empty() {
144            // If no restrictions, allow all (for backward compatibility)
145            return true;
146        }
147
148        // Check if path is within any allowed directory
149        for allowed_dir in &self.allowed_script_dirs {
150            if path.starts_with(allowed_dir) {
151                return true;
152            }
153        }
154
155        false
156    }
157
158    /// Validate Python code for dangerous operations
159    pub fn validate_code(&self, code: &str) -> Result<()> {
160        // Check code length
161        if code.len() > self.max_code_length {
162            return Err(ZoeyError::Runtime(format!(
163                "Code too long: {} bytes (max: {})",
164                code.len(),
165                self.max_code_length
166            )));
167        }
168
169        // Check for blocked operations
170        for blocked_op in &self.blocked_operations {
171            if code.contains(blocked_op) {
172                return Err(ZoeyError::Runtime(format!(
173                    "Blocked operation detected: {}",
174                    blocked_op
175                )));
176            }
177        }
178
179        // Check for blocked module imports
180        for blocked_mod in &self.blocked_modules {
181            let import_patterns = vec![
182                format!("import {}", blocked_mod),
183                format!("from {} import", blocked_mod),
184                format!(
185                    "import {}",
186                    blocked_mod.replace("__builtin__", "__builtins__")
187                ),
188            ];
189
190            for pattern in import_patterns {
191                if code.contains(&pattern) {
192                    return Err(ZoeyError::Runtime(format!(
193                        "Blocked module import detected: {}",
194                        blocked_mod
195                    )));
196                }
197            }
198        }
199
200        // Check for dangerous patterns
201        let dangerous_patterns = vec![
202            "__import__",
203            "eval(",
204            "exec(",
205            "compile(",
206            "open(",
207            "file(",
208            "input(",
209            "raw_input(",
210            "reload(",
211            "execfile(",
212            "subprocess",
213            "os.system",
214            "os.popen",
215            "os.exec",
216            "shutil.",
217            "socket.",
218            "urllib.",
219            "requests.",
220            "pickle.load",
221            "marshal.load",
222            "ctypes.",
223        ];
224
225        for pattern in dangerous_patterns {
226            if code.contains(pattern) {
227                return Err(ZoeyError::Runtime(format!(
228                    "Dangerous pattern detected: {}",
229                    pattern
230                )));
231            }
232        }
233
234        Ok(())
235    }
236}
237
238/// Python environment configuration
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct PythonEnvironment {
241    /// Python interpreter path (default: "python3")
242    pub python_path: String,
243
244    /// Virtual environment path (optional)
245    pub venv_path: Option<PathBuf>,
246
247    /// Additional environment variables
248    pub env_vars: HashMap<String, String>,
249
250    /// Working directory for Python execution
251    pub working_dir: Option<PathBuf>,
252
253    /// Security configuration
254    pub security: SecurityConfig,
255}
256
257impl Default for PythonEnvironment {
258    fn default() -> Self {
259        Self {
260            python_path: "python3".to_string(),
261            venv_path: None,
262            env_vars: HashMap::new(),
263            working_dir: None,
264            security: SecurityConfig::default(),
265        }
266    }
267}
268
269impl PythonEnvironment {
270    /// Create a new Python environment with custom settings
271    pub fn new(python_path: String) -> Self {
272        Self {
273            python_path,
274            ..Default::default()
275        }
276    }
277
278    /// Set virtual environment path
279    pub fn with_venv(mut self, venv_path: PathBuf) -> Self {
280        self.venv_path = Some(venv_path);
281        self
282    }
283
284    /// Add environment variable
285    pub fn with_env_var(mut self, key: String, value: String) -> Self {
286        self.env_vars.insert(key, value);
287        self
288    }
289
290    /// Set working directory
291    pub fn with_working_dir(mut self, dir: PathBuf) -> Self {
292        self.working_dir = Some(dir);
293        self
294    }
295
296    /// Set security configuration
297    pub fn with_security(mut self, security: SecurityConfig) -> Self {
298        self.security = security;
299        self
300    }
301
302    /// Check if Python is available
303    pub async fn check_availability(&self) -> Result<bool> {
304        let output = Command::new(&self.python_path).arg("--version").output();
305
306        match output {
307            Ok(output) => {
308                if output.status.success() {
309                    let version = String::from_utf8_lossy(&output.stdout);
310                    debug!("Python available: {}", version.trim());
311                    Ok(true)
312                } else {
313                    warn!("Python command failed: {:?}", output.stderr);
314                    Ok(false)
315                }
316            }
317            Err(e) => {
318                warn!("Failed to execute Python: {}", e);
319                Ok(false)
320            }
321        }
322    }
323
324    /// Run a Python script and return output
325    #[instrument(skip(self, script_path), level = "debug")]
326    pub async fn run_script(&self, script_path: &Path, args: &[&str]) -> Result<String> {
327        // Security: Validate path
328        let script_path = if script_path.exists() {
329            script_path.canonicalize().map_err(|e| {
330                ZoeyError::Runtime(format!("Failed to canonicalize script path: {}", e))
331            })?
332        } else {
333            // If path doesn't exist yet, check if parent directory is allowed
334            if let Some(parent) = script_path.parent() {
335                if let Ok(canonical_parent) = parent.canonicalize() {
336                    if !self.security.is_path_allowed(&canonical_parent) {
337                        return Err(ZoeyError::Runtime(format!(
338                            "Script path not allowed: {:?}. Allowed directories: {:?}",
339                            script_path, self.security.allowed_script_dirs
340                        )));
341                    }
342                }
343            }
344            script_path.to_path_buf()
345        };
346
347        if !self.security.allowed_script_dirs.is_empty()
348            && !self.security.is_path_allowed(&script_path)
349        {
350            return Err(ZoeyError::Runtime(format!(
351                "Script path not allowed: {:?}. Allowed directories: {:?}",
352                script_path, self.security.allowed_script_dirs
353            )));
354        }
355
356        // Security: Validate script content if possible
357        if let Ok(contents) = std::fs::read_to_string(&script_path) {
358            self.security.validate_code(&contents)?;
359        }
360
361        // Security: Validate arguments
362        for arg in args {
363            if arg.contains(";") || arg.contains("&") || arg.contains("|") || arg.contains("`") {
364                return Err(ZoeyError::Runtime(format!(
365                    "Dangerous argument detected: {}",
366                    arg
367                )));
368            }
369        }
370
371        let mut cmd = Command::new(&self.python_path);
372        cmd.arg(&script_path);
373
374        for arg in args {
375            cmd.arg(arg);
376        }
377
378        // Security: Limit environment variables
379        for (key, value) in &self.env_vars {
380            // Block dangerous environment variables
381            if key == "PATH" || key == "LD_LIBRARY_PATH" || key == "PYTHONPATH" {
382                warn!("Blocked dangerous environment variable: {}", key);
383                continue;
384            }
385            cmd.env(key, value);
386        }
387
388        // Security: Set safe defaults
389        cmd.env("PYTHONUNBUFFERED", "1");
390        cmd.env("PYTHONDONTWRITEBYTECODE", "1");
391
392        // Set working directory
393        if let Some(ref dir) = self.working_dir {
394            cmd.current_dir(dir);
395        }
396
397        // Use venv if configured
398        if let Some(ref venv) = self.venv_path {
399            let venv_python = venv.join("bin").join("python");
400            if venv_python.exists() {
401                cmd = Command::new(&venv_python);
402                cmd.arg(&script_path);
403                for arg in args {
404                    cmd.arg(arg);
405                }
406            }
407        }
408
409        debug!("Executing Python script: {:?}", script_path);
410
411        let timeout_duration = Duration::from_secs(self.security.max_timeout_secs);
412        let python_path = self.python_path.clone();
413        let script_path_clone = script_path.clone();
414        let args_vec: Vec<String> = args.iter().map(|s| s.to_string()).collect();
415        let env_vars_clone: HashMap<String, String> = self
416            .env_vars
417            .iter()
418            .filter(|(k, _)| {
419                k.as_str() != "PATH"
420                    && k.as_str() != "LD_LIBRARY_PATH"
421                    && k.as_str() != "PYTHONPATH"
422            })
423            .map(|(k, v)| (k.clone(), v.clone()))
424            .collect();
425        let venv_path_clone = self.venv_path.clone();
426        let working_dir_clone = self.working_dir.clone();
427
428        let output_future = tokio::task::spawn_blocking(move || {
429            let mut cmd = Command::new(&python_path);
430            cmd.arg(&script_path_clone);
431            for arg in &args_vec {
432                cmd.arg(arg);
433            }
434            for (key, value) in &env_vars_clone {
435                cmd.env(key, value);
436            }
437            cmd.env("PYTHONUNBUFFERED", "1");
438            cmd.env("PYTHONDONTWRITEBYTECODE", "1");
439            if let Some(ref dir) = working_dir_clone {
440                cmd.current_dir(dir);
441            }
442            if let Some(ref venv) = venv_path_clone {
443                let venv_python = venv.join("bin").join("python");
444                if venv_python.exists() {
445                    cmd = Command::new(&venv_python);
446                    cmd.arg(&script_path_clone);
447                    for arg in &args_vec {
448                        cmd.arg(arg);
449                    }
450                    for (key, value) in &env_vars_clone {
451                        cmd.env(key, value);
452                    }
453                    cmd.env("PYTHONUNBUFFERED", "1");
454                    cmd.env("PYTHONDONTWRITEBYTECODE", "1");
455                }
456            }
457            cmd.stdout(Stdio::piped()).stderr(Stdio::piped()).output()
458        });
459
460        let output = timeout(timeout_duration, output_future)
461            .await
462            .map_err(|_| {
463                ZoeyError::Runtime(format!(
464                    "Python script execution timed out after {} seconds",
465                    self.security.max_timeout_secs
466                ))
467            })?
468            .map_err(|e| ZoeyError::Runtime(format!("Failed to spawn Python process: {}", e)))?
469            .map_err(|e| ZoeyError::Runtime(format!("Failed to execute Python script: {}", e)))?;
470
471        if output.status.success() {
472            let stdout = String::from_utf8_lossy(&output.stdout).to_string();
473            debug!("Python script output: {} bytes", stdout.len());
474            Ok(stdout)
475        } else {
476            let stderr = String::from_utf8_lossy(&output.stderr).to_string();
477            error!("Python script failed: {}", stderr);
478            Err(ZoeyError::Runtime(format!(
479                "Python script failed: {}",
480                stderr
481            )))
482        }
483    }
484
485    /// Run Python code directly
486    ///
487    /// # Security Warning
488    /// This method is dangerous and should only be used with trusted code.
489    /// By default, `allow_direct_code` is false in SecurityConfig.
490    #[instrument(skip(self, code), level = "debug")]
491    pub async fn run_code(&self, code: &str) -> Result<String> {
492        // Security: Check if direct code execution is allowed
493        if !self.security.allow_direct_code {
494            return Err(ZoeyError::Runtime(
495                "Direct code execution is disabled for security. Use run_script() with a whitelisted script instead.".to_string()
496            ));
497        }
498
499        // Security: Validate code
500        self.security.validate_code(code)?;
501
502        let mut cmd = Command::new(&self.python_path);
503        cmd.arg("-c").arg(code);
504
505        // Security: Limit environment variables
506        for (key, value) in &self.env_vars {
507            // Block dangerous environment variables
508            if key == "PATH" || key == "LD_LIBRARY_PATH" || key == "PYTHONPATH" {
509                warn!("Blocked dangerous environment variable: {}", key);
510                continue;
511            }
512            cmd.env(key, value);
513        }
514
515        // Security: Set safe defaults
516        cmd.env("PYTHONUNBUFFERED", "1");
517        cmd.env("PYTHONDONTWRITEBYTECODE", "1");
518
519        if let Some(ref venv) = self.venv_path {
520            let venv_python = venv.join("bin").join("python");
521            if venv_python.exists() {
522                cmd = Command::new(&venv_python);
523                cmd.arg("-c").arg(code);
524            }
525        }
526
527        debug!("Executing Python code: {} bytes", code.len());
528
529        let timeout_duration = Duration::from_secs(self.security.max_timeout_secs);
530        let python_path = self.python_path.clone();
531        let code_clone = code.to_string();
532        let env_vars_clone: HashMap<String, String> = self
533            .env_vars
534            .iter()
535            .filter(|(k, _)| {
536                k.as_str() != "PATH"
537                    && k.as_str() != "LD_LIBRARY_PATH"
538                    && k.as_str() != "PYTHONPATH"
539            })
540            .map(|(k, v)| (k.clone(), v.clone()))
541            .collect();
542        let venv_path_clone = self.venv_path.clone();
543
544        let output_future = tokio::task::spawn_blocking(move || {
545            let mut cmd = Command::new(&python_path);
546            cmd.arg("-c").arg(&code_clone);
547            for (key, value) in &env_vars_clone {
548                cmd.env(key, value);
549            }
550            cmd.env("PYTHONUNBUFFERED", "1");
551            cmd.env("PYTHONDONTWRITEBYTECODE", "1");
552            if let Some(ref venv) = venv_path_clone {
553                let venv_python = venv.join("bin").join("python");
554                if venv_python.exists() {
555                    cmd = Command::new(&venv_python);
556                    cmd.arg("-c").arg(&code_clone);
557                    for (key, value) in &env_vars_clone {
558                        cmd.env(key, value);
559                    }
560                    cmd.env("PYTHONUNBUFFERED", "1");
561                    cmd.env("PYTHONDONTWRITEBYTECODE", "1");
562                }
563            }
564            cmd.stdout(Stdio::piped()).stderr(Stdio::piped()).output()
565        });
566
567        let output = timeout(timeout_duration, output_future)
568            .await
569            .map_err(|_| {
570                ZoeyError::Runtime(format!(
571                    "Python code execution timed out after {} seconds",
572                    self.security.max_timeout_secs
573                ))
574            })?
575            .map_err(|e| ZoeyError::Runtime(format!("Failed to spawn Python process: {}", e)))?
576            .map_err(|e| ZoeyError::Runtime(format!("Failed to execute Python code: {}", e)))?;
577
578        if output.status.success() {
579            Ok(String::from_utf8_lossy(&output.stdout).to_string())
580        } else {
581            let stderr = String::from_utf8_lossy(&output.stderr).to_string();
582            Err(ZoeyError::Runtime(format!(
583                "Python code failed: {}",
584                stderr
585            )))
586        }
587    }
588
589    /// Check if a Python package is installed
590    pub async fn check_package(&self, package: &str) -> Result<bool> {
591        // Security: Validate package name
592        if package.contains(";")
593            || package.contains("&")
594            || package.contains("|")
595            || package.contains("`")
596        {
597            return Err(ZoeyError::Runtime(format!(
598                "Invalid package name: {}",
599                package
600            )));
601        }
602
603        // Security: Check if package is blocked
604        if self.security.blocked_modules.contains(package) {
605            return Err(ZoeyError::Runtime(format!(
606                "Package is blocked: {}",
607                package
608            )));
609        }
610
611        // Use a safe check method that doesn't require direct code execution
612        let mut cmd = Command::new(&self.python_path);
613        cmd.arg("-c");
614        cmd.arg(format!("import {}; print('installed')", package));
615
616        // Security: Set safe defaults
617        cmd.env("PYTHONUNBUFFERED", "1");
618        cmd.env("PYTHONDONTWRITEBYTECODE", "1");
619
620        if let Some(ref venv) = self.venv_path {
621            let venv_python = venv.join("bin").join("python");
622            if venv_python.exists() {
623                cmd = Command::new(&venv_python);
624                cmd.arg("-c");
625                cmd.arg(format!("import {}; print('installed')", package));
626            }
627        }
628
629        let timeout_duration = Duration::from_secs(10); // Short timeout for package checks
630        let python_path = self.python_path.clone();
631        let package_clone = package.to_string();
632        let venv_path_clone = self.venv_path.clone();
633
634        let output_future = tokio::task::spawn_blocking(move || {
635            let mut cmd = Command::new(&python_path);
636            cmd.arg("-c");
637            cmd.arg(format!("import {}; print('installed')", package_clone));
638            cmd.env("PYTHONUNBUFFERED", "1");
639            cmd.env("PYTHONDONTWRITEBYTECODE", "1");
640            if let Some(ref venv) = venv_path_clone {
641                let venv_python = venv.join("bin").join("python");
642                if venv_python.exists() {
643                    cmd = Command::new(&venv_python);
644                    cmd.arg("-c");
645                    cmd.arg(format!("import {}; print('installed')", package_clone));
646                    cmd.env("PYTHONUNBUFFERED", "1");
647                    cmd.env("PYTHONDONTWRITEBYTECODE", "1");
648                }
649            }
650            cmd.stdout(Stdio::piped()).stderr(Stdio::piped()).output()
651        });
652
653        match timeout(timeout_duration, output_future).await {
654            Ok(join_result) => match join_result {
655                Ok(output_result) => match output_result {
656                    Ok(output) => {
657                        if output.status.success() {
658                            let stdout = String::from_utf8_lossy(&output.stdout);
659                            Ok(stdout.contains("installed"))
660                        } else {
661                            Ok(false)
662                        }
663                    }
664                    Err(_) => Ok(false),
665                },
666                Err(_) => Ok(false),
667            },
668            Err(_) => Ok(false),
669        }
670    }
671}
672
673/// Model interface for trained ML models
674#[async_trait::async_trait]
675pub trait ModelInterface: Send + Sync {
676    /// Get model name
677    fn name(&self) -> &str;
678
679    /// Get model framework
680    fn framework(&self) -> MLFramework;
681
682    /// Load model from path
683    async fn load(&mut self, path: &Path) -> Result<()>;
684
685    /// Save model to path
686    async fn save(&self, path: &Path) -> Result<()>;
687
688    /// Run inference on input data
689    async fn predict(&self, input: &[f32]) -> Result<Vec<f32>>;
690
691    /// Get model metadata
692    fn metadata(&self) -> HashMap<String, String>;
693}
694
695/// Trained model wrapper
696#[derive(Debug, Clone)]
697pub struct TrainedModel {
698    /// Model name/identifier
699    pub name: String,
700
701    /// Framework used
702    pub framework: MLFramework,
703
704    /// Model file path
705    pub path: PathBuf,
706
707    /// Model metadata
708    pub metadata: HashMap<String, String>,
709}
710
711impl TrainedModel {
712    /// Create a new trained model reference
713    pub fn new(name: String, framework: MLFramework, path: PathBuf) -> Self {
714        Self {
715            name,
716            framework,
717            path,
718            metadata: HashMap::new(),
719        }
720    }
721
722    /// Add metadata to the model
723    pub fn with_metadata(mut self, key: String, value: String) -> Self {
724        self.metadata.insert(key, value);
725        self
726    }
727}
728
729/// ML Bridge - Main interface for ML operations
730pub struct MLBridge {
731    /// Python environment
732    python_env: PythonEnvironment,
733
734    /// Cached models
735    models: Arc<RwLock<HashMap<String, TrainedModel>>>,
736
737    /// Framework availability cache
738    frameworks: Arc<RwLock<HashMap<MLFramework, bool>>>,
739}
740
741impl MLBridge {
742    /// Create a new ML bridge
743    pub fn new(python_env: PythonEnvironment) -> Self {
744        Self {
745            python_env,
746            models: Arc::new(RwLock::new(HashMap::new())),
747            frameworks: Arc::new(RwLock::new(HashMap::new())),
748        }
749    }
750
751    /// Get Python environment reference
752    pub fn python_env(&self) -> &PythonEnvironment {
753        &self.python_env
754    }
755
756    /// Check if a framework is available
757    #[instrument(skip(self), level = "info")]
758    pub async fn check_framework(&self, framework: MLFramework) -> Result<bool> {
759        // Check cache first
760        {
761            let cache = self.frameworks.read().await;
762            if let Some(&available) = cache.get(&framework) {
763                return Ok(available);
764            }
765        }
766
767        // Check package availability
768        let available = self
769            .python_env
770            .check_package(framework.package_name())
771            .await?;
772
773        // Update cache
774        {
775            let mut cache = self.frameworks.write().await;
776            cache.insert(framework, available);
777        }
778
779        if available {
780            info!("✓ {} is available", framework.package_name());
781        } else {
782            warn!("✗ {} is not installed", framework.package_name());
783        }
784
785        Ok(available)
786    }
787
788    /// Register a trained model
789    pub async fn register_model(&self, model: TrainedModel) -> Result<()> {
790        let name = model.name.clone();
791        let mut models = self.models.write().await;
792        models.insert(name.clone(), model);
793        info!("Registered model: {}", name);
794        Ok(())
795    }
796
797    /// Get a registered model
798    pub async fn get_model(&self, name: &str) -> Option<TrainedModel> {
799        let models = self.models.read().await;
800        models.get(name).cloned()
801    }
802
803    /// List all registered models
804    pub async fn list_models(&self) -> Vec<String> {
805        let models = self.models.read().await;
806        models.keys().cloned().collect()
807    }
808
809    /// Remove a model from registry
810    pub async fn unregister_model(&self, name: &str) -> Result<()> {
811        let mut models = self.models.write().await;
812        models.remove(name);
813        info!("Unregistered model: {}", name);
814        Ok(())
815    }
816
817    /// Execute a Python ML script
818    pub async fn execute_script(&self, script_path: &Path, args: &[&str]) -> Result<String> {
819        self.python_env.run_script(script_path, args).await
820    }
821
822    /// Execute Python ML code directly
823    ///
824    /// # Security Warning
825    /// This method requires `allow_direct_code` to be true in the security configuration.
826    /// Use `execute_script()` with whitelisted scripts instead for better security.
827    pub async fn execute_code(&self, code: &str) -> Result<String> {
828        self.python_env.run_code(code).await
829    }
830
831    /// Get security configuration
832    pub fn security_config(&self) -> &SecurityConfig {
833        &self.python_env.security
834    }
835}
836
#[cfg(test)]
mod tests {
    //! Unit tests for the ML bridge: Python environment defaults, framework
    //! metadata, model registry bookkeeping, and the security configuration.

    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// The default environment uses `python3`, has no venv, and forbids direct code.
    #[test]
    fn test_python_environment_creation() {
        let env = PythonEnvironment::default();
        assert_eq!(env.python_path, "python3");
        assert!(env.venv_path.is_none());
        assert!(!env.security.allow_direct_code);
    }

    /// Every framework variant maps to its expected package name.
    #[test]
    fn test_ml_framework_package_names() {
        assert_eq!(MLFramework::PyTorch.package_name(), "torch");
        assert_eq!(MLFramework::TensorFlow.package_name(), "tensorflow");
        // Cover the remaining variant so a change to its mapping is caught too.
        assert_eq!(MLFramework::Custom.package_name(), "custom");
    }

    #[test]
    fn test_trained_model_creation() {
        let model = TrainedModel::new(
            "test_model".to_string(),
            MLFramework::PyTorch,
            PathBuf::from("/tmp/model.pt"),
        );
        assert_eq!(model.name, "test_model");
        assert_eq!(model.framework, MLFramework::PyTorch);
    }

    /// A freshly created bridge has no registered models.
    #[tokio::test]
    async fn test_ml_bridge_creation() {
        let env = PythonEnvironment::default();
        let bridge = MLBridge::new(env);
        assert!(bridge.list_models().await.is_empty());
    }

    #[test]
    fn test_security_config_default() {
        let config = SecurityConfig::default();
        assert!(!config.allow_direct_code);
        assert_eq!(config.max_timeout_secs, 300);
        assert!(!config.blocked_modules.is_empty());
        assert!(!config.blocked_operations.is_empty());
    }

    #[test]
    fn test_security_config_strict() {
        let config = SecurityConfig::strict();
        assert!(!config.allow_direct_code);
        assert_eq!(config.max_timeout_secs, 60);
        assert_eq!(config.max_code_length, 100_000);
    }

    /// Scripts inside a whitelisted directory are allowed; others are rejected.
    #[test]
    fn test_security_path_validation() {
        let temp_dir = TempDir::new().unwrap();
        let allowed_dir = temp_dir.path().join("allowed");
        fs::create_dir_all(&allowed_dir).unwrap();

        let mut config = SecurityConfig::default();
        config.allowed_script_dirs.push(allowed_dir.clone());

        let allowed_script = allowed_dir.join("script.py");
        assert!(config.is_path_allowed(&allowed_script));

        // A path outside every allowed directory must be rejected.
        let disallowed_script = PathBuf::from("/tmp/evil.py");
        assert!(!config.is_path_allowed(&disallowed_script));
    }

    #[test]
    fn test_security_code_validation() {
        let config = SecurityConfig::default();

        // Valid code should pass
        assert!(config.validate_code("print('hello')").is_ok());

        // Dangerous operations should be blocked
        assert!(config
            .validate_code("import os; os.system('rm -rf /')")
            .is_err());
        assert!(config.validate_code("eval('malicious code')").is_err());
        assert!(config.validate_code("exec('dangerous')").is_err());
        assert!(config.validate_code("import subprocess").is_err());
        assert!(config.validate_code("import socket").is_err());
        assert!(config.validate_code("import pickle").is_err());
    }

    /// Code longer than `max_code_length` is rejected.
    #[test]
    fn test_security_code_length_limit() {
        let mut config = SecurityConfig::default();
        config.max_code_length = 100;

        let short_code = "print('hello')";
        assert!(config.validate_code(short_code).is_ok());

        let long_code = "x".repeat(200);
        assert!(config.validate_code(&long_code).is_err());
    }

    #[test]
    fn test_security_direct_code_blocked() {
        let env = PythonEnvironment::default();
        // Direct code should be blocked by default
        assert!(!env.security.allow_direct_code);
    }

    /// With the default policy, `run_code` refuses to execute and says why.
    #[tokio::test]
    async fn test_security_direct_code_execution_blocked() {
        let env = PythonEnvironment::default();
        let result = env.run_code("print('test')").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("disabled"));
    }

    #[tokio::test]
    async fn test_security_allowed_direct_code() {
        let mut config = SecurityConfig::default();
        config.allow_direct_code = true;
        let env = PythonEnvironment::default().with_security(config);

        // This should still fail because of dangerous operations check
        let result = env.run_code("import os").await;
        assert!(result.is_err());
    }

    #[test]
    fn test_security_package_validation() {
        let config = SecurityConfig::default();
        assert!(config.blocked_modules.contains("os"));
        assert!(config.blocked_modules.contains("subprocess"));
    }

    /// Registering, looking up, and unregistering a model updates the registry.
    #[tokio::test]
    async fn test_ml_bridge_model_registration() {
        let env = PythonEnvironment::default();
        let bridge = MLBridge::new(env);

        let model = TrainedModel::new(
            "test_model".to_string(),
            MLFramework::PyTorch,
            PathBuf::from("/tmp/model.pt"),
        );

        bridge.register_model(model).await.unwrap();
        assert_eq!(bridge.list_models().await.len(), 1);
        assert!(bridge.get_model("test_model").await.is_some());

        bridge.unregister_model("test_model").await.unwrap();
        assert!(bridge.list_models().await.is_empty());
    }

    #[test]
    fn test_security_config_with_allowed_dir() {
        let temp_dir = TempDir::new().unwrap();
        let allowed_dir = temp_dir.path().to_path_buf();

        let config = SecurityConfig::default().with_allowed_dir(allowed_dir.clone());
        assert!(config.allowed_script_dirs.contains(&allowed_dir));
    }
}