
swarm_engine_core/pipeline/sink.rs

//! Event sink trait and implementations.
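//!
//! # Example
//!
//! A minimal sketch of driving a sink by hand; the learning path and scenario
//! name below are illustrative, not crate defaults:
//!
//! ```ignore
//! let mut sink = LearningSink::new(PathBuf::from("/var/lib/swarm/learning"), 20);
//! sink.process(WatchEvent::new("my-scenario".into())).await?;
//! ```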

use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;

use super::WatchEvent;
use crate::error::SwarmError;
use crate::learn::{AlwaysTrigger, TrainTrigger, TriggerContext};

/// Trait for event sinks (terminal processors).
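///
/// # Example
///
/// A minimal sketch of an implementation; `LogSink` is illustrative and not
/// part of this crate:
///
/// ```ignore
/// struct LogSink;
///
/// impl EventSink for LogSink {
///     type Event = WatchEvent;
///
///     async fn process(&mut self, event: Self::Event) -> Result<(), SwarmError> {
///         tracing::debug!(scenario = %event.scenario, "event observed");
///         Ok(())
///     }
/// }
/// ```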
pub trait EventSink: Send {
    /// Event type consumed by this sink.
    type Event: Send;

    /// Process an event.
    fn process(
        &mut self,
        event: Self::Event,
    ) -> impl std::future::Future<Output = Result<(), SwarmError>> + Send;
}

/// Learning sink - triggers offline learning when events arrive.
///
/// Uses `spawn_blocking` internally to run synchronous `LearningStore::run_offline_learning`
/// without blocking the async runtime.
///
/// ## Trigger Integration
///
/// By default, learning runs on every event (AlwaysTrigger).
/// Use `with_trigger()` to customize when learning runs:
///
/// ```ignore
/// let sink = LearningSink::new(path, 20)
///     .with_trigger(TriggerBuilder::every_n_episodes(10));
/// ```
pub struct LearningSink {
    learning_path: Arc<PathBuf>,
    max_sessions: usize,
    /// Trigger for deciding when to run learning
    trigger: Arc<dyn TrainTrigger>,
    /// Number of events received since last training
    event_count: usize,
    /// Timestamp of last training (for TimeTrigger)
    last_train_at: Option<Instant>,
    /// Event count at last training
    last_train_count: usize,
}

impl LearningSink {
    /// Create a new learning sink.
    ///
    /// By default, runs learning on every event (AlwaysTrigger).
    ///
    /// # Arguments
    /// * `learning_path` - Path to learning data directory
    /// * `max_sessions` - Maximum sessions to analyze
    pub fn new(learning_path: PathBuf, max_sessions: usize) -> Self {
        Self {
            learning_path: Arc::new(learning_path),
            max_sessions,
            trigger: Arc::new(AlwaysTrigger),
            event_count: 0,
            last_train_at: None,
            last_train_count: 0,
        }
    }

    /// Set a custom trigger for controlling when learning runs.
    ///
    /// # Example
    /// ```ignore
    /// use swarm_engine_core::learn::TriggerBuilder;
    ///
    /// // Run learning every 10 events
    /// let sink = LearningSink::new(path, 20)
    ///     .with_trigger(TriggerBuilder::every_n_episodes(10));
    ///
    /// // Run learning every 5 minutes OR when 50 events accumulated
    /// let sink = LearningSink::new(path, 20)
    ///     .with_trigger(Arc::new(OrTrigger::new(vec![
    ///         TriggerBuilder::every_minutes(5),
    ///         TriggerBuilder::every_n_episodes(50),
    ///     ])));
    /// ```
    pub fn with_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
        self.trigger = trigger;
        self
    }

    /// Get the learning path.
    pub fn learning_path(&self) -> &PathBuf {
        &self.learning_path
    }

    /// Get the current event count.
    pub fn event_count(&self) -> usize {
        self.event_count
    }

    /// Check if training should run based on the trigger.
    fn should_train(&self) -> bool {
        // Note: TimeTrigger uses Unix timestamp (ms), but we track Instant internally.
        // For simplicity, we don't fully support TimeTrigger via LearningSink.
        // Full TimeTrigger support requires LearnProcess with EpisodeStore.
        // CountTrigger and AlwaysTrigger work correctly.
        let ctx =
            TriggerContext::with_count(self.event_count).last_train_count(self.last_train_count);

        // Ignore errors - if trigger fails, don't train
        self.trigger.should_train(&ctx).unwrap_or(false)
    }

    /// Mark that training was performed.
    fn mark_trained(&mut self) {
        self.last_train_at = Some(Instant::now());
        self.last_train_count = self.event_count;
    }
}

impl EventSink for LearningSink {
    type Event = WatchEvent;

    async fn process(&mut self, event: Self::Event) -> Result<(), SwarmError> {
        // Increment event counter
        self.event_count += 1;

        // Check trigger condition
        if !self.should_train() {
            tracing::debug!(
                scenario = %event.scenario,
                event_count = self.event_count,
                trigger = self.trigger.name(),
                "Trigger not met, skipping learning"
            );
            return Ok(());
        }

        tracing::info!(
            scenario = %event.scenario,
            event_count = self.event_count,
            trigger = self.trigger.name(),
            "Trigger condition met, running offline learning"
        );

        let learning_path = Arc::clone(&self.learning_path);
        let scenario = event.scenario.clone();
        let max_sessions = self.max_sessions;

        // Run synchronous LearningStore operations in a blocking task
        let result = tokio::task::spawn_blocking(move || {
            use crate::learn::LearningStore;

            let store = LearningStore::new(&*learning_path)?;
            store.run_offline_learning(&scenario, max_sessions)
        })
        .await;

        match result {
            Ok(Ok(model)) => {
                tracing::info!(
                    scenario = %event.scenario,
                    sessions = model.analyzed_sessions,
                    "Offline learning completed"
                );
                // Mark training completed
                self.mark_trained();
                Ok(())
            }
            Ok(Err(e)) => {
                tracing::warn!(
                    scenario = %event.scenario,
                    error = %e,
                    "Offline learning failed"
                );
                // Don't propagate - continue processing other events
                // Don't mark as trained on failure
                Ok(())
            }
            Err(e) => {
                tracing::error!(
                    scenario = %event.scenario,
                    error = %e,
                    "Blocking task panicked"
                );
                Ok(())
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Counting sink for testing.
    pub struct CountingSink {
        count: Arc<AtomicUsize>,
    }

    impl CountingSink {
        pub fn new() -> Self {
            Self {
                count: Arc::new(AtomicUsize::new(0)),
            }
        }

        pub fn count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl EventSink for CountingSink {
        type Event = WatchEvent;

        async fn process(&mut self, _event: Self::Event) -> Result<(), SwarmError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_counting_sink() {
        let mut sink = CountingSink::new();
        assert_eq!(sink.count(), 0);

        sink.process(WatchEvent::new("test".into())).await.unwrap();
        assert_eq!(sink.count(), 1);

        sink.process(WatchEvent::new("test2".into())).await.unwrap();
        assert_eq!(sink.count(), 2);
    }

    #[test]
    fn test_learning_sink_creation() {
        let sink = LearningSink::new(PathBuf::from("/tmp/test"), 20);
        assert_eq!(sink.learning_path().to_str().unwrap(), "/tmp/test");
    }
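
    // Sketch of an additional check using only the constructors defined above:
    // a freshly built sink reports zero processed events, with or without a
    // custom trigger.
    #[test]
    fn test_learning_sink_event_count_starts_at_zero() {
        let sink = LearningSink::new(PathBuf::from("/tmp/test"), 20)
            .with_trigger(Arc::new(AlwaysTrigger));
        assert_eq!(sink.event_count(), 0);
    }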
}