1use std::collections::HashMap;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Mutex;
8
9use super::rng::DeterministicRng;
10use crate::constants::DST_FAULT_PROBABILITY_MAX;
11
/// Names every kind of fault that can be configured and reported.
///
/// The enum only identifies a fault; this type carries no behavior beyond
/// [`FaultType::as_str`]. Variants are grouped by the subsystem their name
/// refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FaultType {
    // ---- Storage ----
    StorageWriteFail,
    StorageReadFail,
    StorageDeleteFail,
    StorageCorruption,
    StorageDiskFull,
    StorageLatency,

    // ---- Database ----
    DbConnectionFail,
    DbQueryTimeout,
    DbDeadlock,
    DbSerializationFail,
    DbPoolExhausted,

    // ---- Network ----
    NetworkTimeout,
    NetworkConnectionRefused,
    NetworkDnsFail,
    NetworkPartialWrite,
    NetworkReset,

    // ---- LLM ----
    LlmTimeout,
    LlmRateLimit,
    LlmContextOverflow,
    LlmInvalidResponse,
    LlmServiceUnavailable,

    // ---- Embedding ----
    EmbeddingTimeout,
    EmbeddingRateLimit,
    EmbeddingContextOverflow,
    EmbeddingInvalidResponse,
    EmbeddingServiceUnavailable,

    // ---- Vector search / store ----
    VectorSearchTimeout,
    VectorSearchFail,
    VectorStoreFail,

    // ---- Resources ----
    ResourceOom,
    ResourceFileLimit,
    ResourceCpuThrottle,

    // ---- Time ----
    TimeClockSkew,
    TimeLeapSecond,
}

