//! Stream operators (map, filter, fold, async map) implementing
//! [`StreamProcessor`]; source path `rust_logic_graph/streaming/operators.rs`.

use std::marker::PhantomData;
use std::sync::Arc;

use async_trait::async_trait;
use serde_json::Value;

use crate::core::Context;
use crate::rule::{RuleError, RuleResult};
use crate::streaming::StreamProcessor;
9
10#[async_trait]
12pub trait StreamOperator: StreamProcessor {
13 fn name(&self) -> &str;
15}
16
/// Operator that transforms every item with a synchronous function.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than on
/// the struct, so the type can be named in generic code without restating them.
pub struct MapOperator<F> {
    /// Operator name, returned by the `StreamOperator::name` impl.
    pub name: String,
    /// Transformation applied to each item; shared behind an `Arc`.
    pub func: Arc<F>,
}
25
26impl<F> MapOperator<F>
27where
28 F: Fn(Value) -> Value + Send + Sync,
29{
30 pub fn new(name: impl Into<String>, func: F) -> Self {
31 Self {
32 name: name.into(),
33 func: Arc::new(func),
34 }
35 }
36}
37
38#[async_trait]
39impl<F> StreamProcessor for MapOperator<F>
40where
41 F: Fn(Value) -> Value + Send + Sync,
42{
43 async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
44 Ok((self.func)(item))
45 }
46}
47
48#[async_trait]
49impl<F> StreamOperator for MapOperator<F>
50where
51 F: Fn(Value) -> Value + Send + Sync,
52{
53 fn name(&self) -> &str {
54 &self.name
55 }
56}
57
/// Operator that keeps only the items accepted by a predicate.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than on
/// the struct, so the type can be named in generic code without restating them.
pub struct FilterOperator<F> {
    /// Operator name, returned by the `StreamOperator::name` impl.
    pub name: String,
    /// Predicate deciding whether an item is kept; shared behind an `Arc`.
    pub predicate: Arc<F>,
}
66
67impl<F> FilterOperator<F>
68where
69 F: Fn(&Value) -> bool + Send + Sync,
70{
71 pub fn new(name: impl Into<String>, predicate: F) -> Self {
72 Self {
73 name: name.into(),
74 predicate: Arc::new(predicate),
75 }
76 }
77}
78
79#[async_trait]
80impl<F> StreamProcessor for FilterOperator<F>
81where
82 F: Fn(&Value) -> bool + Send + Sync,
83{
84 async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
85 if (self.predicate)(&item) {
86 Ok(item)
87 } else {
88 Err(RuleError::Eval("Filtered out".to_string()))
89 }
90 }
91}
92
93#[async_trait]
94impl<F> StreamOperator for FilterOperator<F>
95where
96 F: Fn(&Value) -> bool + Send + Sync,
97{
98 fn name(&self) -> &str {
99 &self.name
100 }
101}
102
/// Operator that reduces a chunk of items to a single accumulated value.
///
/// Trait bounds are declared on the `impl` blocks that need them. The struct
/// previously required `T: Clone + Send + Sync + 'static`, while the processor
/// impl additionally needs `T: Into<Value>`; keeping the bounds only on the impls
/// avoids that mismatch.
pub struct FoldOperator<F, T> {
    /// Operator name, returned by the `StreamOperator::name` impl.
    pub name: String,
    /// Starting accumulator; cloned at the start of every fold.
    pub initial: T,
    /// Step function `(accumulator, item) -> accumulator`; shared behind an `Arc`.
    pub func: Arc<F>,
}
113
114impl<F, T> FoldOperator<F, T>
115where
116 F: Fn(T, Value) -> T + Send + Sync,
117 T: Clone + Send + Sync + 'static,
118{
119 pub fn new(name: impl Into<String>, initial: T, func: F) -> Self {
120 Self {
121 name: name.into(),
122 initial,
123 func: Arc::new(func),
124 }
125 }
126}
127
128#[async_trait]
129impl<F, T> StreamProcessor for FoldOperator<F, T>
130where
131 F: Fn(T, Value) -> T + Send + Sync,
132 T: Clone + Send + Sync + Into<Value> + 'static,
133{
134 async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
135 Ok(item)
137 }
138
139 async fn process_chunk(&self, items: Vec<Value>, _ctx: &Context) -> Result<Vec<Value>, RuleError> {
140 let result = items.into_iter().fold(self.initial.clone(), |acc, item| {
141 (self.func)(acc, item)
142 });
143
144 Ok(vec![result.into()])
145 }
146}
147
148#[async_trait]
149impl<F, T> StreamOperator for FoldOperator<F, T>
150where
151 F: Fn(T, Value) -> T + Send + Sync,
152 T: Clone + Send + Sync + Into<Value> + 'static,
153{
154 fn name(&self) -> &str {
155 &self.name
156 }
157}
158
159pub struct AsyncMapOperator<F, Fut>
161where
162 F: Fn(Value) -> Fut + Send + Sync,
163 Fut: std::future::Future<Output = Value> + Send,
164{
165 pub name: String,
166 pub func: Arc<F>,
167}
168
169impl<F, Fut> AsyncMapOperator<F, Fut>
170where
171 F: Fn(Value) -> Fut + Send + Sync,
172 Fut: std::future::Future<Output = Value> + Send,
173{
174 pub fn new(name: impl Into<String>, func: F) -> Self {
175 Self {
176 name: name.into(),
177 func: Arc::new(func),
178 }
179 }
180}
181
182#[async_trait]
183impl<F, Fut> StreamProcessor for AsyncMapOperator<F, Fut>
184where
185 F: Fn(Value) -> Fut + Send + Sync,
186 Fut: std::future::Future<Output = Value> + Send,
187{
188 async fn process_item(&self, item: Value, _ctx: &Context) -> RuleResult {
189 Ok((self.func)(item).await)
190 }
191}
192
193#[async_trait]
194impl<F, Fut> StreamOperator for AsyncMapOperator<F, Fut>
195where
196 F: Fn(Value) -> Fut + Send + Sync,
197 Fut: std::future::Future<Output = Value> + Send,
198{
199 fn name(&self) -> &str {
200 &self.name
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `Context` with no data; the operators under test ignore it.
    fn empty_ctx() -> Context {
        Context {
            data: HashMap::new(),
        }
    }

    /// Doubles integer values and returns anything else unchanged.
    fn double(v: Value) -> Value {
        match v.as_i64() {
            Some(n) => Value::Number((n * 2).into()),
            None => v,
        }
    }

    #[tokio::test]
    async fn test_map_operator() {
        let op = MapOperator::new("double", double);
        let ctx = empty_ctx();

        let out = op
            .process_item(Value::Number(5.into()), &ctx)
            .await
            .unwrap();
        assert_eq!(out, Value::Number(10.into()));
    }

    #[tokio::test]
    async fn test_filter_operator() {
        let op = FilterOperator::new("even_only", |v: &Value| {
            v.as_i64().map_or(false, |n| n % 2 == 0)
        });
        let ctx = empty_ctx();

        // Even numbers pass through, odd numbers are rejected.
        assert!(op.process_item(Value::Number(4.into()), &ctx).await.is_ok());
        assert!(op.process_item(Value::Number(5.into()), &ctx).await.is_err());
    }

    #[tokio::test]
    async fn test_fold_operator() {
        let op = FoldOperator::new("sum", 0i64, |acc: i64, v: Value| {
            acc + v.as_i64().unwrap_or(0)
        });
        let ctx = empty_ctx();

        let items: Vec<Value> = (1..=5).map(|i| Value::Number(i.into())).collect();
        let out = op.process_chunk(items, &ctx).await.unwrap();

        // 1 + 2 + 3 + 4 + 5, emitted as a single value.
        assert_eq!(out, vec![Value::Number(15.into())]);
    }

    #[tokio::test]
    async fn test_async_map_operator() {
        let op = AsyncMapOperator::new("async_double", |v: Value| async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            double(v)
        });
        let ctx = empty_ctx();

        let out = op
            .process_item(Value::Number(5.into()), &ctx)
            .await
            .unwrap();
        assert_eq!(out, Value::Number(10.into()));
    }
}