Skip to main content

torsh_package/
utils.rs

1//! Utility functions
2//!
3//! This module contains common utility functions used throughout the package system,
4//! including cryptographic operations, validation, and helper functions.
5
6use sha2::{Digest, Sha256};
7use std::path::Path;
8use torsh_core::error::Result;
9
10use crate::package::Package;
11
12/// Calculate SHA-256 hash of data
13pub fn calculate_hash(data: &[u8]) -> String {
14    let mut hasher = Sha256::new();
15    hasher.update(data);
16    hex::encode(hasher.finalize())
17}
18
/// Export `module` as a single-module package written to `path`.
///
/// Convenience wrapper around `crate::builder::PackageBuilder`. It creates a
/// package named `name` at version `version`, adds `module` under the module
/// name `"main"`, and builds the package to `path`.
///
/// This function is compiled only when the `with-nn` feature is enabled,
/// because it depends on `torsh_nn::Module`.
///
/// # Errors
///
/// Propagates any error returned by `PackageBuilder::add_module` or
/// `PackageBuilder::build`.
#[cfg(feature = "with-nn")]
pub fn export_module<M: torsh_nn::Module, P: AsRef<Path>>(
    module: &M,
    path: P,
    name: &str,
    version: &str,
) -> Result<()> {
    crate::builder::PackageBuilder::new(name.to_string(), version.to_string())
        .add_module("main", module)?
        .build(path)
}
31
32/// Quick import function
33pub fn import_module<P: AsRef<Path>>(path: P, module_name: &str) -> Result<Vec<u8>> {
34    let package = Package::load(path)?;
35    package.get_module(module_name)
36}
37
/// Return `true` if `name` is an acceptable package name.
///
/// Rules enforced:
/// - the name is 1 to 100 bytes long (measured with `str::len`, so multi-byte
///   characters count once per UTF-8 byte);
/// - the first character is alphanumeric;
/// - every character is alphanumeric, `-`, or `_`.
///
/// "Alphanumeric" means [`char::is_alphanumeric`], so non-ASCII letters and
/// digits are accepted as well.
pub fn validate_package_name(name: &str) -> bool {
    const MAX_LEN: usize = 100;

    let mut chars = name.chars();
    let starts_alphanumeric = match chars.next() {
        Some(first) => first.is_alphanumeric(),
        // Empty names are rejected.
        None => false,
    };

    // The first char has already been vetted; `chars` now yields the rest.
    starts_alphanumeric
        && name.len() <= MAX_LEN
        && chars.all(|c| c.is_alphanumeric() || c == '-' || c == '_')
}
56
57/// Validate semantic version string
58pub fn validate_version(version: &str) -> bool {
59    semver::Version::parse(version).is_ok()
60}
61
/// Return the extension of `filename` (the text after its final `.`), if any.
///
/// Delegates to [`Path::extension`]: names without a dot and dot-files such as
/// `.hidden` yield `None`, and multi-part names return only the last
/// extension (`archive.tar.gz` gives `gz`).
pub fn get_file_extension(filename: &str) -> Option<&str> {
    let extension = Path::new(filename).extension()?;
    extension.to_str()
}
68
/// Sanitize `filename` so it can be used as a single, safe path component.
///
/// Every character other than an ASCII letter, an ASCII digit, `.`, `-`, or
/// `_` is replaced with `_`. This also neutralises path separators (`/`, `\`).
///
/// A result consisting only of dots (for example `.` or `..`) would still
/// name the current or parent directory. Its dots are therefore replaced
/// with `_` too.
///
/// An empty input yields an empty string.
pub fn sanitize_filename(filename: &str) -> String {
    let sanitized: String = filename
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_') {
                c
            } else {
                '_'
            }
        })
        .collect();

    // `.`/`..` survive the character filter but are directory references.
    // The string is all ASCII here, so byte length equals character count.
    if !sanitized.is_empty() && sanitized.chars().all(|c| c == '.') {
        "_".repeat(sanitized.len())
    } else {
        sanitized
    }
}
82
/// Check whether `path` is a safe relative path, i.e. it cannot escape its base directory.
///
/// A path is rejected when it:
/// - contains `..` anywhere. This is deliberately conservative and also
///   rejects names such as `a..b`;
/// - starts with `/` or `\` (absolute, rooted, or UNC-style paths);
/// - starts with a Windows drive prefix such as `C:`. That covers `C:\x` and
///   `C:/x`, and also the drive-relative `C:x`, which ignores the intended
///   base directory.
pub fn is_safe_path(path: &str) -> bool {
    let has_drive_prefix = matches!(
        path.as_bytes(),
        [drive, b':', ..] if drive.is_ascii_alphabetic()
    );

    !path.contains("..")
        && !path.starts_with('/')
        && !path.starts_with('\\')
        && !has_drive_prefix
}
87
/// Format a byte count as a human-readable string using 1024-based units.
///
/// Values below 1024 are printed as whole bytes (`"512 B"`). Larger values are
/// printed with one decimal place (`"1.5 KB"`, `"1.0 GB"`). Units go up to
/// `EB`, which covers the whole `u64` range (`u64::MAX` is `"16.0 EB"`).
///
/// If rounding to one decimal would display a value as `1024.0` of a unit, the
/// next larger unit is used instead. For example, `1048575` bytes is
/// `"1.0 MB"`, not `"1024.0 KB"`.
pub fn format_file_size(size: u64) -> String {
    const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB", "PB", "EB"];
    const THRESHOLD: f64 = 1024.0;

    if size < 1024 {
        return format!("{} {}", size, UNITS[0]);
    }

    let mut value = size as f64;
    let mut unit_index = 0;
    while value >= THRESHOLD && unit_index + 1 < UNITS.len() {
        value /= THRESHOLD;
        unit_index += 1;
    }

    // `{:.1}` rounds, so a value like 1023.96 would print as "1024.0"; promote instead.
    if (value * 10.0).round() >= THRESHOLD * 10.0 && unit_index + 1 < UNITS.len() {
        value /= THRESHOLD;
        unit_index += 1;
    }

    format!("{:.1} {}", value, UNITS[unit_index])
}
111
/// Estimate the compressed-to-original size ratio for `data`.
///
/// The estimate uses the Shannon entropy of the byte histogram (0 to 8 bits
/// per byte) as a rough proxy. Data with entropy `H` is assumed to compress
/// to about `H / 8` of its size. The result is clamped to `[0.1, 1.0]`, so at
/// most 90% compression is ever predicted. Empty input returns `1.0`.
///
/// This is a zeroth-order estimate: it ignores repeated sequences. A cyclic
/// `0, 1, ..., 255` pattern, for instance, is estimated as incompressible.
pub fn estimate_compression_ratio(data: &[u8]) -> f64 {
    const MAX_ENTROPY_BITS: f64 = 8.0;
    const MAX_COMPRESSION: f64 = 0.9;

    if data.is_empty() {
        return 1.0;
    }

    // `usize` counters cannot overflow, since no count can exceed `data.len()`.
    // `u32` counters would overflow for more than 4 GiB of a single byte value.
    let mut counts = [0usize; 256];
    for &byte in data {
        counts[usize::from(byte)] += 1;
    }

    let len = data.len() as f64;
    let entropy: f64 = counts
        .iter()
        .filter(|&&count| count > 0)
        .map(|&count| {
            let p = count as f64 / len;
            -p * p.log2()
        })
        .sum();

    let compression_potential = (MAX_ENTROPY_BITS - entropy) / MAX_ENTROPY_BITS;
    1.0 - compression_potential.max(0.0).min(MAX_COMPRESSION)
}
140
141/// Validate resource path is within package bounds
142pub fn validate_resource_path(path: &str) -> Result<()> {
143    use torsh_core::error::TorshError;
144
145    if path.is_empty() {
146        return Err(TorshError::InvalidArgument(
147            "Resource path cannot be empty".to_string(),
148        ));
149    }
150
151    if path.len() > 1024 {
152        return Err(TorshError::InvalidArgument(
153            "Resource path exceeds maximum length of 1024 characters".to_string(),
154        ));
155    }
156
157    if !is_safe_path(path) {
158        return Err(TorshError::InvalidArgument(format!(
159            "Resource path contains unsafe components: {}",
160            path
161        )));
162    }
163
164    Ok(())
165}
166
167/// Validate package metadata integrity
168pub fn validate_package_metadata(
169    name: &str,
170    version: &str,
171    description: Option<&str>,
172) -> Result<()> {
173    use torsh_core::error::TorshError;
174
175    if !validate_package_name(name) {
176        return Err(TorshError::InvalidArgument(format!(
177            "Invalid package name: {}",
178            name
179        )));
180    }
181
182    if !validate_version(version) {
183        return Err(TorshError::InvalidArgument(format!(
184            "Invalid semantic version: {}",
185            version
186        )));
187    }
188
189    if let Some(desc) = description {
190        if desc.len() > 10000 {
191            return Err(TorshError::InvalidArgument(
192                "Package description exceeds maximum length of 10000 characters".to_string(),
193            ));
194        }
195    }
196
197    Ok(())
198}
199
/// Compute a fast, non-cryptographic 64-bit checksum of `data`.
///
/// This is a polynomial rolling hash (`h = h * 31 + byte`, wrapping on
/// overflow, starting from 0). It is *not* a CRC. It detects many accidental
/// changes, but it gives no protection against deliberate tampering; use
/// `calculate_hash` (SHA-256) for that.
///
/// Changing this algorithm would invalidate any previously recorded checksums.
pub fn calculate_checksum(data: &[u8]) -> u64 {
    data.iter().fold(0u64, |acc, &byte| {
        acc.wrapping_mul(31).wrapping_add(u64::from(byte))
    })
}
209
210/// Verify data integrity using checksum
211pub fn verify_checksum(data: &[u8], expected: u64) -> bool {
212    calculate_checksum(data) == expected
213}
214
/// Convert every backslash in `path` to a forward slash.
///
/// No other normalisation happens: `.`/`..` segments, repeated separators,
/// and leading or trailing slashes are left as they are.
pub fn normalize_path(path: &str) -> String {
    path.chars()
        .map(|ch| if ch == '\\' { '/' } else { ch })
        .collect()
}
219
/// Compute the relative path that leads from directory `from` to `to`.
///
/// Both inputs are split on `/`, and empty segments are dropped, so leading,
/// trailing, and repeated slashes do not matter. After the longest common
/// leading run of segments, the result contains one `..` for each remaining
/// segment of `from`, followed by the remaining segments of `to`. Identical
/// paths yield `"."`.
///
/// `from` is treated as a directory. Backslashes are not separators here,
/// and `.`/`..` segments in the inputs are not resolved.
pub fn get_relative_path(from: &str, to: &str) -> String {
    fn segments(path: &str) -> Vec<&str> {
        path.split('/').filter(|s| !s.is_empty()).collect()
    }

    let source = segments(from);
    let target = segments(to);

    let shared = source
        .iter()
        .zip(&target)
        .take_while(|(a, b)| a == b)
        .count();

    let parts: Vec<&str> = std::iter::repeat("..")
        .take(source.len() - shared)
        .chain(target[shared..].iter().copied())
        .collect();

    if parts.is_empty() {
        String::from(".")
    } else {
        parts.join("/")
    }
}
246
247/// Parse content type from file extension
248pub fn parse_content_type(filename: &str) -> &'static str {
249    match get_file_extension(filename) {
250        Some("txt") | Some("md") => "text/plain",
251        Some("json") => "application/json",
252        Some("xml") => "application/xml",
253        Some("html") => "text/html",
254        Some("css") => "text/css",
255        Some("js") => "application/javascript",
256        Some("py") => "text/x-python",
257        Some("rs") => "text/x-rust",
258        Some("zip") => "application/zip",
259        Some("tar") => "application/x-tar",
260        Some("gz") => "application/gzip",
261        Some("torshpkg") => "application/x-torsh-package",
262        Some("onnx") => "application/onnx",
263        Some("pkl") | Some("pickle") => "application/python-pickle",
264        _ => "application/octet-stream",
265    }
266}
267
/// Wall-clock timer for profiling an operation.
///
/// The clock starts when the timer is created with [`PerformanceTimer::new`]
/// and can be restarted with [`PerformanceTimer::reset`].
///
/// In builds with `debug_assertions` enabled, dropping a timer prints its
/// label and elapsed time to stderr. Release builds print nothing on drop.
/// Because the type is `Clone`, each clone prints its own line when dropped
/// in debug builds.
#[derive(Debug, Clone)]
pub struct PerformanceTimer {
    /// Instant the timer was created or last reset.
    start: std::time::Instant,
    /// Label used in printed output.
    name: String,
}

