1use std::sync::Arc;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::Instant;
9
10use crossbeam_queue::SegQueue;
11use tokio::sync::broadcast;
12
/// Stage of a data-load operation, reported through `ProgressEvent::Phase`.
///
/// The explicit `u8` discriminants are the raw values accepted by [`LoadPhase::from_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum LoadPhase {
    Connecting = 0,
    Querying = 1,
    Fetching = 2,
    Parsing = 3,
    Converting = 4,
}

impl LoadPhase {
    /// All phases, positioned so that `BY_DISCRIMINANT[d]` has discriminant `d`.
    ///
    /// Must be kept in sync with the discriminants declared on the enum.
    const BY_DISCRIMINANT: [LoadPhase; 5] = [
        LoadPhase::Connecting,
        LoadPhase::Querying,
        LoadPhase::Fetching,
        LoadPhase::Parsing,
        LoadPhase::Converting,
    ];

    /// Decodes a raw discriminant into a phase.
    ///
    /// Returns `None` for any value outside `0..=4`.
    pub fn from_u8(value: u8) -> Option<Self> {
        Self::BY_DISCRIMINANT.get(usize::from(value)).copied()
    }

    /// Returns the variant's name as a static string, e.g. `"Fetching"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            LoadPhase::Connecting => "Connecting",
            LoadPhase::Querying => "Querying",
            LoadPhase::Fetching => "Fetching",
            LoadPhase::Parsing => "Parsing",
            LoadPhase::Converting => "Converting",
        }
    }
}
53
/// Level of detail a progress handle reports.
///
/// With `Coarse` (the default) only phase changes and terminal events are emitted;
/// `Fine` additionally emits row-level `Progress` events.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ProgressGranularity {
    #[default]
    Coarse = 0,
    Fine = 1,
}

