// sochdb_storage/rl_workload.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
// SochDB - LLM-Optimized Embedded Database
// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! RL Workload Classifier (Task 10)
//!
//! This module provides a multi-armed bandit-based workload classifier for
//! automatic parameter tuning based on observed operation patterns.
//!
//! ## Problem
//!
//! Different workloads (OLTP, OLAP, mixed) require different tuning:
//! - OLTP: Small batch sizes, frequent flushes
//! - OLAP: Large batch sizes, aggressive prefetching
//! - Mixed: Adaptive switching
//!
//! ## Solution
//!
//! - **Feature Extraction:** Derive features from operation mix
//! - **UCB1 Algorithm:** Upper Confidence Bound for exploration/exploitation
//! - **Tuning Actions:** Adjust parameters based on classifier output
//!
//! ## Performance
//!
//! | Workload | Static Config | RL-Tuned |
//! |----------|---------------|----------|
//! | OLTP | 1× | 1.5× |
//! | OLAP | 1× | 2.0× |
//! | Mixed | 1× | 1.8× |

use std::f64::consts::SQRT_2;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use std::time::Instant;

/// Number of arms (predefined tuning configurations) the bandit chooses between.
const NUM_ARMS: usize = 8;

/// Exploration constant for UCB1: `sqrt(2)`.
///
/// Taken from the standard library instead of a hand-typed decimal so the
/// value carries full `f64` precision.
const UCB_C: f64 = SQRT_2;

/// Window size for feature calculation.
// Not referenced by any code visible in this file, hence the lint allowance.
#[allow(dead_code)]
const FEATURE_WINDOW_SIZE: usize = 1000;

// ============================================================================
// Workload Types
// ============================================================================

/// Workload category detected from the observed operation mix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkloadType {
    /// Write-heavy transactional (OLTP)
    Oltp,
    /// Read-heavy analytical (OLAP)
    Olap,
    /// Mixed workload
    Mixed,
    /// Vector search heavy
    VectorSearch,
    /// Unknown (not enough data)
    Unknown,
}

impl WorkloadType {
    /// Returns the human-readable name of this workload type.
    ///
    /// The returned string is a static literal, so no allocation takes place.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Oltp => "OLTP",
            Self::Olap => "OLAP",
            Self::Mixed => "Mixed",
            Self::VectorSearch => "VectorSearch",
            Self::Unknown => "Unknown",
        }
    }
}

// ============================================================================
// Feature Vector
// ============================================================================

/// Operation counters for feature extraction.
///
/// One atomic counter per operation kind; `Default` starts every counter at
/// zero.
#[derive(Default)]
struct OperationCounters {
    /// Point reads
    point_reads: AtomicU64,
    /// Range scans
    range_scans: AtomicU64,
    /// Inserts
    inserts: AtomicU64,
    /// Updates
    updates: AtomicU64,
    /// Deletes
    deletes: AtomicU64,
    /// Vector searches
    vector_searches: AtomicU64,
}

111/// Feature vector derived from operation mix
112#[derive(Debug, Clone, Default)]
113pub struct FeatureVector {
114    /// Fraction of reads (point + range)
115    pub read_fraction: f64,
116    /// Fraction of writes (insert + update + delete)
117    pub write_fraction: f64,
118    /// Fraction of range scans
119    pub scan_fraction: f64,
120    /// Fraction of vector searches
121    pub vector_fraction: f64,
122    /// Average operation latency (ms)
123    pub avg_latency_ms: f64,
124    /// Operations per second
125    pub ops_per_second: f64,
126    /// Key locality (0 = random, 1 = sequential)
127    pub key_locality: f64,
128}
129
130impl FeatureVector {
131    /// Classify the workload based on features
132    pub fn classify(&self) -> WorkloadType {
133        if self.ops_per_second < 1.0 {
134            return WorkloadType::Unknown;
135        }
136        
137        if self.vector_fraction > 0.3 {
138            return WorkloadType::VectorSearch;
139        }
140        
141        if self.write_fraction > 0.7 {
142            return WorkloadType::Oltp;
143        }
144        
145        if self.scan_fraction > 0.3 {
146            return WorkloadType::Olap;
147        }
148        
149        if self.read_fraction > 0.7 {
150            return WorkloadType::Olap;
151        }
152        
153        WorkloadType::Mixed
154    }
155}

// ============================================================================
// Tuning Actions
// ============================================================================