impl PerformanceTimer {
    /// Create a timer labelled `name` that starts counting immediately.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            start: std::time::Instant::now(),
            name: name.into(),
        }
    }

    /// Elapsed time in whole milliseconds since creation or the last reset.
    ///
    /// The underlying `u128` millisecond count saturates at `u64::MAX`
    /// instead of being silently truncated by an `as` cast.
    pub fn elapsed_ms(&self) -> u64 {
        let millis = self.start.elapsed().as_millis();
        // Lossless after `min`: the value now fits in a u64.
        millis.min(u128::from(u64::MAX)) as u64
    }

    /// Elapsed time in fractional seconds since creation or the last reset.
    pub fn elapsed_secs(&self) -> f64 {
        self.start.elapsed().as_secs_f64()
    }

    /// Print `[<name>] Elapsed: <seconds>s` (three decimal places) to stderr.
    pub fn print_elapsed(&self) {
        eprintln!("[{}] Elapsed: {:.3}s", self.name, self.elapsed_secs());
    }

    /// Restart the clock from now.
    pub fn reset(&mut self) {
        self.start = std::time::Instant::now();
    }
}

impl Drop for PerformanceTimer {
    /// Report the elapsed time on drop, but only when debug assertions are enabled.
    fn drop(&mut self) {
        if cfg!(debug_assertions) {
            self.print_elapsed();
        }
    }
}
312
313/// Memory usage statistics
314#[derive(Debug, Clone, Default)]
315pub struct MemoryStats {
316    /// Total bytes allocated
317    pub allocated: u64,
318    /// Peak memory usage
319    pub peak: u64,
320    /// Number of allocations
321    pub allocations: u64,
322}
323
324impl MemoryStats {
325    /// Create new memory statistics
326    pub fn new() -> Self {
327        Self::default()
328    }
329
330    /// Record allocation
331    pub fn record_allocation(&mut self, size: u64) {
332        self.allocated += size;
333        self.allocations += 1;
334        if self.allocated > self.peak {
335            self.peak = self.allocated;
336        }
337    }
338
339    /// Record deallocation
340    pub fn record_deallocation(&mut self, size: u64) {
341        self.allocated = self.allocated.saturating_sub(size);
342    }
343
344    /// Format memory stats as human-readable string
345    pub fn format(&self) -> String {
346        format!(
347            "Allocated: {}, Peak: {}, Allocations: {}",
348            format_file_size(self.allocated),
349            format_file_size(self.peak),
350            self.allocations
351        )
352    }
353}
354
#[cfg(test)]
mod tests {
    // Unit tests for the utility helpers in this module.
    use super::*;