impl ProgressGranularity {
    /// Decodes a raw value: `1` selects `Fine`, and every other value falls back to `Coarse`.
    pub fn from_u8(value: u8) -> Self {
        if value == Self::Fine as u8 {
            Self::Fine
        } else {
            Self::Coarse
        }
    }
}
74
75#[derive(Debug, Clone)]
77pub enum ProgressEvent {
78 Phase {
80 operation_id: u64,
81 phase: LoadPhase,
82 source: String,
83 },
84
85 Progress {
87 operation_id: u64,
88 rows_processed: u64,
89 total_rows: Option<u64>,
90 bytes_processed: u64,
91 },
92
93 Complete {
95 operation_id: u64,
96 rows_loaded: u64,
97 duration_ms: u64,
98 },
99
100 Error { operation_id: u64, message: String },
102}
103
104impl ProgressEvent {
105 pub fn operation_id(&self) -> u64 {
107 match self {
108 Self::Phase { operation_id, .. } => *operation_id,
109 Self::Progress { operation_id, .. } => *operation_id,
110 Self::Complete { operation_id, .. } => *operation_id,
111 Self::Error { operation_id, .. } => *operation_id,
112 }
113 }
114}
115
116pub struct ProgressHandle {
118 operation_id: u64,
119 source: String,
120 registry: Arc<ProgressRegistry>,
121 start_time: Instant,
122 granularity: ProgressGranularity,
123}
124
125impl ProgressHandle {
126 pub fn phase(&self, phase: LoadPhase) {
128 self.registry.emit(ProgressEvent::Phase {
129 operation_id: self.operation_id,
130 phase,
131 source: self.source.clone(),
132 });
133 }
134
135 pub fn progress(&self, rows_processed: u64, total_rows: Option<u64>, bytes_processed: u64) {
137 if self.granularity == ProgressGranularity::Fine {
138 self.registry.emit(ProgressEvent::Progress {
139 operation_id: self.operation_id,
140 rows_processed,
141 total_rows,
142 bytes_processed,
143 });
144 }
145 }
146
147 pub fn complete(self, rows_loaded: u64) {
149 let duration_ms = self.start_time.elapsed().as_millis() as u64;
150 self.registry.emit(ProgressEvent::Complete {
151 operation_id: self.operation_id,
152 rows_loaded,
153 duration_ms,
154 });
155 }
156
157 pub fn error(self, message: String) {
159 self.registry.emit(ProgressEvent::Error {
160 operation_id: self.operation_id,
161 message,
162 });
163 }
164
165 pub fn operation_id(&self) -> u64 {
167 self.operation_id
168 }
169
170 pub fn granularity(&self) -> ProgressGranularity {
172 self.granularity
173 }
174}
175
176pub struct ProgressRegistry {
180 events: SegQueue<ProgressEvent>,
182 broadcast_tx: broadcast::Sender<ProgressEvent>,
184 next_id: AtomicU64,
186}
187
188impl ProgressRegistry {
189 pub fn new() -> Arc<Self> {
191 let (broadcast_tx, _) = broadcast::channel(256);
192 Arc::new(Self {
193 events: SegQueue::new(),
194 broadcast_tx,
195 next_id: AtomicU64::new(1),
196 })
197 }
198
199 pub fn start_operation(
201 self: &Arc<Self>,
202 source: &str,
203 granularity: ProgressGranularity,
204 ) -> ProgressHandle {
205 let operation_id = self.next_id.fetch_add(1, Ordering::SeqCst);
206 ProgressHandle {
207 operation_id,
208 source: source.to_string(),
209 registry: Arc::clone(self),
210 start_time: Instant::now(),
211 granularity,
212 }
213 }
214
215 fn emit(&self, event: ProgressEvent) {
217 self.events.push(event.clone());
219 let _ = self.broadcast_tx.send(event);
221 }
222
223 pub fn subscribe(&self) -> broadcast::Receiver<ProgressEvent> {
225 self.broadcast_tx.subscribe()
226 }
227
228 pub fn poll(&self) -> Option<ProgressEvent> {
230 self.events.pop()
231 }
232
233 pub fn poll_all(&self) -> Vec<ProgressEvent> {
235 let mut events = Vec::new();
236 while let Some(event) = self.events.pop() {
237 events.push(event);
238 }
239 events
240 }
241
242 pub fn try_recv(&self) -> Option<ProgressEvent> {
244 self.poll()
245 }
246
247 pub fn is_empty(&self) -> bool {
249 self.events.is_empty()
250 }
251}
252
253impl Default for ProgressRegistry {
254 fn default() -> Self {
255 let (broadcast_tx, _) = broadcast::channel(256);
256 Self {
257 events: SegQueue::new(),
258 broadcast_tx,
259 next_id: AtomicU64::new(1),
260 }
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    // `matches!` only evaluates to a `bool`. Every use below is wrapped in `assert!`
    // so that a mismatched event actually fails the test.

    #[test]
    fn test_progress_handle() {
        let registry = ProgressRegistry::new();
        let handle = registry.start_operation("test-source", ProgressGranularity::Fine);
        let id = handle.operation_id();

        handle.phase(LoadPhase::Connecting);
        handle.progress(100, Some(1000), 8000);
        handle.complete(1000);

        let events = registry.poll_all();
        assert_eq!(events.len(), 3);
        assert!(events.iter().all(|event| event.operation_id() == id));

        assert!(matches!(
            &events[0],
            ProgressEvent::Phase { phase: LoadPhase::Connecting, source, .. }
                if source == "test-source"
        ));
        assert!(matches!(
            &events[1],
            ProgressEvent::Progress {
                rows_processed: 100,
                total_rows: Some(1000),
                bytes_processed: 8000,
                ..
            }
        ));
        assert!(matches!(
            &events[2],
            ProgressEvent::Complete { rows_loaded: 1000, .. }
        ));
        assert!(registry.is_empty());
    }

    #[test]
    fn test_coarse_granularity_skips_progress() {
        let registry = ProgressRegistry::new();
        let handle = registry.start_operation("test-source", ProgressGranularity::Coarse);

        handle.phase(LoadPhase::Fetching);
        handle.progress(100, Some(1000), 8000);
        handle.complete(1000);

        let events = registry.poll_all();
        assert_eq!(events.len(), 2);
        assert!(matches!(
            &events[0],
            ProgressEvent::Phase { phase: LoadPhase::Fetching, .. }
        ));
        assert!(matches!(
            &events[1],
            ProgressEvent::Complete { rows_loaded: 1000, .. }
        ));
    }

    #[test]
    fn test_operation_ids_start_at_one_and_increment() {
        let registry = ProgressRegistry::new();
        let first = registry.start_operation("a", ProgressGranularity::Coarse);
        let second = registry.start_operation("b", ProgressGranularity::Coarse);
        assert_eq!(first.operation_id(), 1);
        assert_eq!(second.operation_id(), 2);
    }

    #[test]
    fn test_error_event() {
        let registry = ProgressRegistry::new();
        let handle = registry.start_operation("src", ProgressGranularity::Coarse);
        let id = handle.operation_id();

        handle.error("boom".to_string());

        let events = registry.poll_all();
        assert_eq!(events.len(), 1);
        assert!(matches!(
            &events[0],
            ProgressEvent::Error { operation_id, message }
                if *operation_id == id && message == "boom"
        ));
    }

    #[test]
    fn test_subscriber_receives_events() {
        let registry = ProgressRegistry::new();
        let mut rx = registry.subscribe();
        let handle = registry.start_operation("src", ProgressGranularity::Coarse);

        handle.phase(LoadPhase::Parsing);

        let event = rx
            .try_recv()
            .expect("subscriber should receive the emitted event");
        assert!(matches!(
            event,
            ProgressEvent::Phase { phase: LoadPhase::Parsing, .. }
        ));
        // The poll queue is fed independently of the broadcast channel.
        assert!(registry.poll().is_some());
        assert!(registry.try_recv().is_none());
    }

    #[test]
    fn test_load_phase_from_u8() {
        assert_eq!(LoadPhase::from_u8(0), Some(LoadPhase::Connecting));
        assert_eq!(LoadPhase::from_u8(4), Some(LoadPhase::Converting));
        assert_eq!(LoadPhase::from_u8(99), None);
        for value in 0u8..=4 {
            let phase = LoadPhase::from_u8(value).expect("0..=4 are valid phases");
            assert_eq!(phase as u8, value);
        }
    }

    #[test]
    fn test_granularity_from_u8() {
        assert_eq!(ProgressGranularity::from_u8(1), ProgressGranularity::Fine);
        assert_eq!(ProgressGranularity::from_u8(0), ProgressGranularity::Coarse);
        assert_eq!(ProgressGranularity::from_u8(42), ProgressGranularity::Coarse);
    }
}