1use super::types::StreamType;
4use serde_json::Value;
5use std::collections::HashSet;
6
7#[derive(Debug, Clone)]
9pub struct StreamConfig {
10 pub enabled_streams: HashSet<StreamType>,
12 pub event_types: Option<HashSet<String>>,
14 pub event_types_exclude: Option<HashSet<String>>,
16 pub event_levels: Option<HashSet<String>>,
18 pub metric_names: Option<HashSet<String>>,
20 pub metric_phases: Option<HashSet<String>>,
22 pub timeline_phases: Option<HashSet<String>>,
24 pub sample_rate: f64,
26 pub max_events_per_poll: Option<usize>,
28 pub deduplicate: bool,
30 pub poll_interval_seconds: f64,
32}
33
34impl Default for StreamConfig {
35 fn default() -> Self {
36 Self::with_default_filters()
37 }
38}
39
40impl StreamConfig {
41 pub fn all() -> Self {
43 Self {
44 enabled_streams: [
45 StreamType::Status,
46 StreamType::Events,
47 StreamType::Metrics,
48 StreamType::Timeline,
49 ]
50 .into_iter()
51 .collect(),
52 event_types: None,
53 event_types_exclude: None,
54 event_levels: None,
55 metric_names: None,
56 metric_phases: None,
57 timeline_phases: None,
58 sample_rate: 1.0,
59 max_events_per_poll: None,
60 deduplicate: true,
61 poll_interval_seconds: 2.0,
62 }
63 }
64
65 pub fn with_default_filters() -> Self {
67 let mut config = Self::all();
68 config.event_types_exclude = Some(
70 [
71 "sft.progress",
72 "sft.loss",
73 "sft.upstream.status",
74 "internal.heartbeat",
75 ]
76 .iter()
77 .map(|s| s.to_string())
78 .collect(),
79 );
80 config
81 }
82
83 pub fn minimal() -> Self {
85 Self {
86 enabled_streams: [StreamType::Status].into_iter().collect(),
87 event_types: None,
88 event_types_exclude: None,
89 event_levels: None,
90 metric_names: None,
91 metric_phases: None,
92 timeline_phases: None,
93 sample_rate: 1.0,
94 max_events_per_poll: None,
95 deduplicate: true,
96 poll_interval_seconds: 5.0,
97 }
98 }
99
100 pub fn errors_only() -> Self {
102 Self {
103 enabled_streams: [StreamType::Status, StreamType::Events]
104 .into_iter()
105 .collect(),
106 event_types: None,
107 event_types_exclude: None,
108 event_levels: Some(["error", "warning"].iter().map(|s| s.to_string()).collect()),
109 metric_names: None,
110 metric_phases: None,
111 timeline_phases: None,
112 sample_rate: 1.0,
113 max_events_per_poll: None,
114 deduplicate: true,
115 poll_interval_seconds: 2.0,
116 }
117 }
118
119 pub fn metrics_only() -> Self {
121 Self {
122 enabled_streams: [StreamType::Status, StreamType::Metrics]
123 .into_iter()
124 .collect(),
125 event_types: None,
126 event_types_exclude: None,
127 event_levels: None,
128 metric_names: None,
129 metric_phases: None,
130 timeline_phases: None,
131 sample_rate: 1.0,
132 max_events_per_poll: None,
133 deduplicate: true,
134 poll_interval_seconds: 1.0,
135 }
136 }
137
138 pub fn enable_stream(mut self, stream_type: StreamType) -> Self {
140 self.enabled_streams.insert(stream_type);
141 self
142 }
143
144 pub fn disable_stream(mut self, stream_type: StreamType) -> Self {
146 self.enabled_streams.remove(&stream_type);
147 self
148 }
149
150 pub fn include_event_type(mut self, event_type: impl Into<String>) -> Self {
152 let types = self.event_types.get_or_insert_with(HashSet::new);
153 types.insert(event_type.into());
154 self
155 }
156
157 pub fn exclude_event_type(mut self, event_type: impl Into<String>) -> Self {
159 let types = self.event_types_exclude.get_or_insert_with(HashSet::new);
160 types.insert(event_type.into());
161 self
162 }
163
164 pub fn with_levels(mut self, levels: Vec<&str>) -> Self {
166 self.event_levels = Some(levels.into_iter().map(String::from).collect());
167 self
168 }
169
170 pub fn with_metric_phases(mut self, phases: Vec<&str>) -> Self {
172 self.metric_phases = Some(phases.into_iter().map(String::from).collect());
173 self
174 }
175
176 pub fn with_timeline_phases(mut self, phases: Vec<&str>) -> Self {
178 self.timeline_phases = Some(phases.into_iter().map(String::from).collect());
179 self
180 }
181
182 pub fn with_interval(mut self, seconds: f64) -> Self {
184 self.poll_interval_seconds = seconds;
185 self
186 }
187
188 pub fn with_sample_rate(mut self, rate: f64) -> Self {
190 self.sample_rate = rate.clamp(0.0, 1.0);
191 self
192 }
193
194 pub fn without_deduplication(mut self) -> Self {
196 self.deduplicate = false;
197 self
198 }
199
200 pub fn is_stream_enabled(&self, stream_type: StreamType) -> bool {
202 self.enabled_streams.contains(&stream_type)
203 }
204
205 pub fn should_include_event(&self, event: &Value) -> bool {
207 let event_type = event.get("type").and_then(|v| v.as_str()).unwrap_or("");
208
209 if let Some(ref exclude) = self.event_types_exclude {
211 if exclude.contains(event_type) {
212 return false;
213 }
214 }
215
216 if let Some(ref include) = self.event_types {
218 if !include.contains(event_type) {
219 return false;
220 }
221 }
222
223 if let Some(ref levels) = self.event_levels {
225 let level = event.get("level").and_then(|v| v.as_str()).unwrap_or("");
226 if !levels.contains(level) {
227 return false;
228 }
229 }
230
231 if self.sample_rate < 1.0 {
233 use std::hash::{Hash, Hasher};
234 let mut hasher = std::collections::hash_map::DefaultHasher::new();
235 event.to_string().hash(&mut hasher);
236 let hash = hasher.finish();
237 let threshold = (self.sample_rate * u64::MAX as f64) as u64;
238 if hash > threshold {
239 return false;
240 }
241 }
242
243 true
244 }
245
246 pub fn should_include_metric(&self, metric: &Value) -> bool {
248 if let Some(ref names) = self.metric_names {
249 let name = metric.get("name").and_then(|v| v.as_str()).unwrap_or("");
250 if !names.contains(name) {
251 return false;
252 }
253 }
254
255 if let Some(ref phases) = self.metric_phases {
256 let phase = metric.get("phase").and_then(|v| v.as_str()).unwrap_or("");
257 return phases.contains(phase);
258 }
259 true
260 }
261
262 pub fn should_include_timeline(&self, entry: &Value) -> bool {
264 if let Some(ref phases) = self.timeline_phases {
265 let phase = entry.get("phase").and_then(|v| v.as_str()).unwrap_or("");
266 return phases.contains(phase);
267 }
268 true
269 }
270}
271
#[cfg(test)]
mod tests {
    //! Unit tests for stream selection, event/metric/timeline filtering,
    //! sampling and the builder API.
    use super::*;