    #[test]
    fn test_calculate_hash() {
        let data = b"hello world";
        let hash = calculate_hash(data);

        // Well-known SHA-256 digest of "hello world".
        assert_eq!(
            hash,
            "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
        );
    }

    #[test]
    fn test_validate_package_name() {
        assert!(validate_package_name("my-package"));
        assert!(validate_package_name("package_name"));
        assert!(validate_package_name("Package123"));

        assert!(!validate_package_name(""));
        assert!(!validate_package_name("-invalid"));
        assert!(!validate_package_name("invalid@name"));
        assert!(!validate_package_name(&"a".repeat(101)));
    }

    #[test]
    fn test_validate_version() {
        assert!(validate_version("1.0.0"));
        assert!(validate_version("2.1.3-alpha.1"));
        assert!(validate_version("0.0.1-beta+build.123"));

        assert!(!validate_version(""));
        assert!(!validate_version("1.0"));
        assert!(!validate_version("invalid"));
    }

    #[test]
    fn test_get_file_extension() {
        assert_eq!(get_file_extension("file.txt"), Some("txt"));
        assert_eq!(get_file_extension("archive.tar.gz"), Some("gz"));
        assert_eq!(get_file_extension("README"), None);
        assert_eq!(get_file_extension(".hidden"), None);
    }

