// tenflowers_dataset/distributed_streaming/types.rs

use crate::{
    distributed_sharding::{ShardConfig, ShardStrategy},
    error_taxonomy::helpers as error_helpers,
    Dataset,
};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::marker::PhantomData;
use std::sync::{Arc, Mutex, RwLock};
use tenflowers_core::{Result, Tensor, TensorError};

14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct StreamingConfig {
17 pub world_size: usize,
19 pub rank: usize,
21 pub partition_strategy: PartitionStrategy,
23 pub prefetch_buffer_size: usize,
25 pub shuffle_seed: Option<u64>,
27 pub checkpoint_interval: Option<usize>,
29 pub fault_tolerant: bool,
31 pub replication_factor: usize,
33 pub dynamic_balancing: bool,
35}
36
37impl Default for StreamingConfig {
38 fn default() -> Self {
39 Self {
40 world_size: 1,
41 rank: 0,
42 partition_strategy: PartitionStrategy::HashBased {
43 num_partitions: 1,
44 hash_seed: 0,
45 },
46 prefetch_buffer_size: 128,
47 shuffle_seed: None,
48 checkpoint_interval: Some(1000),
49 fault_tolerant: false,
50 replication_factor: 1,
51 dynamic_balancing: false,
52 }
53 }
54}
55
56impl StreamingConfig {
57 pub fn new(world_size: usize, rank: usize) -> Result<Self> {
59 if world_size == 0 {
60 return Err(error_helpers::invalid_configuration(
61 "StreamingConfig::new",
62 "world_size",
63 "world_size must be > 0",
64 ));
65 }
66
67 if rank >= world_size {
68 return Err(error_helpers::invalid_configuration(
69 "StreamingConfig::new",
70 "rank",
71 format!("rank {} must be < world_size {}", rank, world_size),
72 ));
73 }
74
75 Ok(Self {
76 world_size,
77 rank,
78 ..Default::default()
79 })
80 }
81
82 pub fn with_partition_strategy(mut self, strategy: PartitionStrategy) -> Self {
84 self.partition_strategy = strategy;
85 self
86 }
87
88 pub fn with_prefetch_buffer_size(mut self, size: usize) -> Self {
90 self.prefetch_buffer_size = size;
91 self
92 }
93
94 pub fn with_shuffle_seed(mut self, seed: u64) -> Self {
96 self.shuffle_seed = Some(seed);
97 self
98 }
99
100 pub fn with_checkpointing(mut self, interval: usize) -> Self {
102 self.checkpoint_interval = Some(interval);
103 self
104 }
105
106 pub fn with_fault_tolerance(mut self, replication_factor: usize) -> Self {
108 self.fault_tolerant = true;
109 self.replication_factor = replication_factor;
110 self
111 }
112
113 pub fn with_dynamic_balancing(mut self, enabled: bool) -> Self {
115 self.dynamic_balancing = enabled;
116 self
117 }
118
119 pub fn validate(&self) -> Result<()> {
121 if self.world_size == 0 {
122 return Err(error_helpers::invalid_configuration(
123 "StreamingConfig::validate",
124 "world_size",
125 "world_size must be > 0",
126 ));
127 }
128
129 if self.rank >= self.world_size {
130 return Err(error_helpers::invalid_configuration(
131 "StreamingConfig::validate",
132 "rank",
133 format!(
134 "rank {} must be < world_size {}",
135 self.rank, self.world_size
136 ),
137 ));
138 }
139
140 if self.replication_factor > self.world_size {
141 return Err(error_helpers::invalid_configuration(
142 "StreamingConfig::validate",
143 "replication_factor",
144 format!(
145 "replication_factor {} cannot exceed world_size {}",
146 self.replication_factor, self.world_size
147 ),
148 ));
149 }
150
151 Ok(())
152 }
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
157pub enum PartitionStrategy {
158 RoundRobin,
160
161 Contiguous,
163
164 HashBased {
166 num_partitions: usize,
167 hash_seed: u64,
168 },
169
170 RangeBased { ranges: Vec<(usize, usize)> },
172
173 Stratified { num_classes: usize },
175
176 Adaptive {
178 base_strategy: Box<PartitionStrategy>,
179 rebalance_threshold: f64,
180 },
181
182 Custom { partition_id: String },
184}
185
186impl Default for PartitionStrategy {
187 fn default() -> Self {
188 Self::RoundRobin
189 }
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct CheckpointState {
195 pub epoch: usize,
197 pub position: usize,
199 pub shuffle_seed: Option<u64>,
201 pub rank: usize,
203 pub timestamp: u64,
205 pub processed_indices: HashSet<usize>,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct StreamingStats {
212 pub samples_loaded: u64,
214 pub local_samples: u64,
216 pub remote_samples: u64,
218 pub prefetch_hits: u64,
220 pub prefetch_misses: u64,
222 pub avg_load_time_us: u64,
224 pub num_checkpoints: u64,
226 pub worker_utilization: f64,
228}
229
230impl Default for StreamingStats {
231 fn default() -> Self {
232 Self {
233 samples_loaded: 0,
234 local_samples: 0,
235 remote_samples: 0,
236 prefetch_hits: 0,
237 prefetch_misses: 0,
238 avg_load_time_us: 0,
239 num_checkpoints: 0,
240 worker_utilization: 0.0,
241 }
242 }
243}
244
245#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct WorkerHealth {
248 pub rank: usize,
249 pub status: WorkerStatus,
250 pub last_heartbeat: u64,
251 pub samples_processed: u64,
252 pub average_throughput: f64,
253}
254
255#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
257pub enum WorkerStatus {
258 Active,
259 Idle,
260 Slow,
261 Failed,
262 Unknown,
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize)]
267pub struct WorkerMetrics {
268 pub rank: usize,
269 pub throughput_samples_per_sec: f64,
270 pub queue_depth: usize,
271 pub cpu_utilization: f64,
272 pub memory_usage_mb: f64,
273}