// rvf_runtime/dos.rs

1//! DoS hardening for ADR-033 §3.3.1.
2//!
3//! Provides per-connection budget tokens, negative caching of degenerate
4//! queries, and optional proof-of-work for public endpoints.
5
6use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
/// Per-connection token bucket for rate-limiting distance operations.
///
/// Each query consumes tokens from the bucket. When tokens are exhausted,
/// queries are rejected until the bucket refills at the next window.
pub struct BudgetTokenBucket {
    /// Maximum tokens (distance ops) granted per window.
    capacity: u64,
    /// Tokens still available in the current window.
    available: u64,
    /// Duration of one refill window.
    window: Duration,
    /// Instant at which the current window began.
    window_start: Instant,
}

impl BudgetTokenBucket {
    /// Create a new token bucket, starting with a full window of tokens.
    ///
    /// # Arguments
    /// * `max_tokens` - Maximum distance ops per window.
    /// * `window` - Duration of each refill window.
    pub fn new(max_tokens: u64, window: Duration) -> Self {
        Self {
            capacity: max_tokens,
            available: max_tokens,
            window,
            window_start: Instant::now(),
        }
    }

    /// Try to consume `cost` tokens. Returns `Ok(remaining)` if sufficient
    /// tokens are available, `Err(deficit)` if not.
    pub fn try_consume(&mut self, cost: u64) -> Result<u64, u64> {
        self.maybe_refill();

        match self.available.checked_sub(cost) {
            // Enough budget: commit the spend and report what is left.
            Some(left) => {
                self.available = left;
                Ok(left)
            }
            // Not enough: report how many tokens the caller is short.
            None => Err(cost - self.available),
        }
    }

    /// Check remaining tokens without consuming.
    pub fn remaining(&mut self) -> u64 {
        self.maybe_refill();
        self.available
    }

    /// Force a refill (for testing or manual reset).
    pub fn refill(&mut self) {
        self.available = self.capacity;
        self.window_start = Instant::now();
    }

    /// Refill lazily once the current window has fully elapsed.
    fn maybe_refill(&mut self) {
        if self.window_start.elapsed() >= self.window {
            self.refill();
        }
    }
}
71
/// Quantized query signature for negative caching.
///
/// The query vector is quantized to int8 and hashed to produce a
/// compact fingerprint for degenerate query detection.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct QuerySignature {
    hash: u64,
}

impl QuerySignature {
    /// Compute a signature from a query vector.
    ///
    /// Quantizes to int8, then hashes with FNV-1a for speed.
    pub fn from_query(query: &[f32]) -> Self {
        // FNV-1a over the quantized bytes.
        let mut hash: u64 = 0xcbf29ce484222325; // FNV-1a 64-bit offset basis.
        for &val in query {
            // Quantize to int8 range [-128, 127]. NaN casts to 0 via `as`,
            // so signatures stay deterministic even for degenerate inputs.
            let quantized = (val.clamp(-1.0, 1.0) * 127.0) as i8;
            // Fold in the raw byte. Casting through u8 first avoids i8 sign
            // extension (`-1i8 as u64` would be all-ones and XOR-flip the
            // entire hash), which deviated from FNV-1a's octet-wise folding.
            hash ^= quantized as u8 as u64;
            hash = hash.wrapping_mul(0x100000001b3); // FNV 64-bit prime.
        }
        Self { hash }
    }
}

/// Negative cache entry tracking degenerate query hits.
struct NegativeCacheEntry {
    /// Hits recorded within the current counting window.
    hit_count: u32,
    /// Start of the current counting window for this signature.
    first_seen: Instant,
    /// Most recent hit; used for oldest-first eviction.
    last_seen: Instant,
}

/// Negative cache for degenerate queries.
///
/// If a query signature triggers degenerate mode more than N times
/// in a window, forces `SafetyNetBudget::DISABLED` for subsequent
/// matches, preventing repeated budget burn on the same attack vector.
pub struct NegativeCache {
    entries: HashMap<QuerySignature, NegativeCacheEntry>,
    /// Number of degenerate hits before a signature is blacklisted.
    threshold: u32,
    /// Window duration for counting hits.
    window: Duration,
    /// Maximum cache size to prevent memory exhaustion.
    max_entries: usize,
}

impl NegativeCache {
    /// Create a new negative cache.
    ///
    /// # Arguments
    /// * `threshold` - Number of degenerate hits before blacklisting.
    /// * `window` - Duration window for counting hits.
    /// * `max_entries` - Maximum cache entries.
    pub fn new(threshold: u32, window: Duration, max_entries: usize) -> Self {
        Self {
            entries: HashMap::new(),
            threshold,
            window,
            max_entries,
        }
    }