/// Tuning configuration
#[derive(Debug, Clone)]
pub struct TuningConfig {
    /// Memtable size (bytes)
    pub memtable_size: usize,
    /// Write buffer count
    pub write_buffer_count: usize,
    /// Batch size for operations
    pub batch_size: usize,
    /// Prefetch distance
    pub prefetch_distance: usize,
    /// Background flush interval (ms)
    pub flush_interval_ms: u64,
    /// Compaction priority
    pub compaction_priority: CompactionPriority,
    /// Cache ratio (0-1)
    pub cache_ratio: f64,
    /// HNSW ef_search parameter
    pub hnsw_ef_search: usize,
}

impl Default for TuningConfig {
    /// Baseline configuration; `get_arm_config` also returns it for arm
    /// indices that have no preset.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        Self {
            memtable_size: 64 * MIB,
            write_buffer_count: 2,
            batch_size: 256,
            prefetch_distance: 4,
            flush_interval_ms: 1000,
            compaction_priority: CompactionPriority::Balanced,
            cache_ratio: 0.5,
            hnsw_ef_search: 100,
        }
    }
}

/// Compaction priority strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompactionPriority {
    /// Minimize write amplification
    WriteOptimized,
    /// Minimize read amplification
    ReadOptimized,
    /// Balance both
    Balanced,
}

/// Predefined tuning configurations (arms).
///
/// Arms `0..=7` each map to a fixed preset; any other index returns
/// `TuningConfig::default()`.
fn get_arm_config(arm: usize) -> TuningConfig {
    const MIB: usize = 1024 * 1024;

    match arm {
        // OLTP optimized: small buffers, fast flush
        0 => TuningConfig {
            memtable_size: 32 * MIB,
            write_buffer_count: 4,
            batch_size: 64,
            prefetch_distance: 2,
            flush_interval_ms: 500,
            compaction_priority: CompactionPriority::WriteOptimized,
            cache_ratio: 0.3,
            hnsw_ef_search: 50,
        },
        // OLAP optimized: large buffers, aggressive prefetch
        1 => TuningConfig {
            memtable_size: 256 * MIB,
            write_buffer_count: 2,
            batch_size: 1024,
            prefetch_distance: 16,
            flush_interval_ms: 5000,
            compaction_priority: CompactionPriority::ReadOptimized,
            cache_ratio: 0.8,
            hnsw_ef_search: 200,
        },
        // Vector search optimized
        2 => TuningConfig {
            memtable_size: 128 * MIB,
            write_buffer_count: 2,
            batch_size: 512,
            prefetch_distance: 8,
            flush_interval_ms: 2000,
            compaction_priority: CompactionPriority::Balanced,
            cache_ratio: 0.9,
            hnsw_ef_search: 300,
        },
        // Balanced for mixed workload
        3 => TuningConfig {
            memtable_size: 64 * MIB,
            write_buffer_count: 3,
            batch_size: 256,
            prefetch_distance: 4,
            flush_interval_ms: 1000,
            compaction_priority: CompactionPriority::Balanced,
            cache_ratio: 0.5,
            hnsw_ef_search: 100,
        },
        // Write burst handling
        4 => TuningConfig {
            memtable_size: 128 * MIB,
            write_buffer_count: 6,
            batch_size: 128,
            prefetch_distance: 2,
            flush_interval_ms: 300,
            compaction_priority: CompactionPriority::WriteOptimized,
            cache_ratio: 0.2,
            hnsw_ef_search: 50,
        },
        // Read burst handling
        5 => TuningConfig {
            memtable_size: 32 * MIB,
            write_buffer_count: 2,
            batch_size: 512,
            prefetch_distance: 32,
            flush_interval_ms: 3000,
            compaction_priority: CompactionPriority::ReadOptimized,
            cache_ratio: 0.95,
            hnsw_ef_search: 150,
        },
        // Latency sensitive
        6 => TuningConfig {
            memtable_size: 16 * MIB,
            write_buffer_count: 8,
            batch_size: 32,
            prefetch_distance: 1,
            flush_interval_ms: 200,
            compaction_priority: CompactionPriority::Balanced,
            cache_ratio: 0.6,
            hnsw_ef_search: 75,
        },
        // Throughput focused
        7 => TuningConfig {
            memtable_size: 512 * MIB,
            write_buffer_count: 2,
            batch_size: 2048,
            prefetch_distance: 64,
            flush_interval_ms: 10000,
            compaction_priority: CompactionPriority::WriteOptimized,
            cache_ratio: 0.4,
            hnsw_ef_search: 100,
        },
        // No preset for this index: fall back to the baseline.
        _ => TuningConfig::default(),
    }
}

// ============================================================================
// UCB1 Arm
// ============================================================================

