sochdb_storage/
rl_workload.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! RL Workload Classifier (Task 10)
16//!
//! This module combines a threshold-based workload classifier with a UCB1
//! multi-armed bandit that selects tuning parameters from observed operation patterns.
19//!
20//! ## Problem
21//!
22//! Different workloads (OLTP, OLAP, mixed) require different tuning:
23//! - OLTP: Small batch sizes, frequent flushes
24//! - OLAP: Large batch sizes, aggressive prefetching
25//! - Mixed: Adaptive switching
26//!
27//! ## Solution
28//!
29//! - **Feature Extraction:** Derive features from operation mix
30//! - **UCB1 Algorithm:** Upper Confidence Bound for exploration/exploitation
31//! - **Tuning Actions:** Adjust parameters based on classifier output
32//!
33//! ## Performance
34//!
35//! | Workload | Static Config | RL-Tuned |
36//! |----------|---------------|----------|
37//! | OLTP | 1× | 1.5× |
38//! | OLAP | 1× | 2.0× |
39//! | Mixed | 1× | 1.8× |
40
41use std::sync::atomic::{AtomicU64, Ordering};
42use std::sync::RwLock;
43use std::time::Instant;
44
/// Number of arms (tuning configurations) the bandit chooses between.
const NUM_ARMS: usize = 8;

/// Exploration constant `c` for UCB1.
///
/// The intended value is `sqrt(2)`; the standard-library constant is used
/// instead of a truncated literal so the value is exact to `f64` precision.
const UCB_C: f64 = std::f64::consts::SQRT_2;

/// Window size for feature calculation.
// Not referenced by the code visible in this file, hence the lint allowance.
#[allow(dead_code)]
const FEATURE_WINDOW_SIZE: usize = 1000;
54
55// ============================================================================
56// Workload Types
57// ============================================================================
58
/// Category assigned to the observed operation mix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkloadType {
    /// Write-heavy transactional workload (OLTP).
    Oltp,
    /// Read-heavy analytical workload (OLAP).
    Olap,
    /// Neither reads nor writes dominate.
    Mixed,
    /// Dominated by vector searches.
    VectorSearch,
    /// Not enough data to decide.
    Unknown,
}

