// ruvector_sona/types.rs
1//! SONA Core Types
2//!
3//! Defines the fundamental data structures for the Self-Optimizing Neural Architecture.
4
5use crate::time_compat::Instant;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// Learning signal generated from an inference trajectory.
///
/// Carries a REINFORCE-style gradient estimate plus metadata about the
/// trajectory it was derived from.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LearningSignal {
    /// Query embedding vector
    pub query_embedding: Vec<f32>,
    /// Estimated gradient direction (L2-normalized by `estimate_gradient`)
    pub gradient_estimate: Vec<f32>,
    /// Quality score [0.0, 1.0]
    pub quality_score: f32,
    /// Signal generation timestamp. NOT serialized (`#[serde(skip)]`):
    /// deserialized signals carry `None` here.
    #[serde(skip)]
    pub timestamp: Option<Instant>,
    /// Additional metadata about the source trajectory
    pub metadata: SignalMetadata,
}
24
/// Metadata for learning signals
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SignalMetadata {
    /// ID of the trajectory this signal was derived from
    pub trajectory_id: u64,
    /// Number of steps in the source trajectory
    pub step_count: usize,
    /// Model route taken, if one was recorded
    pub model_route: Option<String>,
    /// Free-form key/value tags
    pub tags: HashMap<String, String>,
}
37
38impl LearningSignal {
39    /// Create signal from query trajectory using REINFORCE gradient estimation
40    pub fn from_trajectory(trajectory: &QueryTrajectory) -> Self {
41        let gradient = Self::estimate_gradient(trajectory);
42
43        Self {
44            query_embedding: trajectory.query_embedding.clone(),
45            gradient_estimate: gradient,
46            quality_score: trajectory.final_quality,
47            timestamp: Some(Instant::now()),
48            metadata: SignalMetadata {
49                trajectory_id: trajectory.id,
50                step_count: trajectory.steps.len(),
51                model_route: trajectory.model_route.clone(),
52                tags: HashMap::new(),
53            },
54        }
55    }
56
57    /// Create signal with pre-computed gradient
58    pub fn with_gradient(embedding: Vec<f32>, gradient: Vec<f32>, quality: f32) -> Self {
59        Self {
60            query_embedding: embedding,
61            gradient_estimate: gradient,
62            quality_score: quality,
63            timestamp: Some(Instant::now()),
64            metadata: SignalMetadata::default(),
65        }
66    }
67
68    /// Estimate gradient using REINFORCE with baseline
69    fn estimate_gradient(trajectory: &QueryTrajectory) -> Vec<f32> {
70        if trajectory.steps.is_empty() {
71            return trajectory.query_embedding.clone();
72        }
73
74        let dim = trajectory.query_embedding.len();
75        let mut gradient = vec![0.0f32; dim];
76
77        // Compute baseline (average reward)
78        let baseline =
79            trajectory.steps.iter().map(|s| s.reward).sum::<f32>() / trajectory.steps.len() as f32;
80
81        // REINFORCE: gradient = sum((reward - baseline) * activation)
82        for step in &trajectory.steps {
83            let advantage = step.reward - baseline;
84            let activation_len = step.activations.len().min(dim);
85            for i in 0..activation_len {
86                gradient[i] += advantage * step.activations[i];
87            }
88        }
89
90        // L2 normalize
91        let norm: f32 = gradient.iter().map(|x| x * x).sum::<f32>().sqrt();
92        if norm > 1e-8 {
93            gradient.iter_mut().for_each(|x| *x /= norm);
94        }
95
96        gradient
97    }
98
99    /// Scale gradient by quality
100    pub fn scaled_gradient(&self) -> Vec<f32> {
101        self.gradient_estimate
102            .iter()
103            .map(|&g| g * self.quality_score)
104            .collect()
105    }
106}
107
/// Recording of a single query's execution for later learning.
///
/// Built incrementally: create with `new`, append steps with `add_step`,
/// then close with `finalize` once the outcome is known.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QueryTrajectory {
    /// Unique trajectory identifier
    pub id: u64,
    /// Query embedding vector
    pub query_embedding: Vec<f32>,
    /// Execution steps, in order of occurrence
    pub steps: Vec<TrajectoryStep>,
    /// Final quality score [0.0, 1.0]; 0.0 until `finalize` is called
    pub final_quality: f32,
    /// Total latency in microseconds; 0 until `finalize` is called
    pub latency_us: u64,
    /// Model route taken, if recorded
    pub model_route: Option<String>,
    /// IDs of context items used for this query
    pub context_ids: Vec<String>,
}
126
127impl QueryTrajectory {
128    /// Create new trajectory
129    pub fn new(id: u64, query_embedding: Vec<f32>) -> Self {
130        Self {
131            id,
132            query_embedding,
133            steps: Vec::with_capacity(16),
134            final_quality: 0.0,
135            latency_us: 0,
136            model_route: None,
137            context_ids: Vec::new(),
138        }
139    }
140
141    /// Add execution step
142    pub fn add_step(&mut self, step: TrajectoryStep) {
143        self.steps.push(step);
144    }
145
146    /// Finalize trajectory with quality score
147    pub fn finalize(&mut self, quality: f32, latency_us: u64) {
148        self.final_quality = quality;
149        self.latency_us = latency_us;
150    }
151
152    /// Get total reward
153    pub fn total_reward(&self) -> f32 {
154        self.steps.iter().map(|s| s.reward).sum()
155    }
156
157    /// Get average reward
158    pub fn avg_reward(&self) -> f32 {
159        if self.steps.is_empty() {
160            0.0
161        } else {
162            self.total_reward() / self.steps.len() as f32
163        }
164    }
165}
166
/// Single step in a trajectory
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrajectoryStep {
    /// Layer/module activations (a subset, kept small for efficiency)
    pub activations: Vec<f32>,
    /// Attention weights, flattened to one dimension
    pub attention_weights: Vec<f32>,
    /// Reward signal for this step
    pub reward: f32,
    /// Zero-based index of this step within its trajectory
    pub step_idx: usize,
    /// Optional layer name (set via `with_layer`)
    pub layer_name: Option<String>,
}
181
182impl TrajectoryStep {
183    /// Create new step
184    pub fn new(
185        activations: Vec<f32>,
186        attention_weights: Vec<f32>,
187        reward: f32,
188        step_idx: usize,
189    ) -> Self {
190        Self {
191            activations,
192            attention_weights,
193            reward,
194            step_idx,
195            layer_name: None,
196        }
197    }
198
199    /// Create step with layer name
200    pub fn with_layer(mut self, name: &str) -> Self {
201        self.layer_name = Some(name.to_string());
202        self
203    }
204}
205
/// Learned pattern from trajectory clustering
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LearnedPattern {
    /// Pattern identifier
    pub id: u64,
    /// Cluster centroid embedding
    pub centroid: Vec<f32>,
    /// Number of trajectories in the cluster
    pub cluster_size: usize,
    /// Sum of member trajectory weights (reduced over time by `decay`)
    pub total_weight: f32,
    /// Average quality of member trajectories
    pub avg_quality: f32,
    /// Creation timestamp (Unix seconds)
    pub created_at: u64,
    /// Last access timestamp (Unix seconds, refreshed by `touch`)
    pub last_accessed: u64,
    /// Total number of recorded accesses
    pub access_count: u32,
    /// Pattern type/category
    pub pattern_type: PatternType,
}
228
/// Pattern classification
///
/// `General` is the default variant (via `#[derive(Default)]`).
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum PatternType {
    /// Default catch-all category.
    #[default]
    General,
    Reasoning,
    Factual,
    Creative,
    CodeGen,
    Conversational,
}
240
241impl std::fmt::Display for PatternType {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        match self {
244            PatternType::General => write!(f, "general"),
245            PatternType::Reasoning => write!(f, "reasoning"),
246            PatternType::Factual => write!(f, "factual"),
247            PatternType::Creative => write!(f, "creative"),
248            PatternType::CodeGen => write!(f, "codegen"),
249            PatternType::Conversational => write!(f, "conversational"),
250        }
251    }
252}
253
254impl LearnedPattern {
255    /// Create new pattern
256    pub fn new(id: u64, centroid: Vec<f32>) -> Self {
257        use crate::time_compat::SystemTime;
258        let now = SystemTime::now().duration_since_epoch().as_secs();
259
260        Self {
261            id,
262            centroid,
263            cluster_size: 1,
264            total_weight: 1.0,
265            avg_quality: 0.0,
266            created_at: now,
267            last_accessed: now,
268            access_count: 0,
269            pattern_type: PatternType::default(),
270        }
271    }
272
273    /// Merge two patterns
274    pub fn merge(&self, other: &Self) -> Self {
275        let total_size = self.cluster_size + other.cluster_size;
276        let w1 = self.cluster_size as f32 / total_size as f32;
277        let w2 = other.cluster_size as f32 / total_size as f32;
278
279        let centroid: Vec<f32> = self
280            .centroid
281            .iter()
282            .zip(&other.centroid)
283            .map(|(&a, &b)| a * w1 + b * w2)
284            .collect();
285
286        Self {
287            id: self.id,
288            centroid,
289            cluster_size: total_size,
290            total_weight: self.total_weight + other.total_weight,
291            avg_quality: self.avg_quality * w1 + other.avg_quality * w2,
292            created_at: self.created_at.min(other.created_at),
293            last_accessed: self.last_accessed.max(other.last_accessed),
294            access_count: self.access_count + other.access_count,
295            pattern_type: self.pattern_type.clone(),
296        }
297    }
298
299    /// Decay pattern importance
300    pub fn decay(&mut self, factor: f32) {
301        self.total_weight *= factor;
302    }
303
304    /// Record access
305    pub fn touch(&mut self) {
306        use crate::time_compat::SystemTime;
307        self.access_count += 1;
308        self.last_accessed = SystemTime::now().duration_since_epoch().as_secs();
309    }
310
311    /// Check if pattern should be pruned
312    pub fn should_prune(&self, min_quality: f32, min_accesses: u32, max_age_secs: u64) -> bool {
313        use crate::time_compat::SystemTime;
314        let now = SystemTime::now().duration_since_epoch().as_secs();
315        let age = now.saturating_sub(self.last_accessed);
316
317        self.avg_quality < min_quality && self.access_count < min_accesses && age > max_age_secs
318    }
319
320    /// Compute cosine similarity with query
321    pub fn similarity(&self, query: &[f32]) -> f32 {
322        if self.centroid.len() != query.len() {
323            return 0.0;
324        }
325
326        let dot: f32 = self.centroid.iter().zip(query).map(|(a, b)| a * b).sum();
327        let norm_a: f32 = self.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
328        let norm_b: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
329
330        if norm_a > 1e-8 && norm_b > 1e-8 {
331            dot / (norm_a * norm_b)
332        } else {
333            0.0
334        }
335    }
336}
337
/// SONA configuration
///
/// `Default` provides benchmark-tuned general-purpose values; deployment
/// presets are available via the associated constructors (`max_throughput`,
/// `max_quality`, `edge_deployment`, `batch_processing`, `for_ephemeral`,
/// `for_coordinator`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SonaConfig {
    /// Hidden dimension
    pub hidden_dim: usize,
    /// Embedding dimension
    pub embedding_dim: usize,
    /// Micro-LoRA adapter rank
    pub micro_lora_rank: usize,
    /// Base LoRA adapter rank
    pub base_lora_rank: usize,
    /// Micro-LoRA learning rate
    pub micro_lora_lr: f32,
    /// Base LoRA learning rate
    pub base_lora_lr: f32,
    /// EWC regularization strength (lambda)
    pub ewc_lambda: f32,
    /// Number of clusters used for pattern extraction
    pub pattern_clusters: usize,
    /// Trajectory buffer capacity (number of entries)
    pub trajectory_capacity: usize,
    /// Background learning interval in milliseconds
    pub background_interval_ms: u64,
    /// Minimum quality score [0.0, 1.0] required for learning
    pub quality_threshold: f32,
    /// Enable SIMD optimizations
    pub enable_simd: bool,
}
366
367impl Default for SonaConfig {
368    fn default() -> Self {
369        // OPTIMIZED DEFAULTS based on @ruvector/sona v0.1.1 benchmarks:
370        // - Rank-2 is 5% faster than Rank-1 due to better SIMD vectorization
371        // - Learning rate 0.002 yields +55% quality improvement
372        // - 100 clusters = 1.3ms search vs 50 clusters = 3.0ms (2.3x faster)
373        // - EWC lambda 2000 optimal for catastrophic forgetting prevention
374        // - Quality threshold 0.3 balances learning vs noise filtering
375        Self {
376            hidden_dim: 256,
377            embedding_dim: 256,
378            micro_lora_rank: 2, // OPTIMIZED: Rank-2 faster than Rank-1 (2,211 vs 2,100 ops/sec)
379            base_lora_rank: 8,  // Balanced for production
380            micro_lora_lr: 0.002, // OPTIMIZED: +55.3% quality improvement
381            base_lora_lr: 0.0001,
382            ewc_lambda: 2000.0,    // OPTIMIZED: Better forgetting prevention
383            pattern_clusters: 100, // OPTIMIZED: 2.3x faster search (1.3ms vs 3.0ms)
384            trajectory_capacity: 10000,
385            background_interval_ms: 3600000, // 1 hour
386            quality_threshold: 0.3,          // OPTIMIZED: Lower threshold for more learning
387            enable_simd: true,
388        }
389    }
390}
391
392impl SonaConfig {
393    /// Create config optimized for maximum throughput (real-time chat)
394    pub fn max_throughput() -> Self {
395        Self {
396            hidden_dim: 256,
397            embedding_dim: 256,
398            micro_lora_rank: 2,    // Rank-2 + SIMD = 2,211 ops/sec
399            base_lora_rank: 4,     // Minimal base for speed
400            micro_lora_lr: 0.0005, // Conservative for stability
401            base_lora_lr: 0.0001,
402            ewc_lambda: 2000.0,
403            pattern_clusters: 100,
404            trajectory_capacity: 5000,
405            background_interval_ms: 7200000, // 2 hours
406            quality_threshold: 0.4,
407            enable_simd: true,
408        }
409    }
410
411    /// Create config optimized for maximum quality (research/batch)
412    pub fn max_quality() -> Self {
413        Self {
414            hidden_dim: 256,
415            embedding_dim: 256,
416            micro_lora_rank: 2,
417            base_lora_rank: 16,   // Higher rank for expressiveness
418            micro_lora_lr: 0.002, // Optimal learning rate
419            base_lora_lr: 0.001,  // Aggressive base learning
420            ewc_lambda: 2000.0,
421            pattern_clusters: 100,
422            trajectory_capacity: 20000,
423            background_interval_ms: 1800000, // 30 minutes
424            quality_threshold: 0.2,          // Learn from more trajectories
425            enable_simd: true,
426        }
427    }
428
429    /// Create config for edge/mobile deployment (<5MB memory)
430    pub fn edge_deployment() -> Self {
431        Self {
432            hidden_dim: 256,
433            embedding_dim: 256,
434            micro_lora_rank: 1, // Minimal rank for memory
435            base_lora_rank: 4,
436            micro_lora_lr: 0.001,
437            base_lora_lr: 0.0001,
438            ewc_lambda: 1000.0,
439            pattern_clusters: 50,
440            trajectory_capacity: 200, // Small buffer
441            background_interval_ms: 3600000,
442            quality_threshold: 0.5,
443            enable_simd: true,
444        }
445    }
446
447    /// Create config for batch processing (50+ inferences/sec)
448    pub fn batch_processing() -> Self {
449        Self {
450            hidden_dim: 256,
451            embedding_dim: 256,
452            micro_lora_rank: 2,
453            base_lora_rank: 8,
454            micro_lora_lr: 0.001,
455            base_lora_lr: 0.0001,
456            ewc_lambda: 2000.0,
457            pattern_clusters: 100,
458            trajectory_capacity: 10000,
459            background_interval_ms: 3600000,
460            quality_threshold: 0.3,
461            enable_simd: true,
462        }
463    }
464
465    /// Create config for ephemeral agents (~5MB footprint)
466    ///
467    /// Optimized for lightweight federated learning nodes that collect
468    /// trajectories locally before aggregation.
469    pub fn for_ephemeral() -> Self {
470        Self {
471            hidden_dim: 256,
472            embedding_dim: 256,
473            micro_lora_rank: 2,
474            base_lora_rank: 4, // Small base for memory efficiency
475            micro_lora_lr: 0.002,
476            base_lora_lr: 0.0001,
477            ewc_lambda: 1000.0,
478            pattern_clusters: 50,          // Fewer clusters for memory
479            trajectory_capacity: 500,      // Local buffer before aggregation
480            background_interval_ms: 60000, // 1 minute for quick local updates
481            quality_threshold: 0.3,
482            enable_simd: true,
483        }
484    }
485
486    /// Create config for federated coordinator (central aggregation)
487    ///
488    /// Optimized for aggregating trajectories from multiple ephemeral agents
489    /// with larger capacity and pattern storage.
490    pub fn for_coordinator() -> Self {
491        Self {
492            hidden_dim: 256,
493            embedding_dim: 256,
494            micro_lora_rank: 2,
495            base_lora_rank: 16,             // Higher rank for aggregated learning
496            micro_lora_lr: 0.001,           // Conservative for stability
497            base_lora_lr: 0.0005,           // Moderate base learning
498            ewc_lambda: 2000.0,             // Strong forgetting prevention
499            pattern_clusters: 200,          // More clusters for diverse patterns
500            trajectory_capacity: 50000,     // Large capacity for aggregation
501            background_interval_ms: 300000, // 5 minutes consolidation
502            quality_threshold: 0.4,         // Higher threshold for quality filtering
503            enable_simd: true,
504        }
505    }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// A one-step trajectory yields a signal carrying its quality score,
    /// a gradient matching the embedding's dimensionality, and its id.
    #[test]
    fn test_learning_signal_from_trajectory() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1, 0.2, 0.3]);
        trajectory.add_step(TrajectoryStep::new(
            vec![0.5, 0.3, 0.2],
            vec![0.4, 0.4, 0.2],
            0.8,
            0,
        ));
        trajectory.finalize(0.8, 1000);

        let signal = LearningSignal::from_trajectory(&trajectory);
        assert_eq!(signal.quality_score, 0.8);
        assert_eq!(signal.gradient_estimate.len(), 3);
        assert_eq!(signal.metadata.trajectory_id, 1);
    }

    /// Merging two equal-size clusters averages centroids and qualities
    /// 50/50 and sums the cluster sizes.
    #[test]
    fn test_pattern_merge() {
        let p1 = LearnedPattern {
            id: 1,
            centroid: vec![1.0, 0.0],
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality: 0.8,
            created_at: 100,
            last_accessed: 200,
            access_count: 5,
            pattern_type: PatternType::General,
        };

        let p2 = LearnedPattern {
            id: 2,
            centroid: vec![0.0, 1.0],
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality: 0.9,
            created_at: 150,
            last_accessed: 250,
            access_count: 3,
            pattern_type: PatternType::General,
        };

        let merged = p1.merge(&p2);
        assert_eq!(merged.cluster_size, 20);
        assert!((merged.centroid[0] - 0.5).abs() < 1e-6);
        assert!((merged.centroid[1] - 0.5).abs() < 1e-6);
        assert!((merged.avg_quality - 0.85).abs() < 1e-6);
    }

    /// Cosine similarity: identical direction -> 1.0, orthogonal -> 0.0.
    #[test]
    fn test_pattern_similarity() {
        let pattern = LearnedPattern::new(1, vec![1.0, 0.0, 0.0]);

        assert!((pattern.similarity(&[1.0, 0.0, 0.0]) - 1.0).abs() < 1e-6);
        assert!(pattern.similarity(&[0.0, 1.0, 0.0]).abs() < 1e-6);
    }

    /// Total reward is the sum of step rewards; average divides by step count.
    #[test]
    fn test_trajectory_rewards() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1]);
        trajectory.add_step(TrajectoryStep::new(vec![], vec![], 0.5, 0));
        trajectory.add_step(TrajectoryStep::new(vec![], vec![], 0.7, 1));
        trajectory.add_step(TrajectoryStep::new(vec![], vec![], 0.9, 2));

        assert!((trajectory.total_reward() - 2.1).abs() < 1e-6);
        assert!((trajectory.avg_reward() - 0.7).abs() < 1e-6);
    }
}