307/// An arm in the multi-armed bandit
308struct UcbArm {
309    /// Number of times this arm was selected
310    count: AtomicU64,
311    /// Total reward accumulated
312    total_reward: RwLock<f64>,
313    /// Sum of squared rewards (for variance)
314    sum_squared_reward: RwLock<f64>,
315}
316
317impl UcbArm {
318    fn new() -> Self {
319        Self {
320            count: AtomicU64::new(0),
321            total_reward: RwLock::new(0.0),
322            sum_squared_reward: RwLock::new(0.0),
323        }
324    }
325    
326    /// Get the average reward
327    fn avg_reward(&self) -> f64 {
328        let count = self.count.load(Ordering::Relaxed);
329        if count == 0 {
330            return 0.0;
331        }
332        *self.total_reward.read().unwrap() / count as f64
333    }
334    
335    /// Record a reward
336    fn record_reward(&self, reward: f64) {
337        self.count.fetch_add(1, Ordering::Relaxed);
338        *self.total_reward.write().unwrap() += reward;
339        *self.sum_squared_reward.write().unwrap() += reward * reward;
340    }
341    
342    /// Calculate UCB1 value
343    fn ucb(&self, total_count: u64) -> f64 {
344        let count = self.count.load(Ordering::Relaxed);
345        if count == 0 {
346            return f64::MAX; // Unexplored arm has infinite UCB
347        }
348        
349        let avg = self.avg_reward();
350        let exploration = UCB_C * ((total_count as f64).ln() / count as f64).sqrt();
351        
352        avg + exploration
353    }
354}

// ============================================================================
// Workload Classifier
// ============================================================================

