ruvector_sona/
types.rs

//! SONA Core Types
//!
//! Defines the fundamental data structures for the Self-Optimizing Neural Architecture.
4
5use serde::{Deserialize, Serialize};
6use std::time::Instant;
7use std::collections::HashMap;
8
9/// Learning signal generated from inference trajectory
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct LearningSignal {
12    /// Query embedding vector
13    pub query_embedding: Vec<f32>,
14    /// Estimated gradient direction
15    pub gradient_estimate: Vec<f32>,
16    /// Quality score [0.0, 1.0]
17    pub quality_score: f32,
18    /// Signal generation timestamp (serialized as nanos)
19    #[serde(skip)]
20    pub timestamp: Option<Instant>,
21    /// Additional metadata
22    pub metadata: SignalMetadata,
23}
24
25/// Metadata for learning signals
26#[derive(Clone, Debug, Default, Serialize, Deserialize)]
27pub struct SignalMetadata {
28    /// Source trajectory ID
29    pub trajectory_id: u64,
30    /// Number of steps in trajectory
31    pub step_count: usize,
32    /// Model route taken
33    pub model_route: Option<String>,
34    /// Custom tags
35    pub tags: HashMap<String, String>,
36}
37
38impl LearningSignal {
39    /// Create signal from query trajectory using REINFORCE gradient estimation
40    pub fn from_trajectory(trajectory: &QueryTrajectory) -> Self {
41        let gradient = Self::estimate_gradient(trajectory);
42
43        Self {
44            query_embedding: trajectory.query_embedding.clone(),
45            gradient_estimate: gradient,
46            quality_score: trajectory.final_quality,
47            timestamp: Some(Instant::now()),
48            metadata: SignalMetadata {
49                trajectory_id: trajectory.id,
50                step_count: trajectory.steps.len(),
51                model_route: trajectory.model_route.clone(),
52                tags: HashMap::new(),
53            },
54        }
55    }
56
57    /// Create signal with pre-computed gradient
58    pub fn with_gradient(embedding: Vec<f32>, gradient: Vec<f32>, quality: f32) -> Self {
59        Self {
60            query_embedding: embedding,
61            gradient_estimate: gradient,
62            quality_score: quality,
63            timestamp: Some(Instant::now()),
64            metadata: SignalMetadata::default(),
65        }
66    }
67
68    /// Estimate gradient using REINFORCE with baseline
69    fn estimate_gradient(trajectory: &QueryTrajectory) -> Vec<f32> {
70        if trajectory.steps.is_empty() {
71            return trajectory.query_embedding.clone();
72        }
73
74        let dim = trajectory.query_embedding.len();
75        let mut gradient = vec![0.0f32; dim];
76
77        // Compute baseline (average reward)
78        let baseline = trajectory.steps.iter()
79            .map(|s| s.reward)
80            .sum::<f32>() / trajectory.steps.len() as f32;
81
82        // REINFORCE: gradient = sum((reward - baseline) * activation)
83        for step in &trajectory.steps {
84            let advantage = step.reward - baseline;
85            let activation_len = step.activations.len().min(dim);
86            for i in 0..activation_len {
87                gradient[i] += advantage * step.activations[i];
88            }
89        }
90
91        // L2 normalize
92        let norm: f32 = gradient.iter().map(|x| x * x).sum::<f32>().sqrt();
93        if norm > 1e-8 {
94            gradient.iter_mut().for_each(|x| *x /= norm);
95        }
96
97        gradient
98    }
99
100    /// Scale gradient by quality
101    pub fn scaled_gradient(&self) -> Vec<f32> {
102        self.gradient_estimate.iter()
103            .map(|&g| g * self.quality_score)
104            .collect()
105    }
106}
107
108/// Query trajectory recording
109#[derive(Clone, Debug, Serialize, Deserialize)]
110pub struct QueryTrajectory {
111    /// Unique trajectory identifier
112    pub id: u64,
113    /// Query embedding vector
114    pub query_embedding: Vec<f32>,
115    /// Execution steps
116    pub steps: Vec<TrajectoryStep>,
117    /// Final quality score [0.0, 1.0]
118    pub final_quality: f32,
119    /// Total latency in microseconds
120    pub latency_us: u64,
121    /// Model route taken
122    pub model_route: Option<String>,
123    /// Context used
124    pub context_ids: Vec<String>,
125}
126
127impl QueryTrajectory {
128    /// Create new trajectory
129    pub fn new(id: u64, query_embedding: Vec<f32>) -> Self {
130        Self {
131            id,
132            query_embedding,
133            steps: Vec::with_capacity(16),
134            final_quality: 0.0,
135            latency_us: 0,
136            model_route: None,
137            context_ids: Vec::new(),
138        }
139    }
140
141    /// Add execution step
142    pub fn add_step(&mut self, step: TrajectoryStep) {
143        self.steps.push(step);
144    }
145
146    /// Finalize trajectory with quality score
147    pub fn finalize(&mut self, quality: f32, latency_us: u64) {
148        self.final_quality = quality;
149        self.latency_us = latency_us;
150    }
151
152    /// Get total reward
153    pub fn total_reward(&self) -> f32 {
154        self.steps.iter().map(|s| s.reward).sum()
155    }
156
157    /// Get average reward
158    pub fn avg_reward(&self) -> f32 {
159        if self.steps.is_empty() {
160            0.0
161        } else {
162            self.total_reward() / self.steps.len() as f32
163        }
164    }
165}
166
167/// Single step in a trajectory
168#[derive(Clone, Debug, Serialize, Deserialize)]
169pub struct TrajectoryStep {
170    /// Layer/module activations (subset for efficiency)
171    pub activations: Vec<f32>,
172    /// Attention weights (flattened)
173    pub attention_weights: Vec<f32>,
174    /// Reward signal for this step
175    pub reward: f32,
176    /// Step index
177    pub step_idx: usize,
178    /// Optional layer name
179    pub layer_name: Option<String>,
180}
181
182impl TrajectoryStep {
183    /// Create new step
184    pub fn new(activations: Vec<f32>, attention_weights: Vec<f32>, reward: f32, step_idx: usize) -> Self {
185        Self {
186            activations,
187            attention_weights,
188            reward,
189            step_idx,
190            layer_name: None,
191        }
192    }
193
194    /// Create step with layer name
195    pub fn with_layer(mut self, name: &str) -> Self {
196        self.layer_name = Some(name.to_string());
197        self
198    }
199}
200
201/// Learned pattern from trajectory clustering
202#[derive(Clone, Debug, Serialize, Deserialize)]
203pub struct LearnedPattern {
204    /// Pattern identifier
205    pub id: u64,
206    /// Cluster centroid embedding
207    pub centroid: Vec<f32>,
208    /// Number of trajectories in cluster
209    pub cluster_size: usize,
210    /// Sum of trajectory weights
211    pub total_weight: f32,
212    /// Average quality of member trajectories
213    pub avg_quality: f32,
214    /// Creation timestamp (Unix seconds)
215    pub created_at: u64,
216    /// Last access timestamp
217    pub last_accessed: u64,
218    /// Total access count
219    pub access_count: u32,
220    /// Pattern type/category
221    pub pattern_type: PatternType,
222}
223
224/// Pattern classification
225#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
226pub enum PatternType {
227    #[default]
228    General,
229    Reasoning,
230    Factual,
231    Creative,
232    CodeGen,
233    Conversational,
234}
235
236impl std::fmt::Display for PatternType {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        match self {
239            PatternType::General => write!(f, "general"),
240            PatternType::Reasoning => write!(f, "reasoning"),
241            PatternType::Factual => write!(f, "factual"),
242            PatternType::Creative => write!(f, "creative"),
243            PatternType::CodeGen => write!(f, "codegen"),
244            PatternType::Conversational => write!(f, "conversational"),
245        }
246    }
247}
248
249impl LearnedPattern {
250    /// Create new pattern
251    pub fn new(id: u64, centroid: Vec<f32>) -> Self {
252        let now = std::time::SystemTime::now()
253            .duration_since(std::time::UNIX_EPOCH)
254            .unwrap_or_default()
255            .as_secs();
256
257        Self {
258            id,
259            centroid,
260            cluster_size: 1,
261            total_weight: 1.0,
262            avg_quality: 0.0,
263            created_at: now,
264            last_accessed: now,
265            access_count: 0,
266            pattern_type: PatternType::default(),
267        }
268    }
269
270    /// Merge two patterns
271    pub fn merge(&self, other: &Self) -> Self {
272        let total_size = self.cluster_size + other.cluster_size;
273        let w1 = self.cluster_size as f32 / total_size as f32;
274        let w2 = other.cluster_size as f32 / total_size as f32;
275
276        let centroid: Vec<f32> = self.centroid.iter()
277            .zip(&other.centroid)
278            .map(|(&a, &b)| a * w1 + b * w2)
279            .collect();
280
281        Self {
282            id: self.id,
283            centroid,
284            cluster_size: total_size,
285            total_weight: self.total_weight + other.total_weight,
286            avg_quality: self.avg_quality * w1 + other.avg_quality * w2,
287            created_at: self.created_at.min(other.created_at),
288            last_accessed: self.last_accessed.max(other.last_accessed),
289            access_count: self.access_count + other.access_count,
290            pattern_type: self.pattern_type.clone(),
291        }
292    }
293
294    /// Decay pattern importance
295    pub fn decay(&mut self, factor: f32) {
296        self.total_weight *= factor;
297    }
298
299    /// Record access
300    pub fn touch(&mut self) {
301        self.access_count += 1;
302        self.last_accessed = std::time::SystemTime::now()
303            .duration_since(std::time::UNIX_EPOCH)
304            .unwrap_or_default()
305            .as_secs();
306    }
307
308    /// Check if pattern should be pruned
309    pub fn should_prune(&self, min_quality: f32, min_accesses: u32, max_age_secs: u64) -> bool {
310        let now = std::time::SystemTime::now()
311            .duration_since(std::time::UNIX_EPOCH)
312            .unwrap_or_default()
313            .as_secs();
314        let age = now.saturating_sub(self.last_accessed);
315
316        self.avg_quality < min_quality
317            && self.access_count < min_accesses
318            && age > max_age_secs
319    }
320
321    /// Compute cosine similarity with query
322    pub fn similarity(&self, query: &[f32]) -> f32 {
323        if self.centroid.len() != query.len() {
324            return 0.0;
325        }
326
327        let dot: f32 = self.centroid.iter().zip(query).map(|(a, b)| a * b).sum();
328        let norm_a: f32 = self.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
329        let norm_b: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
330
331        if norm_a > 1e-8 && norm_b > 1e-8 {
332            dot / (norm_a * norm_b)
333        } else {
334            0.0
335        }
336    }
337}
338
339/// SONA configuration
340#[derive(Clone, Debug, Serialize, Deserialize)]
341pub struct SonaConfig {
342    /// Hidden dimension
343    pub hidden_dim: usize,
344    /// Embedding dimension
345    pub embedding_dim: usize,
346    /// Micro-LoRA rank
347    pub micro_lora_rank: usize,
348    /// Base LoRA rank
349    pub base_lora_rank: usize,
350    /// Micro-LoRA learning rate
351    pub micro_lora_lr: f32,
352    /// Base LoRA learning rate
353    pub base_lora_lr: f32,
354    /// EWC lambda
355    pub ewc_lambda: f32,
356    /// Pattern extraction clusters
357    pub pattern_clusters: usize,
358    /// Trajectory buffer capacity
359    pub trajectory_capacity: usize,
360    /// Background learning interval (ms)
361    pub background_interval_ms: u64,
362    /// Quality threshold for learning
363    pub quality_threshold: f32,
364    /// Enable SIMD optimizations
365    pub enable_simd: bool,
366}
367
368impl Default for SonaConfig {
369    fn default() -> Self {
370        // OPTIMIZED DEFAULTS based on @ruvector/sona v0.1.1 benchmarks:
371        // - Rank-2 is 5% faster than Rank-1 due to better SIMD vectorization
372        // - Learning rate 0.002 yields +55% quality improvement
373        // - 100 clusters = 1.3ms search vs 50 clusters = 3.0ms (2.3x faster)
374        // - EWC lambda 2000 optimal for catastrophic forgetting prevention
375        // - Quality threshold 0.3 balances learning vs noise filtering
376        Self {
377            hidden_dim: 256,
378            embedding_dim: 256,
379            micro_lora_rank: 2,      // OPTIMIZED: Rank-2 faster than Rank-1 (2,211 vs 2,100 ops/sec)
380            base_lora_rank: 8,       // Balanced for production
381            micro_lora_lr: 0.002,    // OPTIMIZED: +55.3% quality improvement
382            base_lora_lr: 0.0001,
383            ewc_lambda: 2000.0,      // OPTIMIZED: Better forgetting prevention
384            pattern_clusters: 100,   // OPTIMIZED: 2.3x faster search (1.3ms vs 3.0ms)
385            trajectory_capacity: 10000,
386            background_interval_ms: 3600000, // 1 hour
387            quality_threshold: 0.3,  // OPTIMIZED: Lower threshold for more learning
388            enable_simd: true,
389        }
390    }
391}
392
393impl SonaConfig {
394    /// Create config optimized for maximum throughput (real-time chat)
395    pub fn max_throughput() -> Self {
396        Self {
397            hidden_dim: 256,
398            embedding_dim: 256,
399            micro_lora_rank: 2,       // Rank-2 + SIMD = 2,211 ops/sec
400            base_lora_rank: 4,        // Minimal base for speed
401            micro_lora_lr: 0.0005,    // Conservative for stability
402            base_lora_lr: 0.0001,
403            ewc_lambda: 2000.0,
404            pattern_clusters: 100,
405            trajectory_capacity: 5000,
406            background_interval_ms: 7200000, // 2 hours
407            quality_threshold: 0.4,
408            enable_simd: true,
409        }
410    }
411
412    /// Create config optimized for maximum quality (research/batch)
413    pub fn max_quality() -> Self {
414        Self {
415            hidden_dim: 256,
416            embedding_dim: 256,
417            micro_lora_rank: 2,
418            base_lora_rank: 16,       // Higher rank for expressiveness
419            micro_lora_lr: 0.002,     // Optimal learning rate
420            base_lora_lr: 0.001,      // Aggressive base learning
421            ewc_lambda: 2000.0,
422            pattern_clusters: 100,
423            trajectory_capacity: 20000,
424            background_interval_ms: 1800000, // 30 minutes
425            quality_threshold: 0.2,   // Learn from more trajectories
426            enable_simd: true,
427        }
428    }
429
430    /// Create config for edge/mobile deployment (<5MB memory)
431    pub fn edge_deployment() -> Self {
432        Self {
433            hidden_dim: 256,
434            embedding_dim: 256,
435            micro_lora_rank: 1,       // Minimal rank for memory
436            base_lora_rank: 4,
437            micro_lora_lr: 0.001,
438            base_lora_lr: 0.0001,
439            ewc_lambda: 1000.0,
440            pattern_clusters: 50,
441            trajectory_capacity: 200, // Small buffer
442            background_interval_ms: 3600000,
443            quality_threshold: 0.5,
444            enable_simd: true,
445        }
446    }
447
448    /// Create config for batch processing (50+ inferences/sec)
449    pub fn batch_processing() -> Self {
450        Self {
451            hidden_dim: 256,
452            embedding_dim: 256,
453            micro_lora_rank: 2,
454            base_lora_rank: 8,
455            micro_lora_lr: 0.001,
456            base_lora_lr: 0.0001,
457            ewc_lambda: 2000.0,
458            pattern_clusters: 100,
459            trajectory_capacity: 10000,
460            background_interval_ms: 3600000,
461            quality_threshold: 0.3,
462            enable_simd: true,
463        }
464    }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Approximate float equality shared by the tests below.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    /// A step that only carries a reward (no activations or attention).
    fn reward_only_step(reward: f32, step_idx: usize) -> TrajectoryStep {
        TrajectoryStep::new(Vec::new(), Vec::new(), reward, step_idx)
    }

    #[test]
    fn test_learning_signal_from_trajectory() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1, 0.2, 0.3]);
        let step = TrajectoryStep::new(vec![0.5, 0.3, 0.2], vec![0.4, 0.4, 0.2], 0.8, 0);
        trajectory.add_step(step);
        trajectory.finalize(0.8, 1000);

        let signal = LearningSignal::from_trajectory(&trajectory);

        assert_eq!(signal.quality_score, 0.8);
        assert_eq!(signal.gradient_estimate.len(), 3);
        assert_eq!(signal.metadata.trajectory_id, 1);
    }

    #[test]
    fn test_pattern_merge() {
        let p1 = LearnedPattern {
            id: 1,
            centroid: vec![1.0, 0.0],
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality: 0.8,
            created_at: 100,
            last_accessed: 200,
            access_count: 5,
            pattern_type: PatternType::General,
        };
        // Same size, weight and type as `p1`, with an orthogonal centroid.
        let p2 = LearnedPattern {
            id: 2,
            centroid: vec![0.0, 1.0],
            avg_quality: 0.9,
            created_at: 150,
            last_accessed: 250,
            access_count: 3,
            ..p1.clone()
        };

        let merged = p1.merge(&p2);

        assert_eq!(merged.cluster_size, 20);
        assert!(approx(merged.centroid[0], 0.5));
        assert!(approx(merged.centroid[1], 0.5));
        assert!(approx(merged.avg_quality, 0.85));
    }

    #[test]
    fn test_pattern_similarity() {
        let pattern = LearnedPattern::new(1, vec![1.0, 0.0, 0.0]);

        assert!(approx(pattern.similarity(&[1.0, 0.0, 0.0]), 1.0));
        assert!(approx(pattern.similarity(&[0.0, 1.0, 0.0]), 0.0));
    }

    #[test]
    fn test_trajectory_rewards() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1]);
        for (idx, &reward) in [0.5f32, 0.7, 0.9].iter().enumerate() {
            trajectory.add_step(reward_only_step(reward, idx));
        }

        assert!(approx(trajectory.total_reward(), 2.1));
        assert!(approx(trajectory.avg_reward(), 0.7));
    }
}