// rust_logic_graph/streaming/operators.rs

//! Stream transformation operators: map, filter, fold and async map.
3use crate::core::Context;
4use crate::rule::{RuleError, RuleResult};
5use crate::streaming::StreamProcessor;
6use async_trait::async_trait;
7use serde_json::Value;
8use std::sync::Arc;
9
10/// Stream operator trait
11#[async_trait]
12pub trait StreamOperator: StreamProcessor {
13    /// Get operator name
14    fn name(&self) -> &str;
15}
16
17/// Map operator - transforms each item
18pub struct MapOperator<F>
19where
20    F: Fn(Value) -> Value + Send + Sync,
21{
22    pub name: String,
23    pub func: Arc<F>,
24}
25
26impl<F> MapOperator<F>
27where
28    F: Fn(Value) -> Value + Send + Sync,
29{
30    pub fn new(name: impl Into<String>, func: F) -> Self {
31        Self {
32            name: name.into(),
33            func: Arc::new(func),
34        }
35    }
36}
37
38#[async_trait]
39impl<F> StreamProcessor for MapOperator<F>
40where
41    F: Fn(Value) -> Value + Send + Sync,
42{
43    async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
44        Ok((self.func)(item))
45    }
46}
47
48#[async_trait]
49impl<F> StreamOperator for MapOperator<F>
50where
51    F: Fn(Value) -> Value + Send + Sync,
52{
53    fn name(&self) -> &str {
54        &self.name
55    }
56}
57
58/// Filter operator - filters items based on predicate
59pub struct FilterOperator<F>
60where
61    F: Fn(&Value) -> bool + Send + Sync,
62{
63    pub name: String,
64    pub predicate: Arc<F>,
65}
66
67impl<F> FilterOperator<F>
68where
69    F: Fn(&Value) -> bool + Send + Sync,
70{
71    pub fn new(name: impl Into<String>, predicate: F) -> Self {
72        Self {
73            name: name.into(),
74            predicate: Arc::new(predicate),
75        }
76    }
77}
78
79#[async_trait]
80impl<F> StreamProcessor for FilterOperator<F>
81where
82    F: Fn(&Value) -> bool + Send + Sync,
83{
84    async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
85        if (self.predicate)(&item) {
86            Ok(item)
87        } else {
88            Err(RuleError::Eval("Filtered out".to_string()))
89        }
90    }
91}
92
93#[async_trait]
94impl<F> StreamOperator for FilterOperator<F>
95where
96    F: Fn(&Value) -> bool + Send + Sync,
97{
98    fn name(&self) -> &str {
99        &self.name
100    }
101}
102
103/// Fold operator - accumulates values
104pub struct FoldOperator<F, T>
105where
106    F: Fn(T, Value) -> T + Send + Sync,
107    T: Clone + Send + Sync + 'static,
108{
109    pub name: String,
110    pub initial: T,
111    pub func: Arc<F>,
112}
113
114impl<F, T> FoldOperator<F, T>
115where
116    F: Fn(T, Value) -> T + Send + Sync,
117    T: Clone + Send + Sync + 'static,
118{
119    pub fn new(name: impl Into<String>, initial: T, func: F) -> Self {
120        Self {
121            name: name.into(),
122            initial,
123            func: Arc::new(func),
124        }
125    }
126}
127
128#[async_trait]
129impl<F, T> StreamProcessor for FoldOperator<F, T>
130where
131    F: Fn(T, Value) -> T + Send + Sync,
132    T: Clone + Send + Sync + Into<Value> + 'static,
133{
134    async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
135        // For fold, we need to process in chunks
136        Ok(item)
137    }
138
139    async fn process_chunk(
140        &self,
141        items: Vec<Value>,
142        _ctx: &Context,
143    ) -> Result<Vec<Value>, RuleError> {
144        let result = items
145            .into_iter()
146            .fold(self.initial.clone(), |acc, item| (self.func)(acc, item));
147
148        Ok(vec![result.into()])
149    }
150}
151
152#[async_trait]
153impl<F, T> StreamOperator for FoldOperator<F, T>
154where
155    F: Fn(T, Value) -> T + Send + Sync,
156    T: Clone + Send + Sync + Into<Value> + 'static,
157{
158    fn name(&self) -> &str {
159        &self.name
160    }
161}
162
163/// Async map operator - transforms each item asynchronously
164pub struct AsyncMapOperator<F, Fut>
165where
166    F: Fn(Value) -> Fut + Send + Sync,
167    Fut: std::future::Future<Output = Value> + Send,
168{
169    pub name: String,
170    pub func: Arc<F>,
171}
172
173impl<F, Fut> AsyncMapOperator<F, Fut>
174where
175    F: Fn(Value) -> Fut + Send + Sync,
176    Fut: std::future::Future<Output = Value> + Send,
177{
178    pub fn new(name: impl Into<String>, func: F) -> Self {
179        Self {
180            name: name.into(),
181            func: Arc::new(func),
182        }
183    }
184}
185
186#[async_trait]
187impl<F, Fut> StreamProcessor for AsyncMapOperator<F, Fut>
188where
189    F: Fn(Value) -> Fut + Send + Sync,
190    Fut: std::future::Future<Output = Value> + Send,
191{
192    async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
193        Ok((self.func)(item).await)
194    }
195}
196
197#[async_trait]
198impl<F, Fut> StreamOperator for AsyncMapOperator<F, Fut>
199where
200    F: Fn(Value) -> Fut + Send + Sync,
201    Fut: std::future::Future<Output = Value> + Send,
202{
203    fn name(&self) -> &str {
204        &self.name
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a context carrying no data; shared by every test below.
    fn empty_ctx() -> Context {
        Context {
            data: HashMap::new(),
        }
    }

    /// Doubles integer values and returns any other value unchanged.
    fn double_if_int(v: Value) -> Value {
        match v.as_i64() {
            Some(n) => Value::Number((n * 2).into()),
            None => v,
        }
    }

    #[tokio::test]
    async fn test_map_operator() {
        let op = MapOperator::new("double", double_if_int);

        let doubled = op
            .process_item(Value::Number(5.into()), &empty_ctx())
            .await
            .unwrap();

        assert_eq!(doubled, Value::Number(10.into()));
    }

    #[tokio::test]
    async fn test_filter_operator() {
        let op = FilterOperator::new("even_only", |v: &Value| {
            matches!(v.as_i64(), Some(n) if n % 2 == 0)
        });
        let ctx = empty_ctx();

        // Even numbers pass, odd numbers are rejected.
        assert!(op.process_item(Value::Number(4.into()), &ctx).await.is_ok());
        assert!(op.process_item(Value::Number(5.into()), &ctx).await.is_err());
    }

    #[tokio::test]
    async fn test_fold_operator() {
        let op = FoldOperator::new("sum", 0i64, |acc: i64, v: Value| {
            acc + v.as_i64().unwrap_or(0)
        });

        let items = (1..=5)
            .map(|i| Value::Number(i.into()))
            .collect::<Vec<Value>>();
        let folded = op.process_chunk(items, &empty_ctx()).await.unwrap();

        // 1 + 2 + 3 + 4 + 5, emitted as a single value.
        assert_eq!(folded, vec![Value::Number(15.into())]);
    }

    #[tokio::test]
    async fn test_async_map_operator() {
        let op = AsyncMapOperator::new("async_double", |v: Value| async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            double_if_int(v)
        });

        let doubled = op
            .process_item(Value::Number(5.into()), &empty_ctx())
            .await
            .unwrap();

        assert_eq!(doubled, Value::Number(10.into()));
    }
}