360/// RL-based workload classifier with UCB1 algorithm
361pub struct WorkloadClassifier {
362    /// Operation counters
363    counters: OperationCounters,
364    /// Arms for UCB1
365    arms: [UcbArm; NUM_ARMS],
366    /// Currently selected arm
367    current_arm: RwLock<usize>,
368    /// Current config
369    current_config: RwLock<TuningConfig>,
370    /// Start time for ops/sec calculation
371    start_time: Instant,
372    /// Last feature extraction time
373    #[allow(dead_code)]
374    last_feature_time: RwLock<Instant>,
375    /// Cached feature vector
376    #[allow(dead_code)]
377    cached_features: RwLock<FeatureVector>,
378    /// Reward measurement start
379    reward_start: RwLock<Option<Instant>>,
380    /// Operations at reward start
381    ops_at_reward_start: AtomicU64,
382}
383
384impl WorkloadClassifier {
385    /// Create a new classifier
386    pub fn new() -> Self {
387        Self {
388            counters: OperationCounters::default(),
389            arms: std::array::from_fn(|_| UcbArm::new()),
390            current_arm: RwLock::new(3), // Start with balanced config
391            current_config: RwLock::new(get_arm_config(3)),
392            start_time: Instant::now(),
393            last_feature_time: RwLock::new(Instant::now()),
394            cached_features: RwLock::new(FeatureVector::default()),
395            reward_start: RwLock::new(None),
396            ops_at_reward_start: AtomicU64::new(0),
397        }
398    }
399    
400    /// Record a point read
401    #[inline]
402    pub fn record_point_read(&self) {
403        self.counters.point_reads.fetch_add(1, Ordering::Relaxed);
404    }
405    
406    /// Record a range scan
407    #[inline]
408    pub fn record_range_scan(&self) {
409        self.counters.range_scans.fetch_add(1, Ordering::Relaxed);
410    }
411    
412    /// Record an insert
413    #[inline]
414    pub fn record_insert(&self) {
415        self.counters.inserts.fetch_add(1, Ordering::Relaxed);
416    }
417    
418    /// Record an update
419    #[inline]
420    pub fn record_update(&self) {
421        self.counters.updates.fetch_add(1, Ordering::Relaxed);
422    }
423    
424    /// Record a delete
425    #[inline]
426    pub fn record_delete(&self) {
427        self.counters.deletes.fetch_add(1, Ordering::Relaxed);
428    }
429    
430    /// Record a vector search
431    #[inline]
432    pub fn record_vector_search(&self) {
433        self.counters.vector_searches.fetch_add(1, Ordering::Relaxed);
434    }
435    
436    /// Get total operations
437    fn total_ops(&self) -> u64 {
438        self.counters.point_reads.load(Ordering::Relaxed)
439            + self.counters.range_scans.load(Ordering::Relaxed)
440            + self.counters.inserts.load(Ordering::Relaxed)
441            + self.counters.updates.load(Ordering::Relaxed)
442            + self.counters.deletes.load(Ordering::Relaxed)
443            + self.counters.vector_searches.load(Ordering::Relaxed)
444    }
445    
446    /// Extract features from operation counters
447    pub fn extract_features(&self) -> FeatureVector {
448        let reads = self.counters.point_reads.load(Ordering::Relaxed);
449        let scans = self.counters.range_scans.load(Ordering::Relaxed);
450        let inserts = self.counters.inserts.load(Ordering::Relaxed);
451        let updates = self.counters.updates.load(Ordering::Relaxed);
452        let deletes = self.counters.deletes.load(Ordering::Relaxed);
453        let vectors = self.counters.vector_searches.load(Ordering::Relaxed);
454        
455        let total = reads + scans + inserts + updates + deletes + vectors;
456        let total_f = total.max(1) as f64;
457        
458        let elapsed = self.start_time.elapsed().as_secs_f64().max(0.001);
459        
460        FeatureVector {
461            read_fraction: (reads + scans) as f64 / total_f,
462            write_fraction: (inserts + updates + deletes) as f64 / total_f,
463            scan_fraction: scans as f64 / total_f,
464            vector_fraction: vectors as f64 / total_f,
465            avg_latency_ms: 1.0, // Would be measured in practice
466            ops_per_second: total as f64 / elapsed,
467            key_locality: 0.5, // Would be measured from key distribution
468        }
469    }
470    
471    /// Get the current workload type
472    pub fn workload_type(&self) -> WorkloadType {
473        self.extract_features().classify()
474    }
475    
476    /// Get the current tuning configuration
477    pub fn current_config(&self) -> TuningConfig {
478        self.current_config.read().unwrap().clone()
479    }
480    
481    /// Start measuring reward (throughput)
482    pub fn start_reward_measurement(&self) {
483        *self.reward_start.write().unwrap() = Some(Instant::now());
484        self.ops_at_reward_start.store(self.total_ops(), Ordering::Relaxed);
485    }
486    
487    /// End measurement and update the bandit
488    pub fn end_reward_measurement(&self) {
489        let start = match *self.reward_start.read().unwrap() {
490            Some(t) => t,
491            None => return,
492        };
493        
494        let elapsed = start.elapsed().as_secs_f64();
495        if elapsed < 0.001 {
496            return;
497        }
498        
499        let ops_start = self.ops_at_reward_start.load(Ordering::Relaxed);
500        let ops_now = self.total_ops();
501        let throughput = (ops_now - ops_start) as f64 / elapsed;
502        
503        // Normalize reward to [0, 1]
504        let reward = (throughput / 100000.0).min(1.0);
505        
506        // Update current arm
507        let arm_idx = *self.current_arm.read().unwrap();
508        self.arms[arm_idx].record_reward(reward);
509        
510        *self.reward_start.write().unwrap() = None;
511    }
512    
513    /// Select the next arm using UCB1
514    pub fn select_arm(&self) -> usize {
515        let total_count: u64 = self.arms.iter()
516            .map(|a| a.count.load(Ordering::Relaxed))
517            .sum();
518        
519        if total_count < NUM_ARMS as u64 {
520            // Initial exploration: try each arm once
521            return total_count as usize;
522        }
523        
524        // Select arm with highest UCB
525        self.arms.iter()
526            .enumerate()
527            .max_by(|(_, a), (_, b)| {
528                a.ucb(total_count)
529                    .partial_cmp(&b.ucb(total_count))
530                    .unwrap_or(std::cmp::Ordering::Equal)
531            })
532            .map(|(i, _)| i)
533            .unwrap_or(0)
534    }
535    
536    /// Update the configuration based on current workload
537    pub fn update_config(&self) {
538        // End any ongoing measurement
539        self.end_reward_measurement();
540        
541        // Select new arm
542        let new_arm = self.select_arm();
543        let new_config = get_arm_config(new_arm);
544        
545        *self.current_arm.write().unwrap() = new_arm;
546        *self.current_config.write().unwrap() = new_config;
547        
548        // Start measuring with new config
549        self.start_reward_measurement();
550    }
551    
552    /// Get statistics
553    pub fn stats(&self) -> ClassifierStats {
554        let features = self.extract_features();
555        let arm_stats: Vec<_> = self.arms.iter()
556            .enumerate()
557            .map(|(i, arm)| ArmStats {
558                arm_id: i,
559                count: arm.count.load(Ordering::Relaxed),
560                avg_reward: arm.avg_reward(),
561            })
562            .collect();
563        
564        ClassifierStats {
565            workload_type: features.classify(),
566            features,
567            current_arm: *self.current_arm.read().unwrap(),
568            arm_stats,
569        }
570    }
571}
572
573impl Default for WorkloadClassifier {
574    fn default() -> Self {
575        Self::new()
576    }
577}
578
/// Statistics for a single bandit arm
#[derive(Debug, Clone)]
pub struct ArmStats {
    /// Arm ID
    pub arm_id: usize,
    /// Number of times selected
    pub count: u64,
    /// Average reward
    pub avg_reward: f64,
}