    #[test]
    fn test_sanitize_filename() {
        assert_eq!(sanitize_filename("normal_file.txt"), "normal_file.txt");
        assert_eq!(
            sanitize_filename("file with spaces.txt"),
            "file_with_spaces.txt"
        );
        assert_eq!(sanitize_filename("file@#$%.txt"), "file____.txt");
        assert_eq!(sanitize_filename("αβγ.txt"), "___.txt");
    }

    #[test]
    fn test_is_safe_path() {
        assert!(is_safe_path("safe/path/file.txt"));
        assert!(is_safe_path("file.txt"));
        assert!(is_safe_path("subdir/file.txt"));

        assert!(!is_safe_path("../etc/passwd"));
        assert!(!is_safe_path("/absolute/path"));
        assert!(!is_safe_path("\\windows\\path"));
        assert!(!is_safe_path("safe/../unsafe"));
    }

    #[test]
    fn test_format_file_size() {
        assert_eq!(format_file_size(0), "0 B");
        assert_eq!(format_file_size(512), "512 B");
        assert_eq!(format_file_size(1024), "1.0 KB");
        assert_eq!(format_file_size(1536), "1.5 KB");
        assert_eq!(format_file_size(1048576), "1.0 MB");
        assert_eq!(format_file_size(1073741824), "1.0 GB");
    }

