Skip to main content

shape_runtime/
stream_executor.rs

//! Stream executor for real-time data processing
//!
//! This module executes StreamDef blocks by:
//! 1. Connecting to data providers via plugins
//! 2. Managing subscriptions
//! 3. Invoking handlers (on_event, on_window, on_connect, on_disconnect, on_error)
//! 4. Maintaining stream state across callbacks
8
9use crate::context::ExecutionContext;
10use crate::data::Timeframe;
11use crate::plugins::PluginLoader;
12use parking_lot::RwLock;
13use shape_ast::ast::{Statement, StreamDef, VariableDecl};
14use shape_ast::error::{Result, ShapeError};
15use shape_value::ValueWord;
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::mpsc;
19
20/// Stream execution state
21#[derive(Debug)]
22pub struct StreamState {
23    /// Stream name
24    pub name: String,
25    /// State variables (stored as ValueWord for efficiency)
26    pub variables: HashMap<String, ValueWord>,
27    /// Active subscription IDs
28    pub subscriptions: Vec<u64>,
29    /// Is the stream running
30    pub running: bool,
31}
32
33/// Message types for stream callbacks
34#[derive(Debug, Clone)]
35pub enum StreamEvent {
36    /// Generic tick/update data (single data point)
37    Tick {
38        /// Data identifier (e.g., symbol, device_id, sensor_id)
39        id: String,
40        /// Dynamic field data (price, temperature, power, etc.)
41        fields: std::collections::HashMap<String, ValueWord>,
42        /// Timestamp of the update
43        timestamp: i64,
44    },
45    /// Generic aggregated data (bar, candle, summary, etc.)
46    Data {
47        /// Data identifier (e.g., symbol, device_id, sensor_id)
48        id: String,
49        /// Optional timeframe/period (e.g., 1m, 5m, 1h)
50        timeframe: Option<Timeframe>,
51        /// Dynamic field data (open/high/low/close/volume or ANY fields)
52        fields: std::collections::HashMap<String, ValueWord>,
53        /// Timestamp when aggregation occurred
54        timestamp: i64,
55    },
56    /// Connection established
57    Connected,
58    /// Connection lost
59    Disconnected,
60    /// Error occurred
61    Error { message: String },
62    /// Shutdown signal
63    Shutdown,
64}
65
66/// Stream executor manages real-time data streams
67pub struct StreamExecutor {
68    /// Plugin loader for data sources
69    plugin_loader: Arc<RwLock<PluginLoader>>,
70    /// Expression evaluator for executing handlers
71    evaluator: Option<Arc<dyn crate::engine::ExpressionEvaluator>>,
72    /// Active streams
73    streams: HashMap<String, StreamState>,
74    /// Event channel for receiving stream events
75    event_rx: Option<mpsc::Receiver<(String, StreamEvent)>>,
76    /// Event sender for sending stream events
77    event_tx: mpsc::Sender<(String, StreamEvent)>,
78}
79
80impl StreamExecutor {
81    /// Create a new stream executor
82    pub fn new(plugin_loader: Arc<RwLock<PluginLoader>>) -> Self {
83        let (event_tx, event_rx) = mpsc::channel(1000);
84        Self {
85            plugin_loader,
86            evaluator: None,
87            streams: HashMap::new(),
88            event_rx: Some(event_rx),
89            event_tx,
90        }
91    }
92
93    /// Create a stream executor with an expression evaluator for handler execution
94    pub fn with_evaluator(
95        plugin_loader: Arc<RwLock<PluginLoader>>,
96        evaluator: Arc<dyn crate::engine::ExpressionEvaluator>,
97    ) -> Self {
98        let (event_tx, event_rx) = mpsc::channel(1000);
99        Self {
100            plugin_loader,
101            evaluator: Some(evaluator),
102            streams: HashMap::new(),
103            event_rx: Some(event_rx),
104            event_tx,
105        }
106    }
107
108    /// Get event sender for plugins to send events
109    pub fn event_sender(&self) -> mpsc::Sender<(String, StreamEvent)> {
110        self.event_tx.clone()
111    }
112
113    /// Start a stream from a StreamDef
114    pub async fn start_stream(
115        &mut self,
116        stream_def: &StreamDef,
117        ctx: &mut ExecutionContext,
118    ) -> Result<()> {
119        let stream_name = stream_def.name.clone();
120
121        // Check if already running
122        if self.streams.contains_key(&stream_name) {
123            return Err(ShapeError::RuntimeError {
124                message: format!("Stream '{}' is already running", stream_name),
125                location: None,
126            });
127        }
128
129        // Initialize state variables
130        let mut state_vars = HashMap::new();
131        for var_decl in &stream_def.state {
132            let value = self.initialize_variable(var_decl, ctx)?;
133            // Get all identifiers from the pattern
134            for ident in var_decl.pattern.get_identifiers() {
135                state_vars.insert(ident, value.clone());
136            }
137        }
138
139        // Create stream state
140        let stream_state = StreamState {
141            name: stream_name.clone(),
142            variables: state_vars,
143            subscriptions: Vec::new(),
144            running: true,
145        };
146
147        self.streams.insert(stream_name.clone(), stream_state);
148
149        // Fire on_connect handler
150        if let Some(on_connect) = &stream_def.on_connect {
151            self.execute_handler(on_connect, &stream_name, HashMap::new(), ctx)?;
152        }
153
154        // Subscribe to data
155        self.subscribe_to_data(stream_def, ctx).await?;
156
157        Ok(())
158    }
159
160    /// Stop a running stream
161    pub fn stop_stream(&mut self, name: &str) -> Result<()> {
162        let stream_state = self
163            .streams
164            .get_mut(name)
165            .ok_or_else(|| ShapeError::RuntimeError {
166                message: format!("Stream '{}' is not running", name),
167                location: None,
168            })?;
169
170        stream_state.running = false;
171
172        // Unsubscribe from all data sources
173        let plugin_loader = self.plugin_loader.read();
174        for sub_id in &stream_state.subscriptions {
175            // Try to find the plugin and unsubscribe
176            // Note: We'd need to track which plugin each subscription belongs to
177            let _ = sub_id;
178        }
179        drop(plugin_loader);
180
181        self.streams.remove(name);
182        Ok(())
183    }
184
185    /// Handle an incoming event
186    pub fn handle_event(
187        &mut self,
188        stream_name: &str,
189        event: StreamEvent,
190        stream_def: &StreamDef,
191        ctx: &mut ExecutionContext,
192    ) -> Result<()> {
193        let stream_state =
194            self.streams
195                .get_mut(stream_name)
196                .ok_or_else(|| ShapeError::RuntimeError {
197                    message: format!("Stream '{}' is not running", stream_name),
198                    location: None,
199                })?;
200
201        if !stream_state.running {
202            return Ok(());
203        }
204
205        match event {
206            StreamEvent::Tick {
207                id,
208                fields,
209                timestamp,
210            } => {
211                if let Some(on_event) = &stream_def.on_event {
212                    // Convert fields to ValueWord and add metadata
213                    let mut nb_fields: Vec<(String, ValueWord)> =
214                        fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
215                    nb_fields.push(("id".to_string(), ValueWord::from_string(Arc::new(id))));
216                    nb_fields.push((
217                        "timestamp".to_string(),
218                        ValueWord::from_f64(timestamp as f64),
219                    ));
220
221                    let pairs: Vec<(&str, ValueWord)> = nb_fields
222                        .iter()
223                        .map(|(k, v)| (k.as_str(), v.clone()))
224                        .collect();
225                    let event_obj = crate::type_schema::typed_object_from_nb_pairs(&pairs);
226
227                    let mut params = HashMap::new();
228                    params.insert(on_event.event_param.clone(), event_obj.clone());
229
230                    self.execute_handler(&on_event.body, stream_name, params, ctx)?;
231                }
232            }
233
234            StreamEvent::Data {
235                id,
236                timeframe,
237                fields,
238                timestamp,
239            } => {
240                if let Some(on_window) = &stream_def.on_window {
241                    // Convert fields to ValueWord and add metadata
242                    let mut nb_fields: Vec<(String, ValueWord)> =
243                        fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
244                    nb_fields.push((
245                        "timestamp".to_string(),
246                        ValueWord::from_f64(timestamp as f64),
247                    ));
248                    if let Some(tf) = timeframe {
249                        nb_fields.push(("timeframe".to_string(), ValueWord::from_timeframe(tf)));
250                    }
251
252                    let pairs: Vec<(&str, ValueWord)> = nb_fields
253                        .iter()
254                        .map(|(k, v)| (k.as_str(), v.clone()))
255                        .collect();
256                    let window_obj = crate::type_schema::typed_object_from_nb_pairs(&pairs);
257
258                    let mut params = HashMap::new();
259                    params.insert(
260                        on_window.key_param.clone(),
261                        ValueWord::from_string(Arc::new(id.clone())),
262                    );
263                    params.insert(on_window.window_param.clone(), window_obj.clone());
264
265                    self.execute_handler(&on_window.body, stream_name, params, ctx)?;
266                }
267            }
268
269            StreamEvent::Connected => {
270                if let Some(on_connect) = &stream_def.on_connect {
271                    self.execute_handler(on_connect, stream_name, HashMap::new(), ctx)?;
272                }
273            }
274
275            StreamEvent::Disconnected => {
276                if let Some(on_disconnect) = &stream_def.on_disconnect {
277                    self.execute_handler(on_disconnect, stream_name, HashMap::new(), ctx)?;
278                }
279            }
280
281            StreamEvent::Error { message } => {
282                if let Some(on_error) = &stream_def.on_error {
283                    let mut params = HashMap::new();
284                    params.insert(
285                        on_error.error_param.clone(),
286                        ValueWord::from_string(Arc::new(message)),
287                    );
288                    self.execute_handler(&on_error.body, stream_name, params, ctx)?;
289                }
290            }
291
292            StreamEvent::Shutdown => {
293                stream_state.running = false;
294            }
295        }
296
297        Ok(())
298    }
299
300    /// Execute a handler with given parameters
301    fn execute_handler(
302        &mut self,
303        statements: &[Statement],
304        stream_name: &str,
305        params: HashMap<String, ValueWord>,
306        ctx: &mut ExecutionContext,
307    ) -> Result<()> {
308        // Get stream state variables
309        let stream_state = self.streams.get_mut(stream_name);
310
311        // Set up context with stream variables and parameters
312        if let Some(state) = stream_state {
313            for (name, value) in &state.variables {
314                let _ = ctx.set_variable_nb(name, value.clone());
315            }
316        }
317
318        // Add handler parameters to context
319        for (name, value) in params {
320            let _ = ctx.set_variable(&name, value);
321        }
322
323        // Execute statements via evaluator
324        if let Some(ref evaluator) = self.evaluator {
325            let _ = evaluator.eval_statements(statements, ctx)?;
326        }
327
328        // Update stream state variables from context
329        if let Some(state) = self.streams.get_mut(stream_name) {
330            for name in state.variables.keys().cloned().collect::<Vec<_>>() {
331                if let Ok(Some(value)) = ctx.get_variable_nb(&name) {
332                    state.variables.insert(name, value.clone());
333                }
334            }
335        }
336
337        Ok(())
338    }
339
340    /// Initialize a variable declaration
341    fn initialize_variable(
342        &self,
343        var_decl: &VariableDecl,
344        ctx: &mut ExecutionContext,
345    ) -> Result<ValueWord> {
346        if let Some(init_expr) = &var_decl.value {
347            if let Some(ref evaluator) = self.evaluator {
348                Ok(evaluator.eval_expr(init_expr, ctx)?)
349            } else {
350                Ok(ValueWord::none())
351            }
352        } else {
353            Ok(ValueWord::none())
354        }
355    }
356
357    /// Subscribe to data sources based on stream config
358    async fn subscribe_to_data(
359        &mut self,
360        stream_def: &StreamDef,
361        _ctx: &mut ExecutionContext,
362    ) -> Result<()> {
363        let config = &stream_def.config;
364        let stream_name = stream_def.name.clone();
365
366        // Verify the plugin exists
367        let plugin_loader = self.plugin_loader.read();
368        if plugin_loader
369            .get_data_source_vtable(&config.provider)
370            .is_err()
371        {
372            return Err(ShapeError::RuntimeError {
373                message: format!("Data source plugin '{}' not found", config.provider),
374                location: None,
375            });
376        }
377        drop(plugin_loader);
378
379        // Log subscription setup
380        tracing::info!(
381            "Stream '{}' subscribing to {} symbols via provider '{}'",
382            stream_name,
383            config.symbols.len(),
384            config.provider
385        );
386
387        // TODO: Implement actual plugin subscription via FFI callback
388        // For now, subscriptions are set up externally and events are pushed
389        // via the event_sender() method
390
391        Ok(())
392    }
393
394    /// Run the event loop (call this in an async context)
395    pub async fn run_event_loop(
396        &mut self,
397        stream_defs: HashMap<String, StreamDef>,
398        ctx: &mut ExecutionContext,
399    ) -> Result<()> {
400        let mut rx = self
401            .event_rx
402            .take()
403            .ok_or_else(|| ShapeError::RuntimeError {
404                message: "Event loop already running".to_string(),
405                location: None,
406            })?;
407
408        while let Some((stream_name, event)) = rx.recv().await {
409            if let Some(stream_def) = stream_defs.get(&stream_name) {
410                if let Err(e) = self.handle_event(&stream_name, event.clone(), stream_def, ctx) {
411                    tracing::error!("Error handling stream event: {}", e);
412                }
413
414                // Check for shutdown
415                if matches!(event, StreamEvent::Shutdown) {
416                    break;
417                }
418            }
419        }
420
421        Ok(())
422    }
423
424    /// Get list of running streams
425    pub fn list_streams(&self) -> Vec<&str> {
426        self.streams.keys().map(|s| s.as_str()).collect()
427    }
428
429    /// Check if a stream is running
430    pub fn is_running(&self, name: &str) -> bool {
431        self.streams.get(name).map(|s| s.running).unwrap_or(false)
432    }
433
434    /// Get stream state for inspection
435    pub fn get_state(&self, name: &str) -> Option<&StreamState> {
436        self.streams.get(name)
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Tick` event exposes the id and field values it was built with.
    #[test]
    fn test_stream_event_types() {
        let fields: HashMap<String, ValueWord> = vec![
            ("price".to_string(), ValueWord::from_f64(150.0)),
            ("volume".to_string(), ValueWord::from_f64(1000.0)),
        ]
        .into_iter()
        .collect();

        let event = StreamEvent::Tick {
            id: "AAPL".to_string(),
            fields,
            timestamp: 1234567890,
        };

        if let StreamEvent::Tick { id, fields, .. } = event {
            assert_eq!(id, "AAPL");
            assert_eq!(fields.get("price"), Some(&ValueWord::from_f64(150.0)));
        } else {
            panic!("Expected tick event");
        }
    }

    /// A freshly built `StreamState` keeps its name and running flag.
    #[test]
    fn test_stream_state_creation() {
        let name = "test_stream".to_string();
        let state = StreamState {
            name,
            variables: HashMap::new(),
            subscriptions: Vec::new(),
            running: true,
        };

        assert_eq!(state.name, "test_stream");
        assert!(state.running);
    }
}
477}