1use sha2::{Digest, Sha256};
7use std::path::Path;
8use torsh_core::error::Result;
9
10use crate::package::Package;
11
12pub fn calculate_hash(data: &[u8]) -> String {
14 let mut hasher = Sha256::new();
15 hasher.update(data);
16 hex::encode(hasher.finalize())
17}
18
/// Packages a single module and writes the resulting package to `path`.
///
/// A `PackageBuilder` is created from `name` and `version`, `module` is
/// registered under the entry name `"main"`, and the package is built at
/// `path`.
///
/// Only available when the `with-nn` feature is enabled.
///
/// # Errors
///
/// Propagates any error reported while adding the module or while building
/// the package.
#[cfg(feature = "with-nn")]
pub fn export_module<M, P>(module: &M, path: P, name: &str, version: &str) -> Result<()>
where
    M: torsh_nn::Module,
    P: AsRef<Path>,
{
    use crate::builder::PackageBuilder;

    // Kept as a single chain so the builder's ownership flow is untouched.
    PackageBuilder::new(name.to_owned(), version.to_owned())
        .add_module("main", module)?
        .build(path)
}
31
32pub fn import_module<P: AsRef<Path>>(path: P, module_name: &str) -> Result<Vec<u8>> {
34 let package = Package::load(path)?;
35 package.get_module(module_name)
36}
37
/// Checks whether `name` is an acceptable package name.
///
/// A valid name:
/// * is non-empty and at most 100 bytes long (`str::len` counts bytes),
/// * starts with an alphanumeric character (Unicode-aware),
/// * contains only alphanumeric characters, `-` and `_`.
pub fn validate_package_name(name: &str) -> bool {
    if name.len() > 100 {
        return false;
    }

    let mut rest = name.chars();
    match rest.next() {
        // The first character is already known to be alphanumeric here, so
        // only the remaining characters still need the wider check.
        Some(first) if first.is_alphanumeric() => {
            rest.all(|ch| ch.is_alphanumeric() || matches!(ch, '-' | '_'))
        }
        // Empty name, or a leading character such as '-' or '_'.
        _ => false,
    }
}
56
57pub fn validate_version(version: &str) -> bool {
59 semver::Version::parse(version).is_ok()
60}
61
/// Returns the extension of `filename` (the part after the last `.`),
/// without the dot.
///
/// Returns `None` when `std::path::Path::extension` finds no extension,
/// e.g. for `"README"` or a dot-file such as `".hidden"`.
pub fn get_file_extension(filename: &str) -> Option<&str> {
    let extension = std::path::Path::new(filename).extension()?;
    extension.to_str()
}
68
/// Produces a conservative version of `filename` in which every character
/// other than ASCII letters, ASCII digits, `.`, `-` and `_` is replaced by
/// `_`.
///
/// The replacement is per character, so the result has the same number of
/// characters as the input (non-ASCII characters each become one `_`).
pub fn sanitize_filename(filename: &str) -> String {
    let mut sanitized = String::with_capacity(filename.len());
    for ch in filename.chars() {
        let allowed = ch.is_ascii_alphanumeric() || ch == '.' || ch == '-' || ch == '_';
        sanitized.push(if allowed { ch } else { '_' });
    }
    sanitized
}
82
/// Performs a coarse, purely textual safety check on `path`.
///
/// The path is rejected when it contains the substring `..` anywhere, or
/// when it starts with `/` or `\`.
///
/// Note that this is a substring test: a name such as `"file..txt"` is also
/// rejected, and prefixes like `"C:"` are not examined at all.
pub fn is_safe_path(path: &str) -> bool {
    let has_parent_marker = path.contains("..");
    let is_rooted = matches!(path.chars().next(), Some('/') | Some('\\'));
    !(has_parent_marker || is_rooted)
}
87
/// Formats a byte count as a human-readable string using 1024-based units
/// (`B`, `KB`, `MB`, `GB`, `TB`).
///
/// Sizes below 1024 are printed as an exact integer number of bytes
/// (e.g. `"512 B"`); larger sizes are printed with one decimal place
/// (e.g. `"1.5 KB"`). `TB` is the largest unit, so very large values are
/// expressed as a big number of terabytes.
pub fn format_file_size(size: u64) -> String {
    const UNITS: [&str; 5] = ["B", "KB", "MB", "GB", "TB"];
    const STEP: f64 = 1024.0;

    // Plain byte counts never need scaling or a fractional part.
    if size < 1024 {
        return format!("{} {}", size, UNITS[0]);
    }

    let mut scaled = size as f64;
    let mut unit = 0;
    // Stop at the last unit even if the value is still >= 1024.
    while unit + 1 < UNITS.len() && scaled >= STEP {
        scaled /= STEP;
        unit += 1;
    }

    format!("{:.1} {}", scaled, UNITS[unit])
}
111
/// Estimates how well `data` would compress, from its byte-level Shannon
/// entropy.
///
/// The result is a ratio where `1.0` means "not expected to shrink" and
/// smaller values mean better expected compression. The estimated saving is
/// capped at 90%, so the returned value is never (meaningfully) below `0.1`.
/// Empty input yields `1.0`.
///
/// This is a heuristic only: it looks at the byte histogram and ignores any
/// ordering or repetition structure in the data.
pub fn estimate_compression_ratio(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 1.0;
    }

    // Histogram of byte values. `usize` counters cannot overflow because no
    // bucket can exceed `data.len()`; narrower counters could for inputs
    // larger than 4 GiB.
    let mut counts = [0usize; 256];
    for &byte in data {
        counts[usize::from(byte)] += 1;
    }

    let len = data.len() as f64;
    let mut entropy = 0.0;

    // Shannon entropy in bits per byte: -sum(p * log2(p)) over observed bytes.
    for &count in &counts {
        if count > 0 {
            let p = count as f64 / len;
            entropy -= p * p.log2();
        }
    }

    // 8 bits per byte is the maximum possible entropy for byte data.
    let max_entropy = 8.0;
    let compression_potential = (max_entropy - entropy) / max_entropy;

    // Clamp the estimated saving to [0, 0.9] before turning it into a ratio.
    1.0 - compression_potential.max(0.0).min(0.9)
}
140
141pub fn validate_resource_path(path: &str) -> Result<()> {
143 use torsh_core::error::TorshError;
144
145 if path.is_empty() {
146 return Err(TorshError::InvalidArgument(
147 "Resource path cannot be empty".to_string(),
148 ));
149 }
150
151 if path.len() > 1024 {
152 return Err(TorshError::InvalidArgument(
153 "Resource path exceeds maximum length of 1024 characters".to_string(),
154 ));
155 }
156
157 if !is_safe_path(path) {
158 return Err(TorshError::InvalidArgument(format!(
159 "Resource path contains unsafe components: {}",
160 path
161 )));
162 }
163
164 Ok(())
165}
166
167pub fn validate_package_metadata(
169 name: &str,
170 version: &str,
171 description: Option<&str>,
172) -> Result<()> {
173 use torsh_core::error::TorshError;
174
175 if !validate_package_name(name) {
176 return Err(TorshError::InvalidArgument(format!(
177 "Invalid package name: {}",
178 name
179 )));
180 }
181
182 if !validate_version(version) {
183 return Err(TorshError::InvalidArgument(format!(
184 "Invalid semantic version: {}",
185 version
186 )));
187 }
188
189 if let Some(desc) = description {
190 if desc.len() > 10000 {
191 return Err(TorshError::InvalidArgument(
192 "Package description exceeds maximum length of 10000 characters".to_string(),
193 ));
194 }
195 }
196
197 Ok(())
198}
199
/// Computes a simple, non-cryptographic 64-bit checksum of `data`.
///
/// Starting from zero, each byte updates the running value as
/// `value * 31 + byte` using wrapping arithmetic, so the function never
/// panics on overflow. An empty slice yields `0`.
pub fn calculate_checksum(data: &[u8]) -> u64 {
    data.iter().fold(0u64, |acc, &byte| {
        acc.wrapping_mul(31).wrapping_add(u64::from(byte))
    })
}
209
210pub fn verify_checksum(data: &[u8], expected: u64) -> bool {
212 calculate_checksum(data) == expected
213}
214
/// Returns `path` with every backslash replaced by a forward slash.
///
/// No other normalisation is performed (no collapsing of repeated
/// separators, no handling of `.` or `..`).
pub fn normalize_path(path: &str) -> String {
    path.chars()
        .map(|ch| if ch == '\\' { '/' } else { ch })
        .collect()
}
219
/// Computes the `/`-separated relative path that leads from `from` to `to`.
///
/// Both inputs are split on `/` and empty segments are ignored. The shared
/// leading segments are dropped, one `..` is emitted for every remaining
/// segment of `from`, and the remaining segments of `to` are appended.
/// When nothing is left, `"."` is returned.
///
/// Every segment of `from` is treated as a directory level, and `.` / `..`
/// segments in the inputs are compared literally rather than resolved.
pub fn get_relative_path(from: &str, to: &str) -> String {
    /// Splits a path into its non-empty `/`-separated segments.
    fn segments(path: &str) -> Vec<&str> {
        path.split('/').filter(|segment| !segment.is_empty()).collect()
    }

    let source = segments(from);
    let target = segments(to);

    // Length of the common leading run of segments.
    let shared = source
        .iter()
        .zip(&target)
        .take_while(|(a, b)| a == b)
        .count();

    // Climb out of what is left of `from`, then descend into the rest of `to`.
    let climbs = std::iter::repeat("..").take(source.len() - shared);
    let descents = target[shared..].iter().copied();
    let steps: Vec<&str> = climbs.chain(descents).collect();

    if steps.is_empty() {
        ".".to_string()
    } else {
        steps.join("/")
    }
}
246
247pub fn parse_content_type(filename: &str) -> &'static str {
249 match get_file_extension(filename) {
250 Some("txt") | Some("md") => "text/plain",
251 Some("json") => "application/json",
252 Some("xml") => "application/xml",
253 Some("html") => "text/html",
254 Some("css") => "text/css",
255 Some("js") => "application/javascript",
256 Some("py") => "text/x-python",
257 Some("rs") => "text/x-rust",
258 Some("zip") => "application/zip",
259 Some("tar") => "application/x-tar",
260 Some("gz") => "application/gzip",
261 Some("torshpkg") => "application/x-torsh-package",
262 Some("onnx") => "application/onnx",
263 Some("pkl") | Some("pickle") => "application/python-pickle",
264 _ => "application/octet-stream",
265 }
266}
267
/// A named stopwatch for ad-hoc timing.
///
/// The timer starts when it is created. When a timer is dropped in a build
/// with debug assertions enabled, its elapsed time is printed to stderr.
/// Because the type is `Clone`, each clone prints independently on drop.
#[derive(Debug, Clone)]
pub struct PerformanceTimer {
    /// Instant at which the timer was created or last reset.
    start: std::time::Instant,
    /// Label used when printing the elapsed time.
    name: String,
}

