swarm_engine_core/pipeline/
sink.rs1use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::Instant;
6
7use super::WatchEvent;
8use crate::error::SwarmError;
9use crate::learn::{AlwaysTrigger, TrainTrigger, TriggerContext};
10
11pub trait EventSink: Send {
13 type Event: Send;
15
16 fn process(
18 &mut self,
19 event: Self::Event,
20 ) -> impl std::future::Future<Output = Result<(), SwarmError>> + Send;
21}
22
23pub struct LearningSink {
38 learning_path: Arc<PathBuf>,
39 max_sessions: usize,
40 trigger: Arc<dyn TrainTrigger>,
42 event_count: usize,
44 last_train_at: Option<Instant>,
46 last_train_count: usize,
48}
49
50impl LearningSink {
51 pub fn new(learning_path: PathBuf, max_sessions: usize) -> Self {
59 Self {
60 learning_path: Arc::new(learning_path),
61 max_sessions,
62 trigger: Arc::new(AlwaysTrigger),
63 event_count: 0,
64 last_train_at: None,
65 last_train_count: 0,
66 }
67 }
68
69 pub fn with_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
87 self.trigger = trigger;
88 self
89 }
90
91 pub fn learning_path(&self) -> &PathBuf {
93 &self.learning_path
94 }
95
96 pub fn event_count(&self) -> usize {
98 self.event_count
99 }
100
101 fn should_train(&self) -> bool {
103 let ctx =
108 TriggerContext::with_count(self.event_count).last_train_count(self.last_train_count);
109
110 self.trigger.should_train(&ctx).unwrap_or(false)
112 }
113
114 fn mark_trained(&mut self) {
116 self.last_train_at = Some(Instant::now());
117 self.last_train_count = self.event_count;
118 }
119}
120
121impl EventSink for LearningSink {
122 type Event = WatchEvent;
123
124 async fn process(&mut self, event: Self::Event) -> Result<(), SwarmError> {
125 self.event_count += 1;
127
128 if !self.should_train() {
130 tracing::debug!(
131 scenario = %event.scenario,
132 event_count = self.event_count,
133 trigger = self.trigger.name(),
134 "Trigger not met, skipping learning"
135 );
136 return Ok(());
137 }
138
139 tracing::info!(
140 scenario = %event.scenario,
141 event_count = self.event_count,
142 trigger = self.trigger.name(),
143 "Trigger condition met, running offline learning"
144 );
145
146 let learning_path = Arc::clone(&self.learning_path);
147 let scenario = event.scenario.clone();
148 let max_sessions = self.max_sessions;
149
150 let result = tokio::task::spawn_blocking(move || {
152 use crate::learn::LearningStore;
153
154 let store = LearningStore::new(&*learning_path)?;
155 store.run_offline_learning(&scenario, max_sessions)
156 })
157 .await;
158
159 match result {
160 Ok(Ok(model)) => {
161 tracing::info!(
162 scenario = %event.scenario,
163 sessions = model.analyzed_sessions,
164 "Offline learning completed"
165 );
166 self.mark_trained();
168 Ok(())
169 }
170 Ok(Err(e)) => {
171 tracing::warn!(
172 scenario = %event.scenario,
173 error = %e,
174 "Offline learning failed"
175 );
176 Ok(())
179 }
180 Err(e) => {
181 tracing::error!(
182 scenario = %event.scenario,
183 error = %e,
184 "Blocking task panicked"
185 );
186 Ok(())
187 }
188 }
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test sink that only counts how many events it has received.
    #[derive(Default)]
    struct CountingSink {
        /// Number of events processed so far.
        count: Arc<AtomicUsize>,
    }

    impl CountingSink {
        fn new() -> Self {
            Self::default()
        }

        fn count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl EventSink for CountingSink {
        type Event = WatchEvent;

        async fn process(&mut self, _event: Self::Event) -> Result<(), SwarmError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_counting_sink() {
        let mut sink = CountingSink::new();
        assert_eq!(sink.count(), 0);

        sink.process(WatchEvent::new("test".into())).await.unwrap();
        assert_eq!(sink.count(), 1);

        sink.process(WatchEvent::new("test2".into())).await.unwrap();
        assert_eq!(sink.count(), 2);
    }

    #[test]
    fn test_learning_sink_creation() {
        let sink = LearningSink::new(PathBuf::from("/tmp/test"), 20);
        assert_eq!(sink.learning_path(), &PathBuf::from("/tmp/test"));
        assert_eq!(sink.event_count(), 0);
    }

    /// `process` is best-effort: whatever the learning run does, the call succeeds and the
    /// event is counted.
    #[tokio::test]
    async fn test_learning_sink_process_is_best_effort() {
        let dir = std::env::temp_dir()
            .join(format!("swarm_learning_sink_test_{}", std::process::id()));
        let mut sink = LearningSink::new(dir.clone(), 1);

        let outcome = sink.process(WatchEvent::new("test".into())).await;
        let _ = std::fs::remove_dir_all(&dir);

        assert!(outcome.is_ok());
        assert_eq!(sink.event_count(), 1);
    }
}