// rust_logic_graph/streaming/stream_node.rs

use std::sync::Arc;

use async_trait::async_trait;
use serde_json::Value;
use tokio_stream::StreamExt;
use tracing::{error, info, warn};

use crate::core::Context;
use crate::node::{Node, NodeType};
use crate::rule::{RuleError, RuleResult};
use crate::streaming::{
    apply_backpressure, create_chunked_stream, create_stream_from_vec, BackpressureConfig,
    ChunkConfig, StreamProcessor,
};

16#[derive(Clone)]
18pub struct StreamNode {
19 pub id: String,
20 pub processor: Arc<dyn StreamProcessor>,
21 pub backpressure_config: BackpressureConfig,
22 pub chunk_config: Option<ChunkConfig>,
23 pub collect_results: bool,
24}
25
26impl StreamNode {
27 pub fn new(id: impl Into<String>, processor: Arc<dyn StreamProcessor>) -> Self {
29 Self {
30 id: id.into(),
31 processor,
32 backpressure_config: BackpressureConfig::default(),
33 chunk_config: None,
34 collect_results: true,
35 }
36 }
37
38 pub fn with_backpressure(mut self, config: BackpressureConfig) -> Self {
40 self.backpressure_config = config;
41 self
42 }
43
44 pub fn with_chunking(mut self, config: ChunkConfig) -> Self {
46 self.chunk_config = Some(config);
47 self
48 }
49
50 pub fn collect_results(mut self, collect: bool) -> Self {
53 self.collect_results = collect;
54 self
55 }
56
57 pub async fn process_stream(
59 &self,
60 data: Vec<Value>,
61 ctx: &Context,
62 ) -> RuleResult {
63 info!("StreamNode[{}]: Processing {} items", self.id, data.len());
64
65 if let Some(chunk_config) = &self.chunk_config {
66 self.process_chunked(data, chunk_config.clone(), ctx).await
68 } else {
69 self.process_regular(data, ctx).await
71 }
72 }
73
74 async fn process_regular(
76 &self,
77 data: Vec<Value>,
78 ctx: &Context,
79 ) -> RuleResult {
80 let stream = create_stream_from_vec(data, self.backpressure_config.clone());
81 let stream = apply_backpressure(stream, self.backpressure_config.clone());
82
83 let mut stream = Box::pin(stream);
84 let mut results = Vec::new();
85
86 while let Some(item) = stream.next().await {
87 match item {
88 Ok(value) => {
89 match self.processor.process_item(value, ctx).await {
90 Ok(result) => results.push(result),
91 Err(_) => continue,
92 }
93 }
94 Err(_) => continue,
95 }
96 }
97
98 info!("StreamNode[{}]: Processed {} items", self.id, results.len());
99
100 if self.collect_results {
101 Ok(Value::Array(results))
102 } else {
103 results.last().cloned().ok_or_else(|| {
104 RuleError::Eval("No results produced".to_string())
105 })
106 }
107 }
108
109 async fn process_chunked(
111 &self,
112 data: Vec<Value>,
113 chunk_config: ChunkConfig,
114 ctx: &Context,
115 ) -> RuleResult {
116 info!(
117 "StreamNode[{}]: Processing {} items in chunks of {}",
118 self.id,
119 data.len(),
120 chunk_config.chunk_size
121 );
122
123 let mut stream = create_chunked_stream(data, chunk_config);
124 let mut all_results = Vec::new();
125
126 while let Some(chunk_result) = stream.next().await {
127 match chunk_result {
128 Ok(chunk) => {
129 info!("StreamNode[{}]: Processing chunk of {} items", self.id, chunk.len());
130 match self.processor.process_chunk(chunk, ctx).await {
131 Ok(results) => {
132 if self.collect_results {
133 all_results.extend(results);
134 } else {
135 if let Some(last) = results.last() {
136 all_results = vec![last.clone()];
137 }
138 }
139 }
140 Err(e) => {
141 error!("StreamNode[{}]: Chunk processing failed: {}", self.id, e);
142 return Err(e);
143 }
144 }
145 }
146 Err(e) => {
147 error!("StreamNode[{}]: Stream error: {}", self.id, e);
148 return Err(e);
149 }
150 }
151 }
152
153 info!("StreamNode[{}]: Total processed {} items", self.id, all_results.len());
154
155 if self.collect_results {
156 Ok(Value::Array(all_results))
157 } else {
158 all_results.last().cloned().ok_or_else(|| {
159 RuleError::Eval("No results produced".to_string())
160 })
161 }
162 }
163}
164
165impl std::fmt::Debug for StreamNode {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 f.debug_struct("StreamNode")
168 .field("id", &self.id)
169 .field("backpressure_config", &self.backpressure_config)
170 .field("chunk_config", &self.chunk_config)
171 .field("collect_results", &self.collect_results)
172 .finish()
173 }
174}
175
176#[async_trait]
177impl Node for StreamNode {
178 fn id(&self) -> &str {
179 &self.id
180 }
181
182 fn node_type(&self) -> NodeType {
183 NodeType::AINode }
185
186 async fn run(&self, ctx: &mut Context) -> RuleResult {
187 info!("StreamNode[{}]: Starting stream execution", self.id);
188
189 let input_key = format!("{}_input", self.id);
191 let data = ctx.data.get(&input_key)
192 .and_then(|v| v.as_array())
193 .ok_or_else(|| RuleError::Eval(format!("No input data found for key: {}", input_key)))?
194 .clone();
195
196 let result = self.process_stream(data, ctx).await?;
197
198 ctx.data.insert(format!("{}_result", self.id), result.clone());
200
201 Ok(result)
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Doubles integer items and passes every other value through unchanged.
    struct TestProcessor;

    #[async_trait]
    impl StreamProcessor for TestProcessor {
        async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
            if let Some(n) = item.as_i64() {
                Ok(Value::Number((n * 2).into()))
            } else {
                Ok(item)
            }
        }
    }

    /// Context with no data entries.
    fn empty_ctx() -> Context {
        Context {
            data: HashMap::new(),
        }
    }

    /// JSON numbers for every integer in `range`.
    fn numbers(range: std::ops::RangeInclusive<i64>) -> Vec<Value> {
        range.map(|i| Value::Number(i.into())).collect()
    }

    #[tokio::test]
    async fn test_stream_node_basic() {
        let node = StreamNode::new("test", Arc::new(TestProcessor));

        let result = node.process_stream(numbers(1..=5), &empty_ctx()).await.unwrap();

        if let Value::Array(results) = result {
            assert_eq!(results.len(), 5);
            assert_eq!(results[0], Value::Number(2.into()));
            assert_eq!(results[4], Value::Number(10.into()));
        } else {
            panic!("Expected array result");
        }
    }

    #[tokio::test]
    async fn test_stream_node_chunked() {
        let node = StreamNode::new("test", Arc::new(TestProcessor)).with_chunking(ChunkConfig {
            chunk_size: 3,
            overlap: 0,
        });

        let result = node.process_stream(numbers(1..=10), &empty_ctx()).await.unwrap();

        if let Value::Array(results) = result {
            assert_eq!(results.len(), 10);
        } else {
            panic!("Expected array result");
        }
    }

    #[tokio::test]
    async fn test_stream_node_last_result_only() {
        let node = StreamNode::new("test", Arc::new(TestProcessor)).collect_results(false);

        let result = node.process_stream(numbers(1..=5), &empty_ctx()).await.unwrap();

        assert_eq!(result, Value::Number(10.into()));
    }

    #[tokio::test]
    async fn test_stream_node_empty_input_without_collection_is_error() {
        let node = StreamNode::new("test", Arc::new(TestProcessor)).collect_results(false);

        assert!(node.process_stream(Vec::new(), &empty_ctx()).await.is_err());
    }

    #[tokio::test]
    async fn test_run_reads_input_and_stores_result() {
        let node = StreamNode::new("test", Arc::new(TestProcessor));
        let mut ctx = empty_ctx();
        ctx.data
            .insert("test_input".to_string(), Value::Array(numbers(1..=3)));

        let result = node.run(&mut ctx).await.unwrap();

        assert_eq!(result, serde_json::json!([2, 4, 6]));
        assert_eq!(ctx.data.get("test_result"), Some(&result));
    }

    #[tokio::test]
    async fn test_run_missing_input_is_error() {
        let node = StreamNode::new("test", Arc::new(TestProcessor));
        let mut ctx = empty_ctx();

        assert!(node.run(&mut ctx).await.is_err());
        assert!(!ctx.data.contains_key("test_result"));
    }

    #[tokio::test]
    async fn test_run_non_array_input_is_error() {
        let node = StreamNode::new("test", Arc::new(TestProcessor));
        let mut ctx = empty_ctx();
        ctx.data
            .insert("test_input".to_string(), serde_json::json!(42));

        assert!(node.run(&mut ctx).await.is_err());
        assert!(!ctx.data.contains_key("test_result"));
    }
}