Skip to main content

torsh_hub/
security.rs

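//! Security functionality for ToRSh Hub
//!
//! This module provides model signing, verification, and security features
//! to ensure model integrity and authenticity.
//!
//! Note: the cryptographic primitives used for signing and verification are
//! currently placeholders (see the `sign_with_*` / `verify_*` helpers) and do
//! not yet provide a real authenticity guarantee.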
5
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8use std::collections::HashMap;
9use std::fs::File;
10use std::io::{BufReader, Read};
11use std::path::Path;
12use std::sync::Arc;
13use std::sync::Mutex;
14use std::time::{SystemTime, UNIX_EPOCH};
15use torsh_core::error::{GeneralError, Result, TorshError};
16
17/// Digital signature for a model
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ModelSignature {
20    /// SHA-256 hash of the model file
21    pub file_hash: String,
22    /// Signature of the hash using a private key
23    pub signature: String,
24    /// Public key ID used for verification
25    pub key_id: String,
26    /// Timestamp when the signature was created
27    pub timestamp: u64,
28    /// Algorithm used for signing
29    pub algorithm: SignatureAlgorithm,
30    /// Additional metadata
31    pub metadata: HashMap<String, String>,
32}
33
34/// Supported signature algorithms
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
36pub enum SignatureAlgorithm {
37    /// RSA with SHA-256
38    RsaSha256,
39    /// Ed25519
40    Ed25519,
41    /// ECDSA with P-256 curve
42    EcdsaP256,
43}
44
45/// Key pair for signing and verification
46#[derive(Debug, Clone)]
47pub struct KeyPair {
48    pub key_id: String,
49    pub algorithm: SignatureAlgorithm,
50    pub public_key: Vec<u8>,
51    pub private_key: Option<Vec<u8>>, // None for verification-only keys
52}
53
54/// Model security manager
55pub struct SecurityManager {
56    key_store: HashMap<String, KeyPair>,
57    trusted_keys: Vec<String>,
58}
59
60impl SecurityManager {
61    /// Create a new security manager
62    pub fn new() -> Self {
63        Self {
64            key_store: HashMap::new(),
65            trusted_keys: Vec::new(),
66        }
67    }
68
69    /// Add a key pair to the key store
70    pub fn add_key(&mut self, key_pair: KeyPair) {
71        let key_id = key_pair.key_id.clone();
72        self.key_store.insert(key_id, key_pair);
73    }
74
75    /// Mark a key as trusted
76    pub fn trust_key(&mut self, key_id: &str) -> Result<()> {
77        if !self.key_store.contains_key(key_id) {
78            return Err(TorshError::General(GeneralError::InvalidArgument(format!(
79                "Key '{}' not found in key store",
80                key_id
81            ))));
82        }
83
84        if !self.trusted_keys.contains(&key_id.to_string()) {
85            self.trusted_keys.push(key_id.to_string());
86        }
87
88        Ok(())
89    }
90
91    /// Remove trust from a key
92    pub fn untrust_key(&mut self, key_id: &str) {
93        self.trusted_keys.retain(|k| k != key_id);
94    }
95
96    /// Sign a model file
97    pub fn sign_model<P: AsRef<Path>>(
98        &self,
99        model_path: P,
100        key_id: &str,
101        metadata: Option<HashMap<String, String>>,
102    ) -> Result<ModelSignature> {
103        let model_path = model_path.as_ref();
104
105        // Get the key pair
106        let key_pair = self.key_store.get(key_id).ok_or_else(|| {
107            TorshError::General(GeneralError::InvalidArgument(format!(
108                "Key '{}' not found",
109                key_id
110            )))
111        })?;
112
113        if key_pair.private_key.is_none() {
114            return Err(TorshError::General(GeneralError::InvalidArgument(
115                "Cannot sign with a verification-only key".to_string(),
116            )));
117        }
118
119        // Calculate file hash
120        let file_hash = calculate_file_hash(model_path)?;
121
122        // Create signature
123        let signature = match key_pair.algorithm {
124            SignatureAlgorithm::RsaSha256 => sign_with_rsa_sha256(
125                &file_hash,
126                key_pair
127                    .private_key
128                    .as_ref()
129                    .expect("RSA private key required for signing"),
130            )?,
131            SignatureAlgorithm::Ed25519 => sign_with_ed25519(
132                &file_hash,
133                key_pair
134                    .private_key
135                    .as_ref()
136                    .expect("Ed25519 private key required for signing"),
137            )?,
138            SignatureAlgorithm::EcdsaP256 => sign_with_ecdsa_p256(
139                &file_hash,
140                key_pair
141                    .private_key
142                    .as_ref()
143                    .expect("ECDSA P256 private key required for signing"),
144            )?,
145        };
146
147        let timestamp = SystemTime::now()
148            .duration_since(UNIX_EPOCH)
149            .expect("system time should be after UNIX epoch")
150            .as_secs();
151
152        Ok(ModelSignature {
153            file_hash,
154            signature,
155            key_id: key_id.to_string(),
156            timestamp,
157            algorithm: key_pair.algorithm.clone(),
158            metadata: metadata.unwrap_or_default(),
159        })
160    }
161
162    /// Verify a model signature
163    pub fn verify_model<P: AsRef<Path>>(
164        &self,
165        model_path: P,
166        signature: &ModelSignature,
167        require_trusted: bool,
168    ) -> Result<bool> {
169        let model_path = model_path.as_ref();
170
171        // Check if key is trusted (if required)
172        if require_trusted && !self.trusted_keys.contains(&signature.key_id) {
173            return Err(TorshError::General(GeneralError::RuntimeError(format!(
174                "Key '{}' is not trusted",
175                signature.key_id
176            ))));
177        }
178
179        // Get the public key
180        let key_pair = self.key_store.get(&signature.key_id).ok_or_else(|| {
181            TorshError::General(GeneralError::InvalidArgument(format!(
182                "Key '{}' not found",
183                signature.key_id
184            )))
185        })?;
186
187        // Verify algorithm matches
188        if key_pair.algorithm != signature.algorithm {
189            return Ok(false);
190        }
191
192        // Calculate current file hash
193        let current_hash = calculate_file_hash(model_path)?;
194
195        // Verify file hash matches
196        if current_hash != signature.file_hash {
197            return Ok(false);
198        }
199
200        // Verify signature
201        let is_valid = match signature.algorithm {
202            SignatureAlgorithm::RsaSha256 => verify_rsa_sha256(
203                &signature.file_hash,
204                &signature.signature,
205                &key_pair.public_key,
206            )?,
207            SignatureAlgorithm::Ed25519 => verify_ed25519(
208                &signature.file_hash,
209                &signature.signature,
210                &key_pair.public_key,
211            )?,
212            SignatureAlgorithm::EcdsaP256 => verify_ecdsa_p256(
213                &signature.file_hash,
214                &signature.signature,
215                &key_pair.public_key,
216            )?,
217        };
218
219        Ok(is_valid)
220    }
221
222    /// Save a signature to a file
223    pub fn save_signature<P: AsRef<Path>>(
224        signature: &ModelSignature,
225        signature_path: P,
226    ) -> Result<()> {
227        let json = serde_json::to_string_pretty(signature)
228            .map_err(|e| TorshError::SerializationError(e.to_string()))?;
229
230        std::fs::write(signature_path, json)?;
231        Ok(())
232    }
233
234    /// Load a signature from a file
235    pub fn load_signature<P: AsRef<Path>>(signature_path: P) -> Result<ModelSignature> {
236        let content = std::fs::read_to_string(signature_path)?;
237        let signature: ModelSignature = serde_json::from_str(&content)
238            .map_err(|e| TorshError::SerializationError(e.to_string()))?;
239
240        Ok(signature)
241    }
242
243    /// Generate a new key pair
244    pub fn generate_key_pair(key_id: String, algorithm: SignatureAlgorithm) -> Result<KeyPair> {
245        match algorithm {
246            SignatureAlgorithm::RsaSha256 => generate_rsa_key_pair(key_id),
247            SignatureAlgorithm::Ed25519 => generate_ed25519_key_pair(key_id),
248            SignatureAlgorithm::EcdsaP256 => generate_ecdsa_p256_key_pair(key_id),
249        }
250    }
251
252    /// Get list of trusted keys
253    pub fn trusted_keys(&self) -> &[String] {
254        &self.trusted_keys
255    }
256
257    /// Get list of all keys
258    pub fn all_keys(&self) -> Vec<&str> {
259        self.key_store.keys().map(|s| s.as_str()).collect()
260    }
261}
262
263impl Default for SecurityManager {
264    fn default() -> Self {
265        Self::new()
266    }
267}
268
269/// Calculate SHA-256 hash of a file
270pub fn calculate_file_hash<P: AsRef<Path>>(file_path: P) -> Result<String> {
271    let file = File::open(file_path)?;
272    let mut reader = BufReader::new(file);
273    let mut hasher = Sha256::new();
274    let mut buffer = [0; 8192];
275
276    loop {
277        let n = reader.read(&mut buffer)?;
278        if n == 0 {
279            break;
280        }
281        hasher.update(&buffer[..n]);
282    }
283
284    let result = hasher.finalize();
285    Ok(hex::encode(result))
286}
287
288/// Verify file integrity by comparing hashes
289pub fn verify_file_integrity<P: AsRef<Path>>(file_path: P, expected_hash: &str) -> Result<bool> {
290    let actual_hash = calculate_file_hash(file_path)?;
291    Ok(actual_hash == expected_hash)
292}
293
294// Placeholder implementations for cryptographic operations
295// In a real implementation, these would use proper cryptographic libraries
296
297fn sign_with_rsa_sha256(_data: &str, _private_key: &[u8]) -> Result<String> {
298    // Placeholder: In real implementation, use RSA + SHA-256 signing
299    // This would use libraries like `rsa` and `sha2`
300    Ok("rsa_signature_placeholder".to_string())
301}
302
303fn verify_rsa_sha256(_data: &str, _signature: &str, _public_key: &[u8]) -> Result<bool> {
304    // Placeholder: In real implementation, verify RSA + SHA-256 signature
305    Ok(_signature == "rsa_signature_placeholder")
306}
307
308fn sign_with_ed25519(_data: &str, _private_key: &[u8]) -> Result<String> {
309    // Placeholder: In real implementation, use Ed25519 signing
310    // This would use libraries like `ed25519-dalek`
311    Ok("ed25519_signature_placeholder".to_string())
312}
313
314fn verify_ed25519(_data: &str, _signature: &str, _public_key: &[u8]) -> Result<bool> {
315    // Placeholder: In real implementation, verify Ed25519 signature
316    Ok(_signature == "ed25519_signature_placeholder")
317}
318
319fn sign_with_ecdsa_p256(_data: &str, _private_key: &[u8]) -> Result<String> {
320    // Placeholder: In real implementation, use ECDSA P-256 signing
321    // This would use libraries like `p256` and `ecdsa`
322    Ok("ecdsa_signature_placeholder".to_string())
323}
324
325fn verify_ecdsa_p256(_data: &str, _signature: &str, _public_key: &[u8]) -> Result<bool> {
326    // Placeholder: In real implementation, verify ECDSA P-256 signature
327    Ok(_signature == "ecdsa_signature_placeholder")
328}
329
330fn generate_rsa_key_pair(key_id: String) -> Result<KeyPair> {
331    // Placeholder: In real implementation, generate RSA key pair
332    // This would use libraries like `rsa`
333    Ok(KeyPair {
334        key_id,
335        algorithm: SignatureAlgorithm::RsaSha256,
336        public_key: b"rsa_public_key_placeholder".to_vec(),
337        private_key: Some(b"rsa_private_key_placeholder".to_vec()),
338    })
339}
340
341fn generate_ed25519_key_pair(key_id: String) -> Result<KeyPair> {
342    // Placeholder: In real implementation, generate Ed25519 key pair
343    // This would use libraries like `ed25519-dalek`
344    Ok(KeyPair {
345        key_id,
346        algorithm: SignatureAlgorithm::Ed25519,
347        public_key: b"ed25519_public_key_placeholder".to_vec(),
348        private_key: Some(b"ed25519_private_key_placeholder".to_vec()),
349    })
350}
351
352fn generate_ecdsa_p256_key_pair(key_id: String) -> Result<KeyPair> {
353    // Placeholder: In real implementation, generate ECDSA P-256 key pair
354    // This would use libraries like `p256` and `ecdsa`
355    Ok(KeyPair {
356        key_id,
357        algorithm: SignatureAlgorithm::EcdsaP256,
358        public_key: b"ecdsa_public_key_placeholder".to_vec(),
359        private_key: Some(b"ecdsa_private_key_placeholder".to_vec()),
360    })
361}
362
363/// Security configuration for downloads and model loading
364#[derive(Debug, Clone, Serialize, Deserialize)]
365pub struct SecurityConfig {
366    /// Require signatures for all models
367    pub require_signatures: bool,
368    /// Only accept models signed by trusted keys
369    pub require_trusted_keys: bool,
370    /// Maximum age of signatures in seconds
371    pub max_signature_age: Option<u64>,
372    /// List of blocked model sources
373    pub blocked_sources: Vec<String>,
374    /// List of allowed model sources (if empty, all sources are allowed)
375    pub allowed_sources: Vec<String>,
376    /// Verify file integrity on load
377    pub verify_integrity: bool,
378}
379
380impl Default for SecurityConfig {
381    fn default() -> Self {
382        Self {
383            require_signatures: false,
384            require_trusted_keys: false,
385            max_signature_age: Some(30 * 24 * 3600), // 30 days
386            blocked_sources: Vec::new(),
387            allowed_sources: Vec::new(),
388            verify_integrity: true,
389        }
390    }
391}
392
393/// Validate a model source against security configuration
394pub fn validate_model_source(url: &str, config: &SecurityConfig) -> Result<()> {
395    // Check blocked sources
396    for blocked in &config.blocked_sources {
397        if url.contains(blocked) {
398            return Err(TorshError::General(GeneralError::RuntimeError(format!(
399                "Model source '{}' is blocked",
400                blocked
401            ))));
402        }
403    }
404
405    // Check allowed sources (if specified)
406    if !config.allowed_sources.is_empty() {
407        let is_allowed = config
408            .allowed_sources
409            .iter()
410            .any(|allowed| url.contains(allowed));
411        if !is_allowed {
412            return Err(TorshError::General(GeneralError::RuntimeError(
413                "Model source is not in the allowed list".to_string(),
414            )));
415        }
416    }
417
418    Ok(())
419}
420
421/// Validate signature age
422pub fn validate_signature_age(signature: &ModelSignature, max_age: Option<u64>) -> Result<()> {
423    if let Some(max_age_secs) = max_age {
424        let current_time = SystemTime::now()
425            .duration_since(UNIX_EPOCH)
426            .expect("system time should be after UNIX epoch")
427            .as_secs();
428
429        let age = current_time.saturating_sub(signature.timestamp);
430        if age > max_age_secs {
431            return Err(TorshError::General(GeneralError::RuntimeError(format!(
432                "Signature is too old: {} seconds (max: {})",
433                age, max_age_secs
434            ))));
435        }
436    }
437
438    Ok(())
439}
440
441/// Sandbox configuration for model execution
442#[derive(Debug, Clone, Serialize, Deserialize)]
443pub struct SandboxConfig {
444    /// Maximum memory usage in bytes
445    pub max_memory: usize,
446    /// Maximum execution time in seconds
447    pub max_execution_time: u64,
448    /// Maximum number of threads
449    pub max_threads: usize,
450    /// Allow network access
451    pub allow_network: bool,
452    /// Allow file system access
453    pub allow_filesystem: bool,
454    /// Allowed file paths for read access
455    pub read_paths: Vec<String>,
456    /// Allowed file paths for write access
457    pub write_paths: Vec<String>,
458    /// Maximum CPU usage percentage
459    pub max_cpu_usage: f32,
460}
461
462impl Default for SandboxConfig {
463    fn default() -> Self {
464        Self {
465            max_memory: 1024 * 1024 * 1024, // 1GB
466            max_execution_time: 300,        // 5 minutes
467            max_threads: 4,
468            allow_network: false,
469            allow_filesystem: false,
470            read_paths: Vec::new(),
471            write_paths: Vec::new(),
472            max_cpu_usage: 80.0,
473        }
474    }
475}
476
/// Counters describing what a sandboxed run has consumed so far.
///
/// `ModelSandbox::enter` resets these to their defaults and records `start_time`.
#[derive(Clone, Debug, Default)]
pub struct ResourceUsage {
    /// Bytes reported through `ModelSandbox::record_memory_usage` (saturating sum).
    pub memory_used: usize,
    /// CPU time consumed.
    /// NOTE(review): not updated by `ModelSandbox`'s recording methods;
    /// unit unconfirmed (presumably seconds).
    pub cpu_time: f64,
    /// Number of thread creations recorded.
    pub threads_created: usize,
    /// Number of permitted network requests recorded.
    pub network_requests: usize,
    /// Number of permitted file reads recorded.
    pub file_reads: usize,
    /// Number of permitted file writes recorded.
    pub file_writes: usize,
    /// Wall-clock time the sandbox was last entered; `None` before the first entry.
    pub start_time: Option<SystemTime>,
}
488
489/// Sandbox environment for model execution
490pub struct ModelSandbox {
491    config: SandboxConfig,
492    usage: Arc<Mutex<ResourceUsage>>,
493    is_active: Arc<Mutex<bool>>,
494}
495
496impl ModelSandbox {
497    /// Create a new sandbox with configuration
498    pub fn new(config: SandboxConfig) -> Self {
499        Self {
500            config,
501            usage: Arc::new(Mutex::new(ResourceUsage::default())),
502            is_active: Arc::new(Mutex::new(false)),
503        }
504    }
505
506    /// Enter the sandbox environment
507    pub fn enter(&self) -> Result<SandboxGuard<'_>> {
508        let mut is_active = self.is_active.lock().expect("lock should not be poisoned");
509        if *is_active {
510            return Err(TorshError::General(GeneralError::RuntimeError(
511                "Sandbox is already active".to_string(),
512            )));
513        }
514
515        *is_active = true;
516
517        // Initialize resource tracking
518        {
519            let mut usage = self.usage.lock().expect("lock should not be poisoned");
520            *usage = ResourceUsage {
521                start_time: Some(SystemTime::now()),
522                ..Default::default()
523            };
524        }
525
526        // Set up resource limits (platform-specific implementation would go here)
527        self.setup_memory_limits()?;
528        self.setup_thread_limits()?;
529        self.setup_network_limits()?;
530        self.setup_filesystem_limits()?;
531
532        Ok(SandboxGuard {
533            sandbox: self,
534            _phantom: std::marker::PhantomData,
535        })
536    }
537
538    /// Check if resource limits are exceeded
539    pub fn check_limits(&self) -> Result<()> {
540        let usage = self.usage.lock().expect("lock should not be poisoned");
541
542        // Check memory limit
543        if usage.memory_used > self.config.max_memory {
544            return Err(TorshError::General(GeneralError::RuntimeError(format!(
545                "Memory limit exceeded: {} > {}",
546                usage.memory_used, self.config.max_memory
547            ))));
548        }
549
550        // Check execution time limit
551        if let Some(start_time) = usage.start_time {
552            let elapsed = SystemTime::now()
553                .duration_since(start_time)
554                .expect("current time should be after start_time")
555                .as_secs();
556            if elapsed > self.config.max_execution_time {
557                return Err(TorshError::General(GeneralError::RuntimeError(format!(
558                    "Execution time limit exceeded: {} > {}",
559                    elapsed, self.config.max_execution_time
560                ))));
561            }
562        }
563
564        // Check thread limit
565        if usage.threads_created > self.config.max_threads {
566            return Err(TorshError::General(GeneralError::RuntimeError(format!(
567                "Thread limit exceeded: {} > {}",
568                usage.threads_created, self.config.max_threads
569            ))));
570        }
571
572        Ok(())
573    }
574
575    /// Record memory usage
576    pub fn record_memory_usage(&self, bytes: usize) {
577        let mut usage = self.usage.lock().expect("lock should not be poisoned");
578        usage.memory_used = usage.memory_used.saturating_add(bytes);
579    }
580
581    /// Record thread creation
582    pub fn record_thread_creation(&self) {
583        let mut usage = self.usage.lock().expect("lock should not be poisoned");
584        usage.threads_created += 1;
585    }
586
587    /// Record network request
588    pub fn record_network_request(&self) -> Result<()> {
589        if !self.config.allow_network {
590            return Err(TorshError::General(GeneralError::RuntimeError(
591                "Network access is not allowed in sandbox".to_string(),
592            )));
593        }
594
595        let mut usage = self.usage.lock().expect("lock should not be poisoned");
596        usage.network_requests += 1;
597        Ok(())
598    }
599
600    /// Record file system access
601    pub fn record_file_access(&self, path: &str, is_write: bool) -> Result<()> {
602        if !self.config.allow_filesystem {
603            return Err(TorshError::General(GeneralError::RuntimeError(
604                "File system access is not allowed in sandbox".to_string(),
605            )));
606        }
607
608        // Check if path is allowed
609        let allowed_paths = if is_write {
610            &self.config.write_paths
611        } else {
612            &self.config.read_paths
613        };
614
615        if !allowed_paths.is_empty() {
616            let is_allowed = allowed_paths
617                .iter()
618                .any(|allowed_path| path.starts_with(allowed_path));
619
620            if !is_allowed {
621                return Err(TorshError::General(GeneralError::RuntimeError(format!(
622                    "Access to path '{}' is not allowed",
623                    path
624                ))));
625            }
626        }
627
628        let mut usage = self.usage.lock().expect("lock should not be poisoned");
629        if is_write {
630            usage.file_writes += 1;
631        } else {
632            usage.file_reads += 1;
633        }
634
635        Ok(())
636    }
637
638    /// Get current resource usage
639    pub fn get_usage(&self) -> ResourceUsage {
640        self.usage
641            .lock()
642            .expect("lock should not be poisoned")
643            .clone()
644    }
645
646    /// Exit the sandbox (private, called by SandboxGuard)
647    fn exit(&self) {
648        let mut is_active = self.is_active.lock().expect("lock should not be poisoned");
649        *is_active = false;
650
651        // Clean up resource limits
652        self.cleanup_limits();
653    }
654
655    // Platform-specific implementations (simplified for demonstration)
656    fn setup_memory_limits(&self) -> Result<()> {
657        // In a real implementation, this would use platform-specific APIs
658        // like setrlimit on Unix or SetProcessWorkingSetSize on Windows
659        println!("Setting up memory limits: {} bytes", self.config.max_memory);
660        Ok(())
661    }
662
663    fn setup_thread_limits(&self) -> Result<()> {
664        // In a real implementation, this would limit thread creation
665        println!(
666            "Setting up thread limits: {} threads",
667            self.config.max_threads
668        );
669        Ok(())
670    }
671
672    fn setup_network_limits(&self) -> Result<()> {
673        if !self.config.allow_network {
674            // In a real implementation, this would block network access
675            println!("Blocking network access");
676        }
677        Ok(())
678    }
679
680    fn setup_filesystem_limits(&self) -> Result<()> {
681        if !self.config.allow_filesystem {
682            // In a real implementation, this would use chroot or similar
683            println!("Restricting filesystem access");
684        }
685        Ok(())
686    }
687
688    fn cleanup_limits(&self) {
689        // Clean up any resources or restore original limits
690        println!("Cleaning up sandbox limits");
691    }
692}
693
694/// RAII guard for sandbox environment
695pub struct SandboxGuard<'a> {
696    sandbox: &'a ModelSandbox,
697    _phantom: std::marker::PhantomData<&'a ()>,
698}
699
700impl<'a> Drop for SandboxGuard<'a> {
701    fn drop(&mut self) {
702        self.sandbox.exit();
703    }
704}
705
706/// Wrapper for sandboxed model execution
707pub struct SandboxedModel {
708    model: Box<dyn torsh_nn::Module>,
709    sandbox: std::sync::RwLock<ModelSandbox>,
710}
711
712impl SandboxedModel {
713    /// Create a new sandboxed model
714    pub fn new(model: Box<dyn torsh_nn::Module>, config: SandboxConfig) -> Self {
715        Self {
716            model,
717            sandbox: std::sync::RwLock::new(ModelSandbox::new(config)),
718        }
719    }
720
721    /// Execute model forward pass in sandbox
722    pub fn forward_sandboxed(
723        &self,
724        input: &torsh_tensor::Tensor<f32>,
725    ) -> Result<torsh_tensor::Tensor<f32>> {
726        {
727            let sandbox = self.sandbox.read().expect("lock should not be poisoned");
728            let _guard = sandbox.enter()?;
729
730            // Record memory usage for input tensor
731            let input_elements = input.shape().dims().iter().product::<usize>();
732            let input_memory = input_elements * std::mem::size_of::<f32>();
733            sandbox.record_memory_usage(input_memory);
734
735            // Check limits before execution
736            sandbox.check_limits()?;
737        } // guard and sandbox are dropped here
738
739        // Execute model
740        let result = self.model.forward(input)?;
741
742        // Record memory usage for output tensor
743        {
744            let sandbox = self.sandbox.read().expect("lock should not be poisoned");
745            let output_elements = result.shape().dims().iter().product::<usize>();
746            let output_memory = output_elements * std::mem::size_of::<f32>();
747            sandbox.record_memory_usage(output_memory);
748
749            // Check limits after execution
750            sandbox.check_limits()?;
751        }
752
753        Ok(result)
754    }
755
756    /// Get sandbox resource usage
757    pub fn get_sandbox_usage(&self) -> ResourceUsage {
758        self.sandbox
759            .read()
760            .expect("lock should not be poisoned")
761            .get_usage()
762    }
763}
764
765impl torsh_nn::Module for SandboxedModel {
766    fn forward(&self, input: &torsh_tensor::Tensor) -> Result<torsh_tensor::Tensor> {
767        // For now, we'll convert the input for backward compatibility
768        // In a real implementation, this would need proper type handling
769        let result = self.forward_sandboxed(input)?;
770        Ok(result)
771    }
772
773    fn parameters(&self) -> HashMap<String, torsh_nn::Parameter> {
774        self.model.parameters()
775    }
776
777    fn train(&mut self) {
778        self.model.train()
779    }
780
781    fn eval(&mut self) {
782        self.model.eval()
783    }
784
785    fn training(&self) -> bool {
786        self.model.training()
787    }
788
789    fn load_state_dict(
790        &mut self,
791        state_dict: &std::collections::HashMap<String, torsh_tensor::Tensor<f32>>,
792        strict: bool,
793    ) -> Result<()> {
794        self.model.load_state_dict(state_dict, strict)
795    }
796
797    fn state_dict(&self) -> std::collections::HashMap<String, torsh_tensor::Tensor<f32>> {
798        self.model.state_dict()
799    }
800}
801
802/// Create a sandboxed wrapper for any model
803pub fn sandbox_model(
804    model: Box<dyn torsh_nn::Module>,
805    config: Option<SandboxConfig>,
806) -> SandboxedModel {
807    let config = config.unwrap_or_default();
808    SandboxedModel::new(model, config)
809}
810
811/// Vulnerability scanner for model files
812#[derive(Debug, Clone, Serialize, Deserialize)]
813pub struct VulnerabilityScanner {
814    /// Known malicious patterns (simplified for demonstration)
815    malicious_patterns: Vec<String>,
816    /// Known vulnerable model signatures
817    vulnerable_signatures: Vec<String>,
818    /// Maximum file size to scan (in bytes)
819    max_scan_size: usize,
820    /// Enable deep scanning (more thorough but slower)
821    deep_scan: bool,
822}
823
824impl Default for VulnerabilityScanner {
825    fn default() -> Self {
826        Self {
827            malicious_patterns: vec![
828                // Common malicious patterns
829                "eval(".to_string(),
830                "exec(".to_string(),
831                "__import__".to_string(),
832                "subprocess.".to_string(),
833                "os.system".to_string(),
834                "shell=True".to_string(),
835                // Suspicious model operations
836                "torch.jit._script_to_py".to_string(),
837                "_C.import_ir_module".to_string(),
838                // Network access patterns
839                "urllib.request".to_string(),
840                "requests.get".to_string(),
841                "socket.".to_string(),
842                // File system patterns
843                "open(".to_string(),
844                "file.write".to_string(),
845                "pathlib".to_string(),
846            ],
847            vulnerable_signatures: vec![
848                // Known vulnerable model hashes (examples)
849                "bad_model_hash_1".to_string(),
850                "bad_model_hash_2".to_string(),
851            ],
852            max_scan_size: 100 * 1024 * 1024, // 100MB
853            deep_scan: true,
854        }
855    }
856}
857
858/// Vulnerability scan result
859#[derive(Debug, Clone, Serialize, Deserialize)]
860pub struct VulnerabilityScanResult {
861    /// Whether the scan completed successfully
862    pub success: bool,
863    /// List of vulnerabilities found
864    pub vulnerabilities: Vec<Vulnerability>,
865    /// Risk level assessment
866    pub risk_level: RiskLevel,
867    /// Scan metadata
868    pub scan_metadata: ScanMetadata,
869}
870
871/// Individual vulnerability
872#[derive(Debug, Clone, Serialize, Deserialize)]
873pub struct Vulnerability {
874    /// Vulnerability type
875    pub vuln_type: VulnerabilityType,
876    /// Severity level
877    pub severity: Severity,
878    /// Description of the vulnerability
879    pub description: String,
880    /// Location where vulnerability was found
881    pub location: String,
882    /// Remediation suggestions
883    pub remediation: String,
884}
885
886/// Types of vulnerabilities
887#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
888pub enum VulnerabilityType {
889    /// Malicious code execution
890    CodeExecution,
891    /// Unauthorized network access
892    NetworkAccess,
893    /// Unauthorized file system access
894    FileSystemAccess,
895    /// Data exfiltration
896    DataExfiltration,
897    /// Known vulnerable model
898    KnownVulnerable,
899    /// Suspicious patterns
900    SuspiciousPattern,
901    /// Cryptographic issues
902    CryptographicIssue,
903    /// Memory corruption
904    MemoryCorruption,
905}
906
907/// Severity levels
908#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
909pub enum Severity {
910    Low,
911    Medium,
912    High,
913    Critical,
914}
915
916/// Overall risk level
917#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
918pub enum RiskLevel {
919    Safe,
920    Low,
921    Medium,
922    High,
923    Critical,
924}
925
/// Bookkeeping information attached to every scan result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScanMetadata {
    /// Time when the scan started, in whole seconds since the UNIX epoch
    pub scan_time: u64,
    /// Duration of the scan in milliseconds
    pub scan_duration: u64,
    /// Number of files scanned (`scan_file` always reports 1)
    pub files_scanned: usize,
    /// Size of the scanned file in bytes, as reported by its file metadata; note that
    /// for files rejected as too large this is still the file's size even though its
    /// content was not read
    pub bytes_scanned: usize,
    /// Scanner version string (currently hard-coded to "1.0.0" by the scanner)
    pub scanner_version: String,
}
940
941impl VulnerabilityScanner {
942    /// Create a new vulnerability scanner
943    pub fn new() -> Self {
944        Self::default()
945    }
946
947    /// Create a scanner with custom configuration
948    pub fn with_config(
949        max_scan_size: usize,
950        deep_scan: bool,
951        custom_patterns: Vec<String>,
952    ) -> Self {
953        let mut malicious_patterns = Self::default().malicious_patterns;
954        malicious_patterns.extend(custom_patterns);
955        Self {
956            max_scan_size,
957            deep_scan,
958            malicious_patterns,
959            ..Default::default()
960        }
961    }
962
963    /// Scan a model file for vulnerabilities
964    pub fn scan_file<P: AsRef<Path>>(&self, file_path: P) -> Result<VulnerabilityScanResult> {
965        let file_path = file_path.as_ref();
966        let start_time = SystemTime::now();
967
968        // Check file exists and size
969        let metadata = std::fs::metadata(file_path)?;
970        if metadata.len() > self.max_scan_size as u64 {
971            return Ok(VulnerabilityScanResult {
972                success: false,
973                vulnerabilities: vec![Vulnerability {
974                    vuln_type: VulnerabilityType::SuspiciousPattern,
975                    severity: Severity::Medium,
976                    description: format!("File too large to scan: {} bytes", metadata.len()),
977                    location: file_path.to_string_lossy().to_string(),
978                    remediation: "Manually verify large model files".to_string(),
979                }],
980                risk_level: RiskLevel::Medium,
981                scan_metadata: self.create_metadata(start_time, 1, metadata.len() as usize),
982            });
983        }
984
985        let mut vulnerabilities = Vec::new();
986
987        // Calculate file hash and check against known vulnerable models
988        let file_hash = calculate_file_hash(file_path)?;
989        if self.vulnerable_signatures.contains(&file_hash) {
990            vulnerabilities.push(Vulnerability {
991                vuln_type: VulnerabilityType::KnownVulnerable,
992                severity: Severity::Critical,
993                description: "Model matches known vulnerable signature".to_string(),
994                location: file_path.to_string_lossy().to_string(),
995                remediation: "Do not use this model. Find alternative from trusted source"
996                    .to_string(),
997            });
998        }
999
1000        // Scan file content for malicious patterns
1001        let content_vulns = self.scan_file_content(file_path)?;
1002        vulnerabilities.extend(content_vulns);
1003
1004        // If deep scan is enabled, perform additional checks
1005        if self.deep_scan {
1006            let deep_vulns = self.deep_scan_file(file_path)?;
1007            vulnerabilities.extend(deep_vulns);
1008        }
1009
1010        // Assess overall risk level
1011        let risk_level = self.assess_risk_level(&vulnerabilities);
1012
1013        let end_time = SystemTime::now();
1014        let scan_duration = end_time
1015            .duration_since(start_time)
1016            .expect("end_time should be after start_time")
1017            .as_millis() as u64;
1018
1019        Ok(VulnerabilityScanResult {
1020            success: true,
1021            vulnerabilities,
1022            risk_level,
1023            scan_metadata: ScanMetadata {
1024                scan_time: start_time
1025                    .duration_since(UNIX_EPOCH)
1026                    .expect("start_time should be after UNIX epoch")
1027                    .as_secs(),
1028                scan_duration,
1029                files_scanned: 1,
1030                bytes_scanned: metadata.len() as usize,
1031                scanner_version: "1.0.0".to_string(),
1032            },
1033        })
1034    }
1035
1036    /// Scan file content for malicious patterns
1037    fn scan_file_content<P: AsRef<Path>>(&self, file_path: P) -> Result<Vec<Vulnerability>> {
1038        let file_path = file_path.as_ref();
1039        let mut vulnerabilities = Vec::new();
1040
1041        // Read file content (limited to max_scan_size)
1042        let file = File::open(file_path)?;
1043        let mut reader = BufReader::new(file);
1044        let mut content = Vec::new();
1045        reader.read_to_end(&mut content)?;
1046
1047        // Convert to string for pattern matching (may lose some data for binary files)
1048        let content_str = String::from_utf8_lossy(&content);
1049
1050        // Check for malicious patterns
1051        for pattern in &self.malicious_patterns {
1052            if content_str.contains(pattern) {
1053                let vuln_type = self.classify_pattern(pattern);
1054                let severity = self.assess_pattern_severity(pattern);
1055
1056                vulnerabilities.push(Vulnerability {
1057                    vuln_type,
1058                    severity,
1059                    description: format!("Found suspicious pattern: {}", pattern),
1060                    location: file_path.to_string_lossy().to_string(),
1061                    remediation: "Review model source and verify legitimacy".to_string(),
1062                });
1063            }
1064        }
1065
1066        Ok(vulnerabilities)
1067    }
1068
1069    /// Perform deep scan for additional security issues
1070    fn deep_scan_file<P: AsRef<Path>>(&self, file_path: P) -> Result<Vec<Vulnerability>> {
1071        let file_path = file_path.as_ref();
1072        let mut vulnerabilities = Vec::new();
1073
1074        // Check file extension
1075        if let Some(extension) = file_path.extension() {
1076            if let Some("exe" | "bat" | "sh" | "ps1") = extension.to_str() {
1077                vulnerabilities.push(Vulnerability {
1078                    vuln_type: VulnerabilityType::CodeExecution,
1079                    severity: Severity::High,
1080                    description: "Model file has executable extension".to_string(),
1081                    location: file_path.to_string_lossy().to_string(),
1082                    remediation: "Verify this is actually a model file and not malware".to_string(),
1083                });
1084            }
1085        }
1086
1087        // Check for hidden files or suspicious names
1088        if let Some(filename) = file_path.file_name() {
1089            if let Some(filename_str) = filename.to_str() {
1090                if filename_str.starts_with('.') || filename_str.contains("..") {
1091                    vulnerabilities.push(Vulnerability {
1092                        vuln_type: VulnerabilityType::SuspiciousPattern,
1093                        severity: Severity::Medium,
1094                        description: "Suspicious filename pattern".to_string(),
1095                        location: file_path.to_string_lossy().to_string(),
1096                        remediation: "Verify file purpose and rename to standard convention"
1097                            .to_string(),
1098                    });
1099                }
1100            }
1101        }
1102
1103        // Check file permissions (Unix-like systems)
1104        #[cfg(unix)]
1105        {
1106            use std::os::unix::fs::PermissionsExt;
1107            let metadata = std::fs::metadata(file_path)?;
1108            let permissions = metadata.permissions();
1109            let mode = permissions.mode();
1110
1111            // Check if file is executable
1112            if mode & 0o111 != 0 {
1113                vulnerabilities.push(Vulnerability {
1114                    vuln_type: VulnerabilityType::CodeExecution,
1115                    severity: Severity::Medium,
1116                    description: "Model file has execute permissions".to_string(),
1117                    location: file_path.to_string_lossy().to_string(),
1118                    remediation: "Remove execute permissions: chmod -x filename".to_string(),
1119                });
1120            }
1121        }
1122
1123        Ok(vulnerabilities)
1124    }
1125
1126    /// Classify a pattern to determine vulnerability type
1127    fn classify_pattern(&self, pattern: &str) -> VulnerabilityType {
1128        match pattern {
1129            p if p.contains("eval") || p.contains("exec") || p.contains("subprocess") => {
1130                VulnerabilityType::CodeExecution
1131            }
1132            p if p.contains("socket") || p.contains("urllib") || p.contains("requests") => {
1133                VulnerabilityType::NetworkAccess
1134            }
1135            p if p.contains("open") || p.contains("file") || p.contains("pathlib") => {
1136                VulnerabilityType::FileSystemAccess
1137            }
1138            _ => VulnerabilityType::SuspiciousPattern,
1139        }
1140    }
1141
1142    /// Assess severity of a pattern
1143    fn assess_pattern_severity(&self, pattern: &str) -> Severity {
1144        match pattern {
1145            p if p.contains("eval") || p.contains("exec") => Severity::Critical,
1146            p if p.contains("subprocess") || p.contains("os.system") => Severity::High,
1147            p if p.contains("socket") || p.contains("urllib") => Severity::Medium,
1148            _ => Severity::Low,
1149        }
1150    }
1151
1152    /// Assess overall risk level based on vulnerabilities
1153    fn assess_risk_level(&self, vulnerabilities: &[Vulnerability]) -> RiskLevel {
1154        if vulnerabilities.is_empty() {
1155            return RiskLevel::Safe;
1156        }
1157
1158        let max_severity = vulnerabilities
1159            .iter()
1160            .map(|v| &v.severity)
1161            .max()
1162            .unwrap_or(&Severity::Low);
1163
1164        match max_severity {
1165            Severity::Critical => RiskLevel::Critical,
1166            Severity::High => RiskLevel::High,
1167            Severity::Medium => RiskLevel::Medium,
1168            Severity::Low => RiskLevel::Low,
1169        }
1170    }
1171
1172    /// Create scan metadata
1173    fn create_metadata(
1174        &self,
1175        start_time: SystemTime,
1176        files_scanned: usize,
1177        bytes_scanned: usize,
1178    ) -> ScanMetadata {
1179        let end_time = SystemTime::now();
1180        let scan_duration = end_time
1181            .duration_since(start_time)
1182            .expect("end_time should be after start_time")
1183            .as_millis() as u64;
1184
1185        ScanMetadata {
1186            scan_time: start_time
1187                .duration_since(UNIX_EPOCH)
1188                .expect("start_time should be after UNIX epoch")
1189                .as_secs(),
1190            scan_duration,
1191            files_scanned,
1192            bytes_scanned,
1193            scanner_version: "1.0.0".to_string(),
1194        }
1195    }
1196
1197    /// Add custom malicious pattern
1198    pub fn add_pattern(&mut self, pattern: String) {
1199        if !self.malicious_patterns.contains(&pattern) {
1200            self.malicious_patterns.push(pattern);
1201        }
1202    }
1203
1204    /// Add known vulnerable signature
1205    pub fn add_vulnerable_signature(&mut self, signature: String) {
1206        if !self.vulnerable_signatures.contains(&signature) {
1207            self.vulnerable_signatures.push(signature);
1208        }
1209    }
1210
1211    /// Get all patterns
1212    pub fn get_patterns(&self) -> &[String] {
1213        &self.malicious_patterns
1214    }
1215
1216    /// Get all vulnerable signatures
1217    pub fn get_vulnerable_signatures(&self) -> &[String] {
1218        &self.vulnerable_signatures
1219    }
1220}
1221
1222/// Convenience function to scan a model file for vulnerabilities
1223pub fn scan_model_vulnerabilities<P: AsRef<Path>>(
1224    file_path: P,
1225    scanner: Option<VulnerabilityScanner>,
1226) -> Result<VulnerabilityScanResult> {
1227    let scanner = scanner.unwrap_or_default();
1228    scanner.scan_file(file_path)
1229}
1230
// Unit tests for file hashing/integrity, `SecurityManager`, model-source and signature-age
// validation, the model sandbox, and the vulnerability scanner.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_calculate_file_hash() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"test content").unwrap();
        temp_file.flush().unwrap();

        let hash = calculate_file_hash(temp_file.path()).unwrap();
        assert_eq!(hash.len(), 64); // SHA-256 produces 64-char hex string
    }

    #[test]
    fn test_verify_file_integrity() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"test content").unwrap();
        temp_file.flush().unwrap();

        let hash = calculate_file_hash(temp_file.path()).unwrap();
        assert!(verify_file_integrity(temp_file.path(), &hash).unwrap());
        assert!(!verify_file_integrity(temp_file.path(), "wrong_hash").unwrap());
    }

    #[test]
    fn test_security_manager() {
        let mut manager = SecurityManager::new();

        // Generate key pair
        let key_pair =
            SecurityManager::generate_key_pair("test_key".to_string(), SignatureAlgorithm::Ed25519)
                .unwrap();

        // Add key to manager
        manager.add_key(key_pair);

        // Trust the key
        manager.trust_key("test_key").unwrap();

        assert!(manager.trusted_keys().contains(&"test_key".to_string()));
    }

    #[test]
    fn test_validate_model_source() {
        let config = SecurityConfig {
            blocked_sources: vec!["malicious.com".to_string()],
            allowed_sources: vec!["trusted.com".to_string()],
            ..Default::default()
        };

        // Should fail for blocked source
        assert!(validate_model_source("https://malicious.com/model.torsh", &config).is_err());

        // Should fail for non-allowed source when allowed list is specified
        assert!(validate_model_source("https://unknown.com/model.torsh", &config).is_err());

        // Should succeed for allowed source
        assert!(validate_model_source("https://trusted.com/model.torsh", &config).is_ok());
    }

    #[test]
    fn test_validate_signature_age() {
        let current_time = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        // Recent signature should be valid
        let recent_signature = ModelSignature {
            timestamp: current_time - 100, // 100 seconds ago
            file_hash: "hash".to_string(),
            signature: "sig".to_string(),
            key_id: "key".to_string(),
            algorithm: SignatureAlgorithm::Ed25519,
            metadata: HashMap::new(),
        };

        assert!(validate_signature_age(&recent_signature, Some(3600)).is_ok());

        // Old signature should be invalid
        let old_signature = ModelSignature {
            timestamp: current_time - 7200, // 2 hours ago
            ..recent_signature
        };

        assert!(validate_signature_age(&old_signature, Some(3600)).is_err());
    }

    #[test]
    fn test_sandbox_config() {
        let config = SandboxConfig::default();
        assert_eq!(config.max_memory, 1024 * 1024 * 1024);
        assert_eq!(config.max_execution_time, 300);
        assert_eq!(config.max_threads, 4);
        assert!(!config.allow_network);
        assert!(!config.allow_filesystem);
    }

    #[test]
    fn test_sandbox_creation() {
        let config = SandboxConfig::default();
        let sandbox = ModelSandbox::new(config);

        let usage = sandbox.get_usage();
        assert_eq!(usage.memory_used, 0);
        assert_eq!(usage.threads_created, 0);
        assert_eq!(usage.network_requests, 0);
    }

    #[test]
    fn test_sandbox_resource_tracking() {
        let config = SandboxConfig::default();
        let sandbox = ModelSandbox::new(config);

        // Test memory tracking
        sandbox.record_memory_usage(1024);
        let usage = sandbox.get_usage();
        assert_eq!(usage.memory_used, 1024);

        // Test thread tracking
        sandbox.record_thread_creation();
        let usage = sandbox.get_usage();
        assert_eq!(usage.threads_created, 1);
    }

    #[test]
    fn test_sandbox_network_restrictions() {
        let config = SandboxConfig {
            allow_network: false,
            ..Default::default()
        };
        let sandbox = ModelSandbox::new(config);

        // Should fail when network is not allowed
        assert!(sandbox.record_network_request().is_err());

        let config = SandboxConfig {
            allow_network: true,
            ..Default::default()
        };
        let sandbox = ModelSandbox::new(config);

        // Should succeed when network is allowed
        assert!(sandbox.record_network_request().is_ok());
    }

    #[test]
    fn test_sandbox_filesystem_restrictions() {
        let config = SandboxConfig {
            allow_filesystem: false,
            ..Default::default()
        };
        let sandbox = ModelSandbox::new(config);

        let test_path = std::env::temp_dir().join("test");
        let test_path_str = test_path.to_string_lossy();
        // Should fail when filesystem access is not allowed
        assert!(sandbox.record_file_access(&test_path_str, false).is_err());

        // Allow reads only under the platform temp directory.
        let temp_dir_str = std::env::temp_dir().to_string_lossy().into_owned();
        let config = SandboxConfig {
            allow_filesystem: true,
            read_paths: vec![temp_dir_str],
            ..Default::default()
        };
        let sandbox = ModelSandbox::new(config);

        // Should succeed for allowed path
        assert!(sandbox.record_file_access(&test_path_str, false).is_ok());

        // Should fail for disallowed path
        assert!(sandbox.record_file_access("/etc/passwd", false).is_err());
    }

    #[test]
    fn test_sandbox_guard() {
        let config = SandboxConfig::default();
        let sandbox = ModelSandbox::new(config);

        // Test that guard properly manages sandbox state
        {
            let _guard = sandbox.enter().unwrap();
            // Sandbox should be active here
        }
        // Sandbox should be inactive after guard is dropped

        // Should be able to enter again
        let _guard = sandbox.enter().unwrap();
    }

    #[test]
    fn test_vulnerability_scanner_creation() {
        let scanner = VulnerabilityScanner::new();
        assert!(!scanner.get_patterns().is_empty());
        assert!(!scanner.get_vulnerable_signatures().is_empty());
    }

    #[test]
    fn test_vulnerability_scanner_custom_config() {
        let custom_patterns = vec!["custom_pattern".to_string()];
        let scanner = VulnerabilityScanner::with_config(
            50 * 1024 * 1024, // 50MB
            false,
            custom_patterns.clone(),
        );

        // Custom patterns are added on top of the defaults.
        assert!(scanner
            .get_patterns()
            .contains(&"custom_pattern".to_string()));
    }

    #[test]
    fn test_vulnerability_scanner_pattern_detection() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file
            .write_all(b"This file contains eval() function")
            .unwrap();
        temp_file.flush().unwrap();

        let scanner = VulnerabilityScanner::new();
        let result = scanner.scan_file(temp_file.path()).unwrap();

        assert!(result.success);
        assert!(!result.vulnerabilities.is_empty());
        assert_eq!(result.risk_level, RiskLevel::Critical);

        // Check that the vulnerability was detected
        let has_eval_vuln = result
            .vulnerabilities
            .iter()
            .any(|v| v.description.contains("eval"));
        assert!(has_eval_vuln);
    }

    #[test]
    fn test_vulnerability_scanner_clean_file() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file
            .write_all(b"This is a clean model file with no suspicious content")
            .unwrap();
        temp_file.flush().unwrap();

        // Note: `with_config` still includes the default malicious patterns (it only
        // appends custom ones), so this test relies on the content above matching none
        // of the defaults.
        let scanner = VulnerabilityScanner::with_config(
            100 * 1024 * 1024, // 100MB max size
            false,             // disable deep scan to avoid extension-based false positives
            vec![],            // no custom patterns
        );
        let result = scanner.scan_file(temp_file.path()).unwrap();

        assert!(result.success);
        assert!(result.vulnerabilities.is_empty());
        assert_eq!(result.risk_level, RiskLevel::Safe);
    }

    #[test]
    fn test_vulnerability_scanner_known_vulnerable() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"test content").unwrap();
        temp_file.flush().unwrap();

        let file_hash = calculate_file_hash(temp_file.path()).unwrap();

        // Register this exact file's hash as a known-vulnerable signature.
        let mut scanner = VulnerabilityScanner::new();
        scanner.add_vulnerable_signature(file_hash);

        let result = scanner.scan_file(temp_file.path()).unwrap();

        assert!(result.success);
        assert!(!result.vulnerabilities.is_empty());
        assert_eq!(result.risk_level, RiskLevel::Critical);

        // Check that known vulnerable signature was detected
        let has_known_vuln = result
            .vulnerabilities
            .iter()
            .any(|v| v.vuln_type == VulnerabilityType::KnownVulnerable);
        assert!(has_known_vuln);
    }

    #[test]
    fn test_vulnerability_pattern_classification() {
        let scanner = VulnerabilityScanner::new();

        assert_eq!(
            scanner.classify_pattern("eval("),
            VulnerabilityType::CodeExecution
        );
        assert_eq!(
            scanner.classify_pattern("socket."),
            VulnerabilityType::NetworkAccess
        );
        assert_eq!(
            scanner.classify_pattern("open("),
            VulnerabilityType::FileSystemAccess
        );
        assert_eq!(
            scanner.classify_pattern("unknown"),
            VulnerabilityType::SuspiciousPattern
        );
    }

    #[test]
    fn test_vulnerability_severity_assessment() {
        let scanner = VulnerabilityScanner::new();

        assert_eq!(scanner.assess_pattern_severity("eval("), Severity::Critical);
        assert_eq!(
            scanner.assess_pattern_severity("subprocess."),
            Severity::High
        );
        assert_eq!(scanner.assess_pattern_severity("socket."), Severity::Medium);
        assert_eq!(scanner.assess_pattern_severity("unknown"), Severity::Low);
    }

    #[test]
    fn test_risk_level_assessment() {
        let scanner = VulnerabilityScanner::new();

        // No vulnerabilities
        assert_eq!(scanner.assess_risk_level(&[]), RiskLevel::Safe);

        // Critical vulnerability
        let critical_vuln = vec![Vulnerability {
            vuln_type: VulnerabilityType::CodeExecution,
            severity: Severity::Critical,
            description: "test".to_string(),
            location: "test".to_string(),
            remediation: "test".to_string(),
        }];
        assert_eq!(
            scanner.assess_risk_level(&critical_vuln),
            RiskLevel::Critical
        );

        // High vulnerability
        let high_vuln = vec![Vulnerability {
            vuln_type: VulnerabilityType::CodeExecution,
            severity: Severity::High,
            description: "test".to_string(),
            location: "test".to_string(),
            remediation: "test".to_string(),
        }];
        assert_eq!(scanner.assess_risk_level(&high_vuln), RiskLevel::High);
    }

    #[test]
    fn test_scan_model_vulnerabilities_convenience_function() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"exec() call detected").unwrap();
        temp_file.flush().unwrap();

        // `None` makes the convenience function use a default scanner.
        let result = scan_model_vulnerabilities(temp_file.path(), None).unwrap();

        assert!(result.success);
        assert!(!result.vulnerabilities.is_empty());
        assert_eq!(result.risk_level, RiskLevel::Critical);
    }
}