Skip to main content

ruvector_sona/loops/
instant.rs

1//! Loop A - Instant Learning
2//!
3//! Per-request adaptation with <1ms overhead.
4
5use crate::lora::MicroLoRA;
6use crate::trajectory::{TrajectoryBuffer, TrajectoryIdGen};
7use crate::types::{LearningSignal, QueryTrajectory, SonaConfig};
8use parking_lot::RwLock;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11
/// Tunable settings for the instant learning loop.
///
/// Consumed by `InstantLoop::new`, which sizes the trajectory buffer and the
/// micro-LoRA adapter from these values.
#[derive(Debug, Clone)]
pub struct InstantLoopConfig {
    /// Rank passed to the micro-LoRA adapter when it is constructed.
    pub micro_lora_rank: usize,
    /// Learning rate handed to the adapter each time accumulated updates are applied.
    pub micro_lora_lr: f32,
    /// Capacity requested for the trajectory buffer.
    pub buffer_capacity: usize,
    /// Number of accumulated signals that triggers an automatic flush.
    pub flush_threshold: usize,
}
24
25impl Default for InstantLoopConfig {
26    fn default() -> Self {
27        Self {
28            micro_lora_rank: 1,
29            micro_lora_lr: 0.001,
30            buffer_capacity: 10000,
31            flush_threshold: 100,
32        }
33    }
34}
35
36impl From<&SonaConfig> for InstantLoopConfig {
37    fn from(config: &SonaConfig) -> Self {
38        Self {
39            micro_lora_rank: config.micro_lora_rank,
40            micro_lora_lr: config.micro_lora_lr,
41            buffer_capacity: config.trajectory_capacity,
42            flush_threshold: 100,
43        }
44    }
45}
46
/// Counters describing the activity of the instant loop.
///
/// Every field is an atomic, so the counters can be updated through a shared
/// reference. All counters start at zero via `Default`.
#[derive(Default, Debug)]
pub struct InstantLoopMetrics {
    /// Number of trajectories handed to the loop.
    pub trajectories_processed: AtomicU64,
    /// Number of learning signals accumulated into the adapter.
    pub signals_accumulated: AtomicU64,
    /// Number of flushes that actually applied pending updates.
    pub flushes_performed: AtomicU64,
    /// Sum of the pending-update counts applied across all flushes.
    pub updates_applied: AtomicU64,
}
59
60/// Instant learning loop (Loop A)
61pub struct InstantLoop {
62    /// Configuration
63    config: InstantLoopConfig,
64    /// Trajectory buffer
65    trajectory_buffer: Arc<TrajectoryBuffer>,
66    /// Micro-LoRA adapter
67    micro_lora: Arc<RwLock<MicroLoRA>>,
68    /// ID generator
69    id_gen: TrajectoryIdGen,
70    /// Pending signal count
71    pending_signals: AtomicU64,
72    /// Metrics
73    pub metrics: InstantLoopMetrics,
74}
75
76impl InstantLoop {
77    /// Create new instant loop
78    pub fn new(hidden_dim: usize, config: InstantLoopConfig) -> Self {
79        Self {
80            trajectory_buffer: Arc::new(TrajectoryBuffer::new(config.buffer_capacity)),
81            micro_lora: Arc::new(RwLock::new(MicroLoRA::new(
82                hidden_dim,
83                config.micro_lora_rank,
84            ))),
85            id_gen: TrajectoryIdGen::new(),
86            pending_signals: AtomicU64::new(0),
87            config,
88            metrics: InstantLoopMetrics::default(),
89        }
90    }
91
92    /// Create from SONA config
93    pub fn from_sona_config(config: &SonaConfig) -> Self {
94        Self::new(config.hidden_dim, InstantLoopConfig::from(config))
95    }
96
97    /// Generate next trajectory ID
98    pub fn next_id(&self) -> u64 {
99        self.id_gen.next()
100    }
101
102    /// Process completed trajectory
103    pub fn on_trajectory(&self, trajectory: QueryTrajectory) {
104        // Record to buffer
105        self.trajectory_buffer.record(trajectory.clone());
106        self.metrics
107            .trajectories_processed
108            .fetch_add(1, Ordering::Relaxed);
109
110        // Generate learning signal
111        let signal = LearningSignal::from_trajectory(&trajectory);
112
113        // Accumulate gradient (non-blocking)
114        if let Some(mut lora) = self.micro_lora.try_write() {
115            lora.accumulate_gradient(&signal);
116            self.metrics
117                .signals_accumulated
118                .fetch_add(1, Ordering::Relaxed);
119
120            let pending = self.pending_signals.fetch_add(1, Ordering::Relaxed) + 1;
121
122            // Auto-flush if threshold reached
123            if pending >= self.config.flush_threshold as u64 {
124                self.flush_internal(&mut lora);
125            }
126        }
127    }
128
129    /// Manually flush accumulated updates
130    pub fn flush(&self) {
131        if let Some(mut lora) = self.micro_lora.try_write() {
132            self.flush_internal(&mut lora);
133        }
134    }
135
136    fn flush_internal(&self, lora: &mut MicroLoRA) {
137        let pending = lora.pending_updates();
138        if pending > 0 {
139            lora.apply_accumulated(self.config.micro_lora_lr);
140            self.pending_signals.store(0, Ordering::Relaxed);
141            self.metrics
142                .flushes_performed
143                .fetch_add(1, Ordering::Relaxed);
144            self.metrics
145                .updates_applied
146                .fetch_add(pending as u64, Ordering::Relaxed);
147        }
148    }
149
150    /// Drain trajectories for background processing
151    pub fn drain_trajectories(&self) -> Vec<QueryTrajectory> {
152        self.trajectory_buffer.drain()
153    }
154
155    /// Drain up to N trajectories
156    pub fn drain_trajectories_n(&self, n: usize) -> Vec<QueryTrajectory> {
157        self.trajectory_buffer.drain_n(n)
158    }
159
160    /// Get micro-LoRA reference for inference
161    pub fn micro_lora(&self) -> &Arc<RwLock<MicroLoRA>> {
162        &self.micro_lora
163    }
164
165    /// Get trajectory buffer reference
166    pub fn buffer(&self) -> &Arc<TrajectoryBuffer> {
167        &self.trajectory_buffer
168    }
169
170    /// Get pending trajectory count
171    pub fn pending_count(&self) -> usize {
172        self.trajectory_buffer.len()
173    }
174
175    /// Get buffer stats
176    pub fn buffer_stats(&self) -> (usize, u64, f64) {
177        (
178            self.trajectory_buffer.len(),
179            self.trajectory_buffer.dropped_count(),
180            self.trajectory_buffer.success_rate(),
181        )
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrajectoryStep;

    /// Builds a finalized, single-step trajectory using 64-element vectors.
    fn make_trajectory(id: u64) -> QueryTrajectory {
        let mut trajectory = QueryTrajectory::new(id, vec![0.1; 64]);
        let step = TrajectoryStep::new(vec![0.5; 64], vec![], 0.8, 0);
        trajectory.add_step(step);
        trajectory.finalize(0.8, 1000);
        trajectory
    }

    /// A freshly created loop holds no buffered trajectories.
    #[test]
    fn test_instant_loop_creation() {
        let instant = InstantLoop::new(64, InstantLoopConfig::default());
        assert_eq!(instant.pending_count(), 0);
    }

    /// One processed trajectory is both buffered and counted.
    #[test]
    fn test_trajectory_processing() {
        let instant = InstantLoop::new(64, InstantLoopConfig::default());

        let id = instant.next_id();
        instant.on_trajectory(make_trajectory(id));

        assert_eq!(instant.pending_count(), 1);

        let processed = instant
            .metrics
            .trajectories_processed
            .load(Ordering::Relaxed);
        assert_eq!(processed, 1);
    }

    /// Exceeding the flush threshold triggers at least one automatic flush.
    #[test]
    fn test_auto_flush() {
        let config = InstantLoopConfig {
            flush_threshold: 3,
            ..Default::default()
        };
        let instant = InstantLoop::new(64, config);

        (0..5)
            .map(make_trajectory)
            .for_each(|t| instant.on_trajectory(t));

        let flushes = instant.metrics.flushes_performed.load(Ordering::Relaxed);
        assert!(flushes >= 1);
    }

    /// Draining returns every buffered trajectory and empties the buffer.
    #[test]
    fn test_drain() {
        let instant = InstantLoop::new(64, InstantLoopConfig::default());

        (0..10)
            .map(make_trajectory)
            .for_each(|t| instant.on_trajectory(t));

        let drained = instant.drain_trajectories();
        assert_eq!(drained.len(), 10);
        assert_eq!(instant.pending_count(), 0);
    }
}