Skip to main content

this/events/sinks/
counter.rs

1//! Counter sink — updates numeric fields on entities
2//!
3//! Increments, decrements, or sets a numeric field on an entity.
4//! Useful for maintaining derived counters like `follower_count`,
5//! `like_count`, etc. in response to events.
6//!
7//! ```yaml
8//! sinks:
9//!   - name: like-counter
10//!     type: counter
11//!     config:
12//!       field: like_count
13//!       operation: increment
14//! ```
15//!
16//! # Payload format
17//!
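//! The payload (from the `map` operator) should include the following;
//! each one falls back to the context variable of the same name when the
//! payload does not supply it as a string:
//! - `entity_type`: Type of the entity to update
//! - `entity_id`: ID of the entity to update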
21//!
22//! The field name and operation come from the sink configuration
23//! or can be overridden in the payload:
24//! - `field`: Name of the numeric field (default from config)
25//! - `operation`: "increment", "decrement", or "set" (default from config)
26//! - `value`: Amount to increment/decrement or value to set (default: 1)
27
28use crate::config::sinks::SinkType;
29use crate::events::sinks::Sink;
30use anyhow::{Result, anyhow};
31use async_trait::async_trait;
32use serde_json::Value;
33use std::collections::HashMap;
34use std::sync::Arc;
35use tokio::sync::{Mutex, RwLock};
36
/// Operation a counter sink applies to a numeric entity field.
///
/// The variants are plain tags without payload; the amount to apply is
/// passed separately to [`CounterOperation::apply`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CounterOperation {
    /// Add a value to the current count
    Increment,
    /// Subtract a value from the current count
    Decrement,
    /// Replace the current count with an absolute value
    Set,
}
47
48impl CounterOperation {
49    /// Parse from a string
50    pub fn parse(s: &str) -> Result<Self> {
51        match s {
52            "increment" | "inc" | "add" => Ok(Self::Increment),
53            "decrement" | "dec" | "sub" | "subtract" => Ok(Self::Decrement),
54            "set" => Ok(Self::Set),
55            _ => Err(anyhow!(
56                "invalid counter operation '{}': expected 'increment', 'decrement', or 'set'",
57                s
58            )),
59        }
60    }
61
62    /// Apply the operation to a current value
63    pub fn apply(&self, current: f64, amount: f64) -> f64 {
64        match self {
65            Self::Increment => current + amount,
66            Self::Decrement => (current - amount).max(0.0), // Never go negative
67            Self::Set => amount,
68        }
69    }
70}
71
72/// Trait for reading and updating entity fields
73///
74/// Abstracts the entity storage so the counter sink can work
75/// without depending on the server layer.
76#[async_trait]
77pub trait EntityFieldUpdater: Send + Sync + std::fmt::Debug {
78    /// Read a numeric field from an entity
79    ///
80    /// Returns the current field value, or 0.0 if the field doesn't exist.
81    async fn read_field(&self, entity_type: &str, entity_id: &str, field: &str) -> Result<f64>;
82
83    /// Write a numeric field to an entity
84    async fn write_field(
85        &self,
86        entity_type: &str,
87        entity_id: &str,
88        field: &str,
89        value: f64,
90    ) -> Result<()>;
91}
92
93/// Counter sink configuration
94#[derive(Debug, Clone)]
95pub struct CounterConfig {
96    /// Default field name to update
97    pub field: String,
98
99    /// Default operation
100    pub operation: CounterOperation,
101}
102
103/// Counter notification sink
104///
105/// Updates numeric fields on entities. Used for maintaining derived
106/// counters (follower_count, like_count, etc.) in response to events.
107///
108/// Uses per-key locks to ensure atomic read-modify-write operations,
109/// preventing TOCTOU race conditions under concurrent access.
110#[derive(Debug)]
111pub struct CounterSink {
112    /// Default counter configuration
113    config: CounterConfig,
114
115    /// Entity field updater
116    updater: Arc<dyn EntityFieldUpdater>,
117
118    /// Per-key locks for atomic read-modify-write
119    /// Key format: "entity_type:entity_id:field"
120    key_locks: RwLock<HashMap<String, Arc<Mutex<()>>>>,
121}
122
123impl CounterSink {
124    /// Create a new CounterSink
125    pub fn new(updater: Arc<dyn EntityFieldUpdater>, config: CounterConfig) -> Self {
126        Self {
127            config,
128            updater,
129            key_locks: RwLock::new(HashMap::new()),
130        }
131    }
132
133    /// Get or create a lock for the given key
134    async fn get_lock(&self, key: &str) -> Arc<Mutex<()>> {
135        // Fast path: check if lock already exists (read lock)
136        {
137            let locks = self.key_locks.read().await;
138            if let Some(lock) = locks.get(key) {
139                return lock.clone();
140            }
141        }
142
143        // Slow path: create the lock (write lock)
144        let mut locks = self.key_locks.write().await;
145        // Double-check after acquiring write lock
146        locks
147            .entry(key.to_string())
148            .or_insert_with(|| Arc::new(Mutex::new(())))
149            .clone()
150    }
151}
152
153#[async_trait]
154impl Sink for CounterSink {
155    async fn deliver(
156        &self,
157        payload: Value,
158        _recipient_id: Option<&str>,
159        context_vars: &HashMap<String, Value>,
160    ) -> Result<()> {
161        // Extract entity_type and entity_id from payload or context
162        let entity_type = payload
163            .get("entity_type")
164            .and_then(|v| v.as_str())
165            .or_else(|| context_vars.get("entity_type").and_then(|v| v.as_str()))
166            .ok_or_else(|| anyhow!("counter sink: entity_type not found in payload or context"))?
167            .to_string();
168
169        let entity_id = payload
170            .get("entity_id")
171            .and_then(|v| v.as_str())
172            .or_else(|| context_vars.get("entity_id").and_then(|v| v.as_str()))
173            .ok_or_else(|| anyhow!("counter sink: entity_id not found in payload or context"))?
174            .to_string();
175
176        // Field name: payload overrides config default
177        let field = payload
178            .get("field")
179            .and_then(|v| v.as_str())
180            .unwrap_or(&self.config.field)
181            .to_string();
182
183        // Operation: payload overrides config default
184        let operation = if let Some(op_str) = payload.get("operation").and_then(|v| v.as_str()) {
185            CounterOperation::parse(op_str)?
186        } else {
187            self.config.operation.clone()
188        };
189
190        // Value: default 1
191        let amount = payload.get("value").and_then(|v| v.as_f64()).unwrap_or(1.0);
192
193        // Acquire per-key lock for atomic read-modify-write
194        let lock_key = format!("{}:{}:{}", entity_type, entity_id, field);
195        let lock = self.get_lock(&lock_key).await;
196        let _guard = lock.lock().await;
197
198        // Read current value
199        let current = self
200            .updater
201            .read_field(&entity_type, &entity_id, &field)
202            .await?;
203
204        // Apply operation
205        let new_value = operation.apply(current, amount);
206
207        tracing::debug!(
208            entity_type = %entity_type,
209            entity_id = %entity_id,
210            field = %field,
211            current = current,
212            operation = ?operation,
213            amount = amount,
214            new_value = new_value,
215            "counter sink: updating field"
216        );
217
218        // Write new value
219        self.updater
220            .write_field(&entity_type, &entity_id, &field, new_value)
221            .await?;
222
223        Ok(())
224    }
225
226    fn name(&self) -> &str {
227        "counter"
228    }
229
230    fn sink_type(&self) -> SinkType {
231        SinkType::Counter
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio::sync::RwLock;

    /// In-memory stand-in for entity storage
    #[derive(Debug)]
    struct MockEntityStore {
        /// Fields keyed by "entity_type:entity_id:field"
        fields: RwLock<HashMap<String, f64>>,
    }

    impl MockEntityStore {
        fn new() -> Self {
            Self {
                fields: RwLock::new(HashMap::new()),
            }
        }

        fn key(entity_type: &str, entity_id: &str, field: &str) -> String {
            format!("{}:{}:{}", entity_type, entity_id, field)
        }

        /// Seed a field value directly, bypassing the sink
        async fn set(&self, entity_type: &str, entity_id: &str, field: &str, value: f64) {
            self.fields
                .write()
                .await
                .insert(Self::key(entity_type, entity_id, field), value);
        }
    }

    #[async_trait]
    impl EntityFieldUpdater for MockEntityStore {
        async fn read_field(&self, entity_type: &str, entity_id: &str, field: &str) -> Result<f64> {
            // The read guard is a temporary of this statement, so it is
            // released before the yield below.
            let value = self
                .fields
                .read()
                .await
                .get(&Self::key(entity_type, entity_id, field))
                .copied()
                .unwrap_or(0.0);

            // Suspend between the sink's read and its subsequent write so that
            // other tasks get a chance to run in that window. Without this,
            // tasks on the default single-threaded test runtime would run
            // back-to-back and the concurrency test could never observe a
            // lost update.
            tokio::task::yield_now().await;

            Ok(value)
        }

        async fn write_field(
            &self,
            entity_type: &str,
            entity_id: &str,
            field: &str,
            value: f64,
        ) -> Result<()> {
            self.fields
                .write()
                .await
                .insert(Self::key(entity_type, entity_id, field), value);
            Ok(())
        }
    }

    fn increment_config(field: &str) -> CounterConfig {
        CounterConfig {
            field: field.to_string(),
            operation: CounterOperation::Increment,
        }
    }

    #[tokio::test]
    async fn test_counter_increment() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 5.0).await;

        let sink = CounterSink::new(store.clone(), increment_config("like_count"));

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1"
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 6.0);
    }

    #[tokio::test]
    async fn test_counter_increment_from_zero() {
        let store = Arc::new(MockEntityStore::new());
        let sink = CounterSink::new(store.clone(), increment_config("like_count"));

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1"
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 1.0);
    }

    #[tokio::test]
    async fn test_counter_decrement() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 5.0).await;

        let sink = CounterSink::new(
            store.clone(),
            CounterConfig {
                field: "like_count".to_string(),
                operation: CounterOperation::Decrement,
            },
        );

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1"
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 4.0);
    }

    #[tokio::test]
    async fn test_counter_decrement_floor_at_zero() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 0.0).await;

        let sink = CounterSink::new(
            store.clone(),
            CounterConfig {
                field: "like_count".to_string(),
                operation: CounterOperation::Decrement,
            },
        );

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1"
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 0.0); // Never goes negative
    }

    #[tokio::test]
    async fn test_counter_set() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 5.0).await;

        let sink = CounterSink::new(
            store.clone(),
            CounterConfig {
                field: "like_count".to_string(),
                operation: CounterOperation::Set,
            },
        );

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1",
            "value": 42
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 42.0);
    }

    #[tokio::test]
    async fn test_counter_custom_amount() {
        let store = Arc::new(MockEntityStore::new());
        store.set("user", "u-1", "follower_count", 10.0).await;

        let sink = CounterSink::new(store.clone(), increment_config("follower_count"));

        let payload = json!({
            "entity_type": "user",
            "entity_id": "u-1",
            "value": 5
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("user", "u-1", "follower_count")
            .await
            .unwrap();
        assert_eq!(value, 15.0);
    }

    #[tokio::test]
    async fn test_counter_override_field_and_operation() {
        let store = Arc::new(MockEntityStore::new());
        store.set("user", "u-1", "comment_count", 3.0).await;

        // Config says increment like_count, but payload overrides both
        let sink = CounterSink::new(store.clone(), increment_config("like_count"));

        let payload = json!({
            "entity_type": "user",
            "entity_id": "u-1",
            "field": "comment_count",
            "operation": "decrement"
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("user", "u-1", "comment_count")
            .await
            .unwrap();
        assert_eq!(value, 2.0);
    }

    #[tokio::test]
    async fn test_counter_entity_from_context() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 0.0).await;

        let sink = CounterSink::new(store.clone(), increment_config("like_count"));

        let payload = json!({}); // No entity info in payload

        let mut vars = HashMap::new();
        vars.insert(
            "entity_type".to_string(),
            Value::String("capture".to_string()),
        );
        vars.insert("entity_id".to_string(), Value::String("cap-1".to_string()));

        sink.deliver(payload, None, &vars).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 1.0);
    }

    #[tokio::test]
    async fn test_counter_missing_entity_type_error() {
        let store = Arc::new(MockEntityStore::new());
        let sink = CounterSink::new(store, increment_config("like_count"));

        let payload = json!({"entity_id": "cap-1"});
        let result = sink.deliver(payload, None, &HashMap::new()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("entity_type"));
    }

    #[tokio::test]
    async fn test_counter_missing_entity_id_error() {
        let store = Arc::new(MockEntityStore::new());
        let sink = CounterSink::new(store, increment_config("like_count"));

        let payload = json!({"entity_type": "capture"});
        let result = sink.deliver(payload, None, &HashMap::new()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("entity_id"));
    }

    #[test]
    fn test_counter_operation_parse() {
        // Every accepted spelling, including the short aliases.
        for alias in ["increment", "inc", "add"] {
            assert_eq!(
                CounterOperation::parse(alias).unwrap(),
                CounterOperation::Increment
            );
        }
        for alias in ["decrement", "dec", "sub", "subtract"] {
            assert_eq!(
                CounterOperation::parse(alias).unwrap(),
                CounterOperation::Decrement
            );
        }
        assert_eq!(
            CounterOperation::parse("set").unwrap(),
            CounterOperation::Set
        );
        assert!(CounterOperation::parse("invalid").is_err());
    }

    #[test]
    fn test_counter_sink_name_and_type() {
        let store = Arc::new(MockEntityStore::new());
        let sink = CounterSink::new(store, increment_config("like_count"));
        assert_eq!(sink.name(), "counter");
        assert_eq!(sink.sink_type(), SinkType::Counter);
    }

    #[tokio::test]
    async fn test_counter_concurrent_increments() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 0.0).await;

        let sink = Arc::new(CounterSink::new(
            store.clone(),
            increment_config("like_count"),
        ));

        // Spawn 50 concurrent increment tasks
        let mut handles = Vec::new();
        for _ in 0..50 {
            let sink = sink.clone();
            handles.push(tokio::spawn(async move {
                let payload = json!({
                    "entity_type": "capture",
                    "entity_id": "cap-1"
                });
                sink.deliver(payload, None, &HashMap::new()).await.unwrap();
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        // The mock store yields between read and write, so without the
        // per-key lock several tasks would read the same stale value and
        // the final count would be less than 50.
        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(
            value, 50.0,
            "All 50 increments should be applied atomically"
        );
    }
}