rust_logic_graph/streaming/
operators.rs1use crate::core::Context;
4use crate::rule::{RuleError, RuleResult};
5use crate::streaming::StreamProcessor;
6use async_trait::async_trait;
7use serde_json::Value;
8use std::sync::Arc;
9
10#[async_trait]
12pub trait StreamOperator: StreamProcessor {
13 fn name(&self) -> &str;
15}
16
17pub struct MapOperator<F>
19where
20 F: Fn(Value) -> Value + Send + Sync,
21{
22 pub name: String,
23 pub func: Arc<F>,
24}
25
26impl<F> MapOperator<F>
27where
28 F: Fn(Value) -> Value + Send + Sync,
29{
30 pub fn new(name: impl Into<String>, func: F) -> Self {
31 Self {
32 name: name.into(),
33 func: Arc::new(func),
34 }
35 }
36}
37
38#[async_trait]
39impl<F> StreamProcessor for MapOperator<F>
40where
41 F: Fn(Value) -> Value + Send + Sync,
42{
43 async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
44 Ok((self.func)(item))
45 }
46}
47
48#[async_trait]
49impl<F> StreamOperator for MapOperator<F>
50where
51 F: Fn(Value) -> Value + Send + Sync,
52{
53 fn name(&self) -> &str {
54 &self.name
55 }
56}
57
58pub struct FilterOperator<F>
60where
61 F: Fn(&Value) -> bool + Send + Sync,
62{
63 pub name: String,
64 pub predicate: Arc<F>,
65}
66
67impl<F> FilterOperator<F>
68where
69 F: Fn(&Value) -> bool + Send + Sync,
70{
71 pub fn new(name: impl Into<String>, predicate: F) -> Self {
72 Self {
73 name: name.into(),
74 predicate: Arc::new(predicate),
75 }
76 }
77}
78
79#[async_trait]
80impl<F> StreamProcessor for FilterOperator<F>
81where
82 F: Fn(&Value) -> bool + Send + Sync,
83{
84 async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
85 if (self.predicate)(&item) {
86 Ok(item)
87 } else {
88 Err(RuleError::Eval("Filtered out".to_string()))
89 }
90 }
91}
92
93#[async_trait]
94impl<F> StreamOperator for FilterOperator<F>
95where
96 F: Fn(&Value) -> bool + Send + Sync,
97{
98 fn name(&self) -> &str {
99 &self.name
100 }
101}
102
103pub struct FoldOperator<F, T>
105where
106 F: Fn(T, Value) -> T + Send + Sync,
107 T: Clone + Send + Sync + 'static,
108{
109 pub name: String,
110 pub initial: T,
111 pub func: Arc<F>,
112}
113
114impl<F, T> FoldOperator<F, T>
115where
116 F: Fn(T, Value) -> T + Send + Sync,
117 T: Clone + Send + Sync + 'static,
118{
119 pub fn new(name: impl Into<String>, initial: T, func: F) -> Self {
120 Self {
121 name: name.into(),
122 initial,
123 func: Arc::new(func),
124 }
125 }
126}
127
128#[async_trait]
129impl<F, T> StreamProcessor for FoldOperator<F, T>
130where
131 F: Fn(T, Value) -> T + Send + Sync,
132 T: Clone + Send + Sync + Into<Value> + 'static,
133{
134 async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
135 Ok(item)
137 }
138
139 async fn process_chunk(
140 &self,
141 items: Vec<Value>,
142 _ctx: &Context,
143 ) -> Result<Vec<Value>, RuleError> {
144 let result = items
145 .into_iter()
146 .fold(self.initial.clone(), |acc, item| (self.func)(acc, item));
147
148 Ok(vec![result.into()])
149 }
150}
151
152#[async_trait]
153impl<F, T> StreamOperator for FoldOperator<F, T>
154where
155 F: Fn(T, Value) -> T + Send + Sync,
156 T: Clone + Send + Sync + Into<Value> + 'static,
157{
158 fn name(&self) -> &str {
159 &self.name
160 }
161}
162
163pub struct AsyncMapOperator<F, Fut>
165where
166 F: Fn(Value) -> Fut + Send + Sync,
167 Fut: std::future::Future<Output = Value> + Send,
168{
169 pub name: String,
170 pub func: Arc<F>,
171}
172
173impl<F, Fut> AsyncMapOperator<F, Fut>
174where
175 F: Fn(Value) -> Fut + Send + Sync,
176 Fut: std::future::Future<Output = Value> + Send,
177{
178 pub fn new(name: impl Into<String>, func: F) -> Self {
179 Self {
180 name: name.into(),
181 func: Arc::new(func),
182 }
183 }
184}
185
186#[async_trait]
187impl<F, Fut> StreamProcessor for AsyncMapOperator<F, Fut>
188where
189 F: Fn(Value) -> Fut + Send + Sync,
190 Fut: std::future::Future<Output = Value> + Send,
191{
192 async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
193 Ok((self.func)(item).await)
194 }
195}
196
197#[async_trait]
198impl<F, Fut> StreamOperator for AsyncMapOperator<F, Fut>
199where
200 F: Fn(Value) -> Fut + Send + Sync,
201 Fut: std::future::Future<Output = Value> + Send,
202{
203 fn name(&self) -> &str {
204 &self.name
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Context with no data; the operators under test never read it.
    fn empty_ctx() -> Context {
        Context {
            data: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_map_operator() {
        let doubler = MapOperator::new("double", |value: Value| match value.as_i64() {
            Some(n) => Value::Number((n * 2).into()),
            None => value,
        });
        let ctx = empty_ctx();

        let doubled = doubler
            .process_item(Value::from(5), &ctx)
            .await
            .unwrap();

        assert_eq!(doubled, Value::from(10));
    }

    #[tokio::test]
    async fn test_filter_operator() {
        let evens = FilterOperator::new("even_only", |value: &Value| {
            matches!(value.as_i64(), Some(n) if n % 2 == 0)
        });
        let ctx = empty_ctx();

        let kept = evens.process_item(Value::from(4), &ctx).await;
        assert!(kept.is_ok());

        let rejected = evens.process_item(Value::from(5), &ctx).await;
        assert!(rejected.is_err());
    }

    #[tokio::test]
    async fn test_fold_operator() {
        let summer = FoldOperator::new("sum", 0i64, |total: i64, value: Value| {
            total + value.as_i64().unwrap_or(0)
        });
        let ctx = empty_ctx();

        let inputs: Vec<Value> = (1..=5).map(Value::from).collect();
        let outputs = summer.process_chunk(inputs, &ctx).await.unwrap();

        assert_eq!(outputs.len(), 1);
        assert_eq!(outputs[0], Value::from(15));
    }

    #[tokio::test]
    async fn test_async_map_operator() {
        let doubler = AsyncMapOperator::new("async_double", |value: Value| async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            match value.as_i64() {
                Some(n) => Value::Number((n * 2).into()),
                None => value,
            }
        });
        let ctx = empty_ctx();

        let doubled = doubler
            .process_item(Value::from(5), &ctx)
            .await
            .unwrap();

        assert_eq!(doubled, Value::from(10));
    }
}