Skip to main content

shape_runtime/
progress.rs

//! Progress reporting system for data loading operations.
//!
//! Provides a shared observable for monitoring load operations in a TUI/REPL.
//! Uses a lock-free queue for polled progress events and a broadcast channel
//! for real-time subscribers.

6use std::sync::Arc;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::Instant;
9
10use crossbeam_queue::SegQueue;
11use tokio::sync::broadcast;
12
/// Phase of a load operation.
///
/// The explicit `u8` discriminants are the FFI encoding; decode them with
/// [`LoadPhase::from_u8`] or `LoadPhase::try_from(u8)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum LoadPhase {
    /// Establishing connection to data source
    Connecting = 0,
    /// Executing query
    Querying = 1,
    /// Fetching data from source
    Fetching = 2,
    /// Parsing received data
    Parsing = 3,
    /// Converting to table format
    Converting = 4,
}

impl LoadPhase {
    /// Convert from the `u8` FFI encoding.
    ///
    /// Returns `None` for any value that is not a known discriminant (`0..=4`).
    pub fn from_u8(value: u8) -> Option<Self> {
        match value {
            0 => Some(Self::Connecting),
            1 => Some(Self::Querying),
            2 => Some(Self::Fetching),
            3 => Some(Self::Parsing),
            4 => Some(Self::Converting),
            _ => None,
        }
    }

    /// Human-readable name of the phase (identical to the variant name).
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Connecting => "Connecting",
            Self::Querying => "Querying",
            Self::Fetching => "Fetching",
            Self::Parsing => "Parsing",
            Self::Converting => "Converting",
        }
    }
}

/// Standard-trait form of [`LoadPhase::from_u8`]; the error is the rejected raw value.
impl std::convert::TryFrom<u8> for LoadPhase {
    type Error = u8;

    fn try_from(value: u8) -> Result<Self, Self::Error> {
        Self::from_u8(value).ok_or(value)
    }
}

/// Formats the phase using [`LoadPhase::as_str`].
impl std::fmt::Display for LoadPhase {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
53
/// Level of detail for progress reporting.
///
/// Discriminants are the `u8` FFI encoding accepted by [`ProgressGranularity::from_u8`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ProgressGranularity {
    /// Phase changes only (low overhead). This is the default.
    #[default]
    Coarse = 0,
    /// Phase changes plus row counts and percentages (higher overhead).
    Fine = 1,
}

