rust_logic_graph/streaming/
operators.rs

//! Stream transformation operators: map, filter, fold, and async map.
2
3use crate::core::Context;
4use crate::rule::{RuleResult, RuleError};
5use crate::streaming::StreamProcessor;
6use async_trait::async_trait;
7use serde_json::Value;
8use std::sync::Arc;
9
10/// Stream operator trait
11#[async_trait]
12pub trait StreamOperator: StreamProcessor {
13    /// Get operator name
14    fn name(&self) -> &str;
15}
16
17/// Map operator - transforms each item
18pub struct MapOperator<F>
19where
20    F: Fn(Value) -> Value + Send + Sync,
21{
22    pub name: String,
23    pub func: Arc<F>,
24}
25
26impl<F> MapOperator<F>
27where
28    F: Fn(Value) -> Value + Send + Sync,
29{
30    pub fn new(name: impl Into<String>, func: F) -> Self {
31        Self {
32            name: name.into(),
33            func: Arc::new(func),
34        }
35    }
36}
37
38#[async_trait]
39impl<F> StreamProcessor for MapOperator<F>
40where
41    F: Fn(Value) -> Value + Send + Sync,
42{
43    async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
44        Ok((self.func)(item))
45    }
46}
47
48#[async_trait]
49impl<F> StreamOperator for MapOperator<F>
50where
51    F: Fn(Value) -> Value + Send + Sync,
52{
53    fn name(&self) -> &str {
54        &self.name
55    }
56}
57
58/// Filter operator - filters items based on predicate
59pub struct FilterOperator<F>
60where
61    F: Fn(&Value) -> bool + Send + Sync,
62{
63    pub name: String,
64    pub predicate: Arc<F>,
65}
66
67impl<F> FilterOperator<F>
68where
69    F: Fn(&Value) -> bool + Send + Sync,
70{
71    pub fn new(name: impl Into<String>, predicate: F) -> Self {
72        Self {
73            name: name.into(),
74            predicate: Arc::new(predicate),
75        }
76    }
77}
78
79#[async_trait]
80impl<F> StreamProcessor for FilterOperator<F>
81where
82    F: Fn(&Value) -> bool + Send + Sync,
83{
84    async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
85        if (self.predicate)(&item) {
86            Ok(item)
87        } else {
88            Err(RuleError::Eval("Filtered out".to_string()))
89        }
90    }
91}
92
93#[async_trait]
94impl<F> StreamOperator for FilterOperator<F>
95where
96    F: Fn(&Value) -> bool + Send + Sync,
97{
98    fn name(&self) -> &str {
99        &self.name
100    }
101}
102
103/// Fold operator - accumulates values
104pub struct FoldOperator<F, T>
105where
106    F: Fn(T, Value) -> T + Send + Sync,
107    T: Clone + Send + Sync + 'static,
108{
109    pub name: String,
110    pub initial: T,
111    pub func: Arc<F>,
112}
113
114impl<F, T> FoldOperator<F, T>
115where
116    F: Fn(T, Value) -> T + Send + Sync,
117    T: Clone + Send + Sync + 'static,
118{
119    pub fn new(name: impl Into<String>, initial: T, func: F) -> Self {
120        Self {
121            name: name.into(),
122            initial,
123            func: Arc::new(func),
124        }
125    }
126}
127
128#[async_trait]
129impl<F, T> StreamProcessor for FoldOperator<F, T>
130where
131    F: Fn(T, Value) -> T + Send + Sync,
132    T: Clone + Send + Sync + Into<Value> + 'static,
133{
134    async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
135        // For fold, we need to process in chunks
136        Ok(item)
137    }
138
139    async fn process_chunk(&self, items: Vec<Value>, _ctx: &Context) -> Result<Vec<Value>, RuleError> {
140        let result = items.into_iter().fold(self.initial.clone(), |acc, item| {
141            (self.func)(acc, item)
142        });
143
144        Ok(vec![result.into()])
145    }
146}
147
148#[async_trait]
149impl<F, T> StreamOperator for FoldOperator<F, T>
150where
151    F: Fn(T, Value) -> T + Send + Sync,
152    T: Clone + Send + Sync + Into<Value> + 'static,
153{
154    fn name(&self) -> &str {
155        &self.name
156    }
157}
158
159/// Async map operator - transforms each item asynchronously
160pub struct AsyncMapOperator<F, Fut>
161where
162    F: Fn(Value) -> Fut + Send + Sync,
163    Fut: std::future::Future<Output = Value> + Send,
164{
165    pub name: String,
166    pub func: Arc<F>,
167}
168
169impl<F, Fut> AsyncMapOperator<F, Fut>
170where
171    F: Fn(Value) -> Fut + Send + Sync,
172    Fut: std::future::Future<Output = Value> + Send,
173{
174    pub fn new(name: impl Into<String>, func: F) -> Self {
175        Self {
176            name: name.into(),
177            func: Arc::new(func),
178        }
179    }
180}
181
182#[async_trait]
183impl<F, Fut> StreamProcessor for AsyncMapOperator<F, Fut>
184where
185    F: Fn(Value) -> Fut + Send + Sync,
186    Fut: std::future::Future<Output = Value> + Send,
187{
188    async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
189        Ok((self.func)(item).await)
190    }
191}
192
193#[async_trait]
194impl<F, Fut> StreamOperator for AsyncMapOperator<F, Fut>
195where
196    F: Fn(Value) -> Fut + Send + Sync,
197    Fut: std::future::Future<Output = Value> + Send,
198{
199    fn name(&self) -> &str {
200        &self.name
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A context with no data; none of the operators under test read it.
    fn empty_ctx() -> Context {
        Context {
            data: HashMap::new(),
        }
    }

    /// Doubles integer values and returns anything else unchanged.
    fn double_if_int(v: Value) -> Value {
        match v.as_i64() {
            Some(n) => Value::Number((n * 2).into()),
            None => v,
        }
    }

    #[tokio::test]
    async fn test_map_operator() {
        let ctx = empty_ctx();
        let op = MapOperator::new("double", double_if_int);

        let doubled = op
            .process_item(Value::Number(5.into()), &ctx)
            .await
            .unwrap();
        assert_eq!(doubled, Value::Number(10.into()));
    }

    #[tokio::test]
    async fn test_filter_operator() {
        let ctx = empty_ctx();
        let op = FilterOperator::new("even_only", |v: &Value| {
            matches!(v.as_i64(), Some(n) if n % 2 == 0)
        });

        assert!(op.process_item(Value::Number(4.into()), &ctx).await.is_ok());
        assert!(op.process_item(Value::Number(5.into()), &ctx).await.is_err());
    }

    #[tokio::test]
    async fn test_fold_operator() {
        let ctx = empty_ctx();
        let op = FoldOperator::new("sum", 0i64, |total: i64, v: Value| {
            total + v.as_i64().unwrap_or(0)
        });

        let inputs: Vec<Value> = (1..=5).map(|i| Value::Number(i.into())).collect();
        let folded = op.process_chunk(inputs, &ctx).await.unwrap();

        assert_eq!(folded.len(), 1);
        assert_eq!(folded[0], Value::Number(15.into()));
    }

    #[tokio::test]
    async fn test_async_map_operator() {
        let ctx = empty_ctx();
        let op = AsyncMapOperator::new("async_double", |v: Value| async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            double_if_int(v)
        });

        let doubled = op
            .process_item(Value::Number(5.into()), &ctx)
            .await
            .unwrap();
        assert_eq!(doubled, Value::Number(10.into()));
    }
}