impl PerformanceTimer {
    /// Creates a timer labelled `name` and starts it immediately.
    pub fn new(name: impl Into<String>) -> Self {
        let start = std::time::Instant::now();
        Self {
            start,
            name: name.into(),
        }
    }

    /// Whole milliseconds elapsed since the timer was started or last reset.
    pub fn elapsed_ms(&self) -> u64 {
        let elapsed = self.start.elapsed();
        elapsed.as_millis() as u64
    }

    /// Seconds elapsed since the timer was started or last reset, as a
    /// floating-point value.
    pub fn elapsed_secs(&self) -> f64 {
        let elapsed = self.start.elapsed();
        elapsed.as_secs_f64()
    }

    /// Writes `[name] Elapsed: <seconds>s` (three decimal places) to stderr.
    pub fn print_elapsed(&self) {
        let seconds = self.elapsed_secs();
        eprintln!("[{}] Elapsed: {:.3}s", self.name, seconds);
    }

    /// Restarts the timer from the current instant.
    pub fn reset(&mut self) {
        self.start = std::time::Instant::now();
    }
}

impl Drop for PerformanceTimer {
    /// Reports the elapsed time on drop, but only when debug assertions are
    /// enabled.
    fn drop(&mut self) {
        if !cfg!(debug_assertions) {
            return;
        }
        self.print_elapsed();
    }
}
312
313#[derive(Debug, Clone, Default)]
315pub struct MemoryStats {
316 pub allocated: u64,
318 pub peak: u64,
320 pub allocations: u64,
322}
323
324impl MemoryStats {
325 pub fn new() -> Self {
327 Self::default()
328 }
329
330 pub fn record_allocation(&mut self, size: u64) {
332 self.allocated += size;
333 self.allocations += 1;
334 if self.allocated > self.peak {
335 self.peak = self.allocated;
336 }
337 }
338
339 pub fn record_deallocation(&mut self, size: u64) {
341 self.allocated = self.allocated.saturating_sub(size);
342 }
343
344 pub fn format(&self) -> String {
346 format!(
347 "Allocated: {}, Peak: {}, Allocations: {}",
348 format_file_size(self.allocated),
349 format_file_size(self.peak),
350 self.allocations
351 )
352 }
353}
354
/// Unit tests for the utility functions and helper types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_hash() {
        let data = b"hello world";
        let hash = calculate_hash(data);

        // Well-known SHA-256 digest of "hello world".
        assert_eq!(
            hash,
            "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
        );
    }

    #[test]
    fn test_validate_package_name() {
        assert!(validate_package_name("my-package"));
        assert!(validate_package_name("package_name"));
        assert!(validate_package_name("Package123"));

        assert!(!validate_package_name(""));
        assert!(!validate_package_name("-invalid"));
        assert!(!validate_package_name("invalid@name"));
        assert!(!validate_package_name("a".repeat(101).as_str()));
    }

    #[test]
    fn test_validate_version() {
        assert!(validate_version("1.0.0"));
        assert!(validate_version("2.1.3-alpha.1"));
        assert!(validate_version("0.0.1-beta+build.123"));

        assert!(!validate_version(""));
        assert!(!validate_version("1.0"));
        assert!(!validate_version("invalid"));
    }

    #[test]
    fn test_get_file_extension() {
        assert_eq!(get_file_extension("file.txt"), Some("txt"));
        assert_eq!(get_file_extension("archive.tar.gz"), Some("gz"));
        assert_eq!(get_file_extension("README"), None);
        assert_eq!(get_file_extension(".hidden"), None);
    }

    #[test]
    fn test_sanitize_filename() {
        assert_eq!(sanitize_filename("normal_file.txt"), "normal_file.txt");
        assert_eq!(
            sanitize_filename("file with spaces.txt"),
            "file_with_spaces.txt"
        );
        assert_eq!(sanitize_filename("file@#$%.txt"), "file____.txt");
        assert_eq!(sanitize_filename("αβγ.txt"), "___.txt");
    }

    #[test]
    fn test_is_safe_path() {
        assert!(is_safe_path("safe/path/file.txt"));
        assert!(is_safe_path("file.txt"));
        assert!(is_safe_path("subdir/file.txt"));

        assert!(!is_safe_path("../etc/passwd"));
        assert!(!is_safe_path("/absolute/path"));
        assert!(!is_safe_path("\\windows\\path"));
        assert!(!is_safe_path("safe/../unsafe"));
    }

    #[test]
    fn test_format_file_size() {
        assert_eq!(format_file_size(0), "0 B");
        assert_eq!(format_file_size(512), "512 B");
        assert_eq!(format_file_size(1024), "1.0 KB");
        assert_eq!(format_file_size(1536), "1.5 KB");
        assert_eq!(format_file_size(1048576), "1.0 MB");
        assert_eq!(format_file_size(1073741824), "1.0 GB");
    }

    #[test]
    fn test_estimate_compression_ratio() {
        // A single repeated byte has no entropy and should look very compressible.
        let repetitive = vec![b'A'; 1000];
        let ratio = estimate_compression_ratio(&repetitive);
        assert!(ratio < 0.5);

        // A near-uniform byte histogram should look almost incompressible.
        let random: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
        let ratio = estimate_compression_ratio(&random);
        assert!(ratio > 0.8);

        assert_eq!(estimate_compression_ratio(&[]), 1.0);
    }

    #[test]
    fn test_import_module_nonexistent() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let nonexistent_path = temp_dir.path().join("nonexistent.torshpkg");

        let result = import_module(nonexistent_path, "test");
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_resource_path() {
        assert!(validate_resource_path("valid/path.txt").is_ok());
        assert!(validate_resource_path("another_file.rs").is_ok());

        assert!(validate_resource_path("").is_err());
        assert!(validate_resource_path("../unsafe").is_err());
        assert!(validate_resource_path("/absolute").is_err());
        assert!(validate_resource_path(&"x".repeat(1025)).is_err());
    }

    #[test]
    fn test_validate_package_metadata() {
        assert!(validate_package_metadata("my-package", "1.0.0", None).is_ok());
        assert!(validate_package_metadata("test", "2.1.3", Some("A test package")).is_ok());

        assert!(validate_package_metadata("", "1.0.0", None).is_err());
        assert!(validate_package_metadata("test", "invalid", None).is_err());
        assert!(validate_package_metadata("test", "1.0.0", Some(&"x".repeat(10001))).is_err());
    }

    #[test]
    fn test_calculate_checksum() {
        let data1 = b"hello world";
        let data2 = b"hello world";
        let data3 = b"different data";

        let checksum1 = calculate_checksum(data1);
        let checksum2 = calculate_checksum(data2);
        let checksum3 = calculate_checksum(data3);

        assert_eq!(checksum1, checksum2);
        assert_ne!(checksum1, checksum3);
    }

    #[test]
    fn test_verify_checksum() {
        let data = b"test data";
        let checksum = calculate_checksum(data);

        assert!(verify_checksum(data, checksum));
        // `wrapping_add` keeps the "wrong checksum" probe panic-free for any value.
        assert!(!verify_checksum(data, checksum.wrapping_add(1)));
    }

    #[test]
    fn test_normalize_path() {
        assert_eq!(normalize_path("path/to/file"), "path/to/file");
        assert_eq!(normalize_path("path\\to\\file"), "path/to/file");
        assert_eq!(normalize_path("mixed\\path/to\\file"), "mixed/path/to/file");
    }

    #[test]
    fn test_get_relative_path() {
        assert_eq!(get_relative_path("a/b/c", "a/b/d"), "../d");
        assert_eq!(get_relative_path("a/b", "a/b/c/d"), "c/d");
        assert_eq!(get_relative_path("a/b/c", "a/b/c"), ".");
        assert_eq!(get_relative_path("a/b/c", "x/y/z"), "../../../x/y/z");
        assert_eq!(get_relative_path("a/b", "x/y"), "../../x/y");
    }

    #[test]
    fn test_parse_content_type() {
        assert_eq!(parse_content_type("file.txt"), "text/plain");
        assert_eq!(parse_content_type("data.json"), "application/json");
        assert_eq!(parse_content_type("script.py"), "text/x-python");
        assert_eq!(parse_content_type("code.rs"), "text/x-rust");
        assert_eq!(parse_content_type("model.onnx"), "application/onnx");
        assert_eq!(
            parse_content_type("package.torshpkg"),
            "application/x-torsh-package"
        );
        assert_eq!(
            parse_content_type("unknown.xyz"),
            "application/octet-stream"
        );
    }

    #[test]
    fn test_performance_timer() {
        let timer = PerformanceTimer::new("test");
        std::thread::sleep(std::time::Duration::from_millis(10));
        let elapsed = timer.elapsed_ms();

        // Sleeping guarantees a minimum duration only.
        assert!(elapsed >= 10);
        // The upper bound is deliberately generous: a loaded machine can
        // oversleep by a lot, but a unit mix-up (e.g. microseconds) would
        // still exceed it.
        assert!(elapsed < 10_000);
    }

    #[test]
    fn test_memory_stats() {
        let mut stats = MemoryStats::new();
        assert_eq!(stats.allocated, 0);
        assert_eq!(stats.peak, 0);

        stats.record_allocation(1024);
        assert_eq!(stats.allocated, 1024);
        assert_eq!(stats.peak, 1024);
        assert_eq!(stats.allocations, 1);

        stats.record_allocation(512);
        assert_eq!(stats.allocated, 1536);
        assert_eq!(stats.peak, 1536);
        assert_eq!(stats.allocations, 2);

        stats.record_deallocation(512);
        assert_eq!(stats.allocated, 1024);
        // The peak must survive a deallocation.
        assert_eq!(stats.peak, 1536);

        let formatted = stats.format();
        assert!(formatted.contains("KB"));
    }
}