impl ProgressGranularity {
    /// Decode the `u8` FFI encoding.
    ///
    /// `1` selects `Fine`; every other value falls back to `Coarse`, so this
    /// conversion never fails.
    pub fn from_u8(value: u8) -> Self {
        if value == Self::Fine as u8 {
            Self::Fine
        } else {
            Self::Coarse
        }
    }
}
74
75/// Progress event emitted during data loading
76#[derive(Debug, Clone)]
77pub enum ProgressEvent {
78    /// Phase change (coarse-grained)
79    Phase {
80        operation_id: u64,
81        phase: LoadPhase,
82        source: String,
83    },
84
85    /// Progress within a phase (fine-grained)
86    Progress {
87        operation_id: u64,
88        rows_processed: u64,
89        total_rows: Option<u64>,
90        bytes_processed: u64,
91    },
92
93    /// Operation completed successfully
94    Complete {
95        operation_id: u64,
96        rows_loaded: u64,
97        duration_ms: u64,
98    },
99
100    /// Operation failed
101    Error { operation_id: u64, message: String },
102}
103
104impl ProgressEvent {
105    /// Get the operation ID
106    pub fn operation_id(&self) -> u64 {
107        match self {
108            Self::Phase { operation_id, .. } => *operation_id,
109            Self::Progress { operation_id, .. } => *operation_id,
110            Self::Complete { operation_id, .. } => *operation_id,
111            Self::Error { operation_id, .. } => *operation_id,
112        }
113    }
114}
115
116/// Handle for reporting progress on a specific operation
117pub struct ProgressHandle {
118    operation_id: u64,
119    source: String,
120    registry: Arc<ProgressRegistry>,
121    start_time: Instant,
122    granularity: ProgressGranularity,
123}
124
125impl ProgressHandle {
126    /// Report a phase change
127    pub fn phase(&self, phase: LoadPhase) {
128        self.registry.emit(ProgressEvent::Phase {
129            operation_id: self.operation_id,
130            phase,
131            source: self.source.clone(),
132        });
133    }
134
135    /// Report fine-grained progress (only emits if granularity is Fine)
136    pub fn progress(&self, rows_processed: u64, total_rows: Option<u64>, bytes_processed: u64) {
137        if self.granularity == ProgressGranularity::Fine {
138            self.registry.emit(ProgressEvent::Progress {
139                operation_id: self.operation_id,
140                rows_processed,
141                total_rows,
142                bytes_processed,
143            });
144        }
145    }
146
147    /// Mark operation as complete
148    pub fn complete(self, rows_loaded: u64) {
149        let duration_ms = self.start_time.elapsed().as_millis() as u64;
150        self.registry.emit(ProgressEvent::Complete {
151            operation_id: self.operation_id,
152            rows_loaded,
153            duration_ms,
154        });
155    }
156
157    /// Mark operation as failed
158    pub fn error(self, message: String) {
159        self.registry.emit(ProgressEvent::Error {
160            operation_id: self.operation_id,
161            message,
162        });
163    }
164
165    /// Get the operation ID
166    pub fn operation_id(&self) -> u64 {
167        self.operation_id
168    }
169
170    /// Get the granularity setting
171    pub fn granularity(&self) -> ProgressGranularity {
172        self.granularity
173    }
174}
175
176/// Global registry for progress events
177///
178/// Uses a lock-free queue for event storage and broadcast channel for real-time subscribers.
179pub struct ProgressRegistry {
180    /// Lock-free queue for polling events
181    events: SegQueue<ProgressEvent>,
182    /// Broadcast channel for real-time subscribers
183    broadcast_tx: broadcast::Sender<ProgressEvent>,
184    /// Next operation ID
185    next_id: AtomicU64,
186}
187
188impl ProgressRegistry {
189    /// Create a new progress registry
190    pub fn new() -> Arc<Self> {
191        let (broadcast_tx, _) = broadcast::channel(256);
192        Arc::new(Self {
193            events: SegQueue::new(),
194            broadcast_tx,
195            next_id: AtomicU64::new(1),
196        })
197    }
198
199    /// Start a new operation and return a handle for reporting progress
200    pub fn start_operation(
201        self: &Arc<Self>,
202        source: &str,
203        granularity: ProgressGranularity,
204    ) -> ProgressHandle {
205        let operation_id = self.next_id.fetch_add(1, Ordering::SeqCst);
206        ProgressHandle {
207            operation_id,
208            source: source.to_string(),
209            registry: Arc::clone(self),
210            start_time: Instant::now(),
211            granularity,
212        }
213    }
214
215    /// Emit a progress event
216    fn emit(&self, event: ProgressEvent) {
217        // Store in queue for polling
218        self.events.push(event.clone());
219        // Broadcast to subscribers (ignore send errors - no subscribers is OK)
220        let _ = self.broadcast_tx.send(event);
221    }
222
223    /// Subscribe to real-time progress events
224    pub fn subscribe(&self) -> broadcast::Receiver<ProgressEvent> {
225        self.broadcast_tx.subscribe()
226    }
227
228    /// Poll for a single event (non-blocking)
229    pub fn poll(&self) -> Option<ProgressEvent> {
230        self.events.pop()
231    }
232
233    /// Poll all available events (non-blocking)
234    pub fn poll_all(&self) -> Vec<ProgressEvent> {
235        let mut events = Vec::new();
236        while let Some(event) = self.events.pop() {
237            events.push(event);
238        }
239        events
240    }
241
242    /// Try to receive a single event (non-blocking, alias for poll)
243    pub fn try_recv(&self) -> Option<ProgressEvent> {
244        self.poll()
245    }
246
247    /// Check if the queue is empty
248    pub fn is_empty(&self) -> bool {
249        self.events.is_empty()
250    }
251}
252
253impl Default for ProgressRegistry {
254    fn default() -> Self {
255        let (broadcast_tx, _) = broadcast::channel(256);
256        Self {
257            events: SegQueue::new(),
258            broadcast_tx,
259            next_id: AtomicU64::new(1),
260        }
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_progress_handle() {
        let registry = ProgressRegistry::new();
        let handle = registry.start_operation("test-source", ProgressGranularity::Fine);
        let id = handle.operation_id();

        handle.phase(LoadPhase::Connecting);
        handle.progress(100, Some(1000), 8000);
        handle.complete(1000);

        let events = registry.poll_all();
        assert_eq!(events.len(), 3);
        assert!(events.iter().all(|e| e.operation_id() == id));

        // `matches!` only yields a bool; without `assert!` the checks are discarded.
        assert!(matches!(
            &events[0],
            ProgressEvent::Phase {
                phase: LoadPhase::Connecting,
                source,
                ..
            } if source == "test-source"
        ));
        assert!(matches!(
            &events[1],
            ProgressEvent::Progress {
                rows_processed: 100,
                total_rows: Some(1000),
                bytes_processed: 8000,
                ..
            }
        ));
        assert!(matches!(
            &events[2],
            ProgressEvent::Complete {
                rows_loaded: 1000,
                ..
            }
        ));
        assert!(registry.is_empty());
    }

    #[test]
    fn test_coarse_granularity_skips_progress() {
        let registry = ProgressRegistry::new();
        let handle = registry.start_operation("test-source", ProgressGranularity::Coarse);

        handle.phase(LoadPhase::Fetching);
        handle.progress(100, Some(1000), 8000); // Should be skipped
        handle.complete(1000);

        let events = registry.poll_all();
        assert_eq!(events.len(), 2); // Only Phase and Complete, no Progress
        assert!(matches!(
            &events[0],
            ProgressEvent::Phase {
                phase: LoadPhase::Fetching,
                ..
            }
        ));
        assert!(matches!(&events[1], ProgressEvent::Complete { .. }));
    }

    #[test]
    fn test_operation_ids_and_error_event() {
        let registry = ProgressRegistry::new();
        let first = registry.start_operation("a", ProgressGranularity::Coarse);
        let second = registry.start_operation("b", ProgressGranularity::Coarse);
        assert_eq!(first.operation_id(), 1);
        assert_eq!(second.operation_id(), 2);

        let second_id = second.operation_id();
        second.error("boom".to_string());

        let events = registry.poll_all();
        assert_eq!(events.len(), 1);
        assert!(matches!(
            &events[0],
            ProgressEvent::Error { operation_id, message }
                if *operation_id == second_id && message == "boom"
        ));
        assert!(registry.poll().is_none());
    }

    #[test]
    fn test_subscriber_receives_events() {
        let registry = ProgressRegistry::new();
        let mut rx = registry.subscribe();
        let handle = registry.start_operation("sub", ProgressGranularity::Coarse);

        handle.phase(LoadPhase::Querying);

        let event = rx
            .try_recv()
            .expect("broadcast subscriber should receive the emitted event");
        assert!(matches!(
            event,
            ProgressEvent::Phase {
                phase: LoadPhase::Querying,
                ..
            }
        ));
        // The polling queue still receives its own copy.
        assert_eq!(registry.poll_all().len(), 1);
    }

    #[test]
    fn test_load_phase_from_u8() {
        assert_eq!(LoadPhase::from_u8(0), Some(LoadPhase::Connecting));
        assert_eq!(LoadPhase::from_u8(4), Some(LoadPhase::Converting));
        assert_eq!(LoadPhase::from_u8(99), None);
    }
}