Skip to main content

trustformers_models/weight_loading/
config.rs

//! Weight Loading Configuration
//!
//! This module contains all configuration structs and enums for weight loading functionality.
4use std::path::PathBuf;
5use std::time::Duration;
6
/// File formats from which model weights can be loaded.
#[derive(Debug, Clone, PartialEq)]
pub enum WeightFormat {
    /// PyTorch `.bin` files.
    HuggingFaceBin,
    /// SafeTensors format.
    SafeTensors,
    /// ONNX format.
    ONNX,
    /// TensorFlow SavedModel.
    TensorFlow,
    /// GGUF quantized format.
    GGUF,
    /// Custom format; carries a `String` payload.
    Custom(String),
}
17
18/// Weight loading configuration
19#[derive(Debug, Clone)]
20pub struct WeightLoadingConfig {
21    pub format: Option<WeightFormat>,
22    pub lazy_loading: bool,
23    pub memory_mapped: bool,
24    pub streaming: bool,
25    pub device: String,
26    pub dtype: WeightDataType,
27    pub quantization: Option<QuantizationConfig>,
28    pub cache_dir: Option<PathBuf>,
29    pub verify_checksums: bool,
30    pub distributed: Option<DistributedConfig>,
31}
32
33/// Distributed weight loading configuration
34#[derive(Debug, Clone)]
35pub struct DistributedConfig {
36    /// List of worker nodes for distributed loading
37    pub nodes: Vec<NodeConfig>,
38    /// Load balancing strategy
39    pub load_balancer: LoadBalancingStrategy,
40    /// Fault tolerance settings
41    pub fault_tolerance: FaultToleranceConfig,
42    /// Network settings
43    pub network: NetworkConfig,
44    /// Caching strategy across nodes
45    pub distributed_cache: DistributedCacheConfig,
46    /// Enable compression for network transfer
47    pub compression: bool,
48}
49
/// Description of a single node taking part in distributed weight loading.
#[derive(Debug, Clone)]
pub struct NodeConfig {
    /// Node identifier.
    pub id: String,
    /// Node address.
    pub address: String,
    /// Node port.
    pub port: u16,
    /// Maximum weights this node can hold, in bytes.
    pub weight_capacity: u64,
    /// Network bandwidth to this node, in MB/s.
    pub bandwidth: f64,
    /// Priority in the range 0-255; higher-priority nodes are preferred.
    pub priority: u8,
    /// Paths where weights are stored on this node.
    pub storage_paths: Vec<PathBuf>,
}
61
/// Strategies available for balancing load during distributed weight loading.
///
/// This module only names the strategies; how each one behaves is decided by
/// whatever code consumes this enum.
#[derive(Debug, Clone, PartialEq)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    LeastLoaded,
    WeightedRoundRobin,
    ConsistentHashing,
    LocalityAware,
    PerformanceBased,
    Adaptive,
}
73
/// Settings that govern retries, timeouts, failover and health checks.
#[derive(Debug, Clone)]
pub struct FaultToleranceConfig {
    /// Maximum number of retries.
    pub max_retries: u32,
    /// Delay applied between retries.
    pub retry_delay: Duration,
    /// Timeout duration.
    pub timeout: Duration,
    /// Whether failover is enabled.
    pub enable_failover: bool,
    /// Interval between health checks.
    pub health_check_interval: Duration,
    /// Node IDs to use as backups.
    pub backup_nodes: Vec<String>,
}
84
/// Network-level settings used by distributed loading.
#[derive(Debug, Clone)]
pub struct NetworkConfig {
    /// Maximum number of concurrent connections.
    pub max_concurrent_connections: usize,
    /// Timeout for establishing a connection.
    pub connection_timeout: Duration,
    /// Timeout for reads.
    pub read_timeout: Duration,
    /// Chunk size (unit not stated here — presumably bytes; confirm with the transfer code).
    pub chunk_size: usize,
    /// Whether keepalive is enabled.
    pub enable_keepalive: bool,
    /// Transfers larger than this are compressed (presumably measured in bytes — confirm).
    pub compression_threshold: usize,
}
95
96/// Distributed cache configuration
97#[derive(Debug, Clone)]
98pub struct DistributedCacheConfig {
99    pub cache_strategy: CacheStrategy,
100    pub replication_factor: u8, // How many nodes should cache each tensor
101    pub eviction_policy: CacheEvictionPolicy,
102    pub consistency_level: ConsistencyLevel,
103}
104
/// Caching strategies selectable through `DistributedCacheConfig::cache_strategy`.
///
/// Only the names are defined here; the behaviour of each strategy lives in
/// the code that consumes this enum.
#[derive(Debug, Clone, PartialEq)]
pub enum CacheStrategy {
    None,
    ReadThrough,
    WriteThrough,
    WriteBack,
    ReadAround,
}
113
/// Eviction policies selectable through `DistributedCacheConfig::eviction_policy`.
///
/// Only the names are defined here; the behaviour of each policy lives in
/// the code that consumes this enum.
#[derive(Debug, Clone, PartialEq)]
pub enum CacheEvictionPolicy {
    LRU,
    LFU,
    FIFO,
    Random,
    TTL,
}
122
/// Consistency levels selectable through `DistributedCacheConfig::consistency_level`.
///
/// Only the names are defined here; the guarantees attached to each level are
/// determined by the code that consumes this enum.
#[derive(Debug, Clone, PartialEq)]
pub enum ConsistencyLevel {
    Strong,
    Eventual,
    Weak,
}
129
/// Data types in which weights can be represented.
///
/// Derives `PartialEq` and `Eq` so values can be compared with `==`, matching
/// the other fieldless enums in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WeightDataType {
    /// 32-bit floating point.
    Float32,
    /// 16-bit floating point.
    Float16,
    /// 16-bit bfloat.
    BFloat16,
    /// 8-bit integer.
    Int8,
    /// 4-bit integer.
    Int4,
}
138
/// Quantization settings attached to a weight loading configuration.
#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    /// Number of bits (presumably the quantized bit width — confirm with the quantizer).
    pub bits: u8,
    /// Optional group size.
    pub group_size: Option<usize>,
    /// Whether quantization is symmetric.
    pub symmetric: bool,
}
145
146impl Default for WeightLoadingConfig {
147    fn default() -> Self {
148        Self {
149            format: None,
150            lazy_loading: false,
151            memory_mapped: false,
152            streaming: false,
153            device: "cpu".to_string(),
154            dtype: WeightDataType::Float32,
155            quantization: None,
156            cache_dir: None,
157            verify_checksums: true,
158            distributed: None,
159        }
160    }
161}
162
163impl Default for FaultToleranceConfig {
164    fn default() -> Self {
165        Self {
166            max_retries: 3,
167            retry_delay: Duration::from_millis(1000),
168            timeout: Duration::from_secs(30),
169            enable_failover: true,
170            health_check_interval: Duration::from_secs(60),
171            backup_nodes: Vec::new(),
172        }
173    }
174}
175
176impl Default for NetworkConfig {
177    fn default() -> Self {
178        Self {
179            max_concurrent_connections: 10,
180            connection_timeout: Duration::from_secs(30),
181            read_timeout: Duration::from_secs(60),
182            chunk_size: 8192,
183            enable_keepalive: true,
184            compression_threshold: 1024 * 1024, // 1MB
185        }
186    }
187}
188
189impl Default for DistributedCacheConfig {
190    fn default() -> Self {
191        Self {
192            cache_strategy: CacheStrategy::ReadThrough,
193            replication_factor: 2,
194            eviction_policy: CacheEvictionPolicy::LRU,
195            consistency_level: ConsistencyLevel::Eventual,
196        }
197    }
198}