Skip to main content

synth_ai_core/streaming/
config.rs

//! Stream configuration for filtering and controlling stream behavior.
2
3use super::types::StreamType;
4use serde_json::Value;
5use std::collections::HashSet;
6
7/// Configuration for stream filtering and behavior.
8#[derive(Debug, Clone)]
9pub struct StreamConfig {
10    /// Which stream types to enable.
11    pub enabled_streams: HashSet<StreamType>,
12    /// Whitelist of event types to include (None = include all).
13    pub event_types: Option<HashSet<String>>,
14    /// Blacklist of event types to exclude.
15    pub event_types_exclude: Option<HashSet<String>>,
16    /// Filter by event levels (e.g., "error", "warning", "info").
17    pub event_levels: Option<HashSet<String>>,
18    /// Filter metrics by name.
19    pub metric_names: Option<HashSet<String>>,
20    /// Sampling rate (0.0-1.0) for events.
21    pub sample_rate: f64,
22    /// Maximum events to return per poll.
23    pub max_events_per_poll: Option<usize>,
24    /// Enable deduplication.
25    pub deduplicate: bool,
26    /// Polling interval in seconds.
27    pub poll_interval_seconds: f64,
28}
29
30impl Default for StreamConfig {
31    fn default() -> Self {
32        Self::with_default_filters()
33    }
34}
35
36impl StreamConfig {
37    /// Create a config that includes all streams with no filtering.
38    pub fn all() -> Self {
39        Self {
40            enabled_streams: [
41                StreamType::Status,
42                StreamType::Events,
43                StreamType::Metrics,
44                StreamType::Timeline,
45            ]
46            .into_iter()
47            .collect(),
48            event_types: None,
49            event_types_exclude: None,
50            event_levels: None,
51            metric_names: None,
52            sample_rate: 1.0,
53            max_events_per_poll: None,
54            deduplicate: true,
55            poll_interval_seconds: 2.0,
56        }
57    }
58
59    /// Create a config with sensible default filters.
60    pub fn with_default_filters() -> Self {
61        let mut config = Self::all();
62        // Exclude noisy internal events by default
63        config.event_types_exclude = Some(
64            [
65                "sft.progress",
66                "sft.loss",
67                "sft.upstream.status",
68                "internal.heartbeat",
69            ]
70            .iter()
71            .map(|s| s.to_string())
72            .collect(),
73        );
74        config
75    }
76
77    /// Create a minimal config (status only).
78    pub fn minimal() -> Self {
79        Self {
80            enabled_streams: [StreamType::Status].into_iter().collect(),
81            event_types: None,
82            event_types_exclude: None,
83            event_levels: None,
84            metric_names: None,
85            sample_rate: 1.0,
86            max_events_per_poll: None,
87            deduplicate: true,
88            poll_interval_seconds: 5.0,
89        }
90    }
91
92    /// Create a config for errors and warnings only.
93    pub fn errors_only() -> Self {
94        Self {
95            enabled_streams: [StreamType::Status, StreamType::Events].into_iter().collect(),
96            event_types: None,
97            event_types_exclude: None,
98            event_levels: Some(["error", "warning"].iter().map(|s| s.to_string()).collect()),
99            metric_names: None,
100            sample_rate: 1.0,
101            max_events_per_poll: None,
102            deduplicate: true,
103            poll_interval_seconds: 2.0,
104        }
105    }
106
107    /// Create a config for metrics only.
108    pub fn metrics_only() -> Self {
109        Self {
110            enabled_streams: [StreamType::Status, StreamType::Metrics].into_iter().collect(),
111            event_types: None,
112            event_types_exclude: None,
113            event_levels: None,
114            metric_names: None,
115            sample_rate: 1.0,
116            max_events_per_poll: None,
117            deduplicate: true,
118            poll_interval_seconds: 1.0,
119        }
120    }
121
122    /// Enable a specific stream type.
123    pub fn enable_stream(mut self, stream_type: StreamType) -> Self {
124        self.enabled_streams.insert(stream_type);
125        self
126    }
127
128    /// Disable a specific stream type.
129    pub fn disable_stream(mut self, stream_type: StreamType) -> Self {
130        self.enabled_streams.remove(&stream_type);
131        self
132    }
133
134    /// Add an event type to the whitelist.
135    pub fn include_event_type(mut self, event_type: impl Into<String>) -> Self {
136        let types = self.event_types.get_or_insert_with(HashSet::new);
137        types.insert(event_type.into());
138        self
139    }
140
141    /// Add an event type to the blacklist.
142    pub fn exclude_event_type(mut self, event_type: impl Into<String>) -> Self {
143        let types = self.event_types_exclude.get_or_insert_with(HashSet::new);
144        types.insert(event_type.into());
145        self
146    }
147
148    /// Filter by event levels.
149    pub fn with_levels(mut self, levels: Vec<&str>) -> Self {
150        self.event_levels = Some(levels.into_iter().map(String::from).collect());
151        self
152    }
153
154    /// Set the polling interval.
155    pub fn with_interval(mut self, seconds: f64) -> Self {
156        self.poll_interval_seconds = seconds;
157        self
158    }
159
160    /// Set the sample rate.
161    pub fn with_sample_rate(mut self, rate: f64) -> Self {
162        self.sample_rate = rate.clamp(0.0, 1.0);
163        self
164    }
165
166    /// Disable deduplication.
167    pub fn without_deduplication(mut self) -> Self {
168        self.deduplicate = false;
169        self
170    }
171
172    /// Check if a stream type is enabled.
173    pub fn is_stream_enabled(&self, stream_type: StreamType) -> bool {
174        self.enabled_streams.contains(&stream_type)
175    }
176
177    /// Check if an event should be included based on filters.
178    pub fn should_include_event(&self, event: &Value) -> bool {
179        let event_type = event.get("type").and_then(|v| v.as_str()).unwrap_or("");
180
181        // Blacklist takes precedence
182        if let Some(ref exclude) = self.event_types_exclude {
183            if exclude.contains(event_type) {
184                return false;
185            }
186        }
187
188        // Whitelist check (if whitelist exists, event must be in it)
189        if let Some(ref include) = self.event_types {
190            if !include.contains(event_type) {
191                return false;
192            }
193        }
194
195        // Level check
196        if let Some(ref levels) = self.event_levels {
197            let level = event.get("level").and_then(|v| v.as_str()).unwrap_or("");
198            if !levels.contains(level) {
199                return false;
200            }
201        }
202
203        // Sample rate check
204        if self.sample_rate < 1.0 {
205            use std::hash::{Hash, Hasher};
206            let mut hasher = std::collections::hash_map::DefaultHasher::new();
207            event.to_string().hash(&mut hasher);
208            let hash = hasher.finish();
209            let threshold = (self.sample_rate * u64::MAX as f64) as u64;
210            if hash > threshold {
211                return false;
212            }
213        }
214
215        true
216    }
217
218    /// Check if a metric should be included based on filters.
219    pub fn should_include_metric(&self, metric: &Value) -> bool {
220        if let Some(ref names) = self.metric_names {
221            let name = metric.get("name").and_then(|v| v.as_str()).unwrap_or("");
222            return names.contains(name);
223        }
224        true
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `all()` switches on every stream type.
    #[test]
    fn test_all_config() {
        let cfg = StreamConfig::all();
        for stream in [
            StreamType::Status,
            StreamType::Events,
            StreamType::Metrics,
            StreamType::Timeline,
        ] {
            assert!(cfg.is_stream_enabled(stream));
        }
    }

    /// `minimal()` enables the status stream but not events.
    #[test]
    fn test_minimal_config() {
        let cfg = StreamConfig::minimal();
        assert!(cfg.is_stream_enabled(StreamType::Status));
        assert!(!cfg.is_stream_enabled(StreamType::Events));
    }

    /// A blacklisted event type is rejected; other types still pass.
    #[test]
    fn test_event_blacklist() {
        let cfg = StreamConfig::all().exclude_event_type("internal.heartbeat");

        let passes = json!({"type": "progress"});
        let rejected = json!({"type": "internal.heartbeat"});

        assert!(cfg.should_include_event(&passes));
        assert!(!cfg.should_include_event(&rejected));
    }

    /// With a whitelist present, only listed event types pass.
    #[test]
    fn test_event_whitelist() {
        let cfg = StreamConfig::all()
            .include_event_type("error")
            .include_event_type("warning");

        let passes = json!({"type": "error"});
        let rejected = json!({"type": "progress"});

        assert!(cfg.should_include_event(&passes));
        assert!(!cfg.should_include_event(&rejected));
    }

    /// `errors_only()` keeps error and warning events and drops info events.
    #[test]
    fn test_level_filter() {
        let cfg = StreamConfig::errors_only();

        let error_event = json!({"type": "test", "level": "error"});
        let warning_event = json!({"type": "test", "level": "warning"});
        let info_event = json!({"type": "test", "level": "info"});

        assert!(cfg.should_include_event(&error_event));
        assert!(cfg.should_include_event(&warning_event));
        assert!(!cfg.should_include_event(&info_event));
    }

    /// Disabling streams removes only the streams that were named.
    #[test]
    fn test_stream_enable_disable() {
        let cfg = StreamConfig::all()
            .disable_stream(StreamType::Timeline)
            .disable_stream(StreamType::Metrics);

        assert!(cfg.is_stream_enabled(StreamType::Status));
        assert!(cfg.is_stream_enabled(StreamType::Events));
        assert!(!cfg.is_stream_enabled(StreamType::Metrics));
        assert!(!cfg.is_stream_enabled(StreamType::Timeline));
    }

    /// Builder methods chain and each one applies its setting.
    #[test]
    fn test_builder_pattern() {
        let cfg = StreamConfig::minimal()
            .enable_stream(StreamType::Events)
            .exclude_event_type("heartbeat")
            .with_interval(1.0)
            .without_deduplication();

        assert!(cfg.is_stream_enabled(StreamType::Events));
        assert_eq!(cfg.poll_interval_seconds, 1.0);
        assert!(!cfg.deduplicate);
    }
}