synth_ai_core/streaming/
config.rs1use super::types::StreamType;
4use serde_json::Value;
5use std::collections::HashSet;
6
7#[derive(Debug, Clone)]
9pub struct StreamConfig {
10 pub enabled_streams: HashSet<StreamType>,
12 pub event_types: Option<HashSet<String>>,
14 pub event_types_exclude: Option<HashSet<String>>,
16 pub event_levels: Option<HashSet<String>>,
18 pub metric_names: Option<HashSet<String>>,
20 pub sample_rate: f64,
22 pub max_events_per_poll: Option<usize>,
24 pub deduplicate: bool,
26 pub poll_interval_seconds: f64,
28}
29
30impl Default for StreamConfig {
31 fn default() -> Self {
32 Self::with_default_filters()
33 }
34}
35
36impl StreamConfig {
37 pub fn all() -> Self {
39 Self {
40 enabled_streams: [
41 StreamType::Status,
42 StreamType::Events,
43 StreamType::Metrics,
44 StreamType::Timeline,
45 ]
46 .into_iter()
47 .collect(),
48 event_types: None,
49 event_types_exclude: None,
50 event_levels: None,
51 metric_names: None,
52 sample_rate: 1.0,
53 max_events_per_poll: None,
54 deduplicate: true,
55 poll_interval_seconds: 2.0,
56 }
57 }
58
59 pub fn with_default_filters() -> Self {
61 let mut config = Self::all();
62 config.event_types_exclude = Some(
64 [
65 "sft.progress",
66 "sft.loss",
67 "sft.upstream.status",
68 "internal.heartbeat",
69 ]
70 .iter()
71 .map(|s| s.to_string())
72 .collect(),
73 );
74 config
75 }
76
77 pub fn minimal() -> Self {
79 Self {
80 enabled_streams: [StreamType::Status].into_iter().collect(),
81 event_types: None,
82 event_types_exclude: None,
83 event_levels: None,
84 metric_names: None,
85 sample_rate: 1.0,
86 max_events_per_poll: None,
87 deduplicate: true,
88 poll_interval_seconds: 5.0,
89 }
90 }
91
92 pub fn errors_only() -> Self {
94 Self {
95 enabled_streams: [StreamType::Status, StreamType::Events].into_iter().collect(),
96 event_types: None,
97 event_types_exclude: None,
98 event_levels: Some(["error", "warning"].iter().map(|s| s.to_string()).collect()),
99 metric_names: None,
100 sample_rate: 1.0,
101 max_events_per_poll: None,
102 deduplicate: true,
103 poll_interval_seconds: 2.0,
104 }
105 }
106
107 pub fn metrics_only() -> Self {
109 Self {
110 enabled_streams: [StreamType::Status, StreamType::Metrics].into_iter().collect(),
111 event_types: None,
112 event_types_exclude: None,
113 event_levels: None,
114 metric_names: None,
115 sample_rate: 1.0,
116 max_events_per_poll: None,
117 deduplicate: true,
118 poll_interval_seconds: 1.0,
119 }
120 }
121
122 pub fn enable_stream(mut self, stream_type: StreamType) -> Self {
124 self.enabled_streams.insert(stream_type);
125 self
126 }
127
128 pub fn disable_stream(mut self, stream_type: StreamType) -> Self {
130 self.enabled_streams.remove(&stream_type);
131 self
132 }
133
134 pub fn include_event_type(mut self, event_type: impl Into<String>) -> Self {
136 let types = self.event_types.get_or_insert_with(HashSet::new);
137 types.insert(event_type.into());
138 self
139 }
140
141 pub fn exclude_event_type(mut self, event_type: impl Into<String>) -> Self {
143 let types = self.event_types_exclude.get_or_insert_with(HashSet::new);
144 types.insert(event_type.into());
145 self
146 }
147
148 pub fn with_levels(mut self, levels: Vec<&str>) -> Self {
150 self.event_levels = Some(levels.into_iter().map(String::from).collect());
151 self
152 }
153
154 pub fn with_interval(mut self, seconds: f64) -> Self {
156 self.poll_interval_seconds = seconds;
157 self
158 }
159
160 pub fn with_sample_rate(mut self, rate: f64) -> Self {
162 self.sample_rate = rate.clamp(0.0, 1.0);
163 self
164 }
165
166 pub fn without_deduplication(mut self) -> Self {
168 self.deduplicate = false;
169 self
170 }
171
172 pub fn is_stream_enabled(&self, stream_type: StreamType) -> bool {
174 self.enabled_streams.contains(&stream_type)
175 }
176
177 pub fn should_include_event(&self, event: &Value) -> bool {
179 let event_type = event.get("type").and_then(|v| v.as_str()).unwrap_or("");
180
181 if let Some(ref exclude) = self.event_types_exclude {
183 if exclude.contains(event_type) {
184 return false;
185 }
186 }
187
188 if let Some(ref include) = self.event_types {
190 if !include.contains(event_type) {
191 return false;
192 }
193 }
194
195 if let Some(ref levels) = self.event_levels {
197 let level = event.get("level").and_then(|v| v.as_str()).unwrap_or("");
198 if !levels.contains(level) {
199 return false;
200 }
201 }
202
203 if self.sample_rate < 1.0 {
205 use std::hash::{Hash, Hasher};
206 let mut hasher = std::collections::hash_map::DefaultHasher::new();
207 event.to_string().hash(&mut hasher);
208 let hash = hasher.finish();
209 let threshold = (self.sample_rate * u64::MAX as f64) as u64;
210 if hash > threshold {
211 return false;
212 }
213 }
214
215 true
216 }
217
218 pub fn should_include_metric(&self, metric: &Value) -> bool {
220 if let Some(ref names) = self.metric_names {
221 let name = metric.get("name").and_then(|v| v.as_str()).unwrap_or("");
222 return names.contains(name);
223 }
224 true
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_all_config() {
        let config = StreamConfig::all();
        assert!(config.is_stream_enabled(StreamType::Status));
        assert!(config.is_stream_enabled(StreamType::Events));
        assert!(config.is_stream_enabled(StreamType::Metrics));
        assert!(config.is_stream_enabled(StreamType::Timeline));
    }

    #[test]
    fn test_minimal_config() {
        let config = StreamConfig::minimal();
        assert!(config.is_stream_enabled(StreamType::Status));
        assert!(!config.is_stream_enabled(StreamType::Events));
    }

    #[test]
    fn test_default_applies_default_filters() {
        let config = StreamConfig::default();

        let noisy = serde_json::json!({"type": "sft.progress"});
        let heartbeat = serde_json::json!({"type": "internal.heartbeat"});
        let useful = serde_json::json!({"type": "job.completed"});

        assert!(!config.should_include_event(&noisy));
        assert!(!config.should_include_event(&heartbeat));
        assert!(config.should_include_event(&useful));
    }

    #[test]
    fn test_event_blacklist() {
        let config = StreamConfig::all().exclude_event_type("internal.heartbeat");

        let allowed = serde_json::json!({"type": "progress"});
        let blocked = serde_json::json!({"type": "internal.heartbeat"});

        assert!(config.should_include_event(&allowed));
        assert!(!config.should_include_event(&blocked));
    }

    #[test]
    fn test_event_whitelist() {
        let config = StreamConfig::all()
            .include_event_type("error")
            .include_event_type("warning");

        let allowed = serde_json::json!({"type": "error"});
        let blocked = serde_json::json!({"type": "progress"});

        assert!(config.should_include_event(&allowed));
        assert!(!config.should_include_event(&blocked));
    }

    #[test]
    fn test_blacklist_wins_over_whitelist() {
        let config = StreamConfig::all()
            .include_event_type("error")
            .exclude_event_type("error");

        let event = serde_json::json!({"type": "error"});
        assert!(!config.should_include_event(&event));
    }

    #[test]
    fn test_level_filter() {
        let config = StreamConfig::errors_only();

        let error = serde_json::json!({"type": "test", "level": "error"});
        let warning = serde_json::json!({"type": "test", "level": "warning"});
        let info = serde_json::json!({"type": "test", "level": "info"});
        let no_level = serde_json::json!({"type": "test"});

        assert!(config.should_include_event(&error));
        assert!(config.should_include_event(&warning));
        assert!(!config.should_include_event(&info));
        assert!(!config.should_include_event(&no_level));
    }

    #[test]
    fn test_sample_rate_is_clamped() {
        assert_eq!(StreamConfig::all().with_sample_rate(2.0).sample_rate, 1.0);
        assert_eq!(StreamConfig::all().with_sample_rate(-0.5).sample_rate, 0.0);
        assert_eq!(StreamConfig::all().with_sample_rate(0.25).sample_rate, 0.25);
    }

    #[test]
    fn test_zero_sample_rate_drops_events() {
        let config = StreamConfig::all().with_sample_rate(0.0);
        let event = serde_json::json!({"type": "progress"});
        assert!(!config.should_include_event(&event));
    }

    #[test]
    fn test_metric_name_filter() {
        let mut config = StreamConfig::metrics_only();
        let loss = serde_json::json!({"name": "loss", "value": 0.5});
        let accuracy = serde_json::json!({"name": "accuracy", "value": 0.9});

        // No filter: everything passes.
        assert!(config.should_include_metric(&loss));

        config.metric_names = Some(["accuracy".to_string()].into_iter().collect());
        assert!(!config.should_include_metric(&loss));
        assert!(config.should_include_metric(&accuracy));
    }

    #[test]
    fn test_stream_enable_disable() {
        let config = StreamConfig::all()
            .disable_stream(StreamType::Timeline)
            .disable_stream(StreamType::Metrics);

        assert!(config.is_stream_enabled(StreamType::Status));
        assert!(config.is_stream_enabled(StreamType::Events));
        assert!(!config.is_stream_enabled(StreamType::Metrics));
        assert!(!config.is_stream_enabled(StreamType::Timeline));
    }

    #[test]
    fn test_builder_pattern() {
        let config = StreamConfig::minimal()
            .enable_stream(StreamType::Events)
            .exclude_event_type("heartbeat")
            .with_interval(1.0)
            .without_deduplication();

        assert!(config.is_stream_enabled(StreamType::Events));
        assert_eq!(config.poll_interval_seconds, 1.0);
        assert!(!config.deduplicate);
    }
}