    #[test]
    fn test_all_config() {
        let config = StreamConfig::all();
        assert!(config.is_stream_enabled(StreamType::Status));
        assert!(config.is_stream_enabled(StreamType::Events));
        assert!(config.is_stream_enabled(StreamType::Metrics));
        assert!(config.is_stream_enabled(StreamType::Timeline));
    }

    #[test]
    fn test_minimal_config() {
        let config = StreamConfig::minimal();
        assert!(config.is_stream_enabled(StreamType::Status));
        assert!(!config.is_stream_enabled(StreamType::Events));
    }

    #[test]
    fn test_default_uses_default_filters() {
        let config = StreamConfig::default();
        assert!(config.is_stream_enabled(StreamType::Timeline));
        for ty in [
            "sft.progress",
            "sft.loss",
            "sft.upstream.status",
            "internal.heartbeat",
        ] {
            assert!(!config.should_include_event(&serde_json::json!({ "type": ty })));
        }
        assert!(config.should_include_event(&serde_json::json!({"type": "sft.started"})));
    }

    #[test]
    fn test_event_blacklist() {
        let config = StreamConfig::all().exclude_event_type("internal.heartbeat");

        let allowed = serde_json::json!({"type": "progress"});
        let blocked = serde_json::json!({"type": "internal.heartbeat"});

        assert!(config.should_include_event(&allowed));
        assert!(!config.should_include_event(&blocked));
    }

    #[test]
    fn test_event_whitelist() {
        let config = StreamConfig::all()
            .include_event_type("error")
            .include_event_type("warning");

        let allowed = serde_json::json!({"type": "error"});
        let blocked = serde_json::json!({"type": "progress"});

        assert!(config.should_include_event(&allowed));
        assert!(!config.should_include_event(&blocked));
    }

