rust_logic_graph/streaming/
stream_node.rs

1//! Stream-based node implementation
2
3use crate::core::Context;
4use crate::node::{Node, NodeType};
5use crate::rule::{RuleError, RuleResult};
6use crate::streaming::{
7    apply_backpressure, create_chunked_stream, create_stream_from_vec, BackpressureConfig,
8    ChunkConfig, StreamProcessor,
9};
10use async_trait::async_trait;
11use serde_json::Value;
12use std::sync::Arc;
13use tokio_stream::StreamExt;
14use tracing::{error, info};
15
/// Stream node for processing data streams
#[derive(Clone)]
pub struct StreamNode {
    /// Node identifier; also used as the prefix for the context keys
    /// `{id}_input` / `{id}_result` and in log messages.
    pub id: String,
    /// Processor invoked per item (`process_item`) or per chunk (`process_chunk`).
    pub processor: Arc<dyn StreamProcessor>,
    /// Configuration handed to `create_stream_from_vec` and `apply_backpressure`.
    pub backpressure_config: BackpressureConfig,
    /// When `Some`, input is processed through the chunked path instead of item-by-item.
    pub chunk_config: Option<ChunkConfig>,
    /// If true, all results are returned as a `Value::Array`;
    /// if false, only the last produced result is returned.
    pub collect_results: bool,
}
25
26impl StreamNode {
27    /// Create a new stream node
28    pub fn new(id: impl Into<String>, processor: Arc<dyn StreamProcessor>) -> Self {
29        Self {
30            id: id.into(),
31            processor,
32            backpressure_config: BackpressureConfig::default(),
33            chunk_config: None,
34            collect_results: true,
35        }
36    }
37
38    /// Configure backpressure
39    pub fn with_backpressure(mut self, config: BackpressureConfig) -> Self {
40        self.backpressure_config = config;
41        self
42    }
43
44    /// Enable chunked processing for large datasets
45    pub fn with_chunking(mut self, config: ChunkConfig) -> Self {
46        self.chunk_config = Some(config);
47        self
48    }
49
50    /// Set whether to collect all results (default: true)
51    /// If false, only the last result is stored
52    pub fn collect_results(mut self, collect: bool) -> Self {
53        self.collect_results = collect;
54        self
55    }
56
57    /// Process a stream of data
58    pub async fn process_stream(&self, data: Vec<Value>, ctx: &Context) -> RuleResult {
59        info!("StreamNode[{}]: Processing {} items", self.id, data.len());
60
61        if let Some(chunk_config) = &self.chunk_config {
62            // Chunked processing for large datasets
63            self.process_chunked(data, chunk_config.clone(), ctx).await
64        } else {
65            // Regular streaming processing
66            self.process_regular(data, ctx).await
67        }
68    }
69
70    /// Regular stream processing
71    async fn process_regular(&self, data: Vec<Value>, ctx: &Context) -> RuleResult {
72        let stream = create_stream_from_vec(data, self.backpressure_config.clone());
73        let stream = apply_backpressure(stream, self.backpressure_config.clone());
74
75        let mut stream = Box::pin(stream);
76        let mut results = Vec::new();
77
78        while let Some(item) = stream.next().await {
79            match item {
80                Ok(value) => match self.processor.process_item(value, ctx).await {
81                    Ok(result) => results.push(result),
82                    Err(_) => continue,
83                },
84                Err(_) => continue,
85            }
86        }
87
88        info!("StreamNode[{}]: Processed {} items", self.id, results.len());
89
90        if self.collect_results {
91            Ok(Value::Array(results))
92        } else {
93            results
94                .last()
95                .cloned()
96                .ok_or_else(|| RuleError::Eval("No results produced".to_string()))
97        }
98    }
99
100    /// Chunked stream processing for large datasets
101    async fn process_chunked(
102        &self,
103        data: Vec<Value>,
104        chunk_config: ChunkConfig,
105        ctx: &Context,
106    ) -> RuleResult {
107        info!(
108            "StreamNode[{}]: Processing {} items in chunks of {}",
109            self.id,
110            data.len(),
111            chunk_config.chunk_size
112        );
113
114        let mut stream = create_chunked_stream(data, chunk_config);
115        let mut all_results = Vec::new();
116
117        while let Some(chunk_result) = stream.next().await {
118            match chunk_result {
119                Ok(chunk) => {
120                    info!(
121                        "StreamNode[{}]: Processing chunk of {} items",
122                        self.id,
123                        chunk.len()
124                    );
125                    match self.processor.process_chunk(chunk, ctx).await {
126                        Ok(results) => {
127                            if self.collect_results {
128                                all_results.extend(results);
129                            } else {
130                                if let Some(last) = results.last() {
131                                    all_results = vec![last.clone()];
132                                }
133                            }
134                        }
135                        Err(e) => {
136                            error!("StreamNode[{}]: Chunk processing failed: {}", self.id, e);
137                            return Err(e);
138                        }
139                    }
140                }
141                Err(e) => {
142                    error!("StreamNode[{}]: Stream error: {}", self.id, e);
143                    return Err(e);
144                }
145            }
146        }
147
148        info!(
149            "StreamNode[{}]: Total processed {} items",
150            self.id,
151            all_results.len()
152        );
153
154        if self.collect_results {
155            Ok(Value::Array(all_results))
156        } else {
157            all_results
158                .last()
159                .cloned()
160                .ok_or_else(|| RuleError::Eval("No results produced".to_string()))
161        }
162    }
163}
164
165impl std::fmt::Debug for StreamNode {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        f.debug_struct("StreamNode")
168            .field("id", &self.id)
169            .field("backpressure_config", &self.backpressure_config)
170            .field("chunk_config", &self.chunk_config)
171            .field("collect_results", &self.collect_results)
172            .finish()
173    }
174}
175
176#[async_trait]
177impl Node for StreamNode {
178    fn id(&self) -> &str {
179        &self.id
180    }
181
182    fn node_type(&self) -> NodeType {
183        NodeType::AINode // Using AINode as it's for processing
184    }
185
186    async fn run(&self, ctx: &mut Context) -> RuleResult {
187        info!("StreamNode[{}]: Starting stream execution", self.id);
188
189        // Get input data from context
190        let input_key = format!("{}_input", self.id);
191        let data = ctx
192            .data
193            .get(&input_key)
194            .and_then(|v| v.as_array())
195            .ok_or_else(|| RuleError::Eval(format!("No input data found for key: {}", input_key)))?
196            .clone();
197
198        let result = self.process_stream(data, ctx).await?;
199
200        // Store result in context
201        ctx.data
202            .insert(format!("{}_result", self.id), result.clone());
203
204        Ok(result)
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Doubles integer items; passes non-integer values through unchanged.
    struct TestProcessor;

    #[async_trait]
    impl StreamProcessor for TestProcessor {
        async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
            match item.as_i64() {
                Some(n) => Ok(Value::Number((n * 2).into())),
                None => Ok(item),
            }
        }
    }

    /// Build an empty execution context for the tests.
    fn empty_context() -> Context {
        Context {
            data: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_stream_node_basic() {
        let node = StreamNode::new("test", Arc::new(TestProcessor));
        let input: Vec<Value> = (1..=5).map(|i| Value::Number(i.into())).collect();
        let ctx = empty_context();

        let output = node.process_stream(input, &ctx).await.unwrap();

        // Every item should have been doubled and collected in order.
        match output {
            Value::Array(items) => {
                assert_eq!(items.len(), 5);
                assert_eq!(items[0], Value::Number(2.into()));
                assert_eq!(items[4], Value::Number(10.into()));
            }
            _ => panic!("Expected array result"),
        }
    }

    #[tokio::test]
    async fn test_stream_node_chunked() {
        let node = StreamNode::new("test", Arc::new(TestProcessor)).with_chunking(ChunkConfig {
            chunk_size: 3,
            overlap: 0,
        });
        let input: Vec<Value> = (1..=10).map(|i| Value::Number(i.into())).collect();
        let ctx = empty_context();

        let output = node.process_stream(input, &ctx).await.unwrap();

        // Chunking must not lose items: 10 in, 10 out.
        match output {
            Value::Array(items) => assert_eq!(items.len(), 10),
            _ => panic!("Expected array result"),
        }
    }
}