1use serde::{Deserialize, Serialize};
6use std::time::Instant;
7use std::collections::HashMap;
8
9#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct LearningSignal {
12 pub query_embedding: Vec<f32>,
14 pub gradient_estimate: Vec<f32>,
16 pub quality_score: f32,
18 #[serde(skip)]
20 pub timestamp: Option<Instant>,
21 pub metadata: SignalMetadata,
23}
24
25#[derive(Clone, Debug, Default, Serialize, Deserialize)]
27pub struct SignalMetadata {
28 pub trajectory_id: u64,
30 pub step_count: usize,
32 pub model_route: Option<String>,
34 pub tags: HashMap<String, String>,
36}
37
38impl LearningSignal {
39 pub fn from_trajectory(trajectory: &QueryTrajectory) -> Self {
41 let gradient = Self::estimate_gradient(trajectory);
42
43 Self {
44 query_embedding: trajectory.query_embedding.clone(),
45 gradient_estimate: gradient,
46 quality_score: trajectory.final_quality,
47 timestamp: Some(Instant::now()),
48 metadata: SignalMetadata {
49 trajectory_id: trajectory.id,
50 step_count: trajectory.steps.len(),
51 model_route: trajectory.model_route.clone(),
52 tags: HashMap::new(),
53 },
54 }
55 }
56
57 pub fn with_gradient(embedding: Vec<f32>, gradient: Vec<f32>, quality: f32) -> Self {
59 Self {
60 query_embedding: embedding,
61 gradient_estimate: gradient,
62 quality_score: quality,
63 timestamp: Some(Instant::now()),
64 metadata: SignalMetadata::default(),
65 }
66 }
67
68 fn estimate_gradient(trajectory: &QueryTrajectory) -> Vec<f32> {
70 if trajectory.steps.is_empty() {
71 return trajectory.query_embedding.clone();
72 }
73
74 let dim = trajectory.query_embedding.len();
75 let mut gradient = vec![0.0f32; dim];
76
77 let baseline = trajectory.steps.iter()
79 .map(|s| s.reward)
80 .sum::<f32>() / trajectory.steps.len() as f32;
81
82 for step in &trajectory.steps {
84 let advantage = step.reward - baseline;
85 let activation_len = step.activations.len().min(dim);
86 for i in 0..activation_len {
87 gradient[i] += advantage * step.activations[i];
88 }
89 }
90
91 let norm: f32 = gradient.iter().map(|x| x * x).sum::<f32>().sqrt();
93 if norm > 1e-8 {
94 gradient.iter_mut().for_each(|x| *x /= norm);
95 }
96
97 gradient
98 }
99
100 pub fn scaled_gradient(&self) -> Vec<f32> {
102 self.gradient_estimate.iter()
103 .map(|&g| g * self.quality_score)
104 .collect()
105 }
106}
107
108#[derive(Clone, Debug, Serialize, Deserialize)]
110pub struct QueryTrajectory {
111 pub id: u64,
113 pub query_embedding: Vec<f32>,
115 pub steps: Vec<TrajectoryStep>,
117 pub final_quality: f32,
119 pub latency_us: u64,
121 pub model_route: Option<String>,
123 pub context_ids: Vec<String>,
125}
126
127impl QueryTrajectory {
128 pub fn new(id: u64, query_embedding: Vec<f32>) -> Self {
130 Self {
131 id,
132 query_embedding,
133 steps: Vec::with_capacity(16),
134 final_quality: 0.0,
135 latency_us: 0,
136 model_route: None,
137 context_ids: Vec::new(),
138 }
139 }
140
141 pub fn add_step(&mut self, step: TrajectoryStep) {
143 self.steps.push(step);
144 }
145
146 pub fn finalize(&mut self, quality: f32, latency_us: u64) {
148 self.final_quality = quality;
149 self.latency_us = latency_us;
150 }
151
152 pub fn total_reward(&self) -> f32 {
154 self.steps.iter().map(|s| s.reward).sum()
155 }
156
157 pub fn avg_reward(&self) -> f32 {
159 if self.steps.is_empty() {
160 0.0
161 } else {
162 self.total_reward() / self.steps.len() as f32
163 }
164 }
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize)]
169pub struct TrajectoryStep {
170 pub activations: Vec<f32>,
172 pub attention_weights: Vec<f32>,
174 pub reward: f32,
176 pub step_idx: usize,
178 pub layer_name: Option<String>,
180}
181
182impl TrajectoryStep {
183 pub fn new(activations: Vec<f32>, attention_weights: Vec<f32>, reward: f32, step_idx: usize) -> Self {
185 Self {
186 activations,
187 attention_weights,
188 reward,
189 step_idx,
190 layer_name: None,
191 }
192 }
193
194 pub fn with_layer(mut self, name: &str) -> Self {
196 self.layer_name = Some(name.to_string());
197 self
198 }
199}
200
201#[derive(Clone, Debug, Serialize, Deserialize)]
203pub struct LearnedPattern {
204 pub id: u64,
206 pub centroid: Vec<f32>,
208 pub cluster_size: usize,
210 pub total_weight: f32,
212 pub avg_quality: f32,
214 pub created_at: u64,
216 pub last_accessed: u64,
218 pub access_count: u32,
220 pub pattern_type: PatternType,
222}
223
224#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
226pub enum PatternType {
227 #[default]
228 General,
229 Reasoning,
230 Factual,
231 Creative,
232 CodeGen,
233 Conversational,
234}
235
236impl LearnedPattern {
237 pub fn new(id: u64, centroid: Vec<f32>) -> Self {
239 let now = std::time::SystemTime::now()
240 .duration_since(std::time::UNIX_EPOCH)
241 .unwrap_or_default()
242 .as_secs();
243
244 Self {
245 id,
246 centroid,
247 cluster_size: 1,
248 total_weight: 1.0,
249 avg_quality: 0.0,
250 created_at: now,
251 last_accessed: now,
252 access_count: 0,
253 pattern_type: PatternType::default(),
254 }
255 }
256
257 pub fn merge(&self, other: &Self) -> Self {
259 let total_size = self.cluster_size + other.cluster_size;
260 let w1 = self.cluster_size as f32 / total_size as f32;
261 let w2 = other.cluster_size as f32 / total_size as f32;
262
263 let centroid: Vec<f32> = self.centroid.iter()
264 .zip(&other.centroid)
265 .map(|(&a, &b)| a * w1 + b * w2)
266 .collect();
267
268 Self {
269 id: self.id,
270 centroid,
271 cluster_size: total_size,
272 total_weight: self.total_weight + other.total_weight,
273 avg_quality: self.avg_quality * w1 + other.avg_quality * w2,
274 created_at: self.created_at.min(other.created_at),
275 last_accessed: self.last_accessed.max(other.last_accessed),
276 access_count: self.access_count + other.access_count,
277 pattern_type: self.pattern_type.clone(),
278 }
279 }
280
281 pub fn decay(&mut self, factor: f32) {
283 self.total_weight *= factor;
284 }
285
286 pub fn touch(&mut self) {
288 self.access_count += 1;
289 self.last_accessed = std::time::SystemTime::now()
290 .duration_since(std::time::UNIX_EPOCH)
291 .unwrap_or_default()
292 .as_secs();
293 }
294
295 pub fn should_prune(&self, min_quality: f32, min_accesses: u32, max_age_secs: u64) -> bool {
297 let now = std::time::SystemTime::now()
298 .duration_since(std::time::UNIX_EPOCH)
299 .unwrap_or_default()
300 .as_secs();
301 let age = now.saturating_sub(self.last_accessed);
302
303 self.avg_quality < min_quality
304 && self.access_count < min_accesses
305 && age > max_age_secs
306 }
307
308 pub fn similarity(&self, query: &[f32]) -> f32 {
310 if self.centroid.len() != query.len() {
311 return 0.0;
312 }
313
314 let dot: f32 = self.centroid.iter().zip(query).map(|(a, b)| a * b).sum();
315 let norm_a: f32 = self.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
316 let norm_b: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
317
318 if norm_a > 1e-8 && norm_b > 1e-8 {
319 dot / (norm_a * norm_b)
320 } else {
321 0.0
322 }
323 }
324}
325
326#[derive(Clone, Debug, Serialize, Deserialize)]
328pub struct SonaConfig {
329 pub hidden_dim: usize,
331 pub embedding_dim: usize,
333 pub micro_lora_rank: usize,
335 pub base_lora_rank: usize,
337 pub micro_lora_lr: f32,
339 pub base_lora_lr: f32,
341 pub ewc_lambda: f32,
343 pub pattern_clusters: usize,
345 pub trajectory_capacity: usize,
347 pub background_interval_ms: u64,
349 pub quality_threshold: f32,
351 pub enable_simd: bool,
353}
354
355impl Default for SonaConfig {
356 fn default() -> Self {
357 Self {
358 hidden_dim: 256,
359 embedding_dim: 256,
360 micro_lora_rank: 1,
361 base_lora_rank: 8,
362 micro_lora_lr: 0.001,
363 base_lora_lr: 0.0001,
364 ewc_lambda: 1000.0,
365 pattern_clusters: 50,
366 trajectory_capacity: 10000,
367 background_interval_ms: 3600000, quality_threshold: 0.5,
369 enable_simd: true,
370 }
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f32 = 1e-6;

    /// True when `actual` is within [`EPS`] of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_learning_signal_from_trajectory() {
        let step = TrajectoryStep::new(vec![0.5, 0.3, 0.2], vec![0.4, 0.4, 0.2], 0.8, 0);
        let mut trajectory = QueryTrajectory::new(1, vec![0.1, 0.2, 0.3]);
        trajectory.add_step(step);
        trajectory.finalize(0.8, 1000);

        let signal = LearningSignal::from_trajectory(&trajectory);

        assert_eq!(signal.quality_score, 0.8);
        assert_eq!(signal.gradient_estimate.len(), 3);
        assert_eq!(signal.metadata.trajectory_id, 1);
    }

    #[test]
    fn test_pattern_merge() {
        let p1 = LearnedPattern {
            id: 1,
            centroid: vec![1.0, 0.0],
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality: 0.8,
            created_at: 100,
            last_accessed: 200,
            access_count: 5,
            pattern_type: PatternType::General,
        };
        // Same size, weight and type as p1; only the differing fields are listed.
        let p2 = LearnedPattern {
            id: 2,
            centroid: vec![0.0, 1.0],
            avg_quality: 0.9,
            created_at: 150,
            last_accessed: 250,
            access_count: 3,
            ..p1.clone()
        };

        let merged = p1.merge(&p2);

        assert_eq!(merged.cluster_size, 20);
        assert!(close(merged.centroid[0], 0.5));
        assert!(close(merged.centroid[1], 0.5));
        assert!(close(merged.avg_quality, 0.85));
    }

    #[test]
    fn test_pattern_similarity() {
        let pattern = LearnedPattern::new(1, vec![1.0, 0.0, 0.0]);

        assert!(close(pattern.similarity(&[1.0, 0.0, 0.0]), 1.0));
        assert!(close(pattern.similarity(&[0.0, 1.0, 0.0]), 0.0));
    }

    #[test]
    fn test_trajectory_rewards() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1]);
        for (idx, reward) in [0.5, 0.7, 0.9].iter().copied().enumerate() {
            trajectory.add_step(TrajectoryStep::new(vec![], vec![], reward, idx));
        }

        assert!(close(trajectory.total_reward(), 2.1));
        assert!(close(trajectory.avg_reward(), 0.7));
    }
}