Skip to main content

trustformers_models/weight_loading/
utils.rs

//! Utility functions for weight loading.
//!
//! This module provides convenience functions for creating the different kinds of weight loaders.
4use std::path::Path;
5use trustformers_core::errors::{ErrorKind, Result, TrustformersError};
6
7use super::config::{DistributedConfig, WeightLoadingConfig};
8use super::distributed::DistributedWeightLoader;
9use super::gguf::GGUFLoader;
10use super::huggingface::{HuggingFaceLoader, WeightLoader};
11use super::memory_mapped::MemoryMappedLoader;
12
13/// Create a HuggingFace weight loader
14pub fn create_huggingface_loader(
15    model_dir: impl AsRef<Path>,
16    config: Option<WeightLoadingConfig>,
17) -> Result<Box<dyn WeightLoader>> {
18    let config = config.unwrap_or_default();
19    let loader = HuggingFaceLoader::new(model_dir, config)?;
20    Ok(Box::new(loader))
21}
22
23/// Create a memory-mapped weight loader
24pub fn create_memory_mapped_loader(path: impl AsRef<Path>) -> Result<Box<dyn WeightLoader>> {
25    let loader = MemoryMappedLoader::new(path)?;
26    Ok(Box::new(loader))
27}
28
29/// Create a GGUF loader
30pub fn create_gguf_loader(path: impl AsRef<Path>) -> Result<Box<dyn WeightLoader>> {
31    let loader = GGUFLoader::new(path)?;
32    Ok(Box::new(loader))
33}
34
35/// Create a distributed weight loader
36pub fn create_distributed_loader(
37    config: WeightLoadingConfig,
38    distributed_config: DistributedConfig,
39) -> Result<Box<dyn WeightLoader>> {
40    let loader = DistributedWeightLoader::new(config, distributed_config)?;
41    Ok(Box::new(loader))
42}
43
44/// Auto-detect format and create appropriate loader
45pub fn auto_create_loader(
46    path: impl AsRef<Path>,
47    config: Option<WeightLoadingConfig>,
48) -> Result<Box<dyn WeightLoader>> {
49    let path = path.as_ref();
50    let config = config.unwrap_or_default();
51
52    // Check if distributed loading is configured
53    if let Some(distributed_config) = config.distributed.clone() {
54        return create_distributed_loader(config, distributed_config);
55    }
56
57    if path.is_dir() {
58        // Directory - assume HuggingFace format
59        create_huggingface_loader(path, Some(config))
60    } else if path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
61        // Single SafeTensors file
62        if config.memory_mapped {
63            create_memory_mapped_loader(path)
64        } else {
65            // Create single-file HuggingFace loader
66            create_huggingface_loader(path.parent().unwrap_or(path), Some(config))
67        }
68    } else if path.extension().and_then(|s| s.to_str()) == Some("gguf") {
69        // GGUF file
70        create_gguf_loader(path)
71    } else {
72        Err(TrustformersError::new(ErrorKind::InvalidFormat {
73            expected: "HuggingFace directory, .safetensors, or .gguf".to_string(),
74            actual: "Unknown weight format".to_string(),
75        }))
76    }
77}