impl WorkloadType {
    /// Returns the name of this workload type as a static string.
    pub fn as_str(&self) -> &'static str {
        match *self {
            WorkloadType::Oltp => "OLTP",
            WorkloadType::Olap => "OLAP",
            WorkloadType::Mixed => "Mixed",
            WorkloadType::VectorSearch => "VectorSearch",
            WorkloadType::Unknown => "Unknown",
        }
    }
}
86
87// ============================================================================
88// Feature Vector
89// ============================================================================
90
/// Running totals of each operation kind, used as input for feature extraction.
///
/// Every counter is an atomic, so it can be bumped through a shared reference.
#[derive(Default)]
struct OperationCounters {
    /// Number of point reads recorded.
    point_reads: AtomicU64,
    /// Number of range scans recorded.
    range_scans: AtomicU64,
    /// Number of inserts recorded.
    inserts: AtomicU64,
    /// Number of updates recorded.
    updates: AtomicU64,
    /// Number of deletes recorded.
    deletes: AtomicU64,
    /// Number of vector searches recorded.
    vector_searches: AtomicU64,
}
107
108/// Feature vector derived from operation mix
109#[derive(Debug, Clone, Default)]
110pub struct FeatureVector {
111    /// Fraction of reads (point + range)
112    pub read_fraction: f64,
113    /// Fraction of writes (insert + update + delete)
114    pub write_fraction: f64,
115    /// Fraction of range scans
116    pub scan_fraction: f64,
117    /// Fraction of vector searches
118    pub vector_fraction: f64,
119    /// Average operation latency (ms)
120    pub avg_latency_ms: f64,
121    /// Operations per second
122    pub ops_per_second: f64,
123    /// Key locality (0 = random, 1 = sequential)
124    pub key_locality: f64,
125}
126
127impl FeatureVector {
128    /// Classify the workload based on features
129    pub fn classify(&self) -> WorkloadType {
130        if self.ops_per_second < 1.0 {
131            return WorkloadType::Unknown;
132        }
133        
134        if self.vector_fraction > 0.3 {
135            return WorkloadType::VectorSearch;
136        }
137        
138        if self.write_fraction > 0.7 {
139            return WorkloadType::Oltp;
140        }
141        
142        if self.scan_fraction > 0.3 {
143            return WorkloadType::Olap;
144        }
145        
146        if self.read_fraction > 0.7 {
147            return WorkloadType::Olap;
148        }
149        
150        WorkloadType::Mixed
151    }
152}
153
154// ============================================================================
155// Tuning Actions
156// ============================================================================
157
158/// Tuning configuration
159#[derive(Debug, Clone)]
160pub struct TuningConfig {
161    /// Memtable size (bytes)
162    pub memtable_size: usize,
163    /// Write buffer count
164    pub write_buffer_count: usize,
165    /// Batch size for operations
166    pub batch_size: usize,
167    /// Prefetch distance
168    pub prefetch_distance: usize,
169    /// Background flush interval (ms)
170    pub flush_interval_ms: u64,
171    /// Compaction priority
172    pub compaction_priority: CompactionPriority,
173    /// Cache ratio (0-1)
174    pub cache_ratio: f64,
175    /// HNSW ef_search parameter
176    pub hnsw_ef_search: usize,
177}
178
179impl Default for TuningConfig {
180    fn default() -> Self {
181        Self {
182            memtable_size: 64 * 1024 * 1024, // 64 MB
183            write_buffer_count: 2,
184            batch_size: 256,
185            prefetch_distance: 4,
186            flush_interval_ms: 1000,
187            compaction_priority: CompactionPriority::Balanced,
188            cache_ratio: 0.5,
189            hnsw_ef_search: 100,
190        }
191    }
192}
193
/// Which kind of amplification compaction should favour reducing.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompactionPriority {
    /// Keep write amplification low.
    WriteOptimized,
    /// Keep read amplification low.
    ReadOptimized,
    /// Trade off between the two.
    Balanced,
}
204
205/// Predefined tuning configurations (arms)
206fn get_arm_config(arm: usize) -> TuningConfig {
207    match arm {
208        0 => TuningConfig {
209            // OLTP optimized: small buffers, fast flush
210            memtable_size: 32 * 1024 * 1024,
211            write_buffer_count: 4,
212            batch_size: 64,
213            prefetch_distance: 2,
214            flush_interval_ms: 500,
215            compaction_priority: CompactionPriority::WriteOptimized,
216            cache_ratio: 0.3,
217            hnsw_ef_search: 50,
218        },
219        1 => TuningConfig {
220            // OLAP optimized: large buffers, aggressive prefetch
221            memtable_size: 256 * 1024 * 1024,
222            write_buffer_count: 2,
223            batch_size: 1024,
224            prefetch_distance: 16,
225            flush_interval_ms: 5000,
226            compaction_priority: CompactionPriority::ReadOptimized,
227            cache_ratio: 0.8,
228            hnsw_ef_search: 200,
229        },
230        2 => TuningConfig {
231            // Vector search optimized
232            memtable_size: 128 * 1024 * 1024,
233            write_buffer_count: 2,
234            batch_size: 512,
235            prefetch_distance: 8,
236            flush_interval_ms: 2000,
237            compaction_priority: CompactionPriority::Balanced,
238            cache_ratio: 0.9,
239            hnsw_ef_search: 300,
240        },
241        3 => TuningConfig {
242            // Balanced for mixed workload
243            memtable_size: 64 * 1024 * 1024,
244            write_buffer_count: 3,
245            batch_size: 256,
246            prefetch_distance: 4,
247            flush_interval_ms: 1000,
248            compaction_priority: CompactionPriority::Balanced,
249            cache_ratio: 0.5,
250            hnsw_ef_search: 100,
251        },
252        4 => TuningConfig {
253            // Write burst handling
254            memtable_size: 128 * 1024 * 1024,
255            write_buffer_count: 6,
256            batch_size: 128,
257            prefetch_distance: 2,
258            flush_interval_ms: 300,
259            compaction_priority: CompactionPriority::WriteOptimized,
260            cache_ratio: 0.2,
261            hnsw_ef_search: 50,
262        },
263        5 => TuningConfig {
264            // Read burst handling
265            memtable_size: 32 * 1024 * 1024,
266            write_buffer_count: 2,
267            batch_size: 512,
268            prefetch_distance: 32,
269            flush_interval_ms: 3000,
270            compaction_priority: CompactionPriority::ReadOptimized,
271            cache_ratio: 0.95,
272            hnsw_ef_search: 150,
273        },
274        6 => TuningConfig {
275            // Latency sensitive
276            memtable_size: 16 * 1024 * 1024,
277            write_buffer_count: 8,
278            batch_size: 32,
279            prefetch_distance: 1,
280            flush_interval_ms: 200,
281            compaction_priority: CompactionPriority::Balanced,
282            cache_ratio: 0.6,
283            hnsw_ef_search: 75,
284        },
285        7 => TuningConfig {
286            // Throughput focused
287            memtable_size: 512 * 1024 * 1024,
288            write_buffer_count: 2,
289            batch_size: 2048,
290            prefetch_distance: 64,
291            flush_interval_ms: 10000,
292            compaction_priority: CompactionPriority::WriteOptimized,
293            cache_ratio: 0.4,
294            hnsw_ef_search: 100,
295        },
296        _ => TuningConfig::default(),
297    }
298}
299
300// ============================================================================
301// UCB1 Arm
302// ============================================================================
303
304/// An arm in the multi-armed bandit
305struct UcbArm {
306    /// Number of times this arm was selected
307    count: AtomicU64,
308    /// Total reward accumulated
309    total_reward: RwLock<f64>,
310    /// Sum of squared rewards (for variance)
311    sum_squared_reward: RwLock<f64>,
312}
313
314impl UcbArm {
315    fn new() -> Self {
316        Self {
317            count: AtomicU64::new(0),
318            total_reward: RwLock::new(0.0),
319            sum_squared_reward: RwLock::new(0.0),
320        }
321    }
322    
323    /// Get the average reward
324    fn avg_reward(&self) -> f64 {
325        let count = self.count.load(Ordering::Relaxed);
326        if count == 0 {
327            return 0.0;
328        }
329        *self.total_reward.read().unwrap() / count as f64
330    }
331    
332    /// Record a reward
333    fn record_reward(&self, reward: f64) {
334        self.count.fetch_add(1, Ordering::Relaxed);
335        *self.total_reward.write().unwrap() += reward;
336        *self.sum_squared_reward.write().unwrap() += reward * reward;
337    }
338    
339    /// Calculate UCB1 value
340    fn ucb(&self, total_count: u64) -> f64 {
341        let count = self.count.load(Ordering::Relaxed);
342        if count == 0 {
343            return f64::MAX; // Unexplored arm has infinite UCB
344        }
345        
346        let avg = self.avg_reward();
347        let exploration = UCB_C * ((total_count as f64).ln() / count as f64).sqrt();
348        
349        avg + exploration
350    }
351}
352
353// ============================================================================
354// Workload Classifier
355// ============================================================================
356
357/// RL-based workload classifier with UCB1 algorithm
358pub struct WorkloadClassifier {
359    /// Operation counters
360    counters: OperationCounters,
361    /// Arms for UCB1
362    arms: [UcbArm; NUM_ARMS],
363    /// Currently selected arm
364    current_arm: RwLock<usize>,
365    /// Current config
366    current_config: RwLock<TuningConfig>,
367    /// Start time for ops/sec calculation
368    start_time: Instant,
369    /// Last feature extraction time
370    #[allow(dead_code)]
371    last_feature_time: RwLock<Instant>,
372    /// Cached feature vector
373    #[allow(dead_code)]
374    cached_features: RwLock<FeatureVector>,
375    /// Reward measurement start
376    reward_start: RwLock<Option<Instant>>,
377    /// Operations at reward start
378    ops_at_reward_start: AtomicU64,
379}
380
381impl WorkloadClassifier {
382    /// Create a new classifier
383    pub fn new() -> Self {
384        Self {
385            counters: OperationCounters::default(),
386            arms: std::array::from_fn(|_| UcbArm::new()),
387            current_arm: RwLock::new(3), // Start with balanced config
388            current_config: RwLock::new(get_arm_config(3)),
389            start_time: Instant::now(),
390            last_feature_time: RwLock::new(Instant::now()),
391            cached_features: RwLock::new(FeatureVector::default()),
392            reward_start: RwLock::new(None),
393            ops_at_reward_start: AtomicU64::new(0),
394        }
395    }
396    
397    /// Record a point read
398    #[inline]
399    pub fn record_point_read(&self) {
400        self.counters.point_reads.fetch_add(1, Ordering::Relaxed);
401    }
402    
403    /// Record a range scan
404    #[inline]
405    pub fn record_range_scan(&self) {
406        self.counters.range_scans.fetch_add(1, Ordering::Relaxed);
407    }
408    
409    /// Record an insert
410    #[inline]
411    pub fn record_insert(&self) {
412        self.counters.inserts.fetch_add(1, Ordering::Relaxed);
413    }
414    
415    /// Record an update
416    #[inline]
417    pub fn record_update(&self) {
418        self.counters.updates.fetch_add(1, Ordering::Relaxed);
419    }
420    
421    /// Record a delete
422    #[inline]
423    pub fn record_delete(&self) {
424        self.counters.deletes.fetch_add(1, Ordering::Relaxed);
425    }
426    
427    /// Record a vector search
428    #[inline]
429    pub fn record_vector_search(&self) {
430        self.counters.vector_searches.fetch_add(1, Ordering::Relaxed);
431    }
432    
433    /// Get total operations
434    fn total_ops(&self) -> u64 {
435        self.counters.point_reads.load(Ordering::Relaxed)
436            + self.counters.range_scans.load(Ordering::Relaxed)
437            + self.counters.inserts.load(Ordering::Relaxed)
438            + self.counters.updates.load(Ordering::Relaxed)
439            + self.counters.deletes.load(Ordering::Relaxed)
440            + self.counters.vector_searches.load(Ordering::Relaxed)
441    }
442    
443    /// Extract features from operation counters
444    pub fn extract_features(&self) -> FeatureVector {
445        let reads = self.counters.point_reads.load(Ordering::Relaxed);
446        let scans = self.counters.range_scans.load(Ordering::Relaxed);
447        let inserts = self.counters.inserts.load(Ordering::Relaxed);
448        let updates = self.counters.updates.load(Ordering::Relaxed);
449        let deletes = self.counters.deletes.load(Ordering::Relaxed);
450        let vectors = self.counters.vector_searches.load(Ordering::Relaxed);
451        
452        let total = reads + scans + inserts + updates + deletes + vectors;
453        let total_f = total.max(1) as f64;
454        
455        let elapsed = self.start_time.elapsed().as_secs_f64().max(0.001);
456        
457        FeatureVector {
458            read_fraction: (reads + scans) as f64 / total_f,
459            write_fraction: (inserts + updates + deletes) as f64 / total_f,
460            scan_fraction: scans as f64 / total_f,
461            vector_fraction: vectors as f64 / total_f,
462            avg_latency_ms: 1.0, // Would be measured in practice
463            ops_per_second: total as f64 / elapsed,
464            key_locality: 0.5, // Would be measured from key distribution
465        }
466    }
467    
468    /// Get the current workload type
469    pub fn workload_type(&self) -> WorkloadType {
470        self.extract_features().classify()
471    }
472    
473    /// Get the current tuning configuration
474    pub fn current_config(&self) -> TuningConfig {
475        self.current_config.read().unwrap().clone()
476    }
477    
478    /// Start measuring reward (throughput)
479    pub fn start_reward_measurement(&self) {
480        *self.reward_start.write().unwrap() = Some(Instant::now());
481        self.ops_at_reward_start.store(self.total_ops(), Ordering::Relaxed);
482    }
483    
484    /// End measurement and update the bandit
485    pub fn end_reward_measurement(&self) {
486        let start = match *self.reward_start.read().unwrap() {
487            Some(t) => t,
488            None => return,
489        };
490        
491        let elapsed = start.elapsed().as_secs_f64();
492        if elapsed < 0.001 {
493            return;
494        }
495        
496        let ops_start = self.ops_at_reward_start.load(Ordering::Relaxed);
497        let ops_now = self.total_ops();
498        let throughput = (ops_now - ops_start) as f64 / elapsed;
499        
500        // Normalize reward to [0, 1]
501        let reward = (throughput / 100000.0).min(1.0);
502        
503        // Update current arm
504        let arm_idx = *self.current_arm.read().unwrap();
505        self.arms[arm_idx].record_reward(reward);
506        
507        *self.reward_start.write().unwrap() = None;
508    }
509    
510    /// Select the next arm using UCB1
511    pub fn select_arm(&self) -> usize {
512        let total_count: u64 = self.arms.iter()
513            .map(|a| a.count.load(Ordering::Relaxed))
514            .sum();
515        
516        if total_count < NUM_ARMS as u64 {
517            // Initial exploration: try each arm once
518            return total_count as usize;
519        }
520        
521        // Select arm with highest UCB
522        self.arms.iter()
523            .enumerate()
524            .max_by(|(_, a), (_, b)| {
525                a.ucb(total_count)
526                    .partial_cmp(&b.ucb(total_count))
527                    .unwrap_or(std::cmp::Ordering::Equal)
528            })
529            .map(|(i, _)| i)
530            .unwrap_or(0)
531    }
532    
533    /// Update the configuration based on current workload
534    pub fn update_config(&self) {
535        // End any ongoing measurement
536        self.end_reward_measurement();
537        
538        // Select new arm
539        let new_arm = self.select_arm();
540        let new_config = get_arm_config(new_arm);
541        
542        *self.current_arm.write().unwrap() = new_arm;
543        *self.current_config.write().unwrap() = new_config;
544        
545        // Start measuring with new config
546        self.start_reward_measurement();
547    }
548    
549    /// Get statistics
550    pub fn stats(&self) -> ClassifierStats {
551        let features = self.extract_features();
552        let arm_stats: Vec<_> = self.arms.iter()
553            .enumerate()
554            .map(|(i, arm)| ArmStats {
555                arm_id: i,
556                count: arm.count.load(Ordering::Relaxed),
557                avg_reward: arm.avg_reward(),
558            })
559            .collect();
560        
561        ClassifierStats {
562            workload_type: features.classify(),
563            features,
564            current_arm: *self.current_arm.read().unwrap(),
565            arm_stats,
566        }
567    }
568}
569
570impl Default for WorkloadClassifier {
571    fn default() -> Self {
572        Self::new()
573    }
574}
575
/// Snapshot of one bandit arm's reward statistics.
#[derive(Clone, Debug)]
pub struct ArmStats {
    /// Index of the arm.
    pub arm_id: usize,
    /// How many rewards have been recorded for the arm.
    pub count: u64,
    /// Mean of the recorded rewards.
    pub avg_reward: f64,
}
586
587/// Classifier statistics
588#[derive(Debug, Clone)]
589pub struct ClassifierStats {
590    /// Detected workload type
591    pub workload_type: WorkloadType,
592    /// Current features
593    pub features: FeatureVector,
594    /// Currently selected arm
595    pub current_arm: usize,
596    /// Per-arm statistics
597    pub arm_stats: Vec<ArmStats>,
598}
599
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;
    use std::thread;
    use std::time::Duration;

    /// A write-dominated mix (200 writes, 50 reads) classifies as OLTP.
    #[test]
    fn test_feature_extraction() {
        let classifier = WorkloadClassifier::new();

        // Simulate OLTP workload
        for _ in 0..100 {
            classifier.record_insert();
            classifier.record_update();
        }
        for _ in 0..50 {
            classifier.record_point_read();
        }

        let features = classifier.extract_features();
        assert!(features.write_fraction > 0.5);

        let workload = features.classify();
        assert_eq!(workload, WorkloadType::Oltp);
    }

    /// A scan- and read-dominated mix classifies as OLAP.
    #[test]
    fn test_olap_detection() {
        let classifier = WorkloadClassifier::new();

        // Simulate OLAP workload
        for _ in 0..100 {
            classifier.record_range_scan();
            classifier.record_point_read();
        }
        for _ in 0..10 {
            classifier.record_insert();
        }

        let features = classifier.extract_features();
        let workload = features.classify();
        assert_eq!(workload, WorkloadType::Olap);
    }

    /// A mix with more than 30% vector searches classifies as vector search.
    #[test]
    fn test_vector_search_detection() {
        let classifier = WorkloadClassifier::new();

        for _ in 0..100 {
            classifier.record_vector_search();
        }
        for _ in 0..50 {
            classifier.record_point_read();
        }

        let features = classifier.extract_features();
        let workload = features.classify();
        assert_eq!(workload, WorkloadType::VectorSearch);
    }

    /// Exploration tries every arm once; afterwards high-reward arms win.
    #[test]
    fn test_ucb_arm_selection() {
        let classifier = WorkloadClassifier::new();

        // Initially should explore
        for _ in 0..NUM_ARMS {
            let arm = classifier.select_arm();
            // Give fake reward
            classifier.arms[arm].record_reward(if arm % 2 == 0 { 0.8 } else { 0.2 });
        }

        // Exploration must have covered every arm exactly once.
        for arm in classifier.arms.iter() {
            assert_eq!(arm.count.load(Ordering::Relaxed), 1);
        }

        // After exploration, should prefer higher-reward (even) arms
        let even_count = (0..20)
            .map(|_| classifier.select_arm())
            .filter(|arm| arm % 2 == 0)
            .count();

        assert!(even_count > 10);
    }

    /// `update_config` records the pending reward and leaves a valid config.
    #[test]
    fn test_config_update() {
        let classifier = WorkloadClassifier::new();

        // The classifier starts on the balanced preset.
        let initial = classifier.current_config();
        assert_eq!(initial.compaction_priority, CompactionPriority::Balanced);

        // Simulate some activity
        for _ in 0..100 {
            classifier.record_insert();
        }

        classifier.start_reward_measurement();
        thread::sleep(Duration::from_millis(10));
        classifier.update_config();

        // The 10 ms measurement must have been credited to exactly one arm.
        let pulls: u64 = classifier
            .arms
            .iter()
            .map(|arm| arm.count.load(Ordering::Relaxed))
            .sum();
        assert_eq!(pulls, 1);

        // Config may or may not change, but should be valid
        let updated = classifier.current_config();
        assert!(updated.memtable_size > 0);
    }

    /// Every predefined arm yields a usable configuration.
    #[test]
    fn test_arm_configs() {
        for i in 0..NUM_ARMS {
            let config = get_arm_config(i);
            assert!(config.memtable_size > 0);
            assert!(config.batch_size > 0);
            assert!(config.prefetch_distance > 0);
        }
    }

    /// `stats` reports one entry per arm and a positive throughput.
    #[test]
    fn test_stats() {
        let classifier = WorkloadClassifier::new();

        for _ in 0..50 {
            classifier.record_insert();
            classifier.record_point_read();
        }

        let stats = classifier.stats();
        assert_eq!(stats.arm_stats.len(), NUM_ARMS);
        assert!(stats.features.ops_per_second > 0.0);
    }
}