// rust_logic_graph/streaming/stream_node.rs

use crate::core::Context;
use crate::node::{Node, NodeType};
use crate::rule::{RuleError, RuleResult};
use crate::streaming::{
    apply_backpressure, create_chunked_stream, create_stream_from_vec, BackpressureConfig,
    ChunkConfig, StreamProcessor,
};
use async_trait::async_trait;
use serde_json::Value;
use std::sync::Arc;
use tokio_stream::StreamExt;
use tracing::{error, info};
15
/// A graph node that pushes a `Vec<Value>` through a [`StreamProcessor`],
/// optionally in chunks and with backpressure applied.
#[derive(Clone)]
pub struct StreamNode {
    /// Unique node identifier; also used by `run` to derive the context
    /// keys `"{id}_input"` and `"{id}_result"`.
    pub id: String,
    /// User-supplied per-item / per-chunk processing logic.
    pub processor: Arc<dyn StreamProcessor>,
    /// Backpressure settings used by the non-chunked processing path.
    pub backpressure_config: BackpressureConfig,
    /// When `Some`, items are processed chunk-by-chunk instead of one at a time.
    pub chunk_config: Option<ChunkConfig>,
    /// `true`: return all results as a JSON array; `false`: return only the
    /// last produced result.
    pub collect_results: bool,
}
25
26impl StreamNode {
27 pub fn new(id: impl Into<String>, processor: Arc<dyn StreamProcessor>) -> Self {
29 Self {
30 id: id.into(),
31 processor,
32 backpressure_config: BackpressureConfig::default(),
33 chunk_config: None,
34 collect_results: true,
35 }
36 }
37
38 pub fn with_backpressure(mut self, config: BackpressureConfig) -> Self {
40 self.backpressure_config = config;
41 self
42 }
43
44 pub fn with_chunking(mut self, config: ChunkConfig) -> Self {
46 self.chunk_config = Some(config);
47 self
48 }
49
50 pub fn collect_results(mut self, collect: bool) -> Self {
53 self.collect_results = collect;
54 self
55 }
56
57 pub async fn process_stream(&self, data: Vec<Value>, ctx: &Context) -> RuleResult {
59 info!("StreamNode[{}]: Processing {} items", self.id, data.len());
60
61 if let Some(chunk_config) = &self.chunk_config {
62 self.process_chunked(data, chunk_config.clone(), ctx).await
64 } else {
65 self.process_regular(data, ctx).await
67 }
68 }
69
70 async fn process_regular(&self, data: Vec<Value>, ctx: &Context) -> RuleResult {
72 let stream = create_stream_from_vec(data, self.backpressure_config.clone());
73 let stream = apply_backpressure(stream, self.backpressure_config.clone());
74
75 let mut stream = Box::pin(stream);
76 let mut results = Vec::new();
77
78 while let Some(item) = stream.next().await {
79 match item {
80 Ok(value) => match self.processor.process_item(value, ctx).await {
81 Ok(result) => results.push(result),
82 Err(_) => continue,
83 },
84 Err(_) => continue,
85 }
86 }
87
88 info!("StreamNode[{}]: Processed {} items", self.id, results.len());
89
90 if self.collect_results {
91 Ok(Value::Array(results))
92 } else {
93 results
94 .last()
95 .cloned()
96 .ok_or_else(|| RuleError::Eval("No results produced".to_string()))
97 }
98 }
99
100 async fn process_chunked(
102 &self,
103 data: Vec<Value>,
104 chunk_config: ChunkConfig,
105 ctx: &Context,
106 ) -> RuleResult {
107 info!(
108 "StreamNode[{}]: Processing {} items in chunks of {}",
109 self.id,
110 data.len(),
111 chunk_config.chunk_size
112 );
113
114 let mut stream = create_chunked_stream(data, chunk_config);
115 let mut all_results = Vec::new();
116
117 while let Some(chunk_result) = stream.next().await {
118 match chunk_result {
119 Ok(chunk) => {
120 info!(
121 "StreamNode[{}]: Processing chunk of {} items",
122 self.id,
123 chunk.len()
124 );
125 match self.processor.process_chunk(chunk, ctx).await {
126 Ok(results) => {
127 if self.collect_results {
128 all_results.extend(results);
129 } else {
130 if let Some(last) = results.last() {
131 all_results = vec![last.clone()];
132 }
133 }
134 }
135 Err(e) => {
136 error!("StreamNode[{}]: Chunk processing failed: {}", self.id, e);
137 return Err(e);
138 }
139 }
140 }
141 Err(e) => {
142 error!("StreamNode[{}]: Stream error: {}", self.id, e);
143 return Err(e);
144 }
145 }
146 }
147
148 info!(
149 "StreamNode[{}]: Total processed {} items",
150 self.id,
151 all_results.len()
152 );
153
154 if self.collect_results {
155 Ok(Value::Array(all_results))
156 } else {
157 all_results
158 .last()
159 .cloned()
160 .ok_or_else(|| RuleError::Eval("No results produced".to_string()))
161 }
162 }
163}
164
/// Manual `Debug` implementation: every field except `processor` is shown.
/// `processor` is intentionally omitted — presumably because
/// `Arc<dyn StreamProcessor>` does not implement `Debug` (trait definition
/// not visible here; TODO confirm).
impl std::fmt::Debug for StreamNode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StreamNode")
            .field("id", &self.id)
            .field("backpressure_config", &self.backpressure_config)
            .field("chunk_config", &self.chunk_config)
            .field("collect_results", &self.collect_results)
            .finish()
    }
}
175
176#[async_trait]
177impl Node for StreamNode {
178 fn id(&self) -> &str {
179 &self.id
180 }
181
182 fn node_type(&self) -> NodeType {
183 NodeType::AINode }
185
186 async fn run(&self, ctx: &mut Context) -> RuleResult {
187 info!("StreamNode[{}]: Starting stream execution", self.id);
188
189 let input_key = format!("{}_input", self.id);
191 let data = ctx
192 .data
193 .get(&input_key)
194 .and_then(|v| v.as_array())
195 .ok_or_else(|| RuleError::Eval(format!("No input data found for key: {}", input_key)))?
196 .clone();
197
198 let result = self.process_stream(data, ctx).await?;
199
200 ctx.data
202 .insert(format!("{}_result", self.id), result.clone());
203
204 Ok(result)
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Doubles integer items; passes every other JSON value through unchanged.
    struct TestProcessor;

    #[async_trait]
    impl StreamProcessor for TestProcessor {
        async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
            match item.as_i64() {
                Some(n) => Ok(Value::Number((n * 2).into())),
                None => Ok(item),
            }
        }
    }

    /// Builds a context with no pre-populated data.
    fn empty_ctx() -> Context {
        Context {
            data: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_stream_node_basic() {
        let node = StreamNode::new("test", Arc::new(TestProcessor));
        let input: Vec<Value> = (1..=5).map(|i| Value::Number(i.into())).collect();
        let ctx = empty_ctx();

        let output = node.process_stream(input, &ctx).await.unwrap();

        match output {
            Value::Array(items) => {
                assert_eq!(items.len(), 5);
                assert_eq!(items[0], Value::Number(2.into()));
                assert_eq!(items[4], Value::Number(10.into()));
            }
            _ => panic!("Expected array result"),
        }
    }

    #[tokio::test]
    async fn test_stream_node_chunked() {
        let chunking = ChunkConfig {
            chunk_size: 3,
            overlap: 0,
        };
        let node = StreamNode::new("test", Arc::new(TestProcessor)).with_chunking(chunking);
        let input: Vec<Value> = (1..=10).map(|i| Value::Number(i.into())).collect();
        let ctx = empty_ctx();

        let output = node.process_stream(input, &ctx).await.unwrap();

        match output {
            Value::Array(items) => assert_eq!(items.len(), 10),
            _ => panic!("Expected array result"),
        }
    }
}