    #[test]
    fn test_exclusion_wins_over_inclusion() {
        let config = StreamConfig::all()
            .include_event_type("error")
            .exclude_event_type("error");
        assert!(!config.should_include_event(&serde_json::json!({"type": "error"})));
    }

    #[test]
    fn test_missing_fields_are_treated_as_empty() {
        let whitelist = StreamConfig::all().include_event_type("x");
        assert!(!whitelist.should_include_event(&serde_json::json!({})));

        let levels = StreamConfig::errors_only();
        assert!(!levels.should_include_event(&serde_json::json!({"type": "test"})));
    }

    #[test]
    fn test_level_filter() {
        let config = StreamConfig::errors_only();

        let error = serde_json::json!({"type": "test", "level": "error"});
        let warning = serde_json::json!({"type": "test", "level": "warning"});
        let info = serde_json::json!({"type": "test", "level": "info"});

        assert!(config.should_include_event(&error));
        assert!(config.should_include_event(&warning));
        assert!(!config.should_include_event(&info));
    }

    #[test]
    fn test_with_levels_replaces_existing_levels() {
        let config = StreamConfig::errors_only().with_levels(vec!["info"]);
        assert!(config.should_include_event(&serde_json::json!({"type": "t", "level": "info"})));
        assert!(!config.should_include_event(&serde_json::json!({"type": "t", "level": "error"})));
    }

    #[test]
    fn test_sample_rate_is_clamped() {
        assert_eq!(StreamConfig::all().with_sample_rate(2.0).sample_rate, 1.0);
        assert_eq!(StreamConfig::all().with_sample_rate(-0.5).sample_rate, 0.0);
        assert_eq!(StreamConfig::all().with_sample_rate(0.25).sample_rate, 0.25);
    }

    #[test]
    fn test_sampling_extremes_and_determinism() {
        let event = serde_json::json!({"type": "progress", "step": 7});
        assert!(StreamConfig::all()
            .with_sample_rate(1.0)
            .should_include_event(&event));
        assert!(!StreamConfig::all()
            .with_sample_rate(0.0)
            .should_include_event(&event));

        let half = StreamConfig::all().with_sample_rate(0.5);
        assert_eq!(
            half.should_include_event(&event),
            half.should_include_event(&event)
        );

        let kept = (0..1000)
            .filter(|i| half.should_include_event(&serde_json::json!({"type": "e", "i": i})))
            .count();
        assert!((400..=600).contains(&kept), "kept {kept} of 1000 at rate 0.5");
    }

    #[test]
    fn test_metric_filters() {
        let mut config = StreamConfig::metrics_only().with_metric_phases(vec!["train"]);
        config.metric_names = Some(["loss"].iter().map(|s| s.to_string()).collect());

        assert!(config.should_include_metric(&serde_json::json!({"name": "loss", "phase": "train"})));
        assert!(!config.should_include_metric(&serde_json::json!({"name": "loss", "phase": "eval"})));
        assert!(!config.should_include_metric(&serde_json::json!({"name": "acc", "phase": "train"})));
        assert!(StreamConfig::all().should_include_metric(&serde_json::json!({})));
    }

    #[test]
    fn test_timeline_filter() {
        let config = StreamConfig::all().with_timeline_phases(vec!["running"]);
        assert!(config.should_include_timeline(&serde_json::json!({"phase": "running"})));
        assert!(!config.should_include_timeline(&serde_json::json!({"phase": "queued"})));
        assert!(!config.should_include_timeline(&serde_json::json!({})));
        assert!(StreamConfig::all().should_include_timeline(&serde_json::json!({})));
    }

    #[test]
    fn test_stream_enable_disable() {
        let config = StreamConfig::all()
            .disable_stream(StreamType::Timeline)
            .disable_stream(StreamType::Metrics);

        assert!(config.is_stream_enabled(StreamType::Status));
        assert!(config.is_stream_enabled(StreamType::Events));
        assert!(!config.is_stream_enabled(StreamType::Metrics));
        assert!(!config.is_stream_enabled(StreamType::Timeline));
    }

    #[test]
    fn test_builder_pattern() {
        let config = StreamConfig::minimal()
            .enable_stream(StreamType::Events)
            .exclude_event_type("heartbeat")
            .with_interval(1.0)
            .without_deduplication();

        assert!(config.is_stream_enabled(StreamType::Events));
        assert_eq!(config.poll_interval_seconds, 1.0);
        assert!(!config.deduplicate);
    }
}