trustformers_models/weight_loading/
utils.rs1use std::path::Path;
5use trustformers_core::errors::{ErrorKind, Result, TrustformersError};
6
7use super::config::{DistributedConfig, WeightLoadingConfig};
8use super::distributed::DistributedWeightLoader;
9use super::gguf::GGUFLoader;
10use super::huggingface::{HuggingFaceLoader, WeightLoader};
11use super::memory_mapped::MemoryMappedLoader;
12
13pub fn create_huggingface_loader(
15 model_dir: impl AsRef<Path>,
16 config: Option<WeightLoadingConfig>,
17) -> Result<Box<dyn WeightLoader>> {
18 let config = config.unwrap_or_default();
19 let loader = HuggingFaceLoader::new(model_dir, config)?;
20 Ok(Box::new(loader))
21}
22
23pub fn create_memory_mapped_loader(path: impl AsRef<Path>) -> Result<Box<dyn WeightLoader>> {
25 let loader = MemoryMappedLoader::new(path)?;
26 Ok(Box::new(loader))
27}
28
29pub fn create_gguf_loader(path: impl AsRef<Path>) -> Result<Box<dyn WeightLoader>> {
31 let loader = GGUFLoader::new(path)?;
32 Ok(Box::new(loader))
33}
34
35pub fn create_distributed_loader(
37 config: WeightLoadingConfig,
38 distributed_config: DistributedConfig,
39) -> Result<Box<dyn WeightLoader>> {
40 let loader = DistributedWeightLoader::new(config, distributed_config)?;
41 Ok(Box::new(loader))
42}
43
44pub fn auto_create_loader(
46 path: impl AsRef<Path>,
47 config: Option<WeightLoadingConfig>,
48) -> Result<Box<dyn WeightLoader>> {
49 let path = path.as_ref();
50 let config = config.unwrap_or_default();
51
52 if let Some(distributed_config) = config.distributed.clone() {
54 return create_distributed_loader(config, distributed_config);
55 }
56
57 if path.is_dir() {
58 create_huggingface_loader(path, Some(config))
60 } else if path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
61 if config.memory_mapped {
63 create_memory_mapped_loader(path)
64 } else {
65 create_huggingface_loader(path.parent().unwrap_or(path), Some(config))
67 }
68 } else if path.extension().and_then(|s| s.to_str()) == Some("gguf") {
69 create_gguf_loader(path)
71 } else {
72 Err(TrustformersError::new(ErrorKind::InvalidFormat {
73 expected: "HuggingFace directory, .safetensors, or .gguf".to_string(),
74 actual: "Unknown weight format".to_string(),
75 }))
76 }
77}