1use serde::{Deserialize, Serialize};
6use std::time::Instant;
7use std::collections::HashMap;
8
/// A learning signal derived from one query: the query embedding, an estimated
/// update direction, and a quality score that scales it.
///
/// Built either from a finished [`QueryTrajectory`] via
/// [`LearningSignal::from_trajectory`] or directly via [`LearningSignal::with_gradient`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LearningSignal {
    /// Embedding of the query this signal was derived from.
    pub query_embedding: Vec<f32>,
    /// Estimated update direction. When produced by `from_trajectory` it is
    /// L2-normalised, except in two cases: a trajectory with no steps (a copy of
    /// the query embedding) or a raw estimate with norm at most `1e-8` (left as is).
    pub gradient_estimate: Vec<f32>,
    /// Quality score; `scaled_gradient` multiplies `gradient_estimate` by it.
    pub quality_score: f32,
    /// Creation time. Skipped by serde because `Instant` has no serialized form,
    /// so this is `None` after deserialization.
    #[serde(skip)]
    pub timestamp: Option<Instant>,
    /// Provenance information (source trajectory, step count, route, tags).
    pub metadata: SignalMetadata,
}
24
/// Provenance attached to a [`LearningSignal`].
///
/// `Default` yields id 0, zero steps, no route and no tags; this is what
/// `LearningSignal::with_gradient` uses.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SignalMetadata {
    /// Id of the source trajectory (`QueryTrajectory::id`).
    pub trajectory_id: u64,
    /// Number of steps the source trajectory contained.
    pub step_count: usize,
    /// Model route copied from the source trajectory, if any.
    pub model_route: Option<String>,
    /// Free-form string tags; the constructors in this file leave it empty.
    pub tags: HashMap<String, String>,
}
37
38impl LearningSignal {
39 pub fn from_trajectory(trajectory: &QueryTrajectory) -> Self {
41 let gradient = Self::estimate_gradient(trajectory);
42
43 Self {
44 query_embedding: trajectory.query_embedding.clone(),
45 gradient_estimate: gradient,
46 quality_score: trajectory.final_quality,
47 timestamp: Some(Instant::now()),
48 metadata: SignalMetadata {
49 trajectory_id: trajectory.id,
50 step_count: trajectory.steps.len(),
51 model_route: trajectory.model_route.clone(),
52 tags: HashMap::new(),
53 },
54 }
55 }
56
57 pub fn with_gradient(embedding: Vec<f32>, gradient: Vec<f32>, quality: f32) -> Self {
59 Self {
60 query_embedding: embedding,
61 gradient_estimate: gradient,
62 quality_score: quality,
63 timestamp: Some(Instant::now()),
64 metadata: SignalMetadata::default(),
65 }
66 }
67
68 fn estimate_gradient(trajectory: &QueryTrajectory) -> Vec<f32> {
70 if trajectory.steps.is_empty() {
71 return trajectory.query_embedding.clone();
72 }
73
74 let dim = trajectory.query_embedding.len();
75 let mut gradient = vec![0.0f32; dim];
76
77 let baseline = trajectory.steps.iter()
79 .map(|s| s.reward)
80 .sum::<f32>() / trajectory.steps.len() as f32;
81
82 for step in &trajectory.steps {
84 let advantage = step.reward - baseline;
85 let activation_len = step.activations.len().min(dim);
86 for i in 0..activation_len {
87 gradient[i] += advantage * step.activations[i];
88 }
89 }
90
91 let norm: f32 = gradient.iter().map(|x| x * x).sum::<f32>().sqrt();
93 if norm > 1e-8 {
94 gradient.iter_mut().for_each(|x| *x /= norm);
95 }
96
97 gradient
98 }
99
100 pub fn scaled_gradient(&self) -> Vec<f32> {
102 self.gradient_estimate.iter()
103 .map(|&g| g * self.quality_score)
104 .collect()
105 }
106}
107
/// The recorded execution of one query: its embedding, the ordered steps taken,
/// and the outcome recorded by `finalize`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QueryTrajectory {
    /// Caller-assigned identifier; copied into `SignalMetadata::trajectory_id`.
    pub id: u64,
    /// Embedding of the query; its length sets the dimension of the estimated gradient.
    pub query_embedding: Vec<f32>,
    /// Steps in the order they were added via `add_step`.
    pub steps: Vec<TrajectoryStep>,
    /// Final quality recorded by `finalize` (0.0 until then); becomes the signal's quality score.
    pub final_quality: f32,
    /// Latency recorded by `finalize` (0 until then); presumably microseconds, per the `_us` suffix.
    pub latency_us: u64,
    /// Model route taken, if known; `new` leaves it `None`.
    pub model_route: Option<String>,
    /// Context identifiers associated with the query; not read by the code visible here.
    pub context_ids: Vec<String>,
}
126
127impl QueryTrajectory {
128 pub fn new(id: u64, query_embedding: Vec<f32>) -> Self {
130 Self {
131 id,
132 query_embedding,
133 steps: Vec::with_capacity(16),
134 final_quality: 0.0,
135 latency_us: 0,
136 model_route: None,
137 context_ids: Vec::new(),
138 }
139 }
140
141 pub fn add_step(&mut self, step: TrajectoryStep) {
143 self.steps.push(step);
144 }
145
146 pub fn finalize(&mut self, quality: f32, latency_us: u64) {
148 self.final_quality = quality;
149 self.latency_us = latency_us;
150 }
151
152 pub fn total_reward(&self) -> f32 {
154 self.steps.iter().map(|s| s.reward).sum()
155 }
156
157 pub fn avg_reward(&self) -> f32 {
159 if self.steps.is_empty() {
160 0.0
161 } else {
162 self.total_reward() / self.steps.len() as f32
163 }
164 }
165}
166
/// One step of a [`QueryTrajectory`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrajectoryStep {
    /// Activation vector recorded for this step; gradient estimation reads at most
    /// the first `query_embedding.len()` entries.
    pub activations: Vec<f32>,
    /// Attention weights recorded for this step; not read by the code visible here.
    pub attention_weights: Vec<f32>,
    /// Reward for this step; compared against the trajectory's mean reward.
    pub reward: f32,
    /// Position of the step as supplied by the caller (not validated against `steps`).
    pub step_idx: usize,
    /// Optional layer label, set via `with_layer`.
    pub layer_name: Option<String>,
}
181
182impl TrajectoryStep {
183 pub fn new(activations: Vec<f32>, attention_weights: Vec<f32>, reward: f32, step_idx: usize) -> Self {
185 Self {
186 activations,
187 attention_weights,
188 reward,
189 step_idx,
190 layer_name: None,
191 }
192 }
193
194 pub fn with_layer(mut self, name: &str) -> Self {
196 self.layer_name = Some(name.to_string());
197 self
198 }
199}
200
/// A learned cluster summarised by its centroid, plus the bookkeeping used for
/// merging, decay and pruning.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LearnedPattern {
    /// Pattern identifier; `merge` keeps the receiver's id.
    pub id: u64,
    /// Cluster centre; compared to queries by cosine similarity in `similarity`.
    pub centroid: Vec<f32>,
    /// Member count; starts at 1 in `new`, is summed by `merge` and is used as the merge weight.
    pub cluster_size: usize,
    /// Accumulated weight; starts at 1.0, is summed by `merge` and multiplied by `decay`.
    pub total_weight: f32,
    /// Average quality; starts at 0.0, and `merge` combines it weighted by cluster size.
    pub avg_quality: f32,
    /// Creation time in whole seconds since the Unix epoch; `merge` keeps the earlier value.
    pub created_at: u64,
    /// Last access time in Unix seconds; refreshed by `touch` and used by `should_prune`.
    pub last_accessed: u64,
    /// Incremented by `touch` and summed by `merge`.
    pub access_count: u32,
    /// Category of the pattern; `merge` keeps the receiver's.
    pub pattern_type: PatternType,
}
223
/// Category assigned to a [`LearnedPattern`]; `General` is the default.
///
/// NOTE(review): serde formats that encode enum variants by index (e.g. bincode)
/// depend on this declaration order, so new variants should be appended.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum PatternType {
    /// Default category; displayed as `general`.
    #[default]
    General,
    /// Displayed as `reasoning`.
    Reasoning,
    /// Displayed as `factual`.
    Factual,
    /// Displayed as `creative`.
    Creative,
    /// Displayed as `codegen`.
    CodeGen,
    /// Displayed as `conversational`.
    Conversational,
}
235
236impl std::fmt::Display for PatternType {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 match self {
239 PatternType::General => write!(f, "general"),
240 PatternType::Reasoning => write!(f, "reasoning"),
241 PatternType::Factual => write!(f, "factual"),
242 PatternType::Creative => write!(f, "creative"),
243 PatternType::CodeGen => write!(f, "codegen"),
244 PatternType::Conversational => write!(f, "conversational"),
245 }
246 }
247}
248
249impl LearnedPattern {
250 pub fn new(id: u64, centroid: Vec<f32>) -> Self {
252 let now = std::time::SystemTime::now()
253 .duration_since(std::time::UNIX_EPOCH)
254 .unwrap_or_default()
255 .as_secs();
256
257 Self {
258 id,
259 centroid,
260 cluster_size: 1,
261 total_weight: 1.0,
262 avg_quality: 0.0,
263 created_at: now,
264 last_accessed: now,
265 access_count: 0,
266 pattern_type: PatternType::default(),
267 }
268 }
269
270 pub fn merge(&self, other: &Self) -> Self {
272 let total_size = self.cluster_size + other.cluster_size;
273 let w1 = self.cluster_size as f32 / total_size as f32;
274 let w2 = other.cluster_size as f32 / total_size as f32;
275
276 let centroid: Vec<f32> = self.centroid.iter()
277 .zip(&other.centroid)
278 .map(|(&a, &b)| a * w1 + b * w2)
279 .collect();
280
281 Self {
282 id: self.id,
283 centroid,
284 cluster_size: total_size,
285 total_weight: self.total_weight + other.total_weight,
286 avg_quality: self.avg_quality * w1 + other.avg_quality * w2,
287 created_at: self.created_at.min(other.created_at),
288 last_accessed: self.last_accessed.max(other.last_accessed),
289 access_count: self.access_count + other.access_count,
290 pattern_type: self.pattern_type.clone(),
291 }
292 }
293
294 pub fn decay(&mut self, factor: f32) {
296 self.total_weight *= factor;
297 }
298
299 pub fn touch(&mut self) {
301 self.access_count += 1;
302 self.last_accessed = std::time::SystemTime::now()
303 .duration_since(std::time::UNIX_EPOCH)
304 .unwrap_or_default()
305 .as_secs();
306 }
307
308 pub fn should_prune(&self, min_quality: f32, min_accesses: u32, max_age_secs: u64) -> bool {
310 let now = std::time::SystemTime::now()
311 .duration_since(std::time::UNIX_EPOCH)
312 .unwrap_or_default()
313 .as_secs();
314 let age = now.saturating_sub(self.last_accessed);
315
316 self.avg_quality < min_quality
317 && self.access_count < min_accesses
318 && age > max_age_secs
319 }
320
321 pub fn similarity(&self, query: &[f32]) -> f32 {
323 if self.centroid.len() != query.len() {
324 return 0.0;
325 }
326
327 let dot: f32 = self.centroid.iter().zip(query).map(|(a, b)| a * b).sum();
328 let norm_a: f32 = self.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
329 let norm_b: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
330
331 if norm_a > 1e-8 && norm_b > 1e-8 {
332 dot / (norm_a * norm_b)
333 } else {
334 0.0
335 }
336 }
337}
338
/// Tuning parameters for SONA learning.
///
/// `Default` is the baseline configuration; named presets are available as
/// associated functions (`max_throughput`, `max_quality`, `edge_deployment`,
/// `batch_processing`, `for_ephemeral`, `for_coordinator`).
///
/// NOTE(review): the code that consumes these fields is not in this file. The
/// per-field descriptions below follow the field names and should be confirmed
/// against their users.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SonaConfig {
    /// Hidden dimension (256 in every configuration defined here).
    pub hidden_dim: usize,
    /// Embedding dimension (256 in every configuration defined here).
    pub embedding_dim: usize,
    /// Rank of the micro-LoRA adapter.
    pub micro_lora_rank: usize,
    /// Rank of the base LoRA adapter.
    pub base_lora_rank: usize,
    /// Learning rate for the micro-LoRA adapter.
    pub micro_lora_lr: f32,
    /// Learning rate for the base LoRA adapter.
    pub base_lora_lr: f32,
    /// Strength of the EWC regularisation term (presumably elastic weight consolidation).
    pub ewc_lambda: f32,
    /// Number of pattern clusters to maintain.
    pub pattern_clusters: usize,
    /// Capacity of the trajectory buffer.
    pub trajectory_capacity: usize,
    /// Interval between background learning passes, in milliseconds.
    pub background_interval_ms: u64,
    /// Quality threshold; presumably signals scoring below it are ignored — TODO confirm.
    pub quality_threshold: f32,
    /// Whether SIMD code paths should be used.
    pub enable_simd: bool,
}
367
368impl Default for SonaConfig {
369 fn default() -> Self {
370 Self {
377 hidden_dim: 256,
378 embedding_dim: 256,
379 micro_lora_rank: 2, base_lora_rank: 8, micro_lora_lr: 0.002, base_lora_lr: 0.0001,
383 ewc_lambda: 2000.0, pattern_clusters: 100, trajectory_capacity: 10000,
386 background_interval_ms: 3600000, quality_threshold: 0.3, enable_simd: true,
389 }
390 }
391}
392
393impl SonaConfig {
394 pub fn max_throughput() -> Self {
396 Self {
397 hidden_dim: 256,
398 embedding_dim: 256,
399 micro_lora_rank: 2, base_lora_rank: 4, micro_lora_lr: 0.0005, base_lora_lr: 0.0001,
403 ewc_lambda: 2000.0,
404 pattern_clusters: 100,
405 trajectory_capacity: 5000,
406 background_interval_ms: 7200000, quality_threshold: 0.4,
408 enable_simd: true,
409 }
410 }
411
412 pub fn max_quality() -> Self {
414 Self {
415 hidden_dim: 256,
416 embedding_dim: 256,
417 micro_lora_rank: 2,
418 base_lora_rank: 16, micro_lora_lr: 0.002, base_lora_lr: 0.001, ewc_lambda: 2000.0,
422 pattern_clusters: 100,
423 trajectory_capacity: 20000,
424 background_interval_ms: 1800000, quality_threshold: 0.2, enable_simd: true,
427 }
428 }
429
430 pub fn edge_deployment() -> Self {
432 Self {
433 hidden_dim: 256,
434 embedding_dim: 256,
435 micro_lora_rank: 1, base_lora_rank: 4,
437 micro_lora_lr: 0.001,
438 base_lora_lr: 0.0001,
439 ewc_lambda: 1000.0,
440 pattern_clusters: 50,
441 trajectory_capacity: 200, background_interval_ms: 3600000,
443 quality_threshold: 0.5,
444 enable_simd: true,
445 }
446 }
447
448 pub fn batch_processing() -> Self {
450 Self {
451 hidden_dim: 256,
452 embedding_dim: 256,
453 micro_lora_rank: 2,
454 base_lora_rank: 8,
455 micro_lora_lr: 0.001,
456 base_lora_lr: 0.0001,
457 ewc_lambda: 2000.0,
458 pattern_clusters: 100,
459 trajectory_capacity: 10000,
460 background_interval_ms: 3600000,
461 quality_threshold: 0.3,
462 enable_simd: true,
463 }
464 }
465
466 pub fn for_ephemeral() -> Self {
471 Self {
472 hidden_dim: 256,
473 embedding_dim: 256,
474 micro_lora_rank: 2,
475 base_lora_rank: 4, micro_lora_lr: 0.002,
477 base_lora_lr: 0.0001,
478 ewc_lambda: 1000.0,
479 pattern_clusters: 50, trajectory_capacity: 500, background_interval_ms: 60000, quality_threshold: 0.3,
483 enable_simd: true,
484 }
485 }
486
487 pub fn for_coordinator() -> Self {
492 Self {
493 hidden_dim: 256,
494 embedding_dim: 256,
495 micro_lora_rank: 2,
496 base_lora_rank: 16, micro_lora_lr: 0.001, base_lora_lr: 0.0005, ewc_lambda: 2000.0, pattern_clusters: 200, trajectory_capacity: 50000, background_interval_ms: 300000, quality_threshold: 0.4, enable_simd: true,
505 }
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a step whose attention weights are irrelevant to the test.
    fn step(activations: Vec<f32>, reward: f32, step_idx: usize) -> TrajectoryStep {
        TrajectoryStep::new(activations, vec![], reward, step_idx)
    }

    #[test]
    fn test_learning_signal_from_trajectory() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1, 0.2, 0.3]);
        trajectory.add_step(TrajectoryStep::new(
            vec![0.5, 0.3, 0.2],
            vec![0.4, 0.4, 0.2],
            0.8,
            0,
        ));
        trajectory.finalize(0.8, 1000);

        let signal = LearningSignal::from_trajectory(&trajectory);
        assert_eq!(signal.quality_score, 0.8);
        assert_eq!(signal.gradient_estimate.len(), 3);
        assert_eq!(signal.metadata.trajectory_id, 1);
        assert_eq!(signal.metadata.step_count, 1);
        // A single step has zero advantage against its own mean reward.
        assert!(signal.gradient_estimate.iter().all(|&g| g == 0.0));
    }

    #[test]
    fn test_gradient_weights_activations_by_advantage() {
        let mut trajectory = QueryTrajectory::new(2, vec![0.0, 0.0]);
        trajectory.add_step(step(vec![1.0, 0.0], 1.0, 0));
        trajectory.add_step(step(vec![0.0, 1.0], 0.0, 1));

        let gradient = LearningSignal::from_trajectory(&trajectory).gradient_estimate;
        let expected = std::f32::consts::FRAC_1_SQRT_2;
        assert!((gradient[0] - expected).abs() < 1e-6);
        assert!((gradient[1] + expected).abs() < 1e-6);
    }

    #[test]
    fn test_gradient_without_steps_falls_back_to_query_embedding() {
        let trajectory = QueryTrajectory::new(3, vec![3.0, 4.0]);
        let signal = LearningSignal::from_trajectory(&trajectory);
        assert_eq!(signal.gradient_estimate, vec![3.0, 4.0]);
        assert_eq!(signal.metadata.step_count, 0);
    }

    #[test]
    fn test_scaled_gradient() {
        let signal = LearningSignal::with_gradient(vec![0.0], vec![2.0, -1.0], 0.5);
        assert_eq!(signal.scaled_gradient(), vec![1.0, -0.5]);
        assert_eq!(signal.metadata.trajectory_id, 0);
        assert!(signal.metadata.tags.is_empty());
    }

    #[test]
    fn test_pattern_merge() {
        let p1 = LearnedPattern {
            id: 1,
            centroid: vec![1.0, 0.0],
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality: 0.8,
            created_at: 100,
            last_accessed: 200,
            access_count: 5,
            pattern_type: PatternType::General,
        };

        let p2 = LearnedPattern {
            id: 2,
            centroid: vec![0.0, 1.0],
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality: 0.9,
            created_at: 150,
            last_accessed: 250,
            access_count: 3,
            pattern_type: PatternType::General,
        };

        let merged = p1.merge(&p2);
        assert_eq!(merged.id, 1);
        assert_eq!(merged.cluster_size, 20);
        assert!((merged.centroid[0] - 0.5).abs() < 1e-6);
        assert!((merged.centroid[1] - 0.5).abs() < 1e-6);
        assert!((merged.avg_quality - 0.85).abs() < 1e-6);
        assert_eq!(merged.total_weight, 10.0);
        assert_eq!(merged.created_at, 100);
        assert_eq!(merged.last_accessed, 250);
        assert_eq!(merged.access_count, 8);
    }

    #[test]
    fn test_pattern_similarity() {
        let pattern = LearnedPattern::new(1, vec![1.0, 0.0, 0.0]);

        assert!((pattern.similarity(&[1.0, 0.0, 0.0]) - 1.0).abs() < 1e-6);
        assert!(pattern.similarity(&[0.0, 1.0, 0.0]).abs() < 1e-6);
        // Mismatched lengths and zero vectors both report no similarity.
        assert_eq!(pattern.similarity(&[1.0, 0.0]), 0.0);
        assert_eq!(pattern.similarity(&[0.0, 0.0, 0.0]), 0.0);
    }

    #[test]
    fn test_should_prune_requires_all_conditions() {
        let mut pattern = LearnedPattern::new(1, vec![1.0]);
        // Freshly created: not old enough to prune.
        assert!(!pattern.should_prune(0.5, 1, 10));

        pattern.last_accessed = 0;
        assert!(pattern.should_prune(0.5, 1, 10));

        // Enough accesses keeps it.
        pattern.access_count = 5;
        assert!(!pattern.should_prune(0.5, 1, 10));
        pattern.access_count = 0;

        // High enough quality keeps it.
        pattern.avg_quality = 0.9;
        assert!(!pattern.should_prune(0.5, 1, 10));
    }

    #[test]
    fn test_touch_and_decay() {
        let mut pattern = LearnedPattern::new(1, vec![1.0]);
        pattern.last_accessed = 0;
        pattern.touch();
        assert_eq!(pattern.access_count, 1);
        assert!(pattern.last_accessed > 0);

        pattern.decay(0.5);
        assert_eq!(pattern.total_weight, 0.5);
    }

    #[test]
    fn test_pattern_type_display() {
        let names: Vec<String> = [
            PatternType::General,
            PatternType::Reasoning,
            PatternType::Factual,
            PatternType::Creative,
            PatternType::CodeGen,
            PatternType::Conversational,
        ]
        .iter()
        .map(|p| p.to_string())
        .collect();
        assert_eq!(
            names,
            ["general", "reasoning", "factual", "creative", "codegen", "conversational"]
        );
        assert_eq!(PatternType::default(), PatternType::General);
    }

    #[test]
    fn test_trajectory_rewards() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1]);
        trajectory.add_step(TrajectoryStep::new(vec![], vec![], 0.5, 0));
        trajectory.add_step(TrajectoryStep::new(vec![], vec![], 0.7, 1));
        trajectory.add_step(TrajectoryStep::new(vec![], vec![], 0.9, 2));

        assert!((trajectory.total_reward() - 2.1).abs() < 1e-6);
        assert!((trajectory.avg_reward() - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_avg_reward_of_empty_trajectory_is_zero() {
        let trajectory = QueryTrajectory::new(1, vec![0.1]);
        assert_eq!(trajectory.avg_reward(), 0.0);
    }

    #[test]
    fn test_step_with_layer() {
        let s = TrajectoryStep::new(vec![1.0], vec![1.0], 0.1, 4).with_layer("layer.3");
        assert_eq!(s.layer_name.as_deref(), Some("layer.3"));
        assert_eq!(s.step_idx, 4);
    }

    #[test]
    fn test_config_presets_share_model_geometry() {
        for config in [
            SonaConfig::default(),
            SonaConfig::max_throughput(),
            SonaConfig::max_quality(),
            SonaConfig::edge_deployment(),
            SonaConfig::batch_processing(),
            SonaConfig::for_ephemeral(),
            SonaConfig::for_coordinator(),
        ] {
            assert_eq!(config.hidden_dim, 256);
            assert_eq!(config.embedding_dim, 256);
            assert!(config.enable_simd);
        }
        assert_eq!(SonaConfig::edge_deployment().micro_lora_rank, 1);
        assert_eq!(SonaConfig::for_coordinator().trajectory_capacity, 50000);
    }
}