1use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8use std::collections::HashMap;
9use std::fs::File;
10use std::io::{BufReader, Read};
11use std::path::Path;
12use std::sync::Arc;
13use std::sync::Mutex;
14use std::time::{SystemTime, UNIX_EPOCH};
15use torsh_core::error::{GeneralError, Result, TorshError};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ModelSignature {
20 pub file_hash: String,
22 pub signature: String,
24 pub key_id: String,
26 pub timestamp: u64,
28 pub algorithm: SignatureAlgorithm,
30 pub metadata: HashMap<String, String>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
36pub enum SignatureAlgorithm {
37 RsaSha256,
39 Ed25519,
41 EcdsaP256,
43}
44
45#[derive(Debug, Clone)]
47pub struct KeyPair {
48 pub key_id: String,
49 pub algorithm: SignatureAlgorithm,
50 pub public_key: Vec<u8>,
51 pub private_key: Option<Vec<u8>>, }
53
54pub struct SecurityManager {
56 key_store: HashMap<String, KeyPair>,
57 trusted_keys: Vec<String>,
58}
59
60impl SecurityManager {
61 pub fn new() -> Self {
63 Self {
64 key_store: HashMap::new(),
65 trusted_keys: Vec::new(),
66 }
67 }
68
69 pub fn add_key(&mut self, key_pair: KeyPair) {
71 let key_id = key_pair.key_id.clone();
72 self.key_store.insert(key_id, key_pair);
73 }
74
75 pub fn trust_key(&mut self, key_id: &str) -> Result<()> {
77 if !self.key_store.contains_key(key_id) {
78 return Err(TorshError::General(GeneralError::InvalidArgument(format!(
79 "Key '{}' not found in key store",
80 key_id
81 ))));
82 }
83
84 if !self.trusted_keys.contains(&key_id.to_string()) {
85 self.trusted_keys.push(key_id.to_string());
86 }
87
88 Ok(())
89 }
90
91 pub fn untrust_key(&mut self, key_id: &str) {
93 self.trusted_keys.retain(|k| k != key_id);
94 }
95
96 pub fn sign_model<P: AsRef<Path>>(
98 &self,
99 model_path: P,
100 key_id: &str,
101 metadata: Option<HashMap<String, String>>,
102 ) -> Result<ModelSignature> {
103 let model_path = model_path.as_ref();
104
105 let key_pair = self.key_store.get(key_id).ok_or_else(|| {
107 TorshError::General(GeneralError::InvalidArgument(format!(
108 "Key '{}' not found",
109 key_id
110 )))
111 })?;
112
113 if key_pair.private_key.is_none() {
114 return Err(TorshError::General(GeneralError::InvalidArgument(
115 "Cannot sign with a verification-only key".to_string(),
116 )));
117 }
118
119 let file_hash = calculate_file_hash(model_path)?;
121
122 let signature = match key_pair.algorithm {
124 SignatureAlgorithm::RsaSha256 => sign_with_rsa_sha256(
125 &file_hash,
126 key_pair
127 .private_key
128 .as_ref()
129 .expect("RSA private key required for signing"),
130 )?,
131 SignatureAlgorithm::Ed25519 => sign_with_ed25519(
132 &file_hash,
133 key_pair
134 .private_key
135 .as_ref()
136 .expect("Ed25519 private key required for signing"),
137 )?,
138 SignatureAlgorithm::EcdsaP256 => sign_with_ecdsa_p256(
139 &file_hash,
140 key_pair
141 .private_key
142 .as_ref()
143 .expect("ECDSA P256 private key required for signing"),
144 )?,
145 };
146
147 let timestamp = SystemTime::now()
148 .duration_since(UNIX_EPOCH)
149 .expect("system time should be after UNIX epoch")
150 .as_secs();
151
152 Ok(ModelSignature {
153 file_hash,
154 signature,
155 key_id: key_id.to_string(),
156 timestamp,
157 algorithm: key_pair.algorithm.clone(),
158 metadata: metadata.unwrap_or_default(),
159 })
160 }
161
162 pub fn verify_model<P: AsRef<Path>>(
164 &self,
165 model_path: P,
166 signature: &ModelSignature,
167 require_trusted: bool,
168 ) -> Result<bool> {
169 let model_path = model_path.as_ref();
170
171 if require_trusted && !self.trusted_keys.contains(&signature.key_id) {
173 return Err(TorshError::General(GeneralError::RuntimeError(format!(
174 "Key '{}' is not trusted",
175 signature.key_id
176 ))));
177 }
178
179 let key_pair = self.key_store.get(&signature.key_id).ok_or_else(|| {
181 TorshError::General(GeneralError::InvalidArgument(format!(
182 "Key '{}' not found",
183 signature.key_id
184 )))
185 })?;
186
187 if key_pair.algorithm != signature.algorithm {
189 return Ok(false);
190 }
191
192 let current_hash = calculate_file_hash(model_path)?;
194
195 if current_hash != signature.file_hash {
197 return Ok(false);
198 }
199
200 let is_valid = match signature.algorithm {
202 SignatureAlgorithm::RsaSha256 => verify_rsa_sha256(
203 &signature.file_hash,
204 &signature.signature,
205 &key_pair.public_key,
206 )?,
207 SignatureAlgorithm::Ed25519 => verify_ed25519(
208 &signature.file_hash,
209 &signature.signature,
210 &key_pair.public_key,
211 )?,
212 SignatureAlgorithm::EcdsaP256 => verify_ecdsa_p256(
213 &signature.file_hash,
214 &signature.signature,
215 &key_pair.public_key,
216 )?,
217 };
218
219 Ok(is_valid)
220 }
221
222 pub fn save_signature<P: AsRef<Path>>(
224 signature: &ModelSignature,
225 signature_path: P,
226 ) -> Result<()> {
227 let json = serde_json::to_string_pretty(signature)
228 .map_err(|e| TorshError::SerializationError(e.to_string()))?;
229
230 std::fs::write(signature_path, json)?;
231 Ok(())
232 }
233
234 pub fn load_signature<P: AsRef<Path>>(signature_path: P) -> Result<ModelSignature> {
236 let content = std::fs::read_to_string(signature_path)?;
237 let signature: ModelSignature = serde_json::from_str(&content)
238 .map_err(|e| TorshError::SerializationError(e.to_string()))?;
239
240 Ok(signature)
241 }
242
243 pub fn generate_key_pair(key_id: String, algorithm: SignatureAlgorithm) -> Result<KeyPair> {
245 match algorithm {
246 SignatureAlgorithm::RsaSha256 => generate_rsa_key_pair(key_id),
247 SignatureAlgorithm::Ed25519 => generate_ed25519_key_pair(key_id),
248 SignatureAlgorithm::EcdsaP256 => generate_ecdsa_p256_key_pair(key_id),
249 }
250 }
251
252 pub fn trusted_keys(&self) -> &[String] {
254 &self.trusted_keys
255 }
256
257 pub fn all_keys(&self) -> Vec<&str> {
259 self.key_store.keys().map(|s| s.as_str()).collect()
260 }
261}
262
263impl Default for SecurityManager {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
269pub fn calculate_file_hash<P: AsRef<Path>>(file_path: P) -> Result<String> {
271 let file = File::open(file_path)?;
272 let mut reader = BufReader::new(file);
273 let mut hasher = Sha256::new();
274 let mut buffer = [0; 8192];
275
276 loop {
277 let n = reader.read(&mut buffer)?;
278 if n == 0 {
279 break;
280 }
281 hasher.update(&buffer[..n]);
282 }
283
284 let result = hasher.finalize();
285 Ok(hex::encode(result))
286}
287
288pub fn verify_file_integrity<P: AsRef<Path>>(file_path: P, expected_hash: &str) -> Result<bool> {
290 let actual_hash = calculate_file_hash(file_path)?;
291 Ok(actual_hash == expected_hash)
292}
293
294fn sign_with_rsa_sha256(_data: &str, _private_key: &[u8]) -> Result<String> {
298 Ok("rsa_signature_placeholder".to_string())
301}
302
303fn verify_rsa_sha256(_data: &str, _signature: &str, _public_key: &[u8]) -> Result<bool> {
304 Ok(_signature == "rsa_signature_placeholder")
306}
307
308fn sign_with_ed25519(_data: &str, _private_key: &[u8]) -> Result<String> {
309 Ok("ed25519_signature_placeholder".to_string())
312}
313
314fn verify_ed25519(_data: &str, _signature: &str, _public_key: &[u8]) -> Result<bool> {
315 Ok(_signature == "ed25519_signature_placeholder")
317}
318
319fn sign_with_ecdsa_p256(_data: &str, _private_key: &[u8]) -> Result<String> {
320 Ok("ecdsa_signature_placeholder".to_string())
323}
324
325fn verify_ecdsa_p256(_data: &str, _signature: &str, _public_key: &[u8]) -> Result<bool> {
326 Ok(_signature == "ecdsa_signature_placeholder")
328}
329
330fn generate_rsa_key_pair(key_id: String) -> Result<KeyPair> {
331 Ok(KeyPair {
334 key_id,
335 algorithm: SignatureAlgorithm::RsaSha256,
336 public_key: b"rsa_public_key_placeholder".to_vec(),
337 private_key: Some(b"rsa_private_key_placeholder".to_vec()),
338 })
339}
340
341fn generate_ed25519_key_pair(key_id: String) -> Result<KeyPair> {
342 Ok(KeyPair {
345 key_id,
346 algorithm: SignatureAlgorithm::Ed25519,
347 public_key: b"ed25519_public_key_placeholder".to_vec(),
348 private_key: Some(b"ed25519_private_key_placeholder".to_vec()),
349 })
350}
351
352fn generate_ecdsa_p256_key_pair(key_id: String) -> Result<KeyPair> {
353 Ok(KeyPair {
356 key_id,
357 algorithm: SignatureAlgorithm::EcdsaP256,
358 public_key: b"ecdsa_public_key_placeholder".to_vec(),
359 private_key: Some(b"ecdsa_private_key_placeholder".to_vec()),
360 })
361}
362
363#[derive(Debug, Clone, Serialize, Deserialize)]
365pub struct SecurityConfig {
366 pub require_signatures: bool,
368 pub require_trusted_keys: bool,
370 pub max_signature_age: Option<u64>,
372 pub blocked_sources: Vec<String>,
374 pub allowed_sources: Vec<String>,
376 pub verify_integrity: bool,
378}
379
380impl Default for SecurityConfig {
381 fn default() -> Self {
382 Self {
383 require_signatures: false,
384 require_trusted_keys: false,
385 max_signature_age: Some(30 * 24 * 3600), blocked_sources: Vec::new(),
387 allowed_sources: Vec::new(),
388 verify_integrity: true,
389 }
390 }
391}
392
393pub fn validate_model_source(url: &str, config: &SecurityConfig) -> Result<()> {
395 for blocked in &config.blocked_sources {
397 if url.contains(blocked) {
398 return Err(TorshError::General(GeneralError::RuntimeError(format!(
399 "Model source '{}' is blocked",
400 blocked
401 ))));
402 }
403 }
404
405 if !config.allowed_sources.is_empty() {
407 let is_allowed = config
408 .allowed_sources
409 .iter()
410 .any(|allowed| url.contains(allowed));
411 if !is_allowed {
412 return Err(TorshError::General(GeneralError::RuntimeError(
413 "Model source is not in the allowed list".to_string(),
414 )));
415 }
416 }
417
418 Ok(())
419}
420
421pub fn validate_signature_age(signature: &ModelSignature, max_age: Option<u64>) -> Result<()> {
423 if let Some(max_age_secs) = max_age {
424 let current_time = SystemTime::now()
425 .duration_since(UNIX_EPOCH)
426 .expect("system time should be after UNIX epoch")
427 .as_secs();
428
429 let age = current_time.saturating_sub(signature.timestamp);
430 if age > max_age_secs {
431 return Err(TorshError::General(GeneralError::RuntimeError(format!(
432 "Signature is too old: {} seconds (max: {})",
433 age, max_age_secs
434 ))));
435 }
436 }
437
438 Ok(())
439}
440
441#[derive(Debug, Clone, Serialize, Deserialize)]
443pub struct SandboxConfig {
444 pub max_memory: usize,
446 pub max_execution_time: u64,
448 pub max_threads: usize,
450 pub allow_network: bool,
452 pub allow_filesystem: bool,
454 pub read_paths: Vec<String>,
456 pub write_paths: Vec<String>,
458 pub max_cpu_usage: f32,
460}
461
462impl Default for SandboxConfig {
463 fn default() -> Self {
464 Self {
465 max_memory: 1024 * 1024 * 1024, max_execution_time: 300, max_threads: 4,
468 allow_network: false,
469 allow_filesystem: false,
470 read_paths: Vec::new(),
471 write_paths: Vec::new(),
472 max_cpu_usage: 80.0,
473 }
474 }
475}
476
/// Counters accumulated by a [`ModelSandbox`] since its last `enter`.
#[derive(Clone, Debug, Default)]
pub struct ResourceUsage {
    /// Bytes reported through `record_memory_usage`.
    pub memory_used: usize,
    /// CPU time; not updated by any `ModelSandbox` recording method.
    pub cpu_time: f64,
    /// Number of `record_thread_creation` calls.
    pub threads_created: usize,
    /// Number of permitted `record_network_request` calls.
    pub network_requests: usize,
    /// Number of permitted read accesses via `record_file_access`.
    pub file_reads: usize,
    /// Number of permitted write accesses via `record_file_access`.
    pub file_writes: usize,
    /// Time at which the sandbox was last entered.
    pub start_time: Option<SystemTime>,
}
488
489pub struct ModelSandbox {
491 config: SandboxConfig,
492 usage: Arc<Mutex<ResourceUsage>>,
493 is_active: Arc<Mutex<bool>>,
494}
495
496impl ModelSandbox {
497 pub fn new(config: SandboxConfig) -> Self {
499 Self {
500 config,
501 usage: Arc::new(Mutex::new(ResourceUsage::default())),
502 is_active: Arc::new(Mutex::new(false)),
503 }
504 }
505
506 pub fn enter(&self) -> Result<SandboxGuard<'_>> {
508 let mut is_active = self.is_active.lock().expect("lock should not be poisoned");
509 if *is_active {
510 return Err(TorshError::General(GeneralError::RuntimeError(
511 "Sandbox is already active".to_string(),
512 )));
513 }
514
515 *is_active = true;
516
517 {
519 let mut usage = self.usage.lock().expect("lock should not be poisoned");
520 *usage = ResourceUsage {
521 start_time: Some(SystemTime::now()),
522 ..Default::default()
523 };
524 }
525
526 self.setup_memory_limits()?;
528 self.setup_thread_limits()?;
529 self.setup_network_limits()?;
530 self.setup_filesystem_limits()?;
531
532 Ok(SandboxGuard {
533 sandbox: self,
534 _phantom: std::marker::PhantomData,
535 })
536 }
537
538 pub fn check_limits(&self) -> Result<()> {
540 let usage = self.usage.lock().expect("lock should not be poisoned");
541
542 if usage.memory_used > self.config.max_memory {
544 return Err(TorshError::General(GeneralError::RuntimeError(format!(
545 "Memory limit exceeded: {} > {}",
546 usage.memory_used, self.config.max_memory
547 ))));
548 }
549
550 if let Some(start_time) = usage.start_time {
552 let elapsed = SystemTime::now()
553 .duration_since(start_time)
554 .expect("current time should be after start_time")
555 .as_secs();
556 if elapsed > self.config.max_execution_time {
557 return Err(TorshError::General(GeneralError::RuntimeError(format!(
558 "Execution time limit exceeded: {} > {}",
559 elapsed, self.config.max_execution_time
560 ))));
561 }
562 }
563
564 if usage.threads_created > self.config.max_threads {
566 return Err(TorshError::General(GeneralError::RuntimeError(format!(
567 "Thread limit exceeded: {} > {}",
568 usage.threads_created, self.config.max_threads
569 ))));
570 }
571
572 Ok(())
573 }
574
575 pub fn record_memory_usage(&self, bytes: usize) {
577 let mut usage = self.usage.lock().expect("lock should not be poisoned");
578 usage.memory_used = usage.memory_used.saturating_add(bytes);
579 }
580
581 pub fn record_thread_creation(&self) {
583 let mut usage = self.usage.lock().expect("lock should not be poisoned");
584 usage.threads_created += 1;
585 }
586
587 pub fn record_network_request(&self) -> Result<()> {
589 if !self.config.allow_network {
590 return Err(TorshError::General(GeneralError::RuntimeError(
591 "Network access is not allowed in sandbox".to_string(),
592 )));
593 }
594
595 let mut usage = self.usage.lock().expect("lock should not be poisoned");
596 usage.network_requests += 1;
597 Ok(())
598 }
599
600 pub fn record_file_access(&self, path: &str, is_write: bool) -> Result<()> {
602 if !self.config.allow_filesystem {
603 return Err(TorshError::General(GeneralError::RuntimeError(
604 "File system access is not allowed in sandbox".to_string(),
605 )));
606 }
607
608 let allowed_paths = if is_write {
610 &self.config.write_paths
611 } else {
612 &self.config.read_paths
613 };
614
615 if !allowed_paths.is_empty() {
616 let is_allowed = allowed_paths
617 .iter()
618 .any(|allowed_path| path.starts_with(allowed_path));
619
620 if !is_allowed {
621 return Err(TorshError::General(GeneralError::RuntimeError(format!(
622 "Access to path '{}' is not allowed",
623 path
624 ))));
625 }
626 }
627
628 let mut usage = self.usage.lock().expect("lock should not be poisoned");
629 if is_write {
630 usage.file_writes += 1;
631 } else {
632 usage.file_reads += 1;
633 }
634
635 Ok(())
636 }
637
638 pub fn get_usage(&self) -> ResourceUsage {
640 self.usage
641 .lock()
642 .expect("lock should not be poisoned")
643 .clone()
644 }
645
646 fn exit(&self) {
648 let mut is_active = self.is_active.lock().expect("lock should not be poisoned");
649 *is_active = false;
650
651 self.cleanup_limits();
653 }
654
655 fn setup_memory_limits(&self) -> Result<()> {
657 println!("Setting up memory limits: {} bytes", self.config.max_memory);
660 Ok(())
661 }
662
663 fn setup_thread_limits(&self) -> Result<()> {
664 println!(
666 "Setting up thread limits: {} threads",
667 self.config.max_threads
668 );
669 Ok(())
670 }
671
672 fn setup_network_limits(&self) -> Result<()> {
673 if !self.config.allow_network {
674 println!("Blocking network access");
676 }
677 Ok(())
678 }
679
680 fn setup_filesystem_limits(&self) -> Result<()> {
681 if !self.config.allow_filesystem {
682 println!("Restricting filesystem access");
684 }
685 Ok(())
686 }
687
688 fn cleanup_limits(&self) {
689 println!("Cleaning up sandbox limits");
691 }
692}
693
694pub struct SandboxGuard<'a> {
696 sandbox: &'a ModelSandbox,
697 _phantom: std::marker::PhantomData<&'a ()>,
698}
699
700impl<'a> Drop for SandboxGuard<'a> {
701 fn drop(&mut self) {
702 self.sandbox.exit();
703 }
704}
705
706pub struct SandboxedModel {
708 model: Box<dyn torsh_nn::Module>,
709 sandbox: std::sync::RwLock<ModelSandbox>,
710}
711
712impl SandboxedModel {
713 pub fn new(model: Box<dyn torsh_nn::Module>, config: SandboxConfig) -> Self {
715 Self {
716 model,
717 sandbox: std::sync::RwLock::new(ModelSandbox::new(config)),
718 }
719 }
720
721 pub fn forward_sandboxed(
723 &self,
724 input: &torsh_tensor::Tensor<f32>,
725 ) -> Result<torsh_tensor::Tensor<f32>> {
726 {
727 let sandbox = self.sandbox.read().expect("lock should not be poisoned");
728 let _guard = sandbox.enter()?;
729
730 let input_elements = input.shape().dims().iter().product::<usize>();
732 let input_memory = input_elements * std::mem::size_of::<f32>();
733 sandbox.record_memory_usage(input_memory);
734
735 sandbox.check_limits()?;
737 } let result = self.model.forward(input)?;
741
742 {
744 let sandbox = self.sandbox.read().expect("lock should not be poisoned");
745 let output_elements = result.shape().dims().iter().product::<usize>();
746 let output_memory = output_elements * std::mem::size_of::<f32>();
747 sandbox.record_memory_usage(output_memory);
748
749 sandbox.check_limits()?;
751 }
752
753 Ok(result)
754 }
755
756 pub fn get_sandbox_usage(&self) -> ResourceUsage {
758 self.sandbox
759 .read()
760 .expect("lock should not be poisoned")
761 .get_usage()
762 }
763}
764
765impl torsh_nn::Module for SandboxedModel {
766 fn forward(&self, input: &torsh_tensor::Tensor) -> Result<torsh_tensor::Tensor> {
767 let result = self.forward_sandboxed(input)?;
770 Ok(result)
771 }
772
773 fn parameters(&self) -> HashMap<String, torsh_nn::Parameter> {
774 self.model.parameters()
775 }
776
777 fn train(&mut self) {
778 self.model.train()
779 }
780
781 fn eval(&mut self) {
782 self.model.eval()
783 }
784
785 fn training(&self) -> bool {
786 self.model.training()
787 }
788
789 fn load_state_dict(
790 &mut self,
791 state_dict: &std::collections::HashMap<String, torsh_tensor::Tensor<f32>>,
792 strict: bool,
793 ) -> Result<()> {
794 self.model.load_state_dict(state_dict, strict)
795 }
796
797 fn state_dict(&self) -> std::collections::HashMap<String, torsh_tensor::Tensor<f32>> {
798 self.model.state_dict()
799 }
800}
801
802pub fn sandbox_model(
804 model: Box<dyn torsh_nn::Module>,
805 config: Option<SandboxConfig>,
806) -> SandboxedModel {
807 let config = config.unwrap_or_default();
808 SandboxedModel::new(model, config)
809}
810
811#[derive(Debug, Clone, Serialize, Deserialize)]
813pub struct VulnerabilityScanner {
814 malicious_patterns: Vec<String>,
816 vulnerable_signatures: Vec<String>,
818 max_scan_size: usize,
820 deep_scan: bool,
822}
823
824impl Default for VulnerabilityScanner {
825 fn default() -> Self {
826 Self {
827 malicious_patterns: vec![
828 "eval(".to_string(),
830 "exec(".to_string(),
831 "__import__".to_string(),
832 "subprocess.".to_string(),
833 "os.system".to_string(),
834 "shell=True".to_string(),
835 "torch.jit._script_to_py".to_string(),
837 "_C.import_ir_module".to_string(),
838 "urllib.request".to_string(),
840 "requests.get".to_string(),
841 "socket.".to_string(),
842 "open(".to_string(),
844 "file.write".to_string(),
845 "pathlib".to_string(),
846 ],
847 vulnerable_signatures: vec![
848 "bad_model_hash_1".to_string(),
850 "bad_model_hash_2".to_string(),
851 ],
852 max_scan_size: 100 * 1024 * 1024, deep_scan: true,
854 }
855 }
856}
857
858#[derive(Debug, Clone, Serialize, Deserialize)]
860pub struct VulnerabilityScanResult {
861 pub success: bool,
863 pub vulnerabilities: Vec<Vulnerability>,
865 pub risk_level: RiskLevel,
867 pub scan_metadata: ScanMetadata,
869}
870
871#[derive(Debug, Clone, Serialize, Deserialize)]
873pub struct Vulnerability {
874 pub vuln_type: VulnerabilityType,
876 pub severity: Severity,
878 pub description: String,
880 pub location: String,
882 pub remediation: String,
884}
885
886#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
888pub enum VulnerabilityType {
889 CodeExecution,
891 NetworkAccess,
893 FileSystemAccess,
895 DataExfiltration,
897 KnownVulnerable,
899 SuspiciousPattern,
901 CryptographicIssue,
903 MemoryCorruption,
905}
906
907#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
909pub enum Severity {
910 Low,
911 Medium,
912 High,
913 Critical,
914}
915
916#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
918pub enum RiskLevel {
919 Safe,
920 Low,
921 Medium,
922 High,
923 Critical,
924}
925
926#[derive(Debug, Clone, Serialize, Deserialize)]
928pub struct ScanMetadata {
929 pub scan_time: u64,
931 pub scan_duration: u64,
933 pub files_scanned: usize,
935 pub bytes_scanned: usize,
937 pub scanner_version: String,
939}
940
941impl VulnerabilityScanner {
942 pub fn new() -> Self {
944 Self::default()
945 }
946
947 pub fn with_config(
949 max_scan_size: usize,
950 deep_scan: bool,
951 custom_patterns: Vec<String>,
952 ) -> Self {
953 let mut malicious_patterns = Self::default().malicious_patterns;
954 malicious_patterns.extend(custom_patterns);
955 Self {
956 max_scan_size,
957 deep_scan,
958 malicious_patterns,
959 ..Default::default()
960 }
961 }
962
963 pub fn scan_file<P: AsRef<Path>>(&self, file_path: P) -> Result<VulnerabilityScanResult> {
965 let file_path = file_path.as_ref();
966 let start_time = SystemTime::now();
967
968 let metadata = std::fs::metadata(file_path)?;
970 if metadata.len() > self.max_scan_size as u64 {
971 return Ok(VulnerabilityScanResult {
972 success: false,
973 vulnerabilities: vec![Vulnerability {
974 vuln_type: VulnerabilityType::SuspiciousPattern,
975 severity: Severity::Medium,
976 description: format!("File too large to scan: {} bytes", metadata.len()),
977 location: file_path.to_string_lossy().to_string(),
978 remediation: "Manually verify large model files".to_string(),
979 }],
980 risk_level: RiskLevel::Medium,
981 scan_metadata: self.create_metadata(start_time, 1, metadata.len() as usize),
982 });
983 }
984
985 let mut vulnerabilities = Vec::new();
986
987 let file_hash = calculate_file_hash(file_path)?;
989 if self.vulnerable_signatures.contains(&file_hash) {
990 vulnerabilities.push(Vulnerability {
991 vuln_type: VulnerabilityType::KnownVulnerable,
992 severity: Severity::Critical,
993 description: "Model matches known vulnerable signature".to_string(),
994 location: file_path.to_string_lossy().to_string(),
995 remediation: "Do not use this model. Find alternative from trusted source"
996 .to_string(),
997 });
998 }
999
1000 let content_vulns = self.scan_file_content(file_path)?;
1002 vulnerabilities.extend(content_vulns);
1003
1004 if self.deep_scan {
1006 let deep_vulns = self.deep_scan_file(file_path)?;
1007 vulnerabilities.extend(deep_vulns);
1008 }
1009
1010 let risk_level = self.assess_risk_level(&vulnerabilities);
1012
1013 let end_time = SystemTime::now();
1014 let scan_duration = end_time
1015 .duration_since(start_time)
1016 .expect("end_time should be after start_time")
1017 .as_millis() as u64;
1018
1019 Ok(VulnerabilityScanResult {
1020 success: true,
1021 vulnerabilities,
1022 risk_level,
1023 scan_metadata: ScanMetadata {
1024 scan_time: start_time
1025 .duration_since(UNIX_EPOCH)
1026 .expect("start_time should be after UNIX epoch")
1027 .as_secs(),
1028 scan_duration,
1029 files_scanned: 1,
1030 bytes_scanned: metadata.len() as usize,
1031 scanner_version: "1.0.0".to_string(),
1032 },
1033 })
1034 }
1035
1036 fn scan_file_content<P: AsRef<Path>>(&self, file_path: P) -> Result<Vec<Vulnerability>> {
1038 let file_path = file_path.as_ref();
1039 let mut vulnerabilities = Vec::new();
1040
1041 let file = File::open(file_path)?;
1043 let mut reader = BufReader::new(file);
1044 let mut content = Vec::new();
1045 reader.read_to_end(&mut content)?;
1046
1047 let content_str = String::from_utf8_lossy(&content);
1049
1050 for pattern in &self.malicious_patterns {
1052 if content_str.contains(pattern) {
1053 let vuln_type = self.classify_pattern(pattern);
1054 let severity = self.assess_pattern_severity(pattern);
1055
1056 vulnerabilities.push(Vulnerability {
1057 vuln_type,
1058 severity,
1059 description: format!("Found suspicious pattern: {}", pattern),
1060 location: file_path.to_string_lossy().to_string(),
1061 remediation: "Review model source and verify legitimacy".to_string(),
1062 });
1063 }
1064 }
1065
1066 Ok(vulnerabilities)
1067 }
1068
1069 fn deep_scan_file<P: AsRef<Path>>(&self, file_path: P) -> Result<Vec<Vulnerability>> {
1071 let file_path = file_path.as_ref();
1072 let mut vulnerabilities = Vec::new();
1073
1074 if let Some(extension) = file_path.extension() {
1076 if let Some("exe" | "bat" | "sh" | "ps1") = extension.to_str() {
1077 vulnerabilities.push(Vulnerability {
1078 vuln_type: VulnerabilityType::CodeExecution,
1079 severity: Severity::High,
1080 description: "Model file has executable extension".to_string(),
1081 location: file_path.to_string_lossy().to_string(),
1082 remediation: "Verify this is actually a model file and not malware".to_string(),
1083 });
1084 }
1085 }
1086
1087 if let Some(filename) = file_path.file_name() {
1089 if let Some(filename_str) = filename.to_str() {
1090 if filename_str.starts_with('.') || filename_str.contains("..") {
1091 vulnerabilities.push(Vulnerability {
1092 vuln_type: VulnerabilityType::SuspiciousPattern,
1093 severity: Severity::Medium,
1094 description: "Suspicious filename pattern".to_string(),
1095 location: file_path.to_string_lossy().to_string(),
1096 remediation: "Verify file purpose and rename to standard convention"
1097 .to_string(),
1098 });
1099 }
1100 }
1101 }
1102
1103 #[cfg(unix)]
1105 {
1106 use std::os::unix::fs::PermissionsExt;
1107 let metadata = std::fs::metadata(file_path)?;
1108 let permissions = metadata.permissions();
1109 let mode = permissions.mode();
1110
1111 if mode & 0o111 != 0 {
1113 vulnerabilities.push(Vulnerability {
1114 vuln_type: VulnerabilityType::CodeExecution,
1115 severity: Severity::Medium,
1116 description: "Model file has execute permissions".to_string(),
1117 location: file_path.to_string_lossy().to_string(),
1118 remediation: "Remove execute permissions: chmod -x filename".to_string(),
1119 });
1120 }
1121 }
1122
1123 Ok(vulnerabilities)
1124 }
1125
1126 fn classify_pattern(&self, pattern: &str) -> VulnerabilityType {
1128 match pattern {
1129 p if p.contains("eval") || p.contains("exec") || p.contains("subprocess") => {
1130 VulnerabilityType::CodeExecution
1131 }
1132 p if p.contains("socket") || p.contains("urllib") || p.contains("requests") => {
1133 VulnerabilityType::NetworkAccess
1134 }
1135 p if p.contains("open") || p.contains("file") || p.contains("pathlib") => {
1136 VulnerabilityType::FileSystemAccess
1137 }
1138 _ => VulnerabilityType::SuspiciousPattern,
1139 }
1140 }
1141
1142 fn assess_pattern_severity(&self, pattern: &str) -> Severity {
1144 match pattern {
1145 p if p.contains("eval") || p.contains("exec") => Severity::Critical,
1146 p if p.contains("subprocess") || p.contains("os.system") => Severity::High,
1147 p if p.contains("socket") || p.contains("urllib") => Severity::Medium,
1148 _ => Severity::Low,
1149 }
1150 }
1151
1152 fn assess_risk_level(&self, vulnerabilities: &[Vulnerability]) -> RiskLevel {
1154 if vulnerabilities.is_empty() {
1155 return RiskLevel::Safe;
1156 }
1157
1158 let max_severity = vulnerabilities
1159 .iter()
1160 .map(|v| &v.severity)
1161 .max()
1162 .unwrap_or(&Severity::Low);
1163
1164 match max_severity {
1165 Severity::Critical => RiskLevel::Critical,
1166 Severity::High => RiskLevel::High,
1167 Severity::Medium => RiskLevel::Medium,
1168 Severity::Low => RiskLevel::Low,
1169 }
1170 }
1171
1172 fn create_metadata(
1174 &self,
1175 start_time: SystemTime,
1176 files_scanned: usize,
1177 bytes_scanned: usize,
1178 ) -> ScanMetadata {
1179 let end_time = SystemTime::now();
1180 let scan_duration = end_time
1181 .duration_since(start_time)
1182 .expect("end_time should be after start_time")
1183 .as_millis() as u64;
1184
1185 ScanMetadata {
1186 scan_time: start_time
1187 .duration_since(UNIX_EPOCH)
1188 .expect("start_time should be after UNIX epoch")
1189 .as_secs(),
1190 scan_duration,
1191 files_scanned,
1192 bytes_scanned,
1193 scanner_version: "1.0.0".to_string(),
1194 }
1195 }
1196
1197 pub fn add_pattern(&mut self, pattern: String) {
1199 if !self.malicious_patterns.contains(&pattern) {
1200 self.malicious_patterns.push(pattern);
1201 }
1202 }
1203
1204 pub fn add_vulnerable_signature(&mut self, signature: String) {
1206 if !self.vulnerable_signatures.contains(&signature) {
1207 self.vulnerable_signatures.push(signature);
1208 }
1209 }
1210
1211 pub fn get_patterns(&self) -> &[String] {
1213 &self.malicious_patterns
1214 }
1215
1216 pub fn get_vulnerable_signatures(&self) -> &[String] {
1218 &self.vulnerable_signatures
1219 }
1220}
1221
1222pub fn scan_model_vulnerabilities<P: AsRef<Path>>(
1224 file_path: P,
1225 scanner: Option<VulnerabilityScanner>,
1226) -> Result<VulnerabilityScanResult> {
1227 let scanner = scanner.unwrap_or_default();
1228 scanner.scan_file(file_path)
1229}
1230
1231#[cfg(test)]
1232mod tests {
1233 use super::*;
1234 use std::io::Write;
1235 use tempfile::NamedTempFile;
1236
1237 #[test]
1238 fn test_calculate_file_hash() {
1239 let mut temp_file = NamedTempFile::new().unwrap();
1240 temp_file.write_all(b"test content").unwrap();
1241 temp_file.flush().unwrap();
1242
1243 let hash = calculate_file_hash(temp_file.path()).unwrap();
1244 assert_eq!(hash.len(), 64); }
1246
1247 #[test]
1248 fn test_verify_file_integrity() {
1249 let mut temp_file = NamedTempFile::new().unwrap();
1250 temp_file.write_all(b"test content").unwrap();
1251 temp_file.flush().unwrap();
1252
1253 let hash = calculate_file_hash(temp_file.path()).unwrap();
1254 assert!(verify_file_integrity(temp_file.path(), &hash).unwrap());
1255 assert!(!verify_file_integrity(temp_file.path(), "wrong_hash").unwrap());
1256 }
1257
1258 #[test]
1259 fn test_security_manager() {
1260 let mut manager = SecurityManager::new();
1261
1262 let key_pair =
1264 SecurityManager::generate_key_pair("test_key".to_string(), SignatureAlgorithm::Ed25519)
1265 .unwrap();
1266
1267 manager.add_key(key_pair);
1269
1270 manager.trust_key("test_key").unwrap();
1272
1273 assert!(manager.trusted_keys().contains(&"test_key".to_string()));
1274 }
1275
1276 #[test]
1277 fn test_validate_model_source() {
1278 let config = SecurityConfig {
1279 blocked_sources: vec!["malicious.com".to_string()],
1280 allowed_sources: vec!["trusted.com".to_string()],
1281 ..Default::default()
1282 };
1283
1284 assert!(validate_model_source("https://malicious.com/model.torsh", &config).is_err());
1286
1287 assert!(validate_model_source("https://unknown.com/model.torsh", &config).is_err());
1289
1290 assert!(validate_model_source("https://trusted.com/model.torsh", &config).is_ok());
1292 }
1293
1294 #[test]
1295 fn test_validate_signature_age() {
1296 let current_time = SystemTime::now()
1297 .duration_since(UNIX_EPOCH)
1298 .unwrap()
1299 .as_secs();
1300
1301 let recent_signature = ModelSignature {
1303 timestamp: current_time - 100, file_hash: "hash".to_string(),
1305 signature: "sig".to_string(),
1306 key_id: "key".to_string(),
1307 algorithm: SignatureAlgorithm::Ed25519,
1308 metadata: HashMap::new(),
1309 };
1310
1311 assert!(validate_signature_age(&recent_signature, Some(3600)).is_ok());
1312
1313 let old_signature = ModelSignature {
1315 timestamp: current_time - 7200, ..recent_signature
1317 };
1318
1319 assert!(validate_signature_age(&old_signature, Some(3600)).is_err());
1320 }
1321
1322 #[test]
1323 fn test_sandbox_config() {
1324 let config = SandboxConfig::default();
1325 assert_eq!(config.max_memory, 1024 * 1024 * 1024);
1326 assert_eq!(config.max_execution_time, 300);
1327 assert_eq!(config.max_threads, 4);
1328 assert!(!config.allow_network);
1329 assert!(!config.allow_filesystem);
1330 }
1331
1332 #[test]
1333 fn test_sandbox_creation() {
1334 let config = SandboxConfig::default();
1335 let sandbox = ModelSandbox::new(config);
1336
1337 let usage = sandbox.get_usage();
1338 assert_eq!(usage.memory_used, 0);
1339 assert_eq!(usage.threads_created, 0);
1340 assert_eq!(usage.network_requests, 0);
1341 }
1342
1343 #[test]
1344 fn test_sandbox_resource_tracking() {
1345 let config = SandboxConfig::default();
1346 let sandbox = ModelSandbox::new(config);
1347
1348 sandbox.record_memory_usage(1024);
1350 let usage = sandbox.get_usage();
1351 assert_eq!(usage.memory_used, 1024);
1352
1353 sandbox.record_thread_creation();
1355 let usage = sandbox.get_usage();
1356 assert_eq!(usage.threads_created, 1);
1357 }
1358
1359 #[test]
1360 fn test_sandbox_network_restrictions() {
1361 let config = SandboxConfig {
1362 allow_network: false,
1363 ..Default::default()
1364 };
1365 let sandbox = ModelSandbox::new(config);
1366
1367 assert!(sandbox.record_network_request().is_err());
1369
1370 let config = SandboxConfig {
1371 allow_network: true,
1372 ..Default::default()
1373 };
1374 let sandbox = ModelSandbox::new(config);
1375
1376 assert!(sandbox.record_network_request().is_ok());
1378 }
1379
1380 #[test]
1381 fn test_sandbox_filesystem_restrictions() {
1382 let config = SandboxConfig {
1383 allow_filesystem: false,
1384 ..Default::default()
1385 };
1386 let sandbox = ModelSandbox::new(config);
1387
1388 let test_path = std::env::temp_dir().join("test");
1389 let test_path_str = test_path.to_string_lossy();
1390 assert!(sandbox.record_file_access(&test_path_str, false).is_err());
1392
1393 let temp_dir_str = std::env::temp_dir().to_string_lossy().into_owned();
1394 let config = SandboxConfig {
1395 allow_filesystem: true,
1396 read_paths: vec![temp_dir_str],
1397 ..Default::default()
1398 };
1399 let sandbox = ModelSandbox::new(config);
1400
1401 assert!(sandbox.record_file_access(&test_path_str, false).is_ok());
1403
1404 assert!(sandbox.record_file_access("/etc/passwd", false).is_err());
1406 }
1407
1408 #[test]
1409 fn test_sandbox_guard() {
1410 let config = SandboxConfig::default();
1411 let sandbox = ModelSandbox::new(config);
1412
1413 {
1415 let _guard = sandbox.enter().unwrap();
1416 }
1418 let _guard = sandbox.enter().unwrap();
1422 }
1423
1424 #[test]
1425 fn test_vulnerability_scanner_creation() {
1426 let scanner = VulnerabilityScanner::new();
1427 assert!(!scanner.get_patterns().is_empty());
1428 assert!(!scanner.get_vulnerable_signatures().is_empty());
1429 }
1430
1431 #[test]
1432 fn test_vulnerability_scanner_custom_config() {
1433 let custom_patterns = vec!["custom_pattern".to_string()];
1434 let scanner = VulnerabilityScanner::with_config(
1435 50 * 1024 * 1024, false,
1437 custom_patterns.clone(),
1438 );
1439
1440 assert!(scanner
1441 .get_patterns()
1442 .contains(&"custom_pattern".to_string()));
1443 }
1444
1445 #[test]
1446 fn test_vulnerability_scanner_pattern_detection() {
1447 let mut temp_file = NamedTempFile::new().unwrap();
1448 temp_file
1449 .write_all(b"This file contains eval() function")
1450 .unwrap();
1451 temp_file.flush().unwrap();
1452
1453 let scanner = VulnerabilityScanner::new();
1454 let result = scanner.scan_file(temp_file.path()).unwrap();
1455
1456 assert!(result.success);
1457 assert!(!result.vulnerabilities.is_empty());
1458 assert_eq!(result.risk_level, RiskLevel::Critical);
1459
1460 let has_eval_vuln = result
1462 .vulnerabilities
1463 .iter()
1464 .any(|v| v.description.contains("eval"));
1465 assert!(has_eval_vuln);
1466 }
1467
1468 #[test]
1469 fn test_vulnerability_scanner_clean_file() {
1470 let mut temp_file = NamedTempFile::new().unwrap();
1471 temp_file
1472 .write_all(b"This is a clean model file with no suspicious content")
1473 .unwrap();
1474 temp_file.flush().unwrap();
1475
1476 let scanner = VulnerabilityScanner::with_config(
1478 100 * 1024 * 1024, false, vec![], );
1482 let result = scanner.scan_file(temp_file.path()).unwrap();
1483
1484 assert!(result.success);
1485 assert!(result.vulnerabilities.is_empty());
1486 assert_eq!(result.risk_level, RiskLevel::Safe);
1487 }
1488
1489 #[test]
1490 fn test_vulnerability_scanner_known_vulnerable() {
1491 let mut temp_file = NamedTempFile::new().unwrap();
1492 temp_file.write_all(b"test content").unwrap();
1493 temp_file.flush().unwrap();
1494
1495 let file_hash = calculate_file_hash(temp_file.path()).unwrap();
1496
1497 let mut scanner = VulnerabilityScanner::new();
1498 scanner.add_vulnerable_signature(file_hash);
1499
1500 let result = scanner.scan_file(temp_file.path()).unwrap();
1501
1502 assert!(result.success);
1503 assert!(!result.vulnerabilities.is_empty());
1504 assert_eq!(result.risk_level, RiskLevel::Critical);
1505
1506 let has_known_vuln = result
1508 .vulnerabilities
1509 .iter()
1510 .any(|v| v.vuln_type == VulnerabilityType::KnownVulnerable);
1511 assert!(has_known_vuln);
1512 }
1513
1514 #[test]
1515 fn test_vulnerability_pattern_classification() {
1516 let scanner = VulnerabilityScanner::new();
1517
1518 assert_eq!(
1519 scanner.classify_pattern("eval("),
1520 VulnerabilityType::CodeExecution
1521 );
1522 assert_eq!(
1523 scanner.classify_pattern("socket."),
1524 VulnerabilityType::NetworkAccess
1525 );
1526 assert_eq!(
1527 scanner.classify_pattern("open("),
1528 VulnerabilityType::FileSystemAccess
1529 );
1530 assert_eq!(
1531 scanner.classify_pattern("unknown"),
1532 VulnerabilityType::SuspiciousPattern
1533 );
1534 }
1535
1536 #[test]
1537 fn test_vulnerability_severity_assessment() {
1538 let scanner = VulnerabilityScanner::new();
1539
1540 assert_eq!(scanner.assess_pattern_severity("eval("), Severity::Critical);
1541 assert_eq!(
1542 scanner.assess_pattern_severity("subprocess."),
1543 Severity::High
1544 );
1545 assert_eq!(scanner.assess_pattern_severity("socket."), Severity::Medium);
1546 assert_eq!(scanner.assess_pattern_severity("unknown"), Severity::Low);
1547 }
1548
1549 #[test]
1550 fn test_risk_level_assessment() {
1551 let scanner = VulnerabilityScanner::new();
1552
1553 assert_eq!(scanner.assess_risk_level(&[]), RiskLevel::Safe);
1555
1556 let critical_vuln = vec![Vulnerability {
1558 vuln_type: VulnerabilityType::CodeExecution,
1559 severity: Severity::Critical,
1560 description: "test".to_string(),
1561 location: "test".to_string(),
1562 remediation: "test".to_string(),
1563 }];
1564 assert_eq!(
1565 scanner.assess_risk_level(&critical_vuln),
1566 RiskLevel::Critical
1567 );
1568
1569 let high_vuln = vec![Vulnerability {
1571 vuln_type: VulnerabilityType::CodeExecution,
1572 severity: Severity::High,
1573 description: "test".to_string(),
1574 location: "test".to_string(),
1575 remediation: "test".to_string(),
1576 }];
1577 assert_eq!(scanner.assess_risk_level(&high_vuln), RiskLevel::High);
1578 }
1579
1580 #[test]
1581 fn test_scan_model_vulnerabilities_convenience_function() {
1582 let mut temp_file = NamedTempFile::new().unwrap();
1583 temp_file.write_all(b"exec() call detected").unwrap();
1584 temp_file.flush().unwrap();
1585
1586 let result = scan_model_vulnerabilities(temp_file.path(), None).unwrap();
1587
1588 assert!(result.success);
1589 assert!(!result.vulnerabilities.is_empty());
1590 assert_eq!(result.risk_level, RiskLevel::Critical);
1591 }
1592}