impl FaultType {
    /// Returns the snake_case identifier of this fault.
    ///
    /// The string is the variant name converted to lower snake case, e.g.
    /// `FaultType::DbDeadlock` yields `"db_deadlock"`. The match is
    /// exhaustive on purpose so that adding a variant forces a name to be
    /// chosen for it.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            // Storage
            FaultType::StorageWriteFail => "storage_write_fail",
            FaultType::StorageReadFail => "storage_read_fail",
            FaultType::StorageDeleteFail => "storage_delete_fail",
            FaultType::StorageCorruption => "storage_corruption",
            FaultType::StorageDiskFull => "storage_disk_full",
            FaultType::StorageLatency => "storage_latency",
            // Database
            FaultType::DbConnectionFail => "db_connection_fail",
            FaultType::DbQueryTimeout => "db_query_timeout",
            FaultType::DbDeadlock => "db_deadlock",
            FaultType::DbSerializationFail => "db_serialization_fail",
            FaultType::DbPoolExhausted => "db_pool_exhausted",
            // Network
            FaultType::NetworkTimeout => "network_timeout",
            FaultType::NetworkConnectionRefused => "network_connection_refused",
            FaultType::NetworkDnsFail => "network_dns_fail",
            FaultType::NetworkPartialWrite => "network_partial_write",
            FaultType::NetworkReset => "network_reset",
            // LLM
            FaultType::LlmTimeout => "llm_timeout",
            FaultType::LlmRateLimit => "llm_rate_limit",
            FaultType::LlmContextOverflow => "llm_context_overflow",
            FaultType::LlmInvalidResponse => "llm_invalid_response",
            FaultType::LlmServiceUnavailable => "llm_service_unavailable",
            // Embedding
            FaultType::EmbeddingTimeout => "embedding_timeout",
            FaultType::EmbeddingRateLimit => "embedding_rate_limit",
            FaultType::EmbeddingContextOverflow => "embedding_context_overflow",
            FaultType::EmbeddingInvalidResponse => "embedding_invalid_response",
            FaultType::EmbeddingServiceUnavailable => "embedding_service_unavailable",
            // Vector search / store
            FaultType::VectorSearchTimeout => "vector_search_timeout",
            FaultType::VectorSearchFail => "vector_search_fail",
            FaultType::VectorStoreFail => "vector_store_fail",
            // Resources
            FaultType::ResourceOom => "resource_oom",
            FaultType::ResourceFileLimit => "resource_file_limit",
            FaultType::ResourceCpuThrottle => "resource_cpu_throttle",
            // Time
            FaultType::TimeClockSkew => "time_clock_skew",
            FaultType::TimeLeapSecond => "time_leap_second",
        }
    }
}
160
161#[derive(Debug, Clone)]
163pub struct FaultConfig {
164 pub fault_type: FaultType,
166 pub probability: f64,
168 pub operation_filter: Option<String>,
170 pub max_injections: Option<u64>,
172}
173
174impl FaultConfig {
175 #[must_use]
180 pub fn new(fault_type: FaultType, probability: f64) -> Self {
181 assert!(
183 (0.0..=DST_FAULT_PROBABILITY_MAX).contains(&probability),
184 "probability must be in [0, {}], got {}",
185 DST_FAULT_PROBABILITY_MAX,
186 probability
187 );
188
189 Self {
190 fault_type,
191 probability,
192 operation_filter: None,
193 max_injections: None,
194 }
195 }
196
197 #[must_use]
199 pub fn with_filter(mut self, filter: impl Into<String>) -> Self {
200 self.operation_filter = Some(filter.into());
201 self
202 }
203
204 #[must_use]
206 pub fn with_max_injections(mut self, max: u64) -> Self {
207 assert!(max > 0, "max_injections must be positive");
209 self.max_injections = Some(max);
210 self
211 }
212}
213
/// Injection counter kept for one fault type.
///
/// The counter is atomic so it can be bumped and read through a shared
/// reference; `Default` starts it at zero.
#[derive(Debug, Default)]
struct FaultStats {
    /// Number of injections recorded since creation or the last reset.
    injection_count: AtomicU64,
}
219
220#[derive(Debug)]
228pub struct FaultInjector {
229 rng: Mutex<DeterministicRng>,
231 configs: Vec<FaultConfig>,
232 stats: HashMap<FaultType, FaultStats>,
233 injection_counts: Mutex<HashMap<FaultType, u64>>,
235}
236
237impl FaultInjector {
238 #[must_use]
240 pub fn new(rng: DeterministicRng) -> Self {
241 Self {
242 rng: Mutex::new(rng),
243 configs: Vec::new(),
244 stats: HashMap::new(),
245 injection_counts: Mutex::new(HashMap::new()),
246 }
247 }
248
249 pub fn register(&mut self, config: FaultConfig) {
253 assert!(
255 config.probability >= 0.0,
256 "probability must be non-negative"
257 );
258 assert!(config.probability <= 1.0, "probability must be <= 1.0");
259
260 self.stats.entry(config.fault_type).or_default();
262 self.injection_counts
263 .lock()
264 .unwrap()
265 .entry(config.fault_type)
266 .or_insert(0);
267
268 self.configs.push(config);
269 }
270
271 pub fn should_inject(&self, operation: &str) -> Option<FaultType> {
278 for config in &self.configs {
279 if let Some(ref filter) = config.operation_filter {
281 if !operation.contains(filter) {
282 continue;
283 }
284 }
285
286 if let Some(max) = config.max_injections {
288 let counts = self.injection_counts.lock().unwrap();
289 let count = counts.get(&config.fault_type).copied().unwrap_or(0);
290 if count >= max {
291 continue;
292 }
293 }
294
295 let should_inject = {
297 let mut rng = self.rng.lock().unwrap();
298 rng.next_bool(config.probability)
299 };
300
301 if should_inject {
302 if let Some(stats) = self.stats.get(&config.fault_type) {
304 stats.injection_count.fetch_add(1, Ordering::Relaxed);
305 }
306 {
307 let mut counts = self.injection_counts.lock().unwrap();
308 if let Some(count) = counts.get_mut(&config.fault_type) {
309 *count += 1;
310 }
311 }
312
313 return Some(config.fault_type);
314 }
315 }
316
317 None
318 }
319
320 #[must_use]
322 pub fn injection_stats(&self) -> HashMap<String, u64> {
323 self.stats
324 .iter()
325 .map(|(fault_type, stats)| {
326 (
327 fault_type.as_str().to_string(),
328 stats.injection_count.load(Ordering::Relaxed),
329 )
330 })
331 .collect()
332 }
333
334 #[must_use]
336 pub fn total_injections(&self) -> u64 {
337 self.stats
338 .values()
339 .map(|s| s.injection_count.load(Ordering::Relaxed))
340 .sum()
341 }
342
343 pub fn reset_stats(&self) {
345 for stats in self.stats.values() {
346 stats.injection_count.store(0, Ordering::Relaxed);
347 }
348 let mut counts = self.injection_counts.lock().unwrap();
349 for count in counts.values_mut() {
350 *count = 0;
351 }
352 }
353}
354
355pub struct FaultInjectorBuilder {
359 rng: DeterministicRng,
360 configs: Vec<FaultConfig>,
361}
362
363impl FaultInjectorBuilder {
364 #[must_use]
366 pub fn new(rng: DeterministicRng) -> Self {
367 Self {
368 rng,
369 configs: Vec::new(),
370 }
371 }
372
373 #[must_use]
375 pub fn with_fault(mut self, config: FaultConfig) -> Self {
376 self.configs.push(config);
377 self
378 }
379
380 #[must_use]
382 pub fn with_storage_faults(self, probability: f64) -> Self {
383 self.with_fault(FaultConfig::new(FaultType::StorageWriteFail, probability))
384 .with_fault(FaultConfig::new(FaultType::StorageReadFail, probability))
385 }
386
387 #[must_use]
389 pub fn with_db_faults(self, probability: f64) -> Self {
390 self.with_fault(FaultConfig::new(FaultType::DbConnectionFail, probability))
391 .with_fault(FaultConfig::new(FaultType::DbQueryTimeout, probability))
392 }
393
394 #[must_use]
396 pub fn with_llm_faults(self, probability: f64) -> Self {
397 self.with_fault(FaultConfig::new(FaultType::LlmTimeout, probability))
398 .with_fault(FaultConfig::new(FaultType::LlmRateLimit, probability))
399 }
400
401 #[must_use]
403 pub fn build(self) -> FaultInjector {
404 let mut injector = FaultInjector::new(self.rng);
405 for config in self.configs {
406 injector.register(config);
407 }
408 injector
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// RNG with the fixed seed used by every test in this module.
    fn seeded_rng() -> DeterministicRng {
        DeterministicRng::new(42)
    }

    /// Injector with exactly one registered fault configuration.
    fn injector_with(config: FaultConfig) -> FaultInjector {
        let mut injector = FaultInjector::new(seeded_rng());
        injector.register(config);
        injector
    }

    /// Shorthand for the fault most tests register.
    fn write_fail(probability: f64) -> FaultConfig {
        FaultConfig::new(FaultType::StorageWriteFail, probability)
    }

    #[test]
    fn test_no_faults_registered() {
        let injector = FaultInjector::new(seeded_rng());

        // Without any registered config nothing can ever fire.
        assert!((0..100).all(|_| injector.should_inject("any_operation").is_none()));
    }

    #[test]
    fn test_always_inject() {
        let injector = injector_with(write_fail(1.0));

        for _ in 0..10 {
            let outcome = injector.should_inject("storage_write");
            assert_eq!(outcome, Some(FaultType::StorageWriteFail));
        }
    }

    #[test]
    fn test_never_inject() {
        let injector = injector_with(write_fail(0.0));

        assert!((0..100).all(|_| injector.should_inject("storage_write").is_none()));
    }

    #[test]
    fn test_operation_filter() {
        let injector = injector_with(write_fail(1.0).with_filter("write"));

        // The operation name contains the filter, so the fault fires.
        let matching = injector.should_inject("storage_write");
        assert_eq!(matching, Some(FaultType::StorageWriteFail));

        // The operation name lacks the filter, so the config is skipped.
        let other = injector.should_inject("storage_read");
        assert!(other.is_none());
    }

    #[test]
    fn test_max_injections() {
        let injector = injector_with(write_fail(1.0).with_max_injections(2));

        // The first two calls stay within the cap.
        for _ in 0..2 {
            assert_eq!(
                injector.should_inject("op"),
                Some(FaultType::StorageWriteFail)
            );
        }

        // The third call hits the cap.
        assert!(injector.should_inject("op").is_none());
    }

    #[test]
    fn test_injection_stats() {
        let injector = injector_with(write_fail(1.0));

        for _ in 0..3 {
            injector.should_inject("op");
        }

        let per_fault = injector.injection_stats();
        assert_eq!(per_fault.get("storage_write_fail"), Some(&3));
        assert_eq!(injector.total_injections(), 3);
    }

    #[test]
    fn test_reset_stats() {
        let injector = injector_with(write_fail(1.0));

        injector.should_inject("op");
        assert_eq!(injector.total_injections(), 1);

        injector.reset_stats();
        assert_eq!(injector.total_injections(), 0);
    }

    #[test]
    fn test_fault_type_as_str() {
        let cases = [
            (FaultType::StorageWriteFail, "storage_write_fail"),
            (FaultType::DbDeadlock, "db_deadlock"),
            (FaultType::LlmRateLimit, "llm_rate_limit"),
        ];

        for (fault, name) in cases.iter() {
            assert_eq!(fault.as_str(), *name);
        }
    }

    #[test]
    #[should_panic(expected = "probability must be in")]
    fn test_invalid_probability() {
        let _ = write_fail(1.5);
    }

    #[test]
    #[should_panic(expected = "max_injections must be positive")]
    fn test_invalid_max_injections() {
        let _ = write_fail(0.5).with_max_injections(0);
    }

    #[test]
    fn test_builder_pattern() {
        let injector = FaultInjectorBuilder::new(seeded_rng())
            .with_storage_faults(0.1)
            .with_db_faults(0.05)
            .build();

        // Building registers faults but injects nothing.
        assert_eq!(injector.total_injections(), 0);
    }

    #[test]
    fn test_arc_sharing() {
        let built = FaultInjectorBuilder::new(seeded_rng())
            .with_fault(write_fail(1.0))
            .build();
        let injector = Arc::new(built);

        assert_eq!(
            injector.should_inject("storage_write"),
            Some(FaultType::StorageWriteFail)
        );

        // A second handle works on the same underlying injector.
        let injector2 = Arc::clone(&injector);
        assert_eq!(
            injector2.should_inject("storage_write"),
            Some(FaultType::StorageWriteFail)
        );

        // Both injections are visible through the first handle.
        assert_eq!(injector.total_injections(), 2);
    }
}