1use crate::time_compat::Instant;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct LearningSignal {
12 pub query_embedding: Vec<f32>,
14 pub gradient_estimate: Vec<f32>,
16 pub quality_score: f32,
18 #[serde(skip)]
20 pub timestamp: Option<Instant>,
21 pub metadata: SignalMetadata,
23}
24
25#[derive(Clone, Debug, Default, Serialize, Deserialize)]
27pub struct SignalMetadata {
28 pub trajectory_id: u64,
30 pub step_count: usize,
32 pub model_route: Option<String>,
34 pub tags: HashMap<String, String>,
36}
37
38impl LearningSignal {
39 pub fn from_trajectory(trajectory: &QueryTrajectory) -> Self {
41 let gradient = Self::estimate_gradient(trajectory);
42
43 Self {
44 query_embedding: trajectory.query_embedding.clone(),
45 gradient_estimate: gradient,
46 quality_score: trajectory.final_quality,
47 timestamp: Some(Instant::now()),
48 metadata: SignalMetadata {
49 trajectory_id: trajectory.id,
50 step_count: trajectory.steps.len(),
51 model_route: trajectory.model_route.clone(),
52 tags: HashMap::new(),
53 },
54 }
55 }
56
57 pub fn with_gradient(embedding: Vec<f32>, gradient: Vec<f32>, quality: f32) -> Self {
59 Self {
60 query_embedding: embedding,
61 gradient_estimate: gradient,
62 quality_score: quality,
63 timestamp: Some(Instant::now()),
64 metadata: SignalMetadata::default(),
65 }
66 }
67
68 fn estimate_gradient(trajectory: &QueryTrajectory) -> Vec<f32> {
70 if trajectory.steps.is_empty() {
71 return trajectory.query_embedding.clone();
72 }
73
74 let dim = trajectory.query_embedding.len();
75 let mut gradient = vec![0.0f32; dim];
76
77 let baseline =
79 trajectory.steps.iter().map(|s| s.reward).sum::<f32>() / trajectory.steps.len() as f32;
80
81 for step in &trajectory.steps {
83 let advantage = step.reward - baseline;
84 let activation_len = step.activations.len().min(dim);
85 for (grad, &act) in gradient
86 .iter_mut()
87 .zip(step.activations.iter())
88 .take(activation_len)
89 {
90 *grad += advantage * act;
91 }
92 }
93
94 let norm: f32 = gradient.iter().map(|x| x * x).sum::<f32>().sqrt();
96 if norm > 1e-8 {
97 gradient.iter_mut().for_each(|x| *x /= norm);
98 }
99
100 gradient
101 }
102
103 pub fn scaled_gradient(&self) -> Vec<f32> {
105 self.gradient_estimate
106 .iter()
107 .map(|&g| g * self.quality_score)
108 .collect()
109 }
110}
111
112#[derive(Clone, Debug, Serialize, Deserialize)]
114pub struct QueryTrajectory {
115 pub id: u64,
117 pub query_embedding: Vec<f32>,
119 pub steps: Vec<TrajectoryStep>,
121 pub final_quality: f32,
123 pub latency_us: u64,
125 pub model_route: Option<String>,
127 pub context_ids: Vec<String>,
129}
130
131impl QueryTrajectory {
132 pub fn new(id: u64, query_embedding: Vec<f32>) -> Self {
134 Self {
135 id,
136 query_embedding,
137 steps: Vec::with_capacity(16),
138 final_quality: 0.0,
139 latency_us: 0,
140 model_route: None,
141 context_ids: Vec::new(),
142 }
143 }
144
145 pub fn add_step(&mut self, step: TrajectoryStep) {
147 self.steps.push(step);
148 }
149
150 pub fn finalize(&mut self, quality: f32, latency_us: u64) {
152 self.final_quality = quality;
153 self.latency_us = latency_us;
154 }
155
156 pub fn total_reward(&self) -> f32 {
158 self.steps.iter().map(|s| s.reward).sum()
159 }
160
161 pub fn avg_reward(&self) -> f32 {
163 if self.steps.is_empty() {
164 0.0
165 } else {
166 self.total_reward() / self.steps.len() as f32
167 }
168 }
169}
170
171#[derive(Clone, Debug, Serialize, Deserialize)]
173pub struct TrajectoryStep {
174 pub activations: Vec<f32>,
176 pub attention_weights: Vec<f32>,
178 pub reward: f32,
180 pub step_idx: usize,
182 pub layer_name: Option<String>,
184}
185
186impl TrajectoryStep {
187 pub fn new(
189 activations: Vec<f32>,
190 attention_weights: Vec<f32>,
191 reward: f32,
192 step_idx: usize,
193 ) -> Self {
194 Self {
195 activations,
196 attention_weights,
197 reward,
198 step_idx,
199 layer_name: None,
200 }
201 }
202
203 pub fn with_layer(mut self, name: &str) -> Self {
205 self.layer_name = Some(name.to_string());
206 self
207 }
208}
209
210#[derive(Clone, Debug, Serialize, Deserialize)]
212pub struct LearnedPattern {
213 pub id: u64,
215 pub centroid: Vec<f32>,
217 pub cluster_size: usize,
219 pub total_weight: f32,
221 pub avg_quality: f32,
223 pub created_at: u64,
225 pub last_accessed: u64,
227 pub access_count: u32,
229 pub pattern_type: PatternType,
231}
232
233#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
235pub enum PatternType {
236 #[default]
237 General,
238 Reasoning,
239 Factual,
240 Creative,
241 CodeGen,
242 Conversational,
243}
244
245impl std::fmt::Display for PatternType {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 match self {
248 PatternType::General => write!(f, "general"),
249 PatternType::Reasoning => write!(f, "reasoning"),
250 PatternType::Factual => write!(f, "factual"),
251 PatternType::Creative => write!(f, "creative"),
252 PatternType::CodeGen => write!(f, "codegen"),
253 PatternType::Conversational => write!(f, "conversational"),
254 }
255 }
256}
257
258impl LearnedPattern {
259 pub fn new(id: u64, centroid: Vec<f32>) -> Self {
261 use crate::time_compat::SystemTime;
262 let now = SystemTime::now().duration_since_epoch().as_secs();
263
264 Self {
265 id,
266 centroid,
267 cluster_size: 1,
268 total_weight: 1.0,
269 avg_quality: 0.0,
270 created_at: now,
271 last_accessed: now,
272 access_count: 0,
273 pattern_type: PatternType::default(),
274 }
275 }
276
277 pub fn merge(&self, other: &Self) -> Self {
279 let total_size = self.cluster_size + other.cluster_size;
280 let w1 = self.cluster_size as f32 / total_size as f32;
281 let w2 = other.cluster_size as f32 / total_size as f32;
282
283 let centroid: Vec<f32> = self
284 .centroid
285 .iter()
286 .zip(&other.centroid)
287 .map(|(&a, &b)| a * w1 + b * w2)
288 .collect();
289
290 Self {
291 id: self.id,
292 centroid,
293 cluster_size: total_size,
294 total_weight: self.total_weight + other.total_weight,
295 avg_quality: self.avg_quality * w1 + other.avg_quality * w2,
296 created_at: self.created_at.min(other.created_at),
297 last_accessed: self.last_accessed.max(other.last_accessed),
298 access_count: self.access_count + other.access_count,
299 pattern_type: self.pattern_type.clone(),
300 }
301 }
302
303 pub fn decay(&mut self, factor: f32) {
305 self.total_weight *= factor;
306 }
307
308 pub fn touch(&mut self) {
310 use crate::time_compat::SystemTime;
311 self.access_count += 1;
312 self.last_accessed = SystemTime::now().duration_since_epoch().as_secs();
313 }
314
315 pub fn should_prune(&self, min_quality: f32, min_accesses: u32, max_age_secs: u64) -> bool {
317 use crate::time_compat::SystemTime;
318 let now = SystemTime::now().duration_since_epoch().as_secs();
319 let age = now.saturating_sub(self.last_accessed);
320
321 self.avg_quality < min_quality && self.access_count < min_accesses && age > max_age_secs
322 }
323
324 pub fn similarity(&self, query: &[f32]) -> f32 {
326 if self.centroid.len() != query.len() {
327 return 0.0;
328 }
329
330 let dot: f32 = self.centroid.iter().zip(query).map(|(a, b)| a * b).sum();
331 let norm_a: f32 = self.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
332 let norm_b: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
333
334 if norm_a > 1e-8 && norm_b > 1e-8 {
335 dot / (norm_a * norm_b)
336 } else {
337 0.0
338 }
339 }
340}
341
342#[derive(Clone, Debug, Serialize, Deserialize)]
344pub struct SonaConfig {
345 pub hidden_dim: usize,
347 pub embedding_dim: usize,
349 pub micro_lora_rank: usize,
351 pub base_lora_rank: usize,
353 pub micro_lora_lr: f32,
355 pub base_lora_lr: f32,
357 pub ewc_lambda: f32,
359 pub pattern_clusters: usize,
361 pub trajectory_capacity: usize,
363 pub background_interval_ms: u64,
365 pub quality_threshold: f32,
367 pub enable_simd: bool,
369}
370
371impl Default for SonaConfig {
372 fn default() -> Self {
373 Self {
380 hidden_dim: 256,
381 embedding_dim: 256,
382 micro_lora_rank: 2, base_lora_rank: 8, micro_lora_lr: 0.002, base_lora_lr: 0.0001,
386 ewc_lambda: 2000.0, pattern_clusters: 100, trajectory_capacity: 10000,
389 background_interval_ms: 3600000, quality_threshold: 0.15, enable_simd: true,
392 }
393 }
394}
395
396impl SonaConfig {
397 pub fn max_throughput() -> Self {
399 Self {
400 hidden_dim: 256,
401 embedding_dim: 256,
402 micro_lora_rank: 2, base_lora_rank: 4, micro_lora_lr: 0.0005, base_lora_lr: 0.0001,
406 ewc_lambda: 2000.0,
407 pattern_clusters: 100,
408 trajectory_capacity: 5000,
409 background_interval_ms: 7200000, quality_threshold: 0.4,
411 enable_simd: true,
412 }
413 }
414
415 pub fn max_quality() -> Self {
417 Self {
418 hidden_dim: 256,
419 embedding_dim: 256,
420 micro_lora_rank: 2,
421 base_lora_rank: 16, micro_lora_lr: 0.002, base_lora_lr: 0.001, ewc_lambda: 2000.0,
425 pattern_clusters: 100,
426 trajectory_capacity: 20000,
427 background_interval_ms: 1800000, quality_threshold: 0.2, enable_simd: true,
430 }
431 }
432
433 pub fn edge_deployment() -> Self {
435 Self {
436 hidden_dim: 256,
437 embedding_dim: 256,
438 micro_lora_rank: 1, base_lora_rank: 4,
440 micro_lora_lr: 0.001,
441 base_lora_lr: 0.0001,
442 ewc_lambda: 1000.0,
443 pattern_clusters: 50,
444 trajectory_capacity: 200, background_interval_ms: 3600000,
446 quality_threshold: 0.5,
447 enable_simd: true,
448 }
449 }
450
451 pub fn batch_processing() -> Self {
453 Self {
454 hidden_dim: 256,
455 embedding_dim: 256,
456 micro_lora_rank: 2,
457 base_lora_rank: 8,
458 micro_lora_lr: 0.001,
459 base_lora_lr: 0.0001,
460 ewc_lambda: 2000.0,
461 pattern_clusters: 100,
462 trajectory_capacity: 10000,
463 background_interval_ms: 3600000,
464 quality_threshold: 0.3,
465 enable_simd: true,
466 }
467 }
468
469 pub fn for_ephemeral() -> Self {
474 Self {
475 hidden_dim: 256,
476 embedding_dim: 256,
477 micro_lora_rank: 2,
478 base_lora_rank: 4, micro_lora_lr: 0.002,
480 base_lora_lr: 0.0001,
481 ewc_lambda: 1000.0,
482 pattern_clusters: 50, trajectory_capacity: 500, background_interval_ms: 60000, quality_threshold: 0.3,
486 enable_simd: true,
487 }
488 }
489
490 pub fn for_coordinator() -> Self {
495 Self {
496 hidden_dim: 256,
497 embedding_dim: 256,
498 micro_lora_rank: 2,
499 base_lora_rank: 16, micro_lora_lr: 0.001, base_lora_lr: 0.0005, ewc_lambda: 2000.0, pattern_clusters: 200, trajectory_capacity: 50000, background_interval_ms: 300000, quality_threshold: 0.4, enable_simd: true,
508 }
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison shared by the assertions below.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    #[test]
    fn test_learning_signal_from_trajectory() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1, 0.2, 0.3]);
        let step = TrajectoryStep::new(vec![0.5, 0.3, 0.2], vec![0.4, 0.4, 0.2], 0.8, 0);
        trajectory.add_step(step);
        trajectory.finalize(0.8, 1000);

        let signal = LearningSignal::from_trajectory(&trajectory);

        assert_eq!(signal.quality_score, 0.8);
        assert_eq!(signal.gradient_estimate.len(), 3);
        assert_eq!(signal.metadata.trajectory_id, 1);
    }

    #[test]
    fn test_pattern_merge() {
        let first = LearnedPattern {
            id: 1,
            centroid: vec![1.0, 0.0],
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality: 0.8,
            created_at: 100,
            last_accessed: 200,
            access_count: 5,
            pattern_type: PatternType::General,
        };
        // Same cluster size and weight as `first`, so both sides get weight 0.5.
        let second = LearnedPattern {
            id: 2,
            centroid: vec![0.0, 1.0],
            avg_quality: 0.9,
            created_at: 150,
            last_accessed: 250,
            access_count: 3,
            ..first.clone()
        };

        let merged = first.merge(&second);

        assert_eq!(merged.cluster_size, 20);
        assert!(close(merged.centroid[0], 0.5));
        assert!(close(merged.centroid[1], 0.5));
        assert!(close(merged.avg_quality, 0.85));
    }

    #[test]
    fn test_pattern_similarity() {
        let pattern = LearnedPattern::new(1, vec![1.0, 0.0, 0.0]);

        let identical = pattern.similarity(&[1.0, 0.0, 0.0]);
        let orthogonal = pattern.similarity(&[0.0, 1.0, 0.0]);

        assert!(close(identical, 1.0));
        assert!(close(orthogonal, 0.0));
    }

    #[test]
    fn test_trajectory_rewards() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1]);
        for (idx, &reward) in [0.5, 0.7, 0.9].iter().enumerate() {
            trajectory.add_step(TrajectoryStep::new(vec![], vec![], reward, idx));
        }

        assert!(close(trajectory.total_reward(), 2.1));
        assert!(close(trajectory.avg_reward(), 0.7));
    }
}