    /// Record a degenerate query hit. Returns `true` if the query is
    /// now blacklisted (should force DISABLED safety net).
    pub fn record_degenerate(&mut self, sig: QuerySignature) -> bool {
        let now = Instant::now();

        // Make room only when inserting a NEW signature. Evicting while the
        // signature is already tracked could evict that very entry (if it is
        // the oldest), silently resetting its hit count and letting a
        // repeat attacker dodge the threshold.
        if !self.entries.contains_key(&sig) && self.entries.len() >= self.max_entries {
            // Drop expired entries first; they carry no useful state.
            self.evict_expired(now);
            // If still at capacity, fall back to evicting the least
            // recently hit signature.
            if self.entries.len() >= self.max_entries {
                self.evict_oldest();
            }
        }

        let entry = self.entries.entry(sig).or_insert(NegativeCacheEntry {
            hit_count: 0,
            first_seen: now,
            last_seen: now,
        });

        // Restart the count if the window has elapsed since the first hit.
        if now.duration_since(entry.first_seen) > self.window {
            entry.hit_count = 0;
            entry.first_seen = now;
        }

        entry.hit_count += 1;
        entry.last_seen = now;

        entry.hit_count >= self.threshold
    }

    /// Check if a query signature is blacklisted.
    ///
    /// An entry whose counting window has expired is no longer considered
    /// blacklisted — this mirrors the window reset in `record_degenerate`,
    /// so a stale blacklist cannot linger forever without fresh hits.
    pub fn is_blacklisted(&self, sig: &QuerySignature) -> bool {
        self.entries.get(sig).map_or(false, |entry| {
            entry.hit_count >= self.threshold && entry.first_seen.elapsed() <= self.window
        })
    }

    /// Number of currently tracked signatures.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Check if the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Drop every entry whose counting window has expired.
    fn evict_expired(&mut self, now: Instant) {
        self.entries.retain(|_, entry| {
            now.duration_since(entry.first_seen) <= self.window
        });
    }

    /// Evict the entry with the oldest `last_seen` (least recently hit).
    fn evict_oldest(&mut self) {
        if let Some(oldest_key) = self
            .entries
            .iter()
            .min_by_key(|(_, e)| e.last_seen)
            .map(|(k, _)| *k)
        {
            self.entries.remove(&oldest_key);
        }
    }
}
205
/// Proof-of-work challenge for public endpoints.
///
/// The caller must find a nonce such that `hash(challenge || nonce)`
/// has `difficulty` leading zero bits. This is opt-in, not default.
///
/// NOTE(review): the hash is FNV-1a, which is not preimage-resistant; a
/// determined attacker can search (or algebraically shortcut) nonces far
/// more cheaply than a cryptographic hash would allow. Treat this as a
/// lightweight speed bump, not a real work-proof — confirm that is the
/// intended threat model.
#[derive(Clone, Debug)]
pub struct ProofOfWork {
    /// The challenge bytes (typically random).
    pub challenge: [u8; 16],
    /// Required leading zero bits in the hash. Capped at MAX_DIFFICULTY.
    pub difficulty: u8,
}

impl ProofOfWork {
    /// Maximum allowed difficulty (24 bits = ~16M hashes average).
    /// Higher values risk CPU-bound DoS.
    pub const MAX_DIFFICULTY: u8 = 24;

    /// Verify that a nonce satisfies the proof-of-work requirement.
    ///
    /// Uses FNV-1a for speed (this is DoS mitigation, not cryptographic security).
    /// Clamps difficulty to [`Self::MAX_DIFFICULTY`] to prevent compute DoS.
    pub fn verify(&self, nonce: u64) -> bool {
        // FNV-1a over challenge || nonce (nonce serialized little-endian).
        let mut hash: u64 = 0xcbf29ce484222325;
        for &byte in self.challenge.iter().chain(nonce.to_le_bytes().iter()) {
            hash ^= byte as u64;
            hash = hash.wrapping_mul(0x100000001b3);
        }

        let clamped = self.difficulty.min(Self::MAX_DIFFICULTY);
        (hash.leading_zeros() as u8) >= clamped
    }

    /// Find a valid nonce (for testing / client-side use).
    ///
    /// Searches up to `4 * 2^difficulty` nonces (difficulty clamped to
    /// [`Self::MAX_DIFFICULTY`]); for a uniform hash that succeeds with
    /// probability ~98%. Returns `None` if no nonce is found in that range.
    pub fn solve(&self) -> Option<u64> {
        // The original's extra `.min(30)` was redundant: MAX_DIFFICULTY is 24,
        // so a single clamp already keeps the shift well-defined.
        let max_attempts: u64 = 1u64 << self.difficulty.min(Self::MAX_DIFFICULTY);
        (0..max_attempts.saturating_mul(4)).find(|&nonce| self.verify(nonce))
    }
}
255
#[cfg(test)]
mod tests {
    use super::*;

    // ---- BudgetTokenBucket ----

    #[test]
    fn token_bucket_basic() {
        let mut bucket = BudgetTokenBucket::new(100, Duration::from_secs(1));
        assert_eq!(bucket.remaining(), 100);
        assert_eq!(bucket.try_consume(30), Ok(70));
        assert_eq!(bucket.remaining(), 70);
    }