    #[test]
    fn test_estimate_compression_ratio() {
        // A single repeated byte has zero entropy, so it should be estimated
        // to compress to less than 50%.
        let repetitive = vec![b'A'; 1000];
        let ratio = estimate_compression_ratio(&repetitive);
        assert!(ratio < 0.5);

        // A cyclic 0..=255 pattern has a flat byte histogram (about 8 bits of
        // entropy), so the histogram-based estimator treats it as nearly
        // incompressible. It is not random data.
        let flat_histogram: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
        let ratio = estimate_compression_ratio(&flat_histogram);
        assert!(ratio > 0.8);

        // Empty data
        assert_eq!(estimate_compression_ratio(&[]), 1.0);
    }

    #[test]
    fn test_import_module_nonexistent() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let nonexistent_path = temp_dir.path().join("nonexistent.torshpkg");

        let result = import_module(nonexistent_path, "test");
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_resource_path() {
        assert!(validate_resource_path("valid/path.txt").is_ok());
        assert!(validate_resource_path("another_file.rs").is_ok());

        assert!(validate_resource_path("").is_err());
        assert!(validate_resource_path("../unsafe").is_err());
        assert!(validate_resource_path("/absolute").is_err());
        assert!(validate_resource_path(&"x".repeat(1025)).is_err());
    }

