ruvector_sona/types.rs

//! SONA Core Types
//!
//! Defines the fundamental data structures for the Self-Optimizing Neural Architecture.

use serde::{Deserialize, Serialize};
use std::time::Instant;
use std::collections::HashMap;

/// Learning signal generated from inference trajectory
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LearningSignal {
    /// Query embedding vector
    pub query_embedding: Vec<f32>,
    /// Estimated gradient direction
    pub gradient_estimate: Vec<f32>,
    /// Quality score [0.0, 1.0]
    pub quality_score: f32,
    /// Signal generation timestamp (skipped during serialization)
    #[serde(skip)]
    pub timestamp: Option<Instant>,
    /// Additional metadata
    pub metadata: SignalMetadata,
}

/// Metadata for learning signals
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SignalMetadata {
    /// Source trajectory ID
    pub trajectory_id: u64,
    /// Number of steps in trajectory
    pub step_count: usize,
    /// Model route taken
    pub model_route: Option<String>,
    /// Custom tags
    pub tags: HashMap<String, String>,
}

impl LearningSignal {
    /// Create signal from query trajectory using REINFORCE gradient estimation
    pub fn from_trajectory(trajectory: &QueryTrajectory) -> Self {
        let gradient = Self::estimate_gradient(trajectory);

        Self {
            query_embedding: trajectory.query_embedding.clone(),
            gradient_estimate: gradient,
            quality_score: trajectory.final_quality,
            timestamp: Some(Instant::now()),
            metadata: SignalMetadata {
                trajectory_id: trajectory.id,
                step_count: trajectory.steps.len(),
                model_route: trajectory.model_route.clone(),
                tags: HashMap::new(),
            },
        }
    }

    /// Create signal with pre-computed gradient
    pub fn with_gradient(embedding: Vec<f32>, gradient: Vec<f32>, quality: f32) -> Self {
        Self {
            query_embedding: embedding,
            gradient_estimate: gradient,
            quality_score: quality,
            timestamp: Some(Instant::now()),
            metadata: SignalMetadata::default(),
        }
    }

    /// Estimate gradient using REINFORCE with baseline
    fn estimate_gradient(trajectory: &QueryTrajectory) -> Vec<f32> {
        if trajectory.steps.is_empty() {
            return trajectory.query_embedding.clone();
        }

        let dim = trajectory.query_embedding.len();
        let mut gradient = vec![0.0f32; dim];

        // Compute baseline (average reward)
        let baseline = trajectory.steps.iter()
            .map(|s| s.reward)
            .sum::<f32>() / trajectory.steps.len() as f32;

        // REINFORCE: gradient = sum((reward - baseline) * activation)
        for step in &trajectory.steps {
            let advantage = step.reward - baseline;
            let activation_len = step.activations.len().min(dim);
            for i in 0..activation_len {
                gradient[i] += advantage * step.activations[i];
            }
        }

        // L2 normalize
        let norm: f32 = gradient.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-8 {
            gradient.iter_mut().for_each(|x| *x /= norm);
        }

        gradient
    }

    /// Scale gradient by quality
    pub fn scaled_gradient(&self) -> Vec<f32> {
        self.gradient_estimate.iter()
            .map(|&g| g * self.quality_score)
            .collect()
    }
}

/// Query trajectory recording
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QueryTrajectory {
    /// Unique trajectory identifier
    pub id: u64,
    /// Query embedding vector
    pub query_embedding: Vec<f32>,
    /// Execution steps
    pub steps: Vec<TrajectoryStep>,
    /// Final quality score [0.0, 1.0]
    pub final_quality: f32,
    /// Total latency in microseconds
    pub latency_us: u64,
    /// Model route taken
    pub model_route: Option<String>,
    /// Context used
    pub context_ids: Vec<String>,
}

impl QueryTrajectory {
    /// Create new trajectory
    pub fn new(id: u64, query_embedding: Vec<f32>) -> Self {
        Self {
            id,
            query_embedding,
            steps: Vec::with_capacity(16),
            final_quality: 0.0,
            latency_us: 0,
            model_route: None,
            context_ids: Vec::new(),
        }
    }

    /// Add execution step
    pub fn add_step(&mut self, step: TrajectoryStep) {
        self.steps.push(step);
    }

    /// Finalize trajectory with quality score
    pub fn finalize(&mut self, quality: f32, latency_us: u64) {
        self.final_quality = quality;
        self.latency_us = latency_us;
    }

    /// Get total reward
    pub fn total_reward(&self) -> f32 {
        self.steps.iter().map(|s| s.reward).sum()
    }

    /// Get average reward
    pub fn avg_reward(&self) -> f32 {
        if self.steps.is_empty() {
            0.0
        } else {
            self.total_reward() / self.steps.len() as f32
        }
    }
}

/// Single step in a trajectory
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrajectoryStep {
    /// Layer/module activations (subset for efficiency)
    pub activations: Vec<f32>,
    /// Attention weights (flattened)
    pub attention_weights: Vec<f32>,
    /// Reward signal for this step
    pub reward: f32,
    /// Step index
    pub step_idx: usize,
    /// Optional layer name
    pub layer_name: Option<String>,
}

impl TrajectoryStep {
    /// Create new step
    pub fn new(activations: Vec<f32>, attention_weights: Vec<f32>, reward: f32, step_idx: usize) -> Self {
        Self {
            activations,
            attention_weights,
            reward,
            step_idx,
            layer_name: None,
        }
    }

    /// Create step with layer name
    pub fn with_layer(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }
}

/// Learned pattern from trajectory clustering
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LearnedPattern {
    /// Pattern identifier
    pub id: u64,
    /// Cluster centroid embedding
    pub centroid: Vec<f32>,
    /// Number of trajectories in cluster
    pub cluster_size: usize,
    /// Sum of trajectory weights
    pub total_weight: f32,
    /// Average quality of member trajectories
    pub avg_quality: f32,
    /// Creation timestamp (Unix seconds)
    pub created_at: u64,
    /// Last access timestamp (Unix seconds)
    pub last_accessed: u64,
    /// Total access count
    pub access_count: u32,
    /// Pattern type/category
    pub pattern_type: PatternType,
}

/// Pattern classification
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum PatternType {
    #[default]
    General,
    Reasoning,
    Factual,
    Creative,
    CodeGen,
    Conversational,
}

impl LearnedPattern {
    /// Create new pattern
    pub fn new(id: u64, centroid: Vec<f32>) -> Self {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        Self {
            id,
            centroid,
            cluster_size: 1,
            total_weight: 1.0,
            avg_quality: 0.0,
            created_at: now,
            last_accessed: now,
            access_count: 0,
            pattern_type: PatternType::default(),
        }
    }

    /// Merge two patterns
    pub fn merge(&self, other: &Self) -> Self {
        let total_size = self.cluster_size + other.cluster_size;
        let w1 = self.cluster_size as f32 / total_size as f32;
        let w2 = other.cluster_size as f32 / total_size as f32;

        let centroid: Vec<f32> = self.centroid.iter()
            .zip(&other.centroid)
            .map(|(&a, &b)| a * w1 + b * w2)
            .collect();

        Self {
            id: self.id,
            centroid,
            cluster_size: total_size,
            total_weight: self.total_weight + other.total_weight,
            avg_quality: self.avg_quality * w1 + other.avg_quality * w2,
            created_at: self.created_at.min(other.created_at),
            last_accessed: self.last_accessed.max(other.last_accessed),
            access_count: self.access_count + other.access_count,
            pattern_type: self.pattern_type.clone(),
        }
    }

    /// Decay pattern importance
    pub fn decay(&mut self, factor: f32) {
        self.total_weight *= factor;
    }

    /// Record access
    pub fn touch(&mut self) {
        self.access_count += 1;
        self.last_accessed = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
    }

    /// Check if pattern should be pruned
    pub fn should_prune(&self, min_quality: f32, min_accesses: u32, max_age_secs: u64) -> bool {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let age = now.saturating_sub(self.last_accessed);

        self.avg_quality < min_quality
            && self.access_count < min_accesses
            && age > max_age_secs
    }

    /// Compute cosine similarity with query
    pub fn similarity(&self, query: &[f32]) -> f32 {
        if self.centroid.len() != query.len() {
            return 0.0;
        }

        let dot: f32 = self.centroid.iter().zip(query).map(|(a, b)| a * b).sum();
        let norm_a: f32 = self.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a > 1e-8 && norm_b > 1e-8 {
            dot / (norm_a * norm_b)
        } else {
            0.0
        }
    }
}

/// SONA configuration
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SonaConfig {
    /// Hidden dimension
    pub hidden_dim: usize,
    /// Embedding dimension
    pub embedding_dim: usize,
    /// Micro-LoRA rank
    pub micro_lora_rank: usize,
    /// Base LoRA rank
    pub base_lora_rank: usize,
    /// Micro-LoRA learning rate
    pub micro_lora_lr: f32,
    /// Base LoRA learning rate
    pub base_lora_lr: f32,
    /// EWC (elastic weight consolidation) regularization strength
    pub ewc_lambda: f32,
    /// Pattern extraction clusters
    pub pattern_clusters: usize,
    /// Trajectory buffer capacity
    pub trajectory_capacity: usize,
    /// Background learning interval (ms)
    pub background_interval_ms: u64,
    /// Quality threshold for learning
    pub quality_threshold: f32,
    /// Enable SIMD optimizations
    pub enable_simd: bool,
}

impl Default for SonaConfig {
    fn default() -> Self {
        Self {
            hidden_dim: 256,
            embedding_dim: 256,
            micro_lora_rank: 1,
            base_lora_rank: 8,
            micro_lora_lr: 0.001,
            base_lora_lr: 0.0001,
            ewc_lambda: 1000.0,
            pattern_clusters: 50,
            trajectory_capacity: 10000,
            background_interval_ms: 3600000, // 1 hour
            quality_threshold: 0.5,
            enable_simd: true,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_learning_signal_from_trajectory() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1, 0.2, 0.3]);
        trajectory.add_step(TrajectoryStep::new(
            vec![0.5, 0.3, 0.2],
            vec![0.4, 0.4, 0.2],
            0.8,
            0,
        ));
        trajectory.finalize(0.8, 1000);

        let signal = LearningSignal::from_trajectory(&trajectory);
        assert_eq!(signal.quality_score, 0.8);
        assert_eq!(signal.gradient_estimate.len(), 3);
        assert_eq!(signal.metadata.trajectory_id, 1);
    }
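
    // Added illustrative check: works through the REINFORCE-with-baseline
    // arithmetic in estimate_gradient via the public from_trajectory constructor.
    // With rewards 0.2 and 0.8 the baseline is 0.5, so the advantages are -0.3
    // and +0.3; after L2 normalization the gradient should point away from the
    // first step's activations and toward the second's.
    #[test]
    fn test_reinforce_gradient_with_baseline() {
        let mut trajectory = QueryTrajectory::new(7, vec![0.0, 0.0]);
        trajectory.add_step(TrajectoryStep::new(vec![1.0, 0.0], vec![], 0.2, 0));
        trajectory.add_step(TrajectoryStep::new(vec![0.0, 1.0], vec![], 0.8, 1));
        trajectory.finalize(0.9, 500);

        let signal = LearningSignal::from_trajectory(&trajectory);
        let g = &signal.gradient_estimate;

        // Raw gradient is [-0.3, 0.3]; normalized it becomes [-1/sqrt(2), 1/sqrt(2)].
        let inv_sqrt2 = 1.0 / 2.0f32.sqrt();
        assert!((g[0] + inv_sqrt2).abs() < 1e-5);
        assert!((g[1] - inv_sqrt2).abs() < 1e-5);

        // Unit norm after L2 normalization.
        let norm: f32 = g.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);

        // With no steps, estimate_gradient falls back to the query embedding itself.
        let empty = QueryTrajectory::new(8, vec![0.25, 0.75]);
        let fallback = LearningSignal::from_trajectory(&empty);
        assert_eq!(fallback.gradient_estimate, vec![0.25, 0.75]);
    }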

    #[test]
    fn test_pattern_merge() {
        let p1 = LearnedPattern {
            id: 1,
            centroid: vec![1.0, 0.0],
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality: 0.8,
            created_at: 100,
            last_accessed: 200,
            access_count: 5,
            pattern_type: PatternType::General,
        };

        let p2 = LearnedPattern {
            id: 2,
            centroid: vec![0.0, 1.0],
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality: 0.9,
            created_at: 150,
            last_accessed: 250,
            access_count: 3,
            pattern_type: PatternType::General,
        };

        let merged = p1.merge(&p2);
        assert_eq!(merged.cluster_size, 20);
        assert!((merged.centroid[0] - 0.5).abs() < 1e-6);
        assert!((merged.centroid[1] - 0.5).abs() < 1e-6);
        assert!((merged.avg_quality - 0.85).abs() < 1e-6);
    }
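
    // A minimal added sketch of the pattern lifecycle helpers: decay scales
    // total_weight, touch bumps the access count and refreshes last_accessed,
    // and should_prune only fires when quality, access count, and age all
    // cross their thresholds.
    #[test]
    fn test_pattern_lifecycle() {
        let mut pattern = LearnedPattern::new(1, vec![1.0, 0.0]);
        pattern.decay(0.5);
        assert!((pattern.total_weight - 0.5).abs() < 1e-6);

        pattern.touch();
        assert_eq!(pattern.access_count, 1);

        // Freshly touched: age is ~0, so it must not be pruned even with low quality.
        assert!(!pattern.should_prune(0.5, 10, 60));

        // A stale, low-quality, rarely accessed pattern qualifies for pruning.
        let stale = LearnedPattern {
            last_accessed: 0,
            ..LearnedPattern::new(2, vec![0.0, 1.0])
        };
        assert!(stale.should_prune(0.5, 1, 60));
    }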

    #[test]
    fn test_pattern_similarity() {
        let pattern = LearnedPattern::new(1, vec![1.0, 0.0, 0.0]);

        assert!((pattern.similarity(&[1.0, 0.0, 0.0]) - 1.0).abs() < 1e-6);
        assert!(pattern.similarity(&[0.0, 1.0, 0.0]).abs() < 1e-6);
    }
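
    // Added sketch covering the guard paths of similarity(): a dimension
    // mismatch and an all-zero query should both return 0.0 rather than NaN.
    #[test]
    fn test_pattern_similarity_edge_cases() {
        let pattern = LearnedPattern::new(1, vec![1.0, 0.0, 0.0]);

        // Mismatched dimensions short-circuit to 0.0.
        assert_eq!(pattern.similarity(&[1.0, 0.0]), 0.0);

        // A zero-norm query avoids division by zero and also yields 0.0.
        assert_eq!(pattern.similarity(&[0.0, 0.0, 0.0]), 0.0);
    }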

    #[test]
    fn test_trajectory_rewards() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1]);
        trajectory.add_step(TrajectoryStep::new(vec![], vec![], 0.5, 0));
        trajectory.add_step(TrajectoryStep::new(vec![], vec![], 0.7, 1));
        trajectory.add_step(TrajectoryStep::new(vec![], vec![], 0.9, 2));

        assert!((trajectory.total_reward() - 2.1).abs() < 1e-6);
        assert!((trajectory.avg_reward() - 0.7).abs() < 1e-6);
    }
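
    // Illustrative additions: scaled_gradient multiplies the stored gradient by
    // the quality score, and SonaConfig::default wires in the documented defaults
    // (1-hour background interval, rank-1 micro-LoRA over a rank-8 base LoRA).
    #[test]
    fn test_scaled_gradient() {
        let signal = LearningSignal::with_gradient(vec![0.1, 0.2], vec![1.0, -2.0], 0.5);
        assert_eq!(signal.scaled_gradient(), vec![0.5, -1.0]);
    }

    #[test]
    fn test_sona_config_defaults() {
        let config = SonaConfig::default();
        assert_eq!(config.micro_lora_rank, 1);
        assert_eq!(config.base_lora_rank, 8);
        assert_eq!(config.background_interval_ms, 3_600_000);
        assert!(config.quality_threshold > 0.0 && config.quality_threshold < 1.0);
    }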
}