1use std::sync::Arc;
33use std::time::Duration;
34
35use super::store::{EpisodeStore, StoreError};
36use crate::util::epoch_millis;
37
38pub struct TriggerContext<'a> {
49 pub store: Option<&'a dyn EpisodeStore>,
51
52 pub event_count: Option<usize>,
54
55 pub last_train_at: Option<u64>,
57
58 pub last_train_count: usize,
60
61 pub metrics: Option<&'a TriggerMetrics>,
63}
64
65impl<'a> TriggerContext<'a> {
66 pub fn with_store(store: &'a dyn EpisodeStore) -> Self {
68 Self {
69 store: Some(store),
70 event_count: None,
71 last_train_at: None,
72 last_train_count: 0,
73 metrics: None,
74 }
75 }
76
77 pub fn with_count(count: usize) -> Self {
79 Self {
80 store: None,
81 event_count: Some(count),
82 last_train_at: None,
83 last_train_count: 0,
84 metrics: None,
85 }
86 }
87
88 pub fn last_train_at(mut self, timestamp: u64) -> Self {
90 self.last_train_at = Some(timestamp);
91 self
92 }
93
94 pub fn last_train_count(mut self, count: usize) -> Self {
96 self.last_train_count = count;
97 self
98 }
99
100 pub fn metrics(mut self, metrics: &'a TriggerMetrics) -> Self {
102 self.metrics = Some(metrics);
103 self
104 }
105
106 pub fn current_count(&self) -> Result<usize, TriggerError> {
108 if let Some(count) = self.event_count {
109 return Ok(count);
110 }
111 if let Some(store) = self.store {
112 return Ok(store.count(None)?);
113 }
114 Ok(0)
116 }
117}
118
/// Quality statistics supplied to metric-based triggers such as [`QualityTrigger`].
///
/// `Default` yields all-zero values.
#[derive(Debug, Clone, Default)]
pub struct TriggerMetrics {
    /// Success rate over the recent window; compared against `QualityTrigger`'s threshold
    /// (presumably a fraction in `0.0..=1.0` — TODO confirm with producers).
    pub recent_success_rate: f64,
    /// Success rate over all recorded episodes; not consulted by any trigger in this module.
    pub overall_success_rate: f64,
    /// Number of samples behind `recent_success_rate`.
    pub recent_sample_size: usize,
}
129
130#[derive(Debug)]
136pub enum TriggerError {
137 Store(StoreError),
139 MetricsUnavailable(String),
141}
142
143impl std::fmt::Display for TriggerError {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 match self {
146 Self::Store(e) => write!(f, "Store error: {}", e),
147 Self::MetricsUnavailable(msg) => write!(f, "Metrics unavailable: {}", msg),
148 }
149 }
150}
151
152impl std::error::Error for TriggerError {}
153
154impl From<StoreError> for TriggerError {
155 fn from(e: StoreError) -> Self {
156 Self::Store(e)
157 }
158}
159
160pub trait TrainTrigger: Send + Sync {
166 fn should_train(&self, context: &TriggerContext) -> Result<bool, TriggerError>;
168
169 fn name(&self) -> &str;
171
172 fn describe(&self) -> String;
174}
175
176pub struct CountTrigger {
182 threshold: usize,
184}
185
186impl CountTrigger {
187 pub fn new(threshold: usize) -> Self {
188 Self { threshold }
189 }
190}
191
192impl TrainTrigger for CountTrigger {
193 fn should_train(&self, ctx: &TriggerContext) -> Result<bool, TriggerError> {
194 let current_count = ctx.current_count()?;
195 let new_episodes = current_count.saturating_sub(ctx.last_train_count);
196 Ok(new_episodes >= self.threshold)
197 }
198
199 fn name(&self) -> &str {
200 "count"
201 }
202
203 fn describe(&self) -> String {
204 format!("Train when {} new episodes accumulated", self.threshold)
205 }
206}
207
208pub struct TimeTrigger {
214 interval_secs: u64,
216}
217
218impl TimeTrigger {
219 pub fn new(interval: Duration) -> Self {
220 Self {
221 interval_secs: interval.as_secs(),
222 }
223 }
224
225 pub fn hours(hours: u64) -> Self {
226 Self {
227 interval_secs: hours * 3600,
228 }
229 }
230
231 pub fn minutes(minutes: u64) -> Self {
232 Self {
233 interval_secs: minutes * 60,
234 }
235 }
236}
237
238impl TrainTrigger for TimeTrigger {
239 fn should_train(&self, ctx: &TriggerContext) -> Result<bool, TriggerError> {
240 let Some(last_train) = ctx.last_train_at else {
241 let count = ctx.current_count()?;
243 return Ok(count > 0);
244 };
245
246 let now = epoch_millis();
247 let elapsed_secs = (now.saturating_sub(last_train)) / 1000;
248 Ok(elapsed_secs >= self.interval_secs)
249 }
250
251 fn name(&self) -> &str {
252 "time"
253 }
254
255 fn describe(&self) -> String {
256 if self.interval_secs >= 3600 {
257 format!("Train every {} hours", self.interval_secs / 3600)
258 } else if self.interval_secs >= 60 {
259 format!("Train every {} minutes", self.interval_secs / 60)
260 } else {
261 format!("Train every {} seconds", self.interval_secs)
262 }
263 }
264}
265
266pub struct QualityTrigger {
272 threshold: f64,
274 min_samples: usize,
276}
277
278impl QualityTrigger {
279 pub fn new(threshold: f64) -> Self {
280 Self {
281 threshold,
282 min_samples: 10,
283 }
284 }
285
286 pub fn with_min_samples(mut self, min: usize) -> Self {
287 self.min_samples = min;
288 self
289 }
290}
291
292impl TrainTrigger for QualityTrigger {
293 fn should_train(&self, ctx: &TriggerContext) -> Result<bool, TriggerError> {
294 let metrics = ctx.metrics.ok_or_else(|| {
295 TriggerError::MetricsUnavailable("QualityTrigger requires metrics".into())
296 })?;
297
298 if metrics.recent_sample_size < self.min_samples {
300 return Ok(false);
301 }
302
303 Ok(metrics.recent_success_rate < self.threshold)
304 }
305
306 fn name(&self) -> &str {
307 "quality"
308 }
309
310 fn describe(&self) -> String {
311 format!(
312 "Train when success rate < {:.0}% (min {} samples)",
313 self.threshold * 100.0,
314 self.min_samples
315 )
316 }
317}
318
319pub struct ManualTrigger;
325
326impl TrainTrigger for ManualTrigger {
327 fn should_train(&self, _ctx: &TriggerContext) -> Result<bool, TriggerError> {
328 Ok(false)
329 }
330
331 fn name(&self) -> &str {
332 "manual"
333 }
334
335 fn describe(&self) -> String {
336 "Manual trigger only".into()
337 }
338}
339
340pub struct NeverTrigger;
346
347impl TrainTrigger for NeverTrigger {
348 fn should_train(&self, _ctx: &TriggerContext) -> Result<bool, TriggerError> {
349 Ok(false)
350 }
351
352 fn name(&self) -> &str {
353 "never"
354 }
355
356 fn describe(&self) -> String {
357 "Never triggers".into()
358 }
359}
360
361pub struct AlwaysTrigger;
367
368impl TrainTrigger for AlwaysTrigger {
369 fn should_train(&self, _ctx: &TriggerContext) -> Result<bool, TriggerError> {
370 Ok(true)
371 }
372
373 fn name(&self) -> &str {
374 "always"
375 }
376
377 fn describe(&self) -> String {
378 "Always triggers".into()
379 }
380}
381
382pub struct OrTrigger {
388 triggers: Vec<Arc<dyn TrainTrigger>>,
389}
390
391impl OrTrigger {
392 pub fn new(triggers: Vec<Arc<dyn TrainTrigger>>) -> Self {
393 Self { triggers }
394 }
395}
396
397impl TrainTrigger for OrTrigger {
398 fn should_train(&self, ctx: &TriggerContext) -> Result<bool, TriggerError> {
399 for trigger in &self.triggers {
400 if trigger.should_train(ctx)? {
401 return Ok(true);
402 }
403 }
404 Ok(false)
405 }
406
407 fn name(&self) -> &str {
408 "or"
409 }
410
411 fn describe(&self) -> String {
412 let names: Vec<_> = self.triggers.iter().map(|t| t.name()).collect();
413 format!("OR({})", names.join(", "))
414 }
415}
416
417pub struct AndTrigger {
423 triggers: Vec<Arc<dyn TrainTrigger>>,
424}
425
426impl AndTrigger {
427 pub fn new(triggers: Vec<Arc<dyn TrainTrigger>>) -> Self {
428 Self { triggers }
429 }
430}
431
432impl TrainTrigger for AndTrigger {
433 fn should_train(&self, ctx: &TriggerContext) -> Result<bool, TriggerError> {
434 if self.triggers.is_empty() {
435 return Ok(false);
436 }
437 for trigger in &self.triggers {
438 if !trigger.should_train(ctx)? {
439 return Ok(false);
440 }
441 }
442 Ok(true)
443 }
444
445 fn name(&self) -> &str {
446 "and"
447 }
448
449 fn describe(&self) -> String {
450 let names: Vec<_> = self.triggers.iter().map(|t| t.name()).collect();
451 format!("AND({})", names.join(", "))
452 }
453}
454
455pub struct TriggerBuilder;
461
462impl TriggerBuilder {
463 pub fn every_n_episodes(n: usize) -> Arc<dyn TrainTrigger> {
465 Arc::new(CountTrigger::new(n))
466 }
467
468 pub fn every_hours(hours: u64) -> Arc<dyn TrainTrigger> {
470 Arc::new(TimeTrigger::hours(hours))
471 }
472
473 pub fn every_minutes(minutes: u64) -> Arc<dyn TrainTrigger> {
475 Arc::new(TimeTrigger::minutes(minutes))
476 }
477
478 pub fn on_quality_drop(threshold: f64) -> Arc<dyn TrainTrigger> {
480 Arc::new(QualityTrigger::new(threshold))
481 }
482
483 pub fn default_watch() -> Arc<dyn TrainTrigger> {
485 Arc::new(OrTrigger::new(vec![
486 Self::every_n_episodes(100),
487 Self::every_hours(1),
488 ]))
489 }
490
491 pub fn manual() -> Arc<dyn TrainTrigger> {
493 Arc::new(ManualTrigger)
494 }
495
496 pub fn never() -> Arc<dyn TrainTrigger> {
498 Arc::new(NeverTrigger)
499 }
500
501 pub fn always() -> Arc<dyn TrainTrigger> {
503 Arc::new(AlwaysTrigger)
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::store::{EpisodeDto, InMemoryEpisodeStore};
    use crate::learn::{EpisodeId, EpisodeMetadata, Outcome};

    /// Returns an in-memory store holding `count` successful test episodes.
    fn seeded_store(count: usize) -> InMemoryEpisodeStore {
        let store = InMemoryEpisodeStore::new();
        for _ in 0..count {
            let episode = EpisodeDto {
                id: EpisodeId::new(),
                learn_model: "test".to_string(),
                outcome: Outcome::success(1.0),
                metadata: EpisodeMetadata::new(),
                record_ids: vec![],
            };
            store.append(&episode).unwrap();
        }
        store
    }

    /// Store-backed context with the given training history and optional metrics.
    fn store_ctx<'a>(
        store: &'a dyn EpisodeStore,
        last_train_at: Option<u64>,
        last_train_count: usize,
        metrics: Option<&'a TriggerMetrics>,
    ) -> TriggerContext<'a> {
        TriggerContext {
            last_train_at,
            last_train_count,
            metrics,
            ..TriggerContext::with_store(store)
        }
    }

    // ---- CountTrigger ----

    #[test]
    fn test_count_trigger_below_threshold() {
        let store = seeded_store(5);
        let ctx = store_ctx(&store, None, 0, None);
        let fired = CountTrigger::new(10).should_train(&ctx).unwrap();
        assert!(!fired);
    }

    #[test]
    fn test_count_trigger_at_threshold() {
        let store = seeded_store(10);
        let ctx = store_ctx(&store, None, 0, None);
        let fired = CountTrigger::new(10).should_train(&ctx).unwrap();
        assert!(fired);
    }

    #[test]
    fn test_count_trigger_with_previous_count() {
        let store = seeded_store(15);
        let trigger = CountTrigger::new(10);

        // 15 - 10 = 5 new episodes: below the threshold.
        let since_ten = store_ctx(&store, None, 10, None);
        assert!(!trigger.should_train(&since_ten).unwrap());

        // 15 - 5 = 10 new episodes: exactly at the threshold.
        let since_five = store_ctx(&store, None, 5, None);
        assert!(trigger.should_train(&since_five).unwrap());
    }

    // ---- TimeTrigger ----

    #[test]
    fn test_time_trigger_first_time_with_episodes() {
        let store = seeded_store(5);
        let ctx = store_ctx(&store, None, 0, None);
        assert!(TimeTrigger::hours(1).should_train(&ctx).unwrap());
    }

    #[test]
    fn test_time_trigger_first_time_no_episodes() {
        let store = seeded_store(0);
        let ctx = store_ctx(&store, None, 0, None);
        assert!(!TimeTrigger::hours(1).should_train(&ctx).unwrap());
    }

    #[test]
    fn test_time_trigger_not_elapsed() {
        let store = seeded_store(5);
        let one_second_ago = epoch_millis() - 1_000;
        let ctx = store_ctx(&store, Some(one_second_ago), 0, None);
        assert!(!TimeTrigger::hours(1).should_train(&ctx).unwrap());
    }

    #[test]
    fn test_time_trigger_elapsed() {
        let store = seeded_store(5);
        let just_over_an_hour_ago = epoch_millis() - 3_601_000;
        let ctx = store_ctx(&store, Some(just_over_an_hour_ago), 0, None);
        assert!(TimeTrigger::hours(1).should_train(&ctx).unwrap());
    }

    // ---- QualityTrigger ----

    #[test]
    fn test_quality_trigger_no_metrics() {
        let store = seeded_store(5);
        let ctx = store_ctx(&store, None, 0, None);
        assert!(QualityTrigger::new(0.5).should_train(&ctx).is_err());
    }

    #[test]
    fn test_quality_trigger_insufficient_samples() {
        let store = seeded_store(5);
        let metrics = TriggerMetrics {
            recent_success_rate: 0.3,
            overall_success_rate: 0.5,
            recent_sample_size: 5,
        };
        let ctx = store_ctx(&store, None, 0, Some(&metrics));
        let trigger = QualityTrigger::new(0.5).with_min_samples(10);
        assert!(!trigger.should_train(&ctx).unwrap());
    }

    #[test]
    fn test_quality_trigger_above_threshold() {
        let store = seeded_store(5);
        let metrics = TriggerMetrics {
            recent_success_rate: 0.7,
            overall_success_rate: 0.7,
            recent_sample_size: 20,
        };
        let ctx = store_ctx(&store, None, 0, Some(&metrics));
        assert!(!QualityTrigger::new(0.5).should_train(&ctx).unwrap());
    }

    #[test]
    fn test_quality_trigger_below_threshold() {
        let store = seeded_store(5);
        let metrics = TriggerMetrics {
            recent_success_rate: 0.3,
            overall_success_rate: 0.5,
            recent_sample_size: 20,
        };
        let ctx = store_ctx(&store, None, 0, Some(&metrics));
        assert!(QualityTrigger::new(0.5).should_train(&ctx).unwrap());
    }

    // ---- Composite triggers ----

    #[test]
    fn test_or_trigger_all_false() {
        let store = seeded_store(5);
        let ctx = store_ctx(&store, None, 0, None);
        let trigger = OrTrigger::new(vec![
            Arc::new(CountTrigger::new(100)),
            Arc::new(NeverTrigger),
        ]);
        assert!(!trigger.should_train(&ctx).unwrap());
    }

    #[test]
    fn test_or_trigger_one_true() {
        let store = seeded_store(5);
        let ctx = store_ctx(&store, None, 0, None);
        let trigger = OrTrigger::new(vec![Arc::new(AlwaysTrigger), Arc::new(NeverTrigger)]);
        assert!(trigger.should_train(&ctx).unwrap());
    }

    #[test]
    fn test_and_trigger_empty() {
        let store = seeded_store(5);
        let ctx = store_ctx(&store, None, 0, None);
        assert!(!AndTrigger::new(vec![]).should_train(&ctx).unwrap());
    }

    #[test]
    fn test_and_trigger_all_true() {
        let store = seeded_store(5);
        let ctx = store_ctx(&store, None, 0, None);
        let trigger = AndTrigger::new(vec![Arc::new(AlwaysTrigger), Arc::new(AlwaysTrigger)]);
        assert!(trigger.should_train(&ctx).unwrap());
    }

    #[test]
    fn test_and_trigger_one_false() {
        let store = seeded_store(5);
        let ctx = store_ctx(&store, None, 0, None);
        let trigger = AndTrigger::new(vec![Arc::new(AlwaysTrigger), Arc::new(NeverTrigger)]);
        assert!(!trigger.should_train(&ctx).unwrap());
    }

    // ---- Builder and descriptions ----

    #[test]
    fn test_trigger_builder_default_watch() {
        let watch = TriggerBuilder::default_watch();
        assert_eq!(watch.name(), "or");
        assert!(watch.describe().contains("OR"));
    }

    #[test]
    fn test_trigger_describe() {
        assert_eq!(
            CountTrigger::new(50).describe(),
            "Train when 50 new episodes accumulated"
        );
        assert_eq!(TimeTrigger::hours(2).describe(), "Train every 2 hours");
        assert_eq!(
            TimeTrigger::minutes(30).describe(),
            "Train every 30 minutes"
        );
        assert!(QualityTrigger::new(0.5).describe().contains("50%"));
    }

    // ---- Count-only contexts (no store) ----

    #[test]
    fn test_context_with_count_no_store() {
        let ctx = TriggerContext::with_count(15);
        assert!(CountTrigger::new(10).should_train(&ctx).unwrap());
    }

    #[test]
    fn test_context_with_count_below_threshold() {
        let ctx = TriggerContext::with_count(5);
        assert!(!CountTrigger::new(10).should_train(&ctx).unwrap());
    }

    #[test]
    fn test_context_with_count_and_last_train_count() {
        // 20 - 15 = 5 new episodes: below the threshold of 10.
        let ctx = TriggerContext::with_count(20).last_train_count(15);
        assert!(!CountTrigger::new(10).should_train(&ctx).unwrap());
    }

    #[test]
    fn test_context_builder_fluent() {
        let metrics = TriggerMetrics {
            recent_success_rate: 0.3,
            overall_success_rate: 0.5,
            recent_sample_size: 20,
        };
        let an_hour_ago = epoch_millis() - 3_600_000;
        let ctx = TriggerContext::with_count(100)
            .last_train_at(an_hour_ago)
            .last_train_count(50)
            .metrics(&metrics);

        // 100 - 50 = 50 new episodes >= 10.
        assert!(CountTrigger::new(10).should_train(&ctx).unwrap());
        // One hour elapsed >= 30 minutes.
        assert!(TimeTrigger::minutes(30).should_train(&ctx).unwrap());
        // 0.3 < 0.5 with 20 >= 10 samples.
        assert!(QualityTrigger::new(0.5).should_train(&ctx).unwrap());
    }

    #[test]
    fn test_time_trigger_with_count_first_time() {
        let ctx = TriggerContext::with_count(5);
        assert!(TimeTrigger::hours(1).should_train(&ctx).unwrap());
    }

    #[test]
    fn test_time_trigger_with_count_first_time_no_events() {
        let ctx = TriggerContext::with_count(0);
        assert!(!TimeTrigger::hours(1).should_train(&ctx).unwrap());
    }
}