    #[test]
    fn test_validate_package_metadata() {
        assert!(validate_package_metadata("my-package", "1.0.0", None).is_ok());
        assert!(validate_package_metadata("test", "2.1.3", Some("A test package")).is_ok());

        assert!(validate_package_metadata("", "1.0.0", None).is_err());
        assert!(validate_package_metadata("test", "invalid", None).is_err());
        assert!(validate_package_metadata("test", "1.0.0", Some(&"x".repeat(10001))).is_err());
    }

    #[test]
    fn test_calculate_checksum() {
        let data1 = b"hello world";
        let data2 = b"hello world";
        let data3 = b"different data";

        let checksum1 = calculate_checksum(data1);
        let checksum2 = calculate_checksum(data2);
        let checksum3 = calculate_checksum(data3);

        assert_eq!(checksum1, checksum2);
        assert_ne!(checksum1, checksum3);
    }

    #[test]
    fn test_verify_checksum() {
        let data = b"test data";
        let checksum = calculate_checksum(data);

        assert!(verify_checksum(data, checksum));
        assert!(!verify_checksum(data, checksum.wrapping_add(1)));
    }

    #[test]
    fn test_normalize_path() {
        assert_eq!(normalize_path("path/to/file"), "path/to/file");
        assert_eq!(normalize_path("path\\to\\file"), "path/to/file");
        assert_eq!(normalize_path("mixed\\path/to\\file"), "mixed/path/to/file");
    }

    #[test]
    fn test_get_relative_path() {
        assert_eq!(get_relative_path("a/b/c", "a/b/d"), "../d");
        assert_eq!(get_relative_path("a/b", "a/b/c/d"), "c/d");
        assert_eq!(get_relative_path("a/b/c", "a/b/c"), ".");
        assert_eq!(get_relative_path("a/b/c", "x/y/z"), "../../../x/y/z");
        assert_eq!(get_relative_path("a/b", "x/y"), "../../x/y");
    }

    #[test]
    fn test_parse_content_type() {
        assert_eq!(parse_content_type("file.txt"), "text/plain");
        assert_eq!(parse_content_type("data.json"), "application/json");
        assert_eq!(parse_content_type("script.py"), "text/x-python");
        assert_eq!(parse_content_type("code.rs"), "text/x-rust");
        assert_eq!(parse_content_type("model.onnx"), "application/onnx");
        assert_eq!(
            parse_content_type("package.torshpkg"),
            "application/x-torsh-package"
        );
        assert_eq!(
            parse_content_type("unknown.xyz"),
            "application/octet-stream"
        );
    }

    #[test]
    fn test_performance_timer() {
        let timer = PerformanceTimer::new("test");
        std::thread::sleep(std::time::Duration::from_millis(10));
        let elapsed = timer.elapsed_ms();
        // `sleep` waits at least the requested duration.
        assert!(elapsed >= 10);
        // Generous upper bound: it only guards against unit mix-ups. A tight
        // bound such as 100 ms flakes when CI machines are under load.
        assert!(elapsed < 5_000);
    }

    #[test]
    fn test_memory_stats() {
        let mut stats = MemoryStats::new();
        assert_eq!(stats.allocated, 0);
        assert_eq!(stats.peak, 0);

        stats.record_allocation(1024);
        assert_eq!(stats.allocated, 1024);
        assert_eq!(stats.peak, 1024);
        assert_eq!(stats.allocations, 1);

        stats.record_allocation(512);
        assert_eq!(stats.allocated, 1536);
        assert_eq!(stats.peak, 1536);
        assert_eq!(stats.allocations, 2);

        stats.record_deallocation(512);
        assert_eq!(stats.allocated, 1024);
        // Peak is not lowered by deallocations.
        assert_eq!(stats.peak, 1536);

        let formatted = stats.format();
        assert!(formatted.contains("KB"));
    }
}