Skip to main content

synth_ai_core/streaming/
config.rs

1//! Stream configuration for filtering and controlling stream behavior.
2
3use super::types::StreamType;
4use serde_json::Value;
5use std::collections::HashSet;
6
7/// Configuration for stream filtering and behavior.
8#[derive(Debug, Clone)]
9pub struct StreamConfig {
10    /// Which stream types to enable.
11    pub enabled_streams: HashSet<StreamType>,
12    /// Whitelist of event types to include (None = include all).
13    pub event_types: Option<HashSet<String>>,
14    /// Blacklist of event types to exclude.
15    pub event_types_exclude: Option<HashSet<String>>,
16    /// Filter by event levels (e.g., "error", "warning", "info").
17    pub event_levels: Option<HashSet<String>>,
18    /// Filter metrics by name.
19    pub metric_names: Option<HashSet<String>>,
20    /// Filter metrics by phase.
21    pub metric_phases: Option<HashSet<String>>,
22    /// Filter timeline entries by phase.
23    pub timeline_phases: Option<HashSet<String>>,
24    /// Sampling rate (0.0-1.0) for events.
25    pub sample_rate: f64,
26    /// Maximum events to return per poll.
27    pub max_events_per_poll: Option<usize>,
28    /// Enable deduplication.
29    pub deduplicate: bool,
30    /// Polling interval in seconds.
31    pub poll_interval_seconds: f64,
32}
33
34impl Default for StreamConfig {
35    fn default() -> Self {
36        Self::with_default_filters()
37    }
38}
39
40impl StreamConfig {
41    /// Create a config that includes all streams with no filtering.
42    pub fn all() -> Self {
43        Self {
44            enabled_streams: [
45                StreamType::Status,
46                StreamType::Events,
47                StreamType::Metrics,
48                StreamType::Timeline,
49            ]
50            .into_iter()
51            .collect(),
52            event_types: None,
53            event_types_exclude: None,
54            event_levels: None,
55            metric_names: None,
56            metric_phases: None,
57            timeline_phases: None,
58            sample_rate: 1.0,
59            max_events_per_poll: None,
60            deduplicate: true,
61            poll_interval_seconds: 2.0,
62        }
63    }
64
65    /// Create a config with sensible default filters.
66    pub fn with_default_filters() -> Self {
67        let mut config = Self::all();
68        // Exclude noisy internal events by default
69        config.event_types_exclude = Some(
70            [
71                "sft.progress",
72                "sft.loss",
73                "sft.upstream.status",
74                "internal.heartbeat",
75            ]
76            .iter()
77            .map(|s| s.to_string())
78            .collect(),
79        );
80        config
81    }
82
83    /// Create a minimal config (status only).
84    pub fn minimal() -> Self {
85        Self {
86            enabled_streams: [StreamType::Status].into_iter().collect(),
87            event_types: None,
88            event_types_exclude: None,
89            event_levels: None,
90            metric_names: None,
91            metric_phases: None,
92            timeline_phases: None,
93            sample_rate: 1.0,
94            max_events_per_poll: None,
95            deduplicate: true,
96            poll_interval_seconds: 5.0,
97        }
98    }
99
100    /// Create a config for errors and warnings only.
101    pub fn errors_only() -> Self {
102        Self {
103            enabled_streams: [StreamType::Status, StreamType::Events]
104                .into_iter()
105                .collect(),
106            event_types: None,
107            event_types_exclude: None,
108            event_levels: Some(["error", "warning"].iter().map(|s| s.to_string()).collect()),
109            metric_names: None,
110            metric_phases: None,
111            timeline_phases: None,
112            sample_rate: 1.0,
113            max_events_per_poll: None,
114            deduplicate: true,
115            poll_interval_seconds: 2.0,
116        }
117    }
118
119    /// Create a config for metrics only.
120    pub fn metrics_only() -> Self {
121        Self {
122            enabled_streams: [StreamType::Status, StreamType::Metrics]
123                .into_iter()
124                .collect(),
125            event_types: None,
126            event_types_exclude: None,
127            event_levels: None,
128            metric_names: None,
129            metric_phases: None,
130            timeline_phases: None,
131            sample_rate: 1.0,
132            max_events_per_poll: None,
133            deduplicate: true,
134            poll_interval_seconds: 1.0,
135        }
136    }
137
138    /// Enable a specific stream type.
139    pub fn enable_stream(mut self, stream_type: StreamType) -> Self {
140        self.enabled_streams.insert(stream_type);
141        self
142    }
143
144    /// Disable a specific stream type.
145    pub fn disable_stream(mut self, stream_type: StreamType) -> Self {
146        self.enabled_streams.remove(&stream_type);
147        self
148    }
149
150    /// Add an event type to the whitelist.
151    pub fn include_event_type(mut self, event_type: impl Into<String>) -> Self {
152        let types = self.event_types.get_or_insert_with(HashSet::new);
153        types.insert(event_type.into());
154        self
155    }
156
157    /// Add an event type to the blacklist.
158    pub fn exclude_event_type(mut self, event_type: impl Into<String>) -> Self {
159        let types = self.event_types_exclude.get_or_insert_with(HashSet::new);
160        types.insert(event_type.into());
161        self
162    }
163
164    /// Filter by event levels.
165    pub fn with_levels(mut self, levels: Vec<&str>) -> Self {
166        self.event_levels = Some(levels.into_iter().map(String::from).collect());
167        self
168    }
169
170    /// Filter metrics by phase.
171    pub fn with_metric_phases(mut self, phases: Vec<&str>) -> Self {
172        self.metric_phases = Some(phases.into_iter().map(String::from).collect());
173        self
174    }
175
176    /// Filter timeline entries by phase.
177    pub fn with_timeline_phases(mut self, phases: Vec<&str>) -> Self {
178        self.timeline_phases = Some(phases.into_iter().map(String::from).collect());
179        self
180    }
181
182    /// Set the polling interval.
183    pub fn with_interval(mut self, seconds: f64) -> Self {
184        self.poll_interval_seconds = seconds;
185        self
186    }
187
188    /// Set the sample rate.
189    pub fn with_sample_rate(mut self, rate: f64) -> Self {
190        self.sample_rate = rate.clamp(0.0, 1.0);
191        self
192    }
193
194    /// Disable deduplication.
195    pub fn without_deduplication(mut self) -> Self {
196        self.deduplicate = false;
197        self
198    }
199
200    /// Check if a stream type is enabled.
201    pub fn is_stream_enabled(&self, stream_type: StreamType) -> bool {
202        self.enabled_streams.contains(&stream_type)
203    }
204
205    /// Check if an event should be included based on filters.
206    pub fn should_include_event(&self, event: &Value) -> bool {
207        let event_type = event.get("type").and_then(|v| v.as_str()).unwrap_or("");
208
209        // Blacklist takes precedence
210        if let Some(ref exclude) = self.event_types_exclude {
211            if exclude.contains(event_type) {
212                return false;
213            }
214        }
215
216        // Whitelist check (if whitelist exists, event must be in it)
217        if let Some(ref include) = self.event_types {
218            if !include.contains(event_type) {
219                return false;
220            }
221        }
222
223        // Level check
224        if let Some(ref levels) = self.event_levels {
225            let level = event.get("level").and_then(|v| v.as_str()).unwrap_or("");
226            if !levels.contains(level) {
227                return false;
228            }
229        }
230
231        // Sample rate check
232        if self.sample_rate < 1.0 {
233            use std::hash::{Hash, Hasher};
234            let mut hasher = std::collections::hash_map::DefaultHasher::new();
235            event.to_string().hash(&mut hasher);
236            let hash = hasher.finish();
237            let threshold = (self.sample_rate * u64::MAX as f64) as u64;
238            if hash > threshold {
239                return false;
240            }
241        }
242
243        true
244    }
245
246    /// Check if a metric should be included based on filters.
247    pub fn should_include_metric(&self, metric: &Value) -> bool {
248        if let Some(ref names) = self.metric_names {
249            let name = metric.get("name").and_then(|v| v.as_str()).unwrap_or("");
250            if !names.contains(name) {
251                return false;
252            }
253        }
254
255        if let Some(ref phases) = self.metric_phases {
256            let phase = metric.get("phase").and_then(|v| v.as_str()).unwrap_or("");
257            return phases.contains(phase);
258        }
259        true
260    }
261
262    /// Check if a timeline entry should be included based on filters.
263    pub fn should_include_timeline(&self, entry: &Value) -> bool {
264        if let Some(ref phases) = self.timeline_phases {
265            let phase = entry.get("phase").and_then(|v| v.as_str()).unwrap_or("");
266            return phases.contains(phase);
267        }
268        true
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_all_config() {
        let config = StreamConfig::all();
        assert!(config.is_stream_enabled(StreamType::Status));
        assert!(config.is_stream_enabled(StreamType::Events));
        assert!(config.is_stream_enabled(StreamType::Metrics));
        assert!(config.is_stream_enabled(StreamType::Timeline));
    }

    #[test]
    fn test_minimal_config() {
        let config = StreamConfig::minimal();
        assert!(config.is_stream_enabled(StreamType::Status));
        assert!(!config.is_stream_enabled(StreamType::Events));
        assert!(!config.is_stream_enabled(StreamType::Metrics));
        assert!(!config.is_stream_enabled(StreamType::Timeline));
        assert_eq!(config.poll_interval_seconds, 5.0);
    }

    #[test]
    fn test_default_excludes_noisy_events() {
        let config = StreamConfig::default();
        assert!(!config.should_include_event(&json!({"type": "internal.heartbeat"})));
        assert!(!config.should_include_event(&json!({"type": "sft.loss"})));
        assert!(config.should_include_event(&json!({"type": "progress"})));
    }

    #[test]
    fn test_event_blacklist() {
        let config = StreamConfig::all().exclude_event_type("internal.heartbeat");

        let allowed = json!({"type": "progress"});
        let blocked = json!({"type": "internal.heartbeat"});

        assert!(config.should_include_event(&allowed));
        assert!(!config.should_include_event(&blocked));
    }

    #[test]
    fn test_event_whitelist() {
        let config = StreamConfig::all()
            .include_event_type("error")
            .include_event_type("warning");

        let allowed = json!({"type": "error"});
        let blocked = json!({"type": "progress"});

        assert!(config.should_include_event(&allowed));
        assert!(!config.should_include_event(&blocked));
    }

    #[test]
    fn test_blacklist_overrides_whitelist() {
        let config = StreamConfig::all()
            .include_event_type("error")
            .exclude_event_type("error");

        assert!(!config.should_include_event(&json!({"type": "error"})));
    }

    #[test]
    fn test_level_filter() {
        let config = StreamConfig::errors_only();

        let error = json!({"type": "test", "level": "error"});
        let warning = json!({"type": "test", "level": "warning"});
        let info = json!({"type": "test", "level": "info"});
        let no_level = json!({"type": "test"});

        assert!(config.should_include_event(&error));
        assert!(config.should_include_event(&warning));
        assert!(!config.should_include_event(&info));
        assert!(!config.should_include_event(&no_level));
    }

    #[test]
    fn test_metric_filters() {
        let mut config = StreamConfig::metrics_only().with_metric_phases(vec!["train"]);
        config.metric_names = Some(["loss"].into_iter().map(String::from).collect());

        assert!(config.should_include_metric(&json!({"name": "loss", "phase": "train"})));
        assert!(!config.should_include_metric(&json!({"name": "loss", "phase": "eval"})));
        assert!(!config.should_include_metric(&json!({"name": "acc", "phase": "train"})));
        assert!(!config.should_include_metric(&json!({"name": "loss"})));

        // No metric filters configured: everything passes.
        assert!(StreamConfig::all().should_include_metric(&json!({"name": "acc"})));
    }

    #[test]
    fn test_timeline_filter() {
        let unfiltered = StreamConfig::all();
        let filtered = StreamConfig::all().with_timeline_phases(vec!["train"]);
        let train = json!({"phase": "train"});
        let eval = json!({"phase": "eval"});

        assert!(unfiltered.should_include_timeline(&eval));
        assert!(filtered.should_include_timeline(&train));
        assert!(!filtered.should_include_timeline(&eval));
    }

    #[test]
    fn test_sample_rate() {
        assert_eq!(StreamConfig::all().with_sample_rate(2.0).sample_rate, 1.0);
        assert_eq!(StreamConfig::all().with_sample_rate(-0.5).sample_rate, 0.0);

        let event = json!({"type": "progress", "step": 7});

        // Sampling is hash-based, so the decision for one event is stable.
        let half = StreamConfig::all().with_sample_rate(0.5);
        assert_eq!(
            half.should_include_event(&event),
            half.should_include_event(&event)
        );

        let none = StreamConfig::all().with_sample_rate(0.0);
        assert!(!none.should_include_event(&event));
    }

    #[test]
    fn test_stream_enable_disable() {
        let config = StreamConfig::all()
            .disable_stream(StreamType::Timeline)
            .disable_stream(StreamType::Metrics);

        assert!(config.is_stream_enabled(StreamType::Status));
        assert!(config.is_stream_enabled(StreamType::Events));
        assert!(!config.is_stream_enabled(StreamType::Metrics));
        assert!(!config.is_stream_enabled(StreamType::Timeline));
    }

    #[test]
    fn test_builder_pattern() {
        let config = StreamConfig::minimal()
            .enable_stream(StreamType::Events)
            .exclude_event_type("heartbeat")
            .with_interval(1.0)
            .without_deduplication();

        assert!(config.is_stream_enabled(StreamType::Events));
        assert!(!config.should_include_event(&json!({"type": "heartbeat"})));
        assert_eq!(config.poll_interval_seconds, 1.0);
        assert!(!config.deduplicate);
    }
}