590/// Classifier statistics
591#[derive(Debug, Clone)]
592pub struct ClassifierStats {
593    /// Detected workload type
594    pub workload_type: WorkloadType,
595    /// Current features
596    pub features: FeatureVector,
597    /// Currently selected arm
598    pub current_arm: usize,
599    /// Per-arm statistics
600    pub arm_stats: Vec<ArmStats>,
601}
602
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// A write-dominated mix is classified as OLTP.
    #[test]
    fn test_feature_extraction() {
        let classifier = WorkloadClassifier::new();

        // Simulate OLTP workload: 200 writes against 50 reads
        for _ in 0..100 {
            classifier.record_insert();
            classifier.record_update();
        }
        for _ in 0..50 {
            classifier.record_point_read();
        }

        let features = classifier.extract_features();
        assert!(features.write_fraction > 0.5);

        let workload = features.classify();
        assert_eq!(workload, WorkloadType::Oltp);
    }

    /// A scan- and read-dominated mix is classified as OLAP.
    #[test]
    fn test_olap_detection() {
        let classifier = WorkloadClassifier::new();

        // Simulate OLAP workload
        for _ in 0..100 {
            classifier.record_range_scan();
            classifier.record_point_read();
        }
        for _ in 0..10 {
            classifier.record_insert();
        }

        let features = classifier.extract_features();
        let workload = features.classify();
        assert_eq!(workload, WorkloadType::Olap);
    }

    /// A mix with more than 30% vector searches is classified as such.
    #[test]
    fn test_vector_search_detection() {
        let classifier = WorkloadClassifier::new();

        for _ in 0..100 {
            classifier.record_vector_search();
        }
        for _ in 0..50 {
            classifier.record_point_read();
        }

        let features = classifier.extract_features();
        let workload = features.classify();
        assert_eq!(workload, WorkloadType::VectorSearch);
    }

    /// After one pull per arm, selection favours the better-rewarded arms.
    #[test]
    fn test_ucb_arm_selection() {
        let classifier = WorkloadClassifier::new();

        // Initially should explore
        for _ in 0..NUM_ARMS {
            let arm = classifier.select_arm();
            // Give fake reward
            classifier.arms[arm].record_reward(if arm % 2 == 0 { 0.8 } else { 0.2 });
        }

        // After exploration, should prefer higher-reward arms
        let selected: Vec<_> = (0..20).map(|_| classifier.select_arm()).collect();
        let even_count = selected.iter().filter(|&&a| a % 2 == 0).count();

        // Should prefer even arms (higher reward)
        assert!(even_count > 10);
    }

    /// `update_config` leaves the classifier holding a valid configuration.
    #[test]
    fn test_config_update() {
        let classifier = WorkloadClassifier::new();

        // A fresh classifier starts on the balanced arm.
        let initial = classifier.current_config();
        assert_eq!(initial.compaction_priority, CompactionPriority::Balanced);

        // Simulate some activity
        for _ in 0..100 {
            classifier.record_insert();
        }

        classifier.start_reward_measurement();
        // Let the measurement window exceed the minimum scorable length.
        thread::sleep(Duration::from_millis(10));
        classifier.update_config();

        // Config may or may not change, but should be valid
        let updated = classifier.current_config();
        assert!(updated.memtable_size > 0);
    }

    /// Every preset arm has non-zero sizing parameters.
    #[test]
    fn test_arm_configs() {
        for i in 0..NUM_ARMS {
            let config = get_arm_config(i);
            assert!(config.memtable_size > 0);
            assert!(config.batch_size > 0);
            assert!(config.prefetch_distance > 0);
        }
    }

    /// `stats` reports one entry per arm and a positive throughput.
    #[test]
    fn test_stats() {
        let classifier = WorkloadClassifier::new();

        for _ in 0..50 {
            classifier.record_insert();
            classifier.record_point_read();
        }

        let stats = classifier.stats();
        assert_eq!(stats.arm_stats.len(), NUM_ARMS);
        assert!(stats.features.ops_per_second > 0.0);
    }
}