1use crate::config::events::BatchConfig;
24use crate::events::context::FlowContext;
25use crate::events::operators::deduplicate::parse_duration;
26use crate::events::operators::{OpResult, PipelineOperator};
27use anyhow::{Result, anyhow};
28use async_trait::async_trait;
29use serde_json::json;
30use std::collections::HashMap;
31use std::sync::Arc;
32use std::time::{Duration, Instant};
33use tokio::sync::RwLock;
34
/// Items buffered for one batch key while its window is open.
#[derive(Debug, Clone)]
struct BatchBucket {
    /// Stringified item identifiers, in arrival order.
    items: Vec<String>,
    /// When the first item of this bucket arrived; the window is measured from here.
    started_at: Instant,
}
43
44#[derive(Debug)]
46pub struct BatchOp {
47 key: String,
49
50 window: Duration,
52
53 min_count: u32,
55
56 buckets: Arc<RwLock<HashMap<String, BatchBucket>>>,
58}
59
60impl BatchOp {
61 pub fn from_config(config: &BatchConfig) -> Result<Self> {
63 let window = parse_duration(&config.window)?;
64 Ok(Self {
65 key: config.key.clone(),
66 window,
67 min_count: config.min_count,
68 buckets: Arc::new(RwLock::new(HashMap::new())),
69 })
70 }
71
72 #[cfg(test)]
74 fn with_params(key: &str, window: Duration, min_count: u32) -> Self {
75 Self {
76 key: key.to_string(),
77 window,
78 min_count,
79 buckets: Arc::new(RwLock::new(HashMap::new())),
80 }
81 }
82}
83
84#[async_trait]
85impl PipelineOperator for BatchOp {
86 async fn execute(&self, ctx: &mut FlowContext) -> Result<OpResult> {
87 let key_value = ctx
89 .get_var(&self.key)
90 .ok_or_else(|| anyhow!("batch: variable '{}' not found in context", self.key))?
91 .clone();
92
93 let key_str = value_to_string(&key_value);
94
95 let item_value = ctx
97 .get_var("source_id")
98 .or_else(|| ctx.get_var("entity_id"))
99 .map(value_to_string)
100 .unwrap_or_default();
101
102 let now = Instant::now();
103 let mut buckets = self.buckets.write().await;
104
105 let (should_flush, should_discard) = if let Some(bucket) = buckets.get(&key_str) {
107 let window_expired = now.duration_since(bucket.started_at) >= self.window;
108 if window_expired && bucket.items.len() as u32 >= self.min_count {
109 (true, false)
110 } else if window_expired {
111 (false, true)
112 } else {
113 (false, false)
114 }
115 } else {
116 (false, false)
117 };
118
119 if should_flush {
120 let bucket = buckets.remove(&key_str).unwrap();
122 let count = bucket.items.len();
123
124 ctx.set_var(
125 "_batch",
126 json!({
127 "count": count,
128 "key": key_str,
129 "items": bucket.items,
130 }),
131 );
132
133 Ok(OpResult::Continue)
134 } else if should_discard {
135 buckets.remove(&key_str);
137 Ok(OpResult::Drop)
138 } else {
139 let bucket = buckets.entry(key_str).or_insert_with(|| BatchBucket {
141 items: Vec::new(),
142 started_at: now,
143 });
144 bucket.items.push(item_value);
145 Ok(OpResult::Drop)
146 }
147 }
148
149 fn name(&self) -> &str {
150 "batch"
151 }
152}
153
154fn value_to_string(value: &serde_json::Value) -> String {
156 match value {
157 serde_json::Value::String(s) => s.clone(),
158 serde_json::Value::Number(n) => n.to_string(),
159 serde_json::Value::Bool(b) => b.to_string(),
160 serde_json::Value::Null => "null".to_string(),
161 other => other.to_string(),
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::events::{FrameworkEvent, LinkEvent};
    use crate::core::link::LinkEntity;
    use crate::core::service::LinkService;
    use uuid::Uuid;

    /// Window used by tests that need a bucket to expire quickly.
    const SHORT_WINDOW: Duration = Duration::from_millis(50);
    /// Sleep long enough for `SHORT_WINDOW` to have elapsed.
    const PAST_SHORT_WINDOW: Duration = Duration::from_millis(60);

    /// `LinkService` stub; the batch operator never calls it, so every method panics.
    struct MockLinkService;

    #[async_trait]
    impl LinkService for MockLinkService {
        async fn create(&self, _: LinkEntity) -> Result<LinkEntity> {
            unimplemented!()
        }
        async fn get(&self, _: &Uuid) -> Result<Option<LinkEntity>> {
            unimplemented!()
        }
        async fn list(&self) -> Result<Vec<LinkEntity>> {
            unimplemented!()
        }
        async fn find_by_source(
            &self,
            _: &Uuid,
            _: Option<&str>,
            _: Option<&str>,
        ) -> Result<Vec<LinkEntity>> {
            unimplemented!()
        }
        async fn find_by_target(
            &self,
            _: &Uuid,
            _: Option<&str>,
            _: Option<&str>,
        ) -> Result<Vec<LinkEntity>> {
            unimplemented!()
        }
        async fn update(&self, _: &Uuid, _: LinkEntity) -> Result<LinkEntity> {
            unimplemented!()
        }
        async fn delete(&self, _: &Uuid) -> Result<()> {
            unimplemented!()
        }
        async fn delete_by_entity(&self, _: &Uuid) -> Result<()> {
            unimplemented!()
        }
    }

    /// Builds a context for a `likes` link-created event from `source_id` to `target_id`.
    fn make_link_context(source_id: Uuid, target_id: Uuid) -> FlowContext {
        let event = FrameworkEvent::Link(LinkEvent::Created {
            link_type: "likes".to_string(),
            link_id: Uuid::new_v4(),
            source_id,
            target_id,
            metadata: None,
        });
        let links: Arc<dyn LinkService> = Arc::new(MockLinkService);
        FlowContext::new(event, links, HashMap::new())
    }

    /// Runs one event with a fresh random source for `target_id`; returns outcome and context.
    async fn run_once(op: &BatchOp, target_id: Uuid) -> (OpResult, FlowContext) {
        let mut ctx = make_link_context(Uuid::new_v4(), target_id);
        let outcome = op.execute(&mut ctx).await.unwrap();
        (outcome, ctx)
    }

    /// Feeds `count` events for `target_id` into `op`, ignoring their outcomes.
    async fn feed(op: &BatchOp, target_id: Uuid, count: usize) {
        for _ in 0..count {
            let _ = run_once(op, target_id).await;
        }
    }

    /// Number of buckets currently held by `op`.
    async fn bucket_count(op: &BatchOp) -> usize {
        op.buckets.read().await.len()
    }

    #[tokio::test]
    async fn test_batch_accumulates_within_window() {
        let target_id = Uuid::new_v4();
        let op = BatchOp::with_params("target_id", Duration::from_secs(60), 1);

        for _ in 0..2 {
            let (outcome, _) = run_once(&op, target_id).await;
            assert!(matches!(outcome, OpResult::Drop));
        }
    }

    #[tokio::test]
    async fn test_batch_flushes_after_window() {
        let target_id = Uuid::new_v4();
        let op = BatchOp::with_params("target_id", SHORT_WINDOW, 1);

        feed(&op, target_id, 3).await;
        tokio::time::sleep(PAST_SHORT_WINDOW).await;

        let (outcome, ctx) = run_once(&op, target_id).await;
        assert!(matches!(outcome, OpResult::Continue));

        let batch = ctx.get_var("_batch").expect("flush should set _batch");
        assert_eq!(batch["count"], 3);
        assert_eq!(batch["key"], target_id.to_string());
        assert_eq!(batch["items"].as_array().map(Vec::len), Some(3));
    }

    #[tokio::test]
    async fn test_batch_min_count_not_met() {
        let target_id = Uuid::new_v4();
        let op = BatchOp::with_params("target_id", SHORT_WINDOW, 5);

        feed(&op, target_id, 2).await;
        tokio::time::sleep(PAST_SHORT_WINDOW).await;

        let (outcome, _) = run_once(&op, target_id).await;
        assert!(matches!(outcome, OpResult::Drop));
    }

    #[tokio::test]
    async fn test_batch_different_keys_independent() {
        let target_a = Uuid::new_v4();
        let target_b = Uuid::new_v4();
        let op = BatchOp::with_params("target_id", SHORT_WINDOW, 1);

        feed(&op, target_a, 1).await;
        feed(&op, target_b, 1).await;
        tokio::time::sleep(PAST_SHORT_WINDOW).await;

        for target in [target_a, target_b] {
            let (outcome, ctx) = run_once(&op, target).await;
            assert!(matches!(outcome, OpResult::Continue));
            assert_eq!(ctx.get_var("_batch").unwrap()["count"], 1);
        }
    }

    #[tokio::test]
    async fn test_batch_missing_key_errors() {
        let op = BatchOp::with_params("nonexistent", Duration::from_secs(60), 1);
        let mut ctx = make_link_context(Uuid::new_v4(), Uuid::new_v4());

        assert!(op.execute(&mut ctx).await.is_err());
    }

    #[tokio::test]
    async fn test_buckets_cleaned_after_flush() {
        let target_id = Uuid::new_v4();
        let op = BatchOp::with_params("target_id", SHORT_WINDOW, 1);

        feed(&op, target_id, 2).await;
        assert_eq!(bucket_count(&op).await, 1);

        tokio::time::sleep(PAST_SHORT_WINDOW).await;

        let (outcome, _) = run_once(&op, target_id).await;
        assert!(matches!(outcome, OpResult::Continue));
        assert_eq!(bucket_count(&op).await, 0);
    }

    #[tokio::test]
    async fn test_buckets_cleaned_after_expired_min_count_not_met() {
        let target_id = Uuid::new_v4();
        let op = BatchOp::with_params("target_id", SHORT_WINDOW, 10);

        feed(&op, target_id, 2).await;
        assert_eq!(bucket_count(&op).await, 1);

        tokio::time::sleep(PAST_SHORT_WINDOW).await;

        let (outcome, _) = run_once(&op, target_id).await;
        assert!(matches!(outcome, OpResult::Drop));
        assert_eq!(bucket_count(&op).await, 0);
    }

    #[tokio::test]
    async fn test_multiple_keys_independent_cleanup() {
        let target_a = Uuid::new_v4();
        let target_b = Uuid::new_v4();
        let op = BatchOp::with_params("target_id", SHORT_WINDOW, 1);

        feed(&op, target_a, 1).await;
        feed(&op, target_b, 1).await;
        assert_eq!(bucket_count(&op).await, 2);

        tokio::time::sleep(PAST_SHORT_WINDOW).await;

        let (outcome, _) = run_once(&op, target_a).await;
        assert!(matches!(outcome, OpResult::Continue));

        // Only target_a's bucket is flushed; target_b's expired bucket waits for its own event.
        let buckets = op.buckets.read().await;
        assert_eq!(buckets.len(), 1);
        assert!(!buckets.contains_key(&target_a.to_string()));
        assert!(buckets.contains_key(&target_b.to_string()));
    }
}