1use crate::time_compat::Instant;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct LearningSignal {
12 pub query_embedding: Vec<f32>,
14 pub gradient_estimate: Vec<f32>,
16 pub quality_score: f32,
18 #[serde(skip)]
20 pub timestamp: Option<Instant>,
21 pub metadata: SignalMetadata,
23}
24
25#[derive(Clone, Debug, Default, Serialize, Deserialize)]
27pub struct SignalMetadata {
28 pub trajectory_id: u64,
30 pub step_count: usize,
32 pub model_route: Option<String>,
34 pub tags: HashMap<String, String>,
36}
37
38impl LearningSignal {
39 pub fn from_trajectory(trajectory: &QueryTrajectory) -> Self {
41 let gradient = Self::estimate_gradient(trajectory);
42
43 Self {
44 query_embedding: trajectory.query_embedding.clone(),
45 gradient_estimate: gradient,
46 quality_score: trajectory.final_quality,
47 timestamp: Some(Instant::now()),
48 metadata: SignalMetadata {
49 trajectory_id: trajectory.id,
50 step_count: trajectory.steps.len(),
51 model_route: trajectory.model_route.clone(),
52 tags: HashMap::new(),
53 },
54 }
55 }
56
57 pub fn with_gradient(embedding: Vec<f32>, gradient: Vec<f32>, quality: f32) -> Self {
59 Self {
60 query_embedding: embedding,
61 gradient_estimate: gradient,
62 quality_score: quality,
63 timestamp: Some(Instant::now()),
64 metadata: SignalMetadata::default(),
65 }
66 }
67
68 fn estimate_gradient(trajectory: &QueryTrajectory) -> Vec<f32> {
70 if trajectory.steps.is_empty() {
71 return trajectory.query_embedding.clone();
72 }
73
74 let dim = trajectory.query_embedding.len();
75 let mut gradient = vec![0.0f32; dim];
76
77 let baseline =
79 trajectory.steps.iter().map(|s| s.reward).sum::<f32>() / trajectory.steps.len() as f32;
80
81 for step in &trajectory.steps {
83 let advantage = step.reward - baseline;
84 let activation_len = step.activations.len().min(dim);
85 for i in 0..activation_len {
86 gradient[i] += advantage * step.activations[i];
87 }
88 }
89
90 let norm: f32 = gradient.iter().map(|x| x * x).sum::<f32>().sqrt();
92 if norm > 1e-8 {
93 gradient.iter_mut().for_each(|x| *x /= norm);
94 }
95
96 gradient
97 }
98
99 pub fn scaled_gradient(&self) -> Vec<f32> {
101 self.gradient_estimate
102 .iter()
103 .map(|&g| g * self.quality_score)
104 .collect()
105 }
106}
107
108#[derive(Clone, Debug, Serialize, Deserialize)]
110pub struct QueryTrajectory {
111 pub id: u64,
113 pub query_embedding: Vec<f32>,
115 pub steps: Vec<TrajectoryStep>,
117 pub final_quality: f32,
119 pub latency_us: u64,
121 pub model_route: Option<String>,
123 pub context_ids: Vec<String>,
125}
126
127impl QueryTrajectory {
128 pub fn new(id: u64, query_embedding: Vec<f32>) -> Self {
130 Self {
131 id,
132 query_embedding,
133 steps: Vec::with_capacity(16),
134 final_quality: 0.0,
135 latency_us: 0,
136 model_route: None,
137 context_ids: Vec::new(),
138 }
139 }
140
141 pub fn add_step(&mut self, step: TrajectoryStep) {
143 self.steps.push(step);
144 }
145
146 pub fn finalize(&mut self, quality: f32, latency_us: u64) {
148 self.final_quality = quality;
149 self.latency_us = latency_us;
150 }
151
152 pub fn total_reward(&self) -> f32 {
154 self.steps.iter().map(|s| s.reward).sum()
155 }
156
157 pub fn avg_reward(&self) -> f32 {
159 if self.steps.is_empty() {
160 0.0
161 } else {
162 self.total_reward() / self.steps.len() as f32
163 }
164 }
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize)]
169pub struct TrajectoryStep {
170 pub activations: Vec<f32>,
172 pub attention_weights: Vec<f32>,
174 pub reward: f32,
176 pub step_idx: usize,
178 pub layer_name: Option<String>,
180}
181
182impl TrajectoryStep {
183 pub fn new(
185 activations: Vec<f32>,
186 attention_weights: Vec<f32>,
187 reward: f32,
188 step_idx: usize,
189 ) -> Self {
190 Self {
191 activations,
192 attention_weights,
193 reward,
194 step_idx,
195 layer_name: None,
196 }
197 }
198
199 pub fn with_layer(mut self, name: &str) -> Self {
201 self.layer_name = Some(name.to_string());
202 self
203 }
204}
205
206#[derive(Clone, Debug, Serialize, Deserialize)]
208pub struct LearnedPattern {
209 pub id: u64,
211 pub centroid: Vec<f32>,
213 pub cluster_size: usize,
215 pub total_weight: f32,
217 pub avg_quality: f32,
219 pub created_at: u64,
221 pub last_accessed: u64,
223 pub access_count: u32,
225 pub pattern_type: PatternType,
227}
228
229#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
231pub enum PatternType {
232 #[default]
233 General,
234 Reasoning,
235 Factual,
236 Creative,
237 CodeGen,
238 Conversational,
239}
240
241impl std::fmt::Display for PatternType {
242 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 match self {
244 PatternType::General => write!(f, "general"),
245 PatternType::Reasoning => write!(f, "reasoning"),
246 PatternType::Factual => write!(f, "factual"),
247 PatternType::Creative => write!(f, "creative"),
248 PatternType::CodeGen => write!(f, "codegen"),
249 PatternType::Conversational => write!(f, "conversational"),
250 }
251 }
252}
253
254impl LearnedPattern {
255 pub fn new(id: u64, centroid: Vec<f32>) -> Self {
257 use crate::time_compat::SystemTime;
258 let now = SystemTime::now().duration_since_epoch().as_secs();
259
260 Self {
261 id,
262 centroid,
263 cluster_size: 1,
264 total_weight: 1.0,
265 avg_quality: 0.0,
266 created_at: now,
267 last_accessed: now,
268 access_count: 0,
269 pattern_type: PatternType::default(),
270 }
271 }
272
273 pub fn merge(&self, other: &Self) -> Self {
275 let total_size = self.cluster_size + other.cluster_size;
276 let w1 = self.cluster_size as f32 / total_size as f32;
277 let w2 = other.cluster_size as f32 / total_size as f32;
278
279 let centroid: Vec<f32> = self
280 .centroid
281 .iter()
282 .zip(&other.centroid)
283 .map(|(&a, &b)| a * w1 + b * w2)
284 .collect();
285
286 Self {
287 id: self.id,
288 centroid,
289 cluster_size: total_size,
290 total_weight: self.total_weight + other.total_weight,
291 avg_quality: self.avg_quality * w1 + other.avg_quality * w2,
292 created_at: self.created_at.min(other.created_at),
293 last_accessed: self.last_accessed.max(other.last_accessed),
294 access_count: self.access_count + other.access_count,
295 pattern_type: self.pattern_type.clone(),
296 }
297 }
298
299 pub fn decay(&mut self, factor: f32) {
301 self.total_weight *= factor;
302 }
303
304 pub fn touch(&mut self) {
306 use crate::time_compat::SystemTime;
307 self.access_count += 1;
308 self.last_accessed = SystemTime::now().duration_since_epoch().as_secs();
309 }
310
311 pub fn should_prune(&self, min_quality: f32, min_accesses: u32, max_age_secs: u64) -> bool {
313 use crate::time_compat::SystemTime;
314 let now = SystemTime::now().duration_since_epoch().as_secs();
315 let age = now.saturating_sub(self.last_accessed);
316
317 self.avg_quality < min_quality && self.access_count < min_accesses && age > max_age_secs
318 }
319
320 pub fn similarity(&self, query: &[f32]) -> f32 {
322 if self.centroid.len() != query.len() {
323 return 0.0;
324 }
325
326 let dot: f32 = self.centroid.iter().zip(query).map(|(a, b)| a * b).sum();
327 let norm_a: f32 = self.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
328 let norm_b: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
329
330 if norm_a > 1e-8 && norm_b > 1e-8 {
331 dot / (norm_a * norm_b)
332 } else {
333 0.0
334 }
335 }
336}
337
338#[derive(Clone, Debug, Serialize, Deserialize)]
340pub struct SonaConfig {
341 pub hidden_dim: usize,
343 pub embedding_dim: usize,
345 pub micro_lora_rank: usize,
347 pub base_lora_rank: usize,
349 pub micro_lora_lr: f32,
351 pub base_lora_lr: f32,
353 pub ewc_lambda: f32,
355 pub pattern_clusters: usize,
357 pub trajectory_capacity: usize,
359 pub background_interval_ms: u64,
361 pub quality_threshold: f32,
363 pub enable_simd: bool,
365}
366
367impl Default for SonaConfig {
368 fn default() -> Self {
369 Self {
376 hidden_dim: 256,
377 embedding_dim: 256,
378 micro_lora_rank: 2, base_lora_rank: 8, micro_lora_lr: 0.002, base_lora_lr: 0.0001,
382 ewc_lambda: 2000.0, pattern_clusters: 100, trajectory_capacity: 10000,
385 background_interval_ms: 3600000, quality_threshold: 0.3, enable_simd: true,
388 }
389 }
390}
391
392impl SonaConfig {
393 pub fn max_throughput() -> Self {
395 Self {
396 hidden_dim: 256,
397 embedding_dim: 256,
398 micro_lora_rank: 2, base_lora_rank: 4, micro_lora_lr: 0.0005, base_lora_lr: 0.0001,
402 ewc_lambda: 2000.0,
403 pattern_clusters: 100,
404 trajectory_capacity: 5000,
405 background_interval_ms: 7200000, quality_threshold: 0.4,
407 enable_simd: true,
408 }
409 }
410
411 pub fn max_quality() -> Self {
413 Self {
414 hidden_dim: 256,
415 embedding_dim: 256,
416 micro_lora_rank: 2,
417 base_lora_rank: 16, micro_lora_lr: 0.002, base_lora_lr: 0.001, ewc_lambda: 2000.0,
421 pattern_clusters: 100,
422 trajectory_capacity: 20000,
423 background_interval_ms: 1800000, quality_threshold: 0.2, enable_simd: true,
426 }
427 }
428
429 pub fn edge_deployment() -> Self {
431 Self {
432 hidden_dim: 256,
433 embedding_dim: 256,
434 micro_lora_rank: 1, base_lora_rank: 4,
436 micro_lora_lr: 0.001,
437 base_lora_lr: 0.0001,
438 ewc_lambda: 1000.0,
439 pattern_clusters: 50,
440 trajectory_capacity: 200, background_interval_ms: 3600000,
442 quality_threshold: 0.5,
443 enable_simd: true,
444 }
445 }
446
447 pub fn batch_processing() -> Self {
449 Self {
450 hidden_dim: 256,
451 embedding_dim: 256,
452 micro_lora_rank: 2,
453 base_lora_rank: 8,
454 micro_lora_lr: 0.001,
455 base_lora_lr: 0.0001,
456 ewc_lambda: 2000.0,
457 pattern_clusters: 100,
458 trajectory_capacity: 10000,
459 background_interval_ms: 3600000,
460 quality_threshold: 0.3,
461 enable_simd: true,
462 }
463 }
464
465 pub fn for_ephemeral() -> Self {
470 Self {
471 hidden_dim: 256,
472 embedding_dim: 256,
473 micro_lora_rank: 2,
474 base_lora_rank: 4, micro_lora_lr: 0.002,
476 base_lora_lr: 0.0001,
477 ewc_lambda: 1000.0,
478 pattern_clusters: 50, trajectory_capacity: 500, background_interval_ms: 60000, quality_threshold: 0.3,
482 enable_simd: true,
483 }
484 }
485
486 pub fn for_coordinator() -> Self {
491 Self {
492 hidden_dim: 256,
493 embedding_dim: 256,
494 micro_lora_rank: 2,
495 base_lora_rank: 16, micro_lora_lr: 0.001, base_lora_lr: 0.0005, ewc_lambda: 2000.0, pattern_clusters: 200, trajectory_capacity: 50000, background_interval_ms: 300000, quality_threshold: 0.4, enable_simd: true,
504 }
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point assertions below.
    const TOLERANCE: f32 = 1e-6;

    /// Returns `true` when `actual` is within [`TOLERANCE`] of `expected`.
    fn is_close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Builds a ten-member `General` pattern with a total weight of 5.0.
    fn ten_member_pattern(
        id: u64,
        centroid: Vec<f32>,
        avg_quality: f32,
        created_at: u64,
        last_accessed: u64,
        access_count: u32,
    ) -> LearnedPattern {
        LearnedPattern {
            id,
            centroid,
            cluster_size: 10,
            total_weight: 5.0,
            avg_quality,
            created_at,
            last_accessed,
            access_count,
            pattern_type: PatternType::General,
        }
    }

    /// A one-step trajectory yields a signal carrying its quality, id and a
    /// gradient of the embedding's dimension.
    #[test]
    fn test_learning_signal_from_trajectory() {
        let step = TrajectoryStep::new(vec![0.5, 0.3, 0.2], vec![0.4, 0.4, 0.2], 0.8, 0);
        let mut trajectory = QueryTrajectory::new(1, vec![0.1, 0.2, 0.3]);
        trajectory.add_step(step);
        trajectory.finalize(0.8, 1000);

        let signal = LearningSignal::from_trajectory(&trajectory);
        assert_eq!(signal.quality_score, 0.8);
        assert_eq!(signal.gradient_estimate.len(), 3);
        assert_eq!(signal.metadata.trajectory_id, 1);
    }

    /// Merging two equally sized clusters averages centroid and quality.
    #[test]
    fn test_pattern_merge() {
        let p1 = ten_member_pattern(1, vec![1.0, 0.0], 0.8, 100, 200, 5);
        let p2 = ten_member_pattern(2, vec![0.0, 1.0], 0.9, 150, 250, 3);

        let merged = p1.merge(&p2);
        assert_eq!(merged.cluster_size, 20);
        assert!(is_close(merged.centroid[0], 0.5));
        assert!(is_close(merged.centroid[1], 0.5));
        assert!(is_close(merged.avg_quality, 0.85));
    }

    /// Identical vectors score 1, orthogonal vectors score 0.
    #[test]
    fn test_pattern_similarity() {
        let pattern = LearnedPattern::new(1, vec![1.0, 0.0, 0.0]);

        assert!(is_close(pattern.similarity(&[1.0, 0.0, 0.0]), 1.0));
        assert!(pattern.similarity(&[0.0, 1.0, 0.0]).abs() < TOLERANCE);
    }

    /// Total and average reward over three steps.
    #[test]
    fn test_trajectory_rewards() {
        let mut trajectory = QueryTrajectory::new(1, vec![0.1]);
        for (idx, reward) in [0.5, 0.7, 0.9].into_iter().enumerate() {
            trajectory.add_step(TrajectoryStep::new(vec![], vec![], reward, idx));
        }

        assert!(is_close(trajectory.total_reward(), 2.1));
        assert!(is_close(trajectory.avg_reward(), 0.7));
    }
}