// ruvector_sona/types.rs
1//! SONA Core Types
2//!
3//! Defines the fundamental data structures for the Self-Optimizing Neural Architecture.
4
5use crate::time_compat::Instant;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// Learning signal generated from an inference trajectory.
///
/// Produced by [`LearningSignal::from_trajectory`]; pairs the query embedding
/// with a REINFORCE-style gradient estimate and a quality score.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LearningSignal {
    /// Query embedding vector
    pub query_embedding: Vec<f32>,
    /// Estimated gradient direction (L2-normalized by `estimate_gradient`)
    pub gradient_estimate: Vec<f32>,
    /// Quality score [0.0, 1.0]
    pub quality_score: f32,
    /// Signal generation timestamp. Skipped by serde: it is NOT serialized
    /// and always deserializes as `None`, so it is only valid in-process.
    #[serde(skip)]
    pub timestamp: Option<Instant>,
    /// Additional metadata
    pub metadata: SignalMetadata,
}
24
/// Metadata attached to a [`LearningSignal`].
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SignalMetadata {
    /// ID of the source trajectory the signal was derived from
    pub trajectory_id: u64,
    /// Number of steps recorded in that trajectory
    pub step_count: usize,
    /// Model route taken, if one was recorded
    pub model_route: Option<String>,
    /// Free-form custom tags
    pub tags: HashMap<String, String>,
}
37
38impl LearningSignal {
39    /// Create signal from query trajectory using REINFORCE gradient estimation
40    pub fn from_trajectory(trajectory: &QueryTrajectory) -> Self {
41        let gradient = Self::estimate_gradient(trajectory);
42
43        Self {
44            query_embedding: trajectory.query_embedding.clone(),
45            gradient_estimate: gradient,
46            quality_score: trajectory.final_quality,
47            timestamp: Some(Instant::now()),
48            metadata: SignalMetadata {
49                trajectory_id: trajectory.id,
50                step_count: trajectory.steps.len(),
51                model_route: trajectory.model_route.clone(),
52                tags: HashMap::new(),
53            },
54        }
55    }
56
57    /// Create signal with pre-computed gradient
58    pub fn with_gradient(embedding: Vec<f32>, gradient: Vec<f32>, quality: f32) -> Self {
59        Self {
60            query_embedding: embedding,
61            gradient_estimate: gradient,
62            quality_score: quality,
63            timestamp: Some(Instant::now()),
64            metadata: SignalMetadata::default(),
65        }
66    }
67
68    /// Estimate gradient using REINFORCE with baseline
69    fn estimate_gradient(trajectory: &QueryTrajectory) -> Vec<f32> {
70        if trajectory.steps.is_empty() {
71            return trajectory.query_embedding.clone();
72        }
73
74        let dim = trajectory.query_embedding.len();
75        let mut gradient = vec![0.0f32; dim];
76
77        // Compute baseline (average reward)
78        let baseline =
79            trajectory.steps.iter().map(|s| s.reward).sum::<f32>() / trajectory.steps.len() as f32;
80
81        // REINFORCE: gradient = sum((reward - baseline) * activation)
82        for step in &trajectory.steps {
83            let advantage = step.reward - baseline;
84            let activation_len = step.activations.len().min(dim);
85            for (grad, &act) in gradient
86                .iter_mut()
87                .zip(step.activations.iter())
88                .take(activation_len)
89            {
90                *grad += advantage * act;
91            }
92        }
93
94        // L2 normalize
95        let norm: f32 = gradient.iter().map(|x| x * x).sum::<f32>().sqrt();
96        if norm > 1e-8 {
97            gradient.iter_mut().for_each(|x| *x /= norm);
98        }
99
100        gradient
101    }
102
103    /// Scale gradient by quality
104    pub fn scaled_gradient(&self) -> Vec<f32> {
105        self.gradient_estimate
106            .iter()
107            .map(|&g| g * self.quality_score)
108            .collect()
109    }
110}
111
/// Recording of a single query's execution, used to derive learning signals.
///
/// Built incrementally via [`QueryTrajectory::add_step`] and closed with
/// [`QueryTrajectory::finalize`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QueryTrajectory {
    /// Unique trajectory identifier
    pub id: u64,
    /// Query embedding vector
    pub query_embedding: Vec<f32>,
    /// Execution steps, in order of occurrence
    pub steps: Vec<TrajectoryStep>,
    /// Final quality score [0.0, 1.0]; stays 0.0 until `finalize` is called
    pub final_quality: f32,
    /// Total latency in microseconds; stays 0 until `finalize` is called
    pub latency_us: u64,
    /// Model route taken, if any
    pub model_route: Option<String>,
    /// IDs of the context entries used
    pub context_ids: Vec<String>,
}
130
131impl QueryTrajectory {
132    /// Create new trajectory
133    pub fn new(id: u64, query_embedding: Vec<f32>) -> Self {
134        Self {
135            id,
136            query_embedding,
137            steps: Vec::with_capacity(16),
138            final_quality: 0.0,
139            latency_us: 0,
140            model_route: None,
141            context_ids: Vec::new(),
142        }
143    }
144
145    /// Add execution step
146    pub fn add_step(&mut self, step: TrajectoryStep) {
147        self.steps.push(step);
148    }
149
150    /// Finalize trajectory with quality score
151    pub fn finalize(&mut self, quality: f32, latency_us: u64) {
152        self.final_quality = quality;
153        self.latency_us = latency_us;
154    }
155
156    /// Get total reward
157    pub fn total_reward(&self) -> f32 {
158        self.steps.iter().map(|s| s.reward).sum()
159    }
160
161    /// Get average reward
162    pub fn avg_reward(&self) -> f32 {
163        if self.steps.is_empty() {
164            0.0
165        } else {
166            self.total_reward() / self.steps.len() as f32
167        }
168    }
169}
170
/// Single step in a [`QueryTrajectory`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrajectoryStep {
    /// Layer/module activations (subset kept for efficiency)
    pub activations: Vec<f32>,
    /// Attention weights, flattened to a single vector
    pub attention_weights: Vec<f32>,
    /// Reward signal for this step (feeds the REINFORCE baseline)
    pub reward: f32,
    /// Zero-based step index within the trajectory
    pub step_idx: usize,
    /// Optional layer name; set via [`TrajectoryStep::with_layer`]
    pub layer_name: Option<String>,
}
185
186impl TrajectoryStep {
187    /// Create new step
188    pub fn new(
189        activations: Vec<f32>,
190        attention_weights: Vec<f32>,
191        reward: f32,
192        step_idx: usize,
193    ) -> Self {
194        Self {
195            activations,
196            attention_weights,
197            reward,
198            step_idx,
199            layer_name: None,
200        }
201    }
202
203    /// Create step with layer name
204    pub fn with_layer(mut self, name: &str) -> Self {
205        self.layer_name = Some(name.to_string());
206        self
207    }
208}
209
/// Learned pattern produced by trajectory clustering.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LearnedPattern {
    /// Pattern identifier
    pub id: u64,
    /// Cluster centroid embedding
    pub centroid: Vec<f32>,
    /// Number of trajectories merged into this cluster
    pub cluster_size: usize,
    /// Sum of trajectory weights; shrinks under [`LearnedPattern::decay`]
    pub total_weight: f32,
    /// Average quality of member trajectories
    pub avg_quality: f32,
    /// Creation timestamp (Unix seconds)
    pub created_at: u64,
    /// Last access timestamp (Unix seconds); refreshed by `touch`
    pub last_accessed: u64,
    /// Total number of recorded accesses
    pub access_count: u32,
    /// Pattern type/category
    pub pattern_type: PatternType,
}
232
/// Pattern classification; defaults to [`PatternType::General`].
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum PatternType {
    /// Catch-all category (the default)
    #[default]
    General,
    Reasoning,
    Factual,
    Creative,
    CodeGen,
    Conversational,
}
244
245impl std::fmt::Display for PatternType {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        match self {
248            PatternType::General => write!(f, "general"),
249            PatternType::Reasoning => write!(f, "reasoning"),
250            PatternType::Factual => write!(f, "factual"),
251            PatternType::Creative => write!(f, "creative"),
252            PatternType::CodeGen => write!(f, "codegen"),
253            PatternType::Conversational => write!(f, "conversational"),
254        }
255    }
256}
257
258impl LearnedPattern {
259    /// Create new pattern
260    pub fn new(id: u64, centroid: Vec<f32>) -> Self {
261        use crate::time_compat::SystemTime;
262        let now = SystemTime::now().duration_since_epoch().as_secs();
263
264        Self {
265            id,
266            centroid,
267            cluster_size: 1,
268            total_weight: 1.0,
269            avg_quality: 0.0,
270            created_at: now,
271            last_accessed: now,
272            access_count: 0,
273            pattern_type: PatternType::default(),
274        }
275    }
276
277    /// Merge two patterns
278    pub fn merge(&self, other: &Self) -> Self {
279        let total_size = self.cluster_size + other.cluster_size;
280        let w1 = self.cluster_size as f32 / total_size as f32;
281        let w2 = other.cluster_size as f32 / total_size as f32;
282
283        let centroid: Vec<f32> = self
284            .centroid
285            .iter()
286            .zip(&other.centroid)
287            .map(|(&a, &b)| a * w1 + b * w2)
288            .collect();
289
290        Self {
291            id: self.id,
292            centroid,
293            cluster_size: total_size,
294            total_weight: self.total_weight + other.total_weight,
295            avg_quality: self.avg_quality * w1 + other.avg_quality * w2,
296            created_at: self.created_at.min(other.created_at),
297            last_accessed: self.last_accessed.max(other.last_accessed),
298            access_count: self.access_count + other.access_count,
299            pattern_type: self.pattern_type.clone(),
300        }
301    }
302
303    /// Decay pattern importance
304    pub fn decay(&mut self, factor: f32) {
305        self.total_weight *= factor;
306    }
307
308    /// Record access
309    pub fn touch(&mut self) {
310        use crate::time_compat::SystemTime;
311        self.access_count += 1;
312        self.last_accessed = SystemTime::now().duration_since_epoch().as_secs();
313    }
314
315    /// Check if pattern should be pruned
316    pub fn should_prune(&self, min_quality: f32, min_accesses: u32, max_age_secs: u64) -> bool {
317        use crate::time_compat::SystemTime;
318        let now = SystemTime::now().duration_since_epoch().as_secs();
319        let age = now.saturating_sub(self.last_accessed);
320
321        self.avg_quality < min_quality && self.access_count < min_accesses && age > max_age_secs
322    }
323
324    /// Compute cosine similarity with query
325    pub fn similarity(&self, query: &[f32]) -> f32 {
326        if self.centroid.len() != query.len() {
327            return 0.0;
328        }
329
330        let dot: f32 = self.centroid.iter().zip(query).map(|(a, b)| a * b).sum();
331        let norm_a: f32 = self.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
332        let norm_b: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
333
334        if norm_a > 1e-8 && norm_b > 1e-8 {
335            dot / (norm_a * norm_b)
336        } else {
337            0.0
338        }
339    }
340}
341
/// SONA configuration.
///
/// Use [`Default`] for the benchmark-tuned baseline, or one of the presets
/// (`max_throughput`, `max_quality`, `edge_deployment`, `batch_processing`,
/// `for_ephemeral`, `for_coordinator`) for specific deployment profiles.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SonaConfig {
    /// Hidden dimension
    pub hidden_dim: usize,
    /// Embedding dimension
    pub embedding_dim: usize,
    /// Micro-LoRA adapter rank
    pub micro_lora_rank: usize,
    /// Base LoRA adapter rank
    pub base_lora_rank: usize,
    /// Micro-LoRA learning rate
    pub micro_lora_lr: f32,
    /// Base LoRA learning rate
    pub base_lora_lr: f32,
    /// EWC regularization strength (catastrophic-forgetting prevention)
    pub ewc_lambda: f32,
    /// Number of clusters used for pattern extraction
    pub pattern_clusters: usize,
    /// Maximum number of buffered trajectories
    pub trajectory_capacity: usize,
    /// Background learning/consolidation interval (ms)
    pub background_interval_ms: u64,
    /// Minimum quality score a trajectory needs to be learned from
    pub quality_threshold: f32,
    /// Enable SIMD optimizations
    pub enable_simd: bool,
}
370
371impl Default for SonaConfig {
372    fn default() -> Self {
373        // OPTIMIZED DEFAULTS based on @ruvector/sona v0.1.1 benchmarks:
374        // - Rank-2 is 5% faster than Rank-1 due to better SIMD vectorization
375        // - Learning rate 0.002 yields +55% quality improvement
376        // - 100 clusters = 1.3ms search vs 50 clusters = 3.0ms (2.3x faster)
377        // - EWC lambda 2000 optimal for catastrophic forgetting prevention
378        // - Quality threshold 0.3 balances learning vs noise filtering
379        Self {
380            hidden_dim: 256,
381            embedding_dim: 256,
382            micro_lora_rank: 2, // OPTIMIZED: Rank-2 faster than Rank-1 (2,211 vs 2,100 ops/sec)
383            base_lora_rank: 8,  // Balanced for production
384            micro_lora_lr: 0.002, // OPTIMIZED: +55.3% quality improvement
385            base_lora_lr: 0.0001,
386            ewc_lambda: 2000.0,    // OPTIMIZED: Better forgetting prevention
387            pattern_clusters: 100, // OPTIMIZED: 2.3x faster search (1.3ms vs 3.0ms)
388            trajectory_capacity: 10000,
389            background_interval_ms: 3600000, // 1 hour
390            quality_threshold: 0.15,         // Was 0.3; lowered 50% so patterns crystallize earlier
391            enable_simd: true,
392        }
393    }
394}
395
396impl SonaConfig {
397    /// Create config optimized for maximum throughput (real-time chat)
398    pub fn max_throughput() -> Self {
399        Self {
400            hidden_dim: 256,
401            embedding_dim: 256,
402            micro_lora_rank: 2,    // Rank-2 + SIMD = 2,211 ops/sec
403            base_lora_rank: 4,     // Minimal base for speed
404            micro_lora_lr: 0.0005, // Conservative for stability
405            base_lora_lr: 0.0001,
406            ewc_lambda: 2000.0,
407            pattern_clusters: 100,
408            trajectory_capacity: 5000,
409            background_interval_ms: 7200000, // 2 hours
410            quality_threshold: 0.4,
411            enable_simd: true,
412        }
413    }
414
415    /// Create config optimized for maximum quality (research/batch)
416    pub fn max_quality() -> Self {
417        Self {
418            hidden_dim: 256,
419            embedding_dim: 256,
420            micro_lora_rank: 2,
421            base_lora_rank: 16,   // Higher rank for expressiveness
422            micro_lora_lr: 0.002, // Optimal learning rate
423            base_lora_lr: 0.001,  // Aggressive base learning
424            ewc_lambda: 2000.0,
425            pattern_clusters: 100,
426            trajectory_capacity: 20000,
427            background_interval_ms: 1800000, // 30 minutes
428            quality_threshold: 0.2,          // Learn from more trajectories
429            enable_simd: true,
430        }
431    }
432
433    /// Create config for edge/mobile deployment (<5MB memory)
434    pub fn edge_deployment() -> Self {
435        Self {
436            hidden_dim: 256,
437            embedding_dim: 256,
438            micro_lora_rank: 1, // Minimal rank for memory
439            base_lora_rank: 4,
440            micro_lora_lr: 0.001,
441            base_lora_lr: 0.0001,
442            ewc_lambda: 1000.0,
443            pattern_clusters: 50,
444            trajectory_capacity: 200, // Small buffer
445            background_interval_ms: 3600000,
446            quality_threshold: 0.5,
447            enable_simd: true,
448        }
449    }
450
451    /// Create config for batch processing (50+ inferences/sec)
452    pub fn batch_processing() -> Self {
453        Self {
454            hidden_dim: 256,
455            embedding_dim: 256,
456            micro_lora_rank: 2,
457            base_lora_rank: 8,
458            micro_lora_lr: 0.001,
459            base_lora_lr: 0.0001,
460            ewc_lambda: 2000.0,
461            pattern_clusters: 100,
462            trajectory_capacity: 10000,
463            background_interval_ms: 3600000,
464            quality_threshold: 0.3,
465            enable_simd: true,
466        }
467    }
468
469    /// Create config for ephemeral agents (~5MB footprint)
470    ///
471    /// Optimized for lightweight federated learning nodes that collect
472    /// trajectories locally before aggregation.
473    pub fn for_ephemeral() -> Self {
474        Self {
475            hidden_dim: 256,
476            embedding_dim: 256,
477            micro_lora_rank: 2,
478            base_lora_rank: 4, // Small base for memory efficiency
479            micro_lora_lr: 0.002,
480            base_lora_lr: 0.0001,
481            ewc_lambda: 1000.0,
482            pattern_clusters: 50,          // Fewer clusters for memory
483            trajectory_capacity: 500,      // Local buffer before aggregation
484            background_interval_ms: 60000, // 1 minute for quick local updates
485            quality_threshold: 0.3,
486            enable_simd: true,
487        }
488    }
489
490    /// Create config for federated coordinator (central aggregation)
491    ///
492    /// Optimized for aggregating trajectories from multiple ephemeral agents
493    /// with larger capacity and pattern storage.
494    pub fn for_coordinator() -> Self {
495        Self {
496            hidden_dim: 256,
497            embedding_dim: 256,
498            micro_lora_rank: 2,
499            base_lora_rank: 16,             // Higher rank for aggregated learning
500            micro_lora_lr: 0.001,           // Conservative for stability
501            base_lora_lr: 0.0005,           // Moderate base learning
502            ewc_lambda: 2000.0,             // Strong forgetting prevention
503            pattern_clusters: 200,          // More clusters for diverse patterns
504            trajectory_capacity: 50000,     // Large capacity for aggregation
505            background_interval_ms: 300000, // 5 minutes consolidation
506            quality_threshold: 0.4,         // Higher threshold for quality filtering
507            enable_simd: true,
508        }
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_learning_signal_from_trajectory() {
        let mut traj = QueryTrajectory::new(1, vec![0.1, 0.2, 0.3]);
        traj.add_step(TrajectoryStep::new(
            vec![0.5, 0.3, 0.2],
            vec![0.4, 0.4, 0.2],
            0.8,
            0,
        ));
        traj.finalize(0.8, 1000);

        let signal = LearningSignal::from_trajectory(&traj);
        assert_eq!(signal.quality_score, 0.8);
        assert_eq!(signal.gradient_estimate.len(), 3);
        assert_eq!(signal.metadata.trajectory_id, 1);
    }

    #[test]
    fn test_pattern_merge() {
        // Both fixtures share size/weight; only id, centroid, timestamps,
        // access count and quality differ.
        let make = |id, centroid, created, accessed, count, quality| LearnedPattern {
            id,
            centroid,
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality: quality,
            created_at: created,
            last_accessed: accessed,
            access_count: count,
            pattern_type: PatternType::General,
        };

        let merged = make(1, vec![1.0, 0.0], 100, 200, 5, 0.8)
            .merge(&make(2, vec![0.0, 1.0], 150, 250, 3, 0.9));

        assert_eq!(merged.cluster_size, 20);
        assert!((merged.centroid[0] - 0.5).abs() < 1e-6);
        assert!((merged.centroid[1] - 0.5).abs() < 1e-6);
        assert!((merged.avg_quality - 0.85).abs() < 1e-6);
    }

    #[test]
    fn test_pattern_similarity() {
        let pattern = LearnedPattern::new(1, vec![1.0, 0.0, 0.0]);

        // Identical direction -> 1.0; orthogonal -> 0.0.
        assert!((pattern.similarity(&[1.0, 0.0, 0.0]) - 1.0).abs() < 1e-6);
        assert!(pattern.similarity(&[0.0, 1.0, 0.0]).abs() < 1e-6);
    }

    #[test]
    fn test_trajectory_rewards() {
        let mut traj = QueryTrajectory::new(1, vec![0.1]);
        for (idx, reward) in [0.5, 0.7, 0.9].into_iter().enumerate() {
            traj.add_step(TrajectoryStep::new(vec![], vec![], reward, idx));
        }

        assert!((traj.total_reward() - 2.1).abs() < 1e-6);
        assert!((traj.avg_reward() - 0.7).abs() < 1e-6);
    }
}