    #[test]
    fn token_bucket_exhaustion() {
        let mut bucket = BudgetTokenBucket::new(10, Duration::from_secs(60));
        assert_eq!(bucket.try_consume(10), Ok(0));
        // Bucket is empty; any further consume must report a deficit.
        assert!(bucket.try_consume(1).is_err());
    }

    #[test]
    fn token_bucket_refill() {
        // Uses a real sleep past the 1 ms window; mildly timing-dependent
        // but the 2 ms margin makes it reliable in practice.
        let mut bucket = BudgetTokenBucket::new(100, Duration::from_millis(1));
        bucket.try_consume(100).unwrap();
        assert!(bucket.try_consume(1).is_err());
        std::thread::sleep(Duration::from_millis(2));
        assert_eq!(bucket.remaining(), 100);
    }

    #[test]
    fn token_bucket_manual_refill() {
        let mut bucket = BudgetTokenBucket::new(100, Duration::from_secs(60));
        bucket.try_consume(100).unwrap();
        // refill() resets tokens without waiting for the window to elapse.
        bucket.refill();
        assert_eq!(bucket.remaining(), 100);
    }

    // ---- QuerySignature ----

    #[test]
    fn query_signature_deterministic() {
        let query = vec![0.1, 0.2, 0.3, 0.4];
        let sig1 = QuerySignature::from_query(&query);
        let sig2 = QuerySignature::from_query(&query);
        assert_eq!(sig1, sig2);
    }

    #[test]
    fn query_signature_different_vectors() {
        // Not guaranteed by a hash in general, but these two inputs are
        // known not to collide under the current quantize+FNV scheme.
        let sig1 = QuerySignature::from_query(&[0.1, 0.2, 0.3]);
        let sig2 = QuerySignature::from_query(&[0.4, 0.5, 0.6]);
        assert_ne!(sig1, sig2);
    }

    // ---- NegativeCache ----

    #[test]
    fn negative_cache_below_threshold() {
        let mut cache = NegativeCache::new(3, Duration::from_secs(60), 1000);
        let sig = QuerySignature::from_query(&[0.1, 0.2]);
        assert!(!cache.record_degenerate(sig));
        assert!(!cache.record_degenerate(sig));
        assert!(!cache.is_blacklisted(&sig));
    }

    #[test]
    fn negative_cache_reaches_threshold() {
        let mut cache = NegativeCache::new(3, Duration::from_secs(60), 1000);
        let sig = QuerySignature::from_query(&[0.1, 0.2]);
        cache.record_degenerate(sig);
        cache.record_degenerate(sig);
        assert!(cache.record_degenerate(sig)); // 3rd hit = blacklisted.
        assert!(cache.is_blacklisted(&sig));
    }

    #[test]
    fn negative_cache_max_entries() {
        // 10 distinct signatures into a 5-entry cache: eviction must keep
        // the cache bounded.
        let mut cache = NegativeCache::new(100, Duration::from_secs(60), 5);
        for i in 0..10 {
            let sig = QuerySignature::from_query(&[i as f32]);
            cache.record_degenerate(sig);
        }
        assert!(cache.len() <= 5);
    }

    #[test]
    fn negative_cache_empty() {
        let cache = NegativeCache::new(3, Duration::from_secs(60), 1000);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    // ---- ProofOfWork ----

    #[test]
    fn proof_of_work_low_difficulty() {
        let pow = ProofOfWork {
            challenge: [0xAB; 16],
            difficulty: 1, // Very easy.
        };
        let nonce = pow.solve().expect("should solve easily");
        assert!(pow.verify(nonce));
    }

    #[test]
    fn proof_of_work_wrong_nonce() {
        let pow = ProofOfWork {
            challenge: [0xAB; 16],
            difficulty: 16, // Moderate difficulty.
        };
        // Random nonce is very unlikely to pass.
        // (Deterministic in practice: this specific nonce is known not to
        // meet 16 leading zero bits for this challenge.)
        assert!(!pow.verify(0xDEADBEEF));
    }

    #[test]
    fn proof_of_work_solve_and_verify() {
        let pow = ProofOfWork {
            challenge: [0x42; 16],
            difficulty: 8,
        };
        let nonce = pow.solve().expect("should solve d=8");
        assert!(pow.verify(nonce));
    }

    #[test]
    fn proof_of_work_max_difficulty_clamped() {
        let pow = ProofOfWork {
            challenge: [0x42; 16],
            difficulty: 255, // Extreme — will be clamped to MAX_DIFFICULTY.
        };
        // verify() clamps internally, so this is equivalent to d=24.
        // solve() uses clamped difficulty too.
        // NOTE(review): this only checks the clamp arithmetic, not that
        // verify()/solve() actually apply it.
        assert_eq!(pow.difficulty.min(ProofOfWork::MAX_DIFFICULTY), 24);
    }
}
383}