// tenflowers_dataset/distributed_streaming/types.rs
//! Types, configs, and errors for distributed streaming.

3use crate::{
4    distributed_sharding::{ShardConfig, ShardStrategy},
5    error_taxonomy::helpers as error_helpers,
6    Dataset,
7};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::marker::PhantomData;
11use std::sync::{Arc, Mutex, RwLock};
12use tenflowers_core::{Result, Tensor, TensorError};
13
/// Configuration for distributed streaming of dataset samples across a
/// group of cooperating workers.
///
/// Construct with [`StreamingConfig::new`] (which checks `world_size` and
/// `rank`) and customize via the `with_*` builder methods; call
/// [`StreamingConfig::validate`] after building to check cross-field
/// invariants such as the replication factor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingConfig {
    /// Total number of workers in the distributed system (must be > 0).
    pub world_size: usize,
    /// Current worker rank (0-indexed; must be < `world_size`).
    pub rank: usize,
    /// Partition strategy for distributing data across workers.
    pub partition_strategy: PartitionStrategy,
    /// Number of samples the prefetch buffer holds ahead of consumption.
    pub prefetch_buffer_size: usize,
    /// Seed for deterministic shuffling; `None` means unseeded shuffling.
    pub shuffle_seed: Option<u64>,
    /// Checkpoint interval in samples; `None` disables checkpointing.
    pub checkpoint_interval: Option<usize>,
    /// Enable fault tolerance (replication across workers).
    pub fault_tolerant: bool,
    /// Replication factor for fault tolerance; `validate` rejects values
    /// greater than `world_size`.
    pub replication_factor: usize,
    /// Whether dynamic load balancing is enabled.
    pub dynamic_balancing: bool,
}
36
37impl Default for StreamingConfig {
38    fn default() -> Self {
39        Self {
40            world_size: 1,
41            rank: 0,
42            partition_strategy: PartitionStrategy::HashBased {
43                num_partitions: 1,
44                hash_seed: 0,
45            },
46            prefetch_buffer_size: 128,
47            shuffle_seed: None,
48            checkpoint_interval: Some(1000),
49            fault_tolerant: false,
50            replication_factor: 1,
51            dynamic_balancing: false,
52        }
53    }
54}
55
56impl StreamingConfig {
57    /// Create a new streaming configuration
58    pub fn new(world_size: usize, rank: usize) -> Result<Self> {
59        if world_size == 0 {
60            return Err(error_helpers::invalid_configuration(
61                "StreamingConfig::new",
62                "world_size",
63                "world_size must be > 0",
64            ));
65        }
66
67        if rank >= world_size {
68            return Err(error_helpers::invalid_configuration(
69                "StreamingConfig::new",
70                "rank",
71                format!("rank {} must be < world_size {}", rank, world_size),
72            ));
73        }
74
75        Ok(Self {
76            world_size,
77            rank,
78            ..Default::default()
79        })
80    }
81
82    /// Set the partition strategy
83    pub fn with_partition_strategy(mut self, strategy: PartitionStrategy) -> Self {
84        self.partition_strategy = strategy;
85        self
86    }
87
88    /// Set the prefetch buffer size
89    pub fn with_prefetch_buffer_size(mut self, size: usize) -> Self {
90        self.prefetch_buffer_size = size;
91        self
92    }
93
94    /// Set the shuffle seed for deterministic shuffling
95    pub fn with_shuffle_seed(mut self, seed: u64) -> Self {
96        self.shuffle_seed = Some(seed);
97        self
98    }
99
100    /// Enable checkpointing with specified interval
101    pub fn with_checkpointing(mut self, interval: usize) -> Self {
102        self.checkpoint_interval = Some(interval);
103        self
104    }
105
106    /// Enable fault tolerance with replication
107    pub fn with_fault_tolerance(mut self, replication_factor: usize) -> Self {
108        self.fault_tolerant = true;
109        self.replication_factor = replication_factor;
110        self
111    }
112
113    /// Enable dynamic load balancing
114    pub fn with_dynamic_balancing(mut self, enabled: bool) -> Self {
115        self.dynamic_balancing = enabled;
116        self
117    }
118
119    /// Validate the configuration
120    pub fn validate(&self) -> Result<()> {
121        if self.world_size == 0 {
122            return Err(error_helpers::invalid_configuration(
123                "StreamingConfig::validate",
124                "world_size",
125                "world_size must be > 0",
126            ));
127        }
128
129        if self.rank >= self.world_size {
130            return Err(error_helpers::invalid_configuration(
131                "StreamingConfig::validate",
132                "rank",
133                format!(
134                    "rank {} must be < world_size {}",
135                    self.rank, self.world_size
136                ),
137            ));
138        }
139
140        if self.replication_factor > self.world_size {
141            return Err(error_helpers::invalid_configuration(
142                "StreamingConfig::validate",
143                "replication_factor",
144                format!(
145                    "replication_factor {} cannot exceed world_size {}",
146                    self.replication_factor, self.world_size
147                ),
148            ));
149        }
150
151        Ok(())
152    }
153}
154
/// Advanced partition strategies for distributed streaming.
///
/// Selected via [`StreamingConfig::with_partition_strategy`]; the default is
/// [`PartitionStrategy::RoundRobin`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PartitionStrategy {
    /// Round-robin distribution (simple, balanced for uniform data).
    RoundRobin,

    /// Contiguous blocks (good for sequential access patterns).
    Contiguous,

    /// Hash-based partitioning (deterministic, good for key-based data).
    HashBased {
        /// Number of hash buckets to partition into.
        num_partitions: usize,
        /// Seed mixed into the hash for deterministic, reseedable placement.
        hash_seed: u64,
    },

    /// Range-based partitioning (good for sorted data).
    RangeBased {
        /// Index ranges, one per partition — presumably half-open
        /// (start, end) pairs; confirm against the partitioner.
        ranges: Vec<(usize, usize)>,
    },

    /// Stratified partitioning (maintains class distribution).
    Stratified {
        /// Number of classes to stratify over.
        num_classes: usize,
    },

    /// Adaptive partitioning (adjusts based on worker performance).
    Adaptive {
        /// Strategy used until rebalancing kicks in (boxed: recursive enum).
        base_strategy: Box<PartitionStrategy>,
        /// Imbalance threshold that triggers a rebalance.
        rebalance_threshold: f64,
    },

    /// Custom partitioning (user-defined function, looked up by id).
    Custom {
        /// Identifier of the registered custom partition function.
        partition_id: String,
    },
}
185
186impl Default for PartitionStrategy {
187    fn default() -> Self {
188        Self::RoundRobin
189    }
190}
191
/// Checkpoint state for stream resumption.
///
/// Serialized at each checkpoint interval so a worker can resume a stream
/// at the same epoch/position with the same shuffle seed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointState {
    /// Epoch the stream was in when the checkpoint was taken.
    pub epoch: usize,
    /// Position in the stream at checkpoint time.
    pub position: usize,
    /// Shuffle seed in effect, so resumption replays the same order.
    pub shuffle_seed: Option<u64>,
    /// Rank of the worker that wrote this checkpoint.
    pub rank: usize,
    /// Checkpoint timestamp — presumably seconds since the Unix epoch;
    /// confirm against the writer.
    pub timestamp: u64,
    /// Indices processed so far (skipped on resume to avoid duplicates).
    pub processed_indices: HashSet<usize>,
}
208
/// Statistics for streaming performance.
///
/// All counters are cumulative; `Default` yields an all-zero snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingStats {
    /// Total samples loaded.
    pub samples_loaded: u64,
    /// Samples loaded from the local shard.
    pub local_samples: u64,
    /// Samples loaded from remote workers.
    pub remote_samples: u64,
    /// Prefetch buffer hits.
    pub prefetch_hits: u64,
    /// Prefetch buffer misses.
    pub prefetch_misses: u64,
    /// Average load time per sample, in microseconds.
    pub avg_load_time_us: u64,
    /// Number of checkpoints created.
    pub num_checkpoints: u64,
    /// Worker utilization as a fraction in [0.0, 1.0].
    pub worker_utilization: f64,
}
229
230impl Default for StreamingStats {
231    fn default() -> Self {
232        Self {
233            samples_loaded: 0,
234            local_samples: 0,
235            remote_samples: 0,
236            prefetch_hits: 0,
237            prefetch_misses: 0,
238            avg_load_time_us: 0,
239            num_checkpoints: 0,
240            worker_utilization: 0.0,
241        }
242    }
243}
244
/// Worker health status as reported via heartbeats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerHealth {
    /// Rank of the worker this report describes.
    pub rank: usize,
    /// Current liveness/performance classification.
    pub status: WorkerStatus,
    /// Timestamp of the last heartbeat — presumably seconds since the Unix
    /// epoch; confirm against the heartbeat sender.
    pub last_heartbeat: u64,
    /// Total samples this worker has processed.
    pub samples_processed: u64,
    /// Average throughput — presumably samples per second; confirm against
    /// the metrics producer.
    pub average_throughput: f64,
}
254
/// Worker status enumeration used in health reports.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkerStatus {
    /// Worker is processing samples normally.
    Active,
    /// Worker is alive but has no work queued.
    Idle,
    /// Worker is responding but under-performing (straggler).
    Slow,
    /// Worker is considered dead or unreachable.
    Failed,
    /// Status has not been determined (e.g. no heartbeat yet).
    Unknown,
}
264
/// Worker performance metrics for load balancing decisions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerMetrics {
    /// Rank of the worker these metrics describe.
    pub rank: usize,
    /// Measured throughput in samples per second.
    pub throughput_samples_per_sec: f64,
    /// Number of items currently queued on this worker.
    pub queue_depth: usize,
    /// CPU utilization — presumably a fraction in [0.0, 1.0]; confirm
    /// against the metrics producer.
    pub cpu_utilization: f64,
    /// Resident memory usage in megabytes.
    pub memory_usage_mb: f64,
}