1use crate::context::ExecutionContext;
10use crate::data::Timeframe;
11use crate::plugins::PluginLoader;
12use parking_lot::RwLock;
13use shape_ast::ast::{Statement, StreamDef, VariableDecl};
14use shape_ast::error::{Result, ShapeError};
15use shape_value::ValueWord;
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::mpsc;
19
20#[derive(Debug)]
22pub struct StreamState {
23 pub name: String,
25 pub variables: HashMap<String, ValueWord>,
27 pub subscriptions: Vec<u64>,
29 pub running: bool,
31}
32
33#[derive(Debug, Clone)]
35pub enum StreamEvent {
36 Tick {
38 id: String,
40 fields: std::collections::HashMap<String, ValueWord>,
42 timestamp: i64,
44 },
45 Data {
47 id: String,
49 timeframe: Option<Timeframe>,
51 fields: std::collections::HashMap<String, ValueWord>,
53 timestamp: i64,
55 },
56 Connected,
58 Disconnected,
60 Error { message: String },
62 Shutdown,
64}
65
66pub struct StreamExecutor {
68 plugin_loader: Arc<RwLock<PluginLoader>>,
70 evaluator: Option<Arc<dyn crate::engine::ExpressionEvaluator>>,
72 streams: HashMap<String, StreamState>,
74 event_rx: Option<mpsc::Receiver<(String, StreamEvent)>>,
76 event_tx: mpsc::Sender<(String, StreamEvent)>,
78}
79
80impl StreamExecutor {
81 pub fn new(plugin_loader: Arc<RwLock<PluginLoader>>) -> Self {
83 let (event_tx, event_rx) = mpsc::channel(1000);
84 Self {
85 plugin_loader,
86 evaluator: None,
87 streams: HashMap::new(),
88 event_rx: Some(event_rx),
89 event_tx,
90 }
91 }
92
93 pub fn with_evaluator(
95 plugin_loader: Arc<RwLock<PluginLoader>>,
96 evaluator: Arc<dyn crate::engine::ExpressionEvaluator>,
97 ) -> Self {
98 let (event_tx, event_rx) = mpsc::channel(1000);
99 Self {
100 plugin_loader,
101 evaluator: Some(evaluator),
102 streams: HashMap::new(),
103 event_rx: Some(event_rx),
104 event_tx,
105 }
106 }
107
108 pub fn event_sender(&self) -> mpsc::Sender<(String, StreamEvent)> {
110 self.event_tx.clone()
111 }
112
113 pub async fn start_stream(
115 &mut self,
116 stream_def: &StreamDef,
117 ctx: &mut ExecutionContext,
118 ) -> Result<()> {
119 let stream_name = stream_def.name.clone();
120
121 if self.streams.contains_key(&stream_name) {
123 return Err(ShapeError::RuntimeError {
124 message: format!("Stream '{}' is already running", stream_name),
125 location: None,
126 });
127 }
128
129 let mut state_vars = HashMap::new();
131 for var_decl in &stream_def.state {
132 let value = self.initialize_variable(var_decl, ctx)?;
133 for ident in var_decl.pattern.get_identifiers() {
135 state_vars.insert(ident, value.clone());
136 }
137 }
138
139 let stream_state = StreamState {
141 name: stream_name.clone(),
142 variables: state_vars,
143 subscriptions: Vec::new(),
144 running: true,
145 };
146
147 self.streams.insert(stream_name.clone(), stream_state);
148
149 if let Some(on_connect) = &stream_def.on_connect {
151 self.execute_handler(on_connect, &stream_name, HashMap::new(), ctx)?;
152 }
153
154 self.subscribe_to_data(stream_def, ctx).await?;
156
157 Ok(())
158 }
159
160 pub fn stop_stream(&mut self, name: &str) -> Result<()> {
162 let stream_state = self
163 .streams
164 .get_mut(name)
165 .ok_or_else(|| ShapeError::RuntimeError {
166 message: format!("Stream '{}' is not running", name),
167 location: None,
168 })?;
169
170 stream_state.running = false;
171
172 let plugin_loader = self.plugin_loader.read();
174 for sub_id in &stream_state.subscriptions {
175 let _ = sub_id;
178 }
179 drop(plugin_loader);
180
181 self.streams.remove(name);
182 Ok(())
183 }
184
185 pub fn handle_event(
187 &mut self,
188 stream_name: &str,
189 event: StreamEvent,
190 stream_def: &StreamDef,
191 ctx: &mut ExecutionContext,
192 ) -> Result<()> {
193 let stream_state =
194 self.streams
195 .get_mut(stream_name)
196 .ok_or_else(|| ShapeError::RuntimeError {
197 message: format!("Stream '{}' is not running", stream_name),
198 location: None,
199 })?;
200
201 if !stream_state.running {
202 return Ok(());
203 }
204
205 match event {
206 StreamEvent::Tick {
207 id,
208 fields,
209 timestamp,
210 } => {
211 if let Some(on_event) = &stream_def.on_event {
212 let mut nb_fields: Vec<(String, ValueWord)> =
214 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
215 nb_fields.push(("id".to_string(), ValueWord::from_string(Arc::new(id))));
216 nb_fields.push((
217 "timestamp".to_string(),
218 ValueWord::from_f64(timestamp as f64),
219 ));
220
221 let pairs: Vec<(&str, ValueWord)> = nb_fields
222 .iter()
223 .map(|(k, v)| (k.as_str(), v.clone()))
224 .collect();
225 let event_obj = crate::type_schema::typed_object_from_nb_pairs(&pairs);
226
227 let mut params = HashMap::new();
228 params.insert(on_event.event_param.clone(), event_obj.clone());
229
230 self.execute_handler(&on_event.body, stream_name, params, ctx)?;
231 }
232 }
233
234 StreamEvent::Data {
235 id,
236 timeframe,
237 fields,
238 timestamp,
239 } => {
240 if let Some(on_window) = &stream_def.on_window {
241 let mut nb_fields: Vec<(String, ValueWord)> =
243 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
244 nb_fields.push((
245 "timestamp".to_string(),
246 ValueWord::from_f64(timestamp as f64),
247 ));
248 if let Some(tf) = timeframe {
249 nb_fields.push(("timeframe".to_string(), ValueWord::from_timeframe(tf)));
250 }
251
252 let pairs: Vec<(&str, ValueWord)> = nb_fields
253 .iter()
254 .map(|(k, v)| (k.as_str(), v.clone()))
255 .collect();
256 let window_obj = crate::type_schema::typed_object_from_nb_pairs(&pairs);
257
258 let mut params = HashMap::new();
259 params.insert(
260 on_window.key_param.clone(),
261 ValueWord::from_string(Arc::new(id.clone())),
262 );
263 params.insert(on_window.window_param.clone(), window_obj.clone());
264
265 self.execute_handler(&on_window.body, stream_name, params, ctx)?;
266 }
267 }
268
269 StreamEvent::Connected => {
270 if let Some(on_connect) = &stream_def.on_connect {
271 self.execute_handler(on_connect, stream_name, HashMap::new(), ctx)?;
272 }
273 }
274
275 StreamEvent::Disconnected => {
276 if let Some(on_disconnect) = &stream_def.on_disconnect {
277 self.execute_handler(on_disconnect, stream_name, HashMap::new(), ctx)?;
278 }
279 }
280
281 StreamEvent::Error { message } => {
282 if let Some(on_error) = &stream_def.on_error {
283 let mut params = HashMap::new();
284 params.insert(
285 on_error.error_param.clone(),
286 ValueWord::from_string(Arc::new(message)),
287 );
288 self.execute_handler(&on_error.body, stream_name, params, ctx)?;
289 }
290 }
291
292 StreamEvent::Shutdown => {
293 stream_state.running = false;
294 }
295 }
296
297 Ok(())
298 }
299
300 fn execute_handler(
302 &mut self,
303 statements: &[Statement],
304 stream_name: &str,
305 params: HashMap<String, ValueWord>,
306 ctx: &mut ExecutionContext,
307 ) -> Result<()> {
308 let stream_state = self.streams.get_mut(stream_name);
310
311 if let Some(state) = stream_state {
313 for (name, value) in &state.variables {
314 let _ = ctx.set_variable_nb(name, value.clone());
315 }
316 }
317
318 for (name, value) in params {
320 let _ = ctx.set_variable(&name, value);
321 }
322
323 if let Some(ref evaluator) = self.evaluator {
325 let _ = evaluator.eval_statements(statements, ctx)?;
326 }
327
328 if let Some(state) = self.streams.get_mut(stream_name) {
330 for name in state.variables.keys().cloned().collect::<Vec<_>>() {
331 if let Ok(Some(value)) = ctx.get_variable_nb(&name) {
332 state.variables.insert(name, value.clone());
333 }
334 }
335 }
336
337 Ok(())
338 }
339
340 fn initialize_variable(
342 &self,
343 var_decl: &VariableDecl,
344 ctx: &mut ExecutionContext,
345 ) -> Result<ValueWord> {
346 if let Some(init_expr) = &var_decl.value {
347 if let Some(ref evaluator) = self.evaluator {
348 Ok(evaluator.eval_expr(init_expr, ctx)?)
349 } else {
350 Ok(ValueWord::none())
351 }
352 } else {
353 Ok(ValueWord::none())
354 }
355 }
356
357 async fn subscribe_to_data(
359 &mut self,
360 stream_def: &StreamDef,
361 _ctx: &mut ExecutionContext,
362 ) -> Result<()> {
363 let config = &stream_def.config;
364 let stream_name = stream_def.name.clone();
365
366 let plugin_loader = self.plugin_loader.read();
368 if plugin_loader
369 .get_data_source_vtable(&config.provider)
370 .is_err()
371 {
372 return Err(ShapeError::RuntimeError {
373 message: format!("Data source plugin '{}' not found", config.provider),
374 location: None,
375 });
376 }
377 drop(plugin_loader);
378
379 tracing::info!(
381 "Stream '{}' subscribing to {} symbols via provider '{}'",
382 stream_name,
383 config.symbols.len(),
384 config.provider
385 );
386
387 Ok(())
392 }
393
394 pub async fn run_event_loop(
396 &mut self,
397 stream_defs: HashMap<String, StreamDef>,
398 ctx: &mut ExecutionContext,
399 ) -> Result<()> {
400 let mut rx = self
401 .event_rx
402 .take()
403 .ok_or_else(|| ShapeError::RuntimeError {
404 message: "Event loop already running".to_string(),
405 location: None,
406 })?;
407
408 while let Some((stream_name, event)) = rx.recv().await {
409 if let Some(stream_def) = stream_defs.get(&stream_name) {
410 if let Err(e) = self.handle_event(&stream_name, event.clone(), stream_def, ctx) {
411 tracing::error!("Error handling stream event: {}", e);
412 }
413
414 if matches!(event, StreamEvent::Shutdown) {
416 break;
417 }
418 }
419 }
420
421 Ok(())
422 }
423
424 pub fn list_streams(&self) -> Vec<&str> {
426 self.streams.keys().map(|s| s.as_str()).collect()
427 }
428
429 pub fn is_running(&self, name: &str) -> bool {
431 self.streams.get(name).map(|s| s.running).unwrap_or(false)
432 }
433
434 pub fn get_state(&self, name: &str) -> Option<&StreamState> {
436 self.streams.get(name)
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// A tick event keeps the id and payload it was constructed with.
    #[test]
    fn test_stream_event_types() {
        let payload = HashMap::from([
            ("price".to_string(), ValueWord::from_f64(150.0)),
            ("volume".to_string(), ValueWord::from_f64(1000.0)),
        ]);

        let event = StreamEvent::Tick {
            id: "AAPL".to_string(),
            fields: payload,
            timestamp: 1234567890,
        };

        if let StreamEvent::Tick { id, fields, .. } = event {
            assert_eq!(id, "AAPL");
            let expected_price = ValueWord::from_f64(150.0);
            assert_eq!(fields.get("price"), Some(&expected_price));
        } else {
            panic!("Expected tick event");
        }
    }

    /// A freshly built stream state reports its name and running flag.
    #[test]
    fn test_stream_state_creation() {
        let stream = StreamState {
            name: String::from("test_stream"),
            variables: HashMap::new(),
            subscriptions: vec![],
            running: true,
        };

        assert_eq!(stream.name, "test_stream");
        assert!(stream.running);
    }
}
477}