use std::sync::Arc;
use std::time::Instant;

use ndarray::Array2;
use rayon::prelude::*;
use tokio::sync::RwLock;
use tracing::{debug, info, instrument, warn};

use crate::domain::entities::{
    EmbeddingId, GnnModelType, LearningConfig, LearningSession, RefinedEmbedding,
    TrainingMetrics, TrainingStatus, TransitionGraph,
};
use crate::domain::repository::LearningRepository;
use crate::ewc::{EwcRegularizer, EwcState};
use crate::infrastructure::gnn_model::{GnnError, GnnModel};
use crate::loss;

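/// Errors produced while configuring, training, or querying the learning
/// service.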
#[derive(Debug, thiserror::Error)]
pub enum LearningError {
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    #[error("Training error: {0}")]
    TrainingError(String),

    #[error("Model error: {0}")]
    ModelError(String),

    #[error("Data error: {0}")]
    DataError(String),

    #[error("Repository error: {0}")]
    RepositoryError(#[from] crate::domain::repository::RepositoryError),

    #[error("GNN error: {0}")]
    GnnError(#[from] GnnError),

    #[error("Session not found: {0}")]
    SessionNotFound(String),

    #[error("A training session is already running")]
    SessionAlreadyRunning,

    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    #[error("Graph is empty or invalid")]
    EmptyGraph,

    #[error("Internal error: {0}")]
    Internal(String),
}

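/// Convenience alias for results returned by the learning service.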
pub type LearningResult<T> = Result<T, LearningError>;

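/// Service that trains a GNN on transition graphs and uses it to refine
/// embeddings and predict edge weights. All mutable state sits behind
/// `tokio::sync::RwLock`s, so a single instance can be shared across tasks.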
pub struct LearningService {
    /// The underlying GNN model, guarded for shared async access.
    model: Arc<RwLock<GnnModel>>,
    /// The learning configuration, fixed at construction time.
    config: LearningConfig,
    /// Consolidated EWC state, if `consolidate_ewc` has been called.
    ewc_state: Arc<RwLock<Option<EwcState>>>,
    /// Optional persistence backend for sessions and refined embeddings.
    repository: Option<Arc<dyn LearningRepository>>,
    /// The session being trained, if one is active.
    current_session: Arc<RwLock<Option<LearningSession>>>,
}

impl LearningService {
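    /// Creates a learning service with a freshly initialized GNN model built
    /// from `config`.
    ///
    /// A minimal usage sketch (assumes the crate's `LearningConfig::default()`
    /// values, as exercised by the tests below):
    ///
    /// ```ignore
    /// let service = LearningService::new(LearningConfig::default());
    /// assert_eq!(service.config().input_dim, 768);
    /// ```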
    #[must_use]
    pub fn new(config: LearningConfig) -> Self {
        let model = GnnModel::new(
            config.model_type,
            config.input_dim,
            config.output_dim,
            config.hyperparameters.num_layers,
            config.hyperparameters.hidden_dim,
            config.hyperparameters.num_heads,
            config.hyperparameters.dropout,
        );

        Self {
            model: Arc::new(RwLock::new(model)),
            config,
            ewc_state: Arc::new(RwLock::new(None)),
            repository: None,
            current_session: Arc::new(RwLock::new(None)),
        }
    }

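    /// Attaches a repository for persisting sessions and refined embeddings
    /// (builder style).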
    #[must_use]
    pub fn with_repository(mut self, repository: Arc<dyn LearningRepository>) -> Self {
        self.repository = Some(repository);
        self
    }

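    /// Returns the learning configuration.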
    #[must_use]
    pub fn config(&self) -> &LearningConfig {
        &self.config
    }

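    /// Returns the configured GNN model type.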
    #[must_use]
    pub fn model_type(&self) -> GnnModelType {
        self.config.model_type
    }

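    /// Starts a new learning session and returns its id.
    ///
    /// Returns [`LearningError::SessionAlreadyRunning`] if an active session
    /// exists; persists the new session when a repository is attached.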
    #[instrument(skip(self), err)]
    pub async fn start_session(&self) -> LearningResult<String> {
        {
            // Reject a new session while another one is still active.
            let session = self.current_session.read().await;
            if let Some(ref s) = *session {
                if s.status.is_active() {
                    return Err(LearningError::SessionAlreadyRunning);
                }
            }
        }

        let mut session = LearningSession::new(self.config.clone());
        session.start();

        let session_id = session.id.clone();

        // Persist the session if a repository is attached.
        if let Some(ref repo) = self.repository {
            repo.save_session(&session).await?;
        }

        *self.current_session.write().await = Some(session);

        info!(session_id = %session_id, "Started new learning session");
        Ok(session_id)
    }

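    /// Runs one training epoch over `graph`: validates the input, builds the
    /// normalized adjacency and feature matrices, performs a forward pass,
    /// computes the contrastive loss and ranking accuracy, estimates and
    /// clips gradients, applies the weight update, and records the metrics
    /// on the active session.
    ///
    /// A sketch of the intended call sequence (types as in the tests below):
    ///
    /// ```ignore
    /// service.start_session().await?;
    /// let metrics = service.train_epoch(&graph).await?;
    /// info!(loss = metrics.loss, epoch = metrics.epoch, "epoch done");
    /// ```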
    #[instrument(skip(self, graph), fields(nodes = graph.num_nodes(), edges = graph.num_edges()), err)]
    pub async fn train_epoch(&self, graph: &TransitionGraph) -> LearningResult<TrainingMetrics> {
        let start_time = Instant::now();

        // Validate the input graph before touching any state.
        if graph.num_nodes() == 0 {
            return Err(LearningError::EmptyGraph);
        }

        if let Some(dim) = graph.embedding_dim() {
            if dim != self.config.input_dim {
                return Err(LearningError::DimensionMismatch {
                    expected: self.config.input_dim,
                    actual: dim,
                });
            }
        }

        // An active session is required to record metrics against.
        let mut session_guard = self.current_session.write().await;
        let session = session_guard
            .as_mut()
            .ok_or_else(|| LearningError::TrainingError("No active session".to_string()))?;

        let current_epoch = session.metrics.epoch + 1;
        let lr = self.compute_learning_rate(current_epoch);

        // Build the normalized adjacency matrix.
        let adj_matrix = self.build_adjacency_matrix(graph);

        // Build the node feature matrix.
        let features = self.build_feature_matrix(graph);

        // Forward pass.
        let mut model = self.model.write().await;
        let output = model.forward(&features, &adj_matrix)?;

        // Contrastive loss and ranking accuracy over the graph's edges.
        let (loss, accuracy) = self.compute_loss(graph, &output).await?;

        // Estimate gradients, clip them if configured, and apply the update.
        let gradients = self.compute_gradients(graph, &features, &output, &adj_matrix, &model)?;
        let grad_norm = self.compute_gradient_norm(&gradients);

        let clipped_gradients = if let Some(clip_value) = self.config.hyperparameters.gradient_clip {
            self.clip_gradients(gradients, clip_value)
        } else {
            gradients
        };

        model.update_weights(&clipped_gradients, lr, self.config.hyperparameters.weight_decay);

        // The EWC penalty is currently computed for observability only.
        if let Some(ref ewc_state) = *self.ewc_state.read().await {
            let ewc_reg = EwcRegularizer::new(self.config.hyperparameters.ewc_lambda);
            let ewc_loss = ewc_reg.compute_penalty(&model, ewc_state);
            debug!(ewc_loss = ewc_loss, "Applied EWC regularization");
        }

        let epoch_time_ms = start_time.elapsed().as_millis() as u64;

        let metrics = TrainingMetrics {
            loss,
            accuracy,
            epoch: current_epoch,
            learning_rate: lr,
            validation_loss: None,
            validation_accuracy: None,
            gradient_norm: Some(grad_norm),
            epoch_time_ms,
            custom_metrics: Default::default(),
        };

        session.update_metrics(metrics.clone());

        // Release the model lock before the repository await point.
        drop(model);

        if let Some(ref repo) = self.repository {
            repo.update_session(session).await?;
        }

        info!(
            epoch = current_epoch,
            loss = loss,
            accuracy = accuracy,
            time_ms = epoch_time_ms,
            "Completed training epoch"
        );

        Ok(metrics)
    }

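    /// Refines a batch of embeddings by running them through the GNN over an
    /// ad-hoc similarity graph: pairs with cosine similarity above 0.5 are
    /// connected, and the graph is degree-normalized before the forward pass.
    /// Each result carries a confidence score of `1 / (1 + delta)`, where
    /// `delta` is how far the embedding moved during refinement.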
    #[instrument(skip(self, embeddings), fields(count = embeddings.len()), err)]
    pub async fn refine_embeddings(
        &self,
        embeddings: &[(EmbeddingId, Vec<f32>)],
    ) -> LearningResult<Vec<RefinedEmbedding>> {
        if embeddings.is_empty() {
            return Ok(Vec::new());
        }

        // Validate dimensions against the first embedding.
        if let Some((_, emb)) = embeddings.first() {
            if emb.len() != self.config.input_dim {
                return Err(LearningError::DimensionMismatch {
                    expected: self.config.input_dim,
                    actual: emb.len(),
                });
            }
        }

        let model = self.model.read().await;

        let n = embeddings.len();
        let similarity_threshold = 0.5;

        // Stack the raw embeddings into a feature matrix.
        let mut features = Array2::zeros((n, self.config.input_dim));
        for (i, (_, emb)) in embeddings.iter().enumerate() {
            for (j, &val) in emb.iter().enumerate() {
                features[[i, j]] = val;
            }
        }

        // Connect pairs whose cosine similarity clears the threshold.
        let mut adj_matrix = Array2::<f32>::eye(n);
        for i in 0..n {
            for j in (i + 1)..n {
                let sim = cosine_similarity(&embeddings[i].1, &embeddings[j].1);
                if sim > similarity_threshold {
                    adj_matrix[[i, j]] = sim;
                    adj_matrix[[j, i]] = sim;
                }
            }
        }

        // Symmetric degree normalization, as in build_adjacency_matrix.
        let degrees: Vec<f32> = (0..n)
            .map(|i| adj_matrix.row(i).sum())
            .collect();
        for i in 0..n {
            for j in 0..n {
                if degrees[i] > 0.0 && degrees[j] > 0.0 {
                    adj_matrix[[i, j]] /= (degrees[i] * degrees[j]).sqrt();
                }
            }
        }

        let output = model.forward(&features, &adj_matrix)?;

        let session_id = self
            .current_session
            .read()
            .await
            .as_ref()
            .map(|s| s.id.clone());

        // Score each refined embedding by how far it moved from the original.
        let refined: Vec<RefinedEmbedding> = embeddings
            .par_iter()
            .enumerate()
            .map(|(i, (id, original))| {
                let refined_vec: Vec<f32> = output.row(i).to_vec();

                // Euclidean distance between the original and refined vectors.
                let delta = original
                    .iter()
                    .zip(&refined_vec)
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f32>()
                    .sqrt();

                // Smaller deltas yield confidence scores closer to 1.0.
                let score = 1.0 / (1.0 + delta);

                let mut refined = RefinedEmbedding::new(id.clone(), refined_vec, score);
                refined.session_id = session_id.clone();
                refined.delta_norm = Some(delta);
                refined.normalize();
                refined
            })
            .collect();

        info!(count = refined.len(), "Refined embeddings");

        if let Some(ref repo) = self.repository {
            repo.save_refined_embeddings(&refined).await?;
        }

        Ok(refined)
    }

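    /// Predicts an edge weight in `[0, 1]` between two embeddings by passing
    /// both through the model and rescaling their cosine similarity from
    /// `[-1, 1]`.
    ///
    /// ```ignore
    /// let weight = service.predict_edge(&from, &to).await?;
    /// assert!((0.0..=1.0).contains(&weight));
    /// ```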
    #[instrument(skip(self, from, to), err)]
    pub async fn predict_edge(&self, from: &[f32], to: &[f32]) -> LearningResult<f32> {
        if from.len() != self.config.input_dim {
            return Err(LearningError::DimensionMismatch {
                expected: self.config.input_dim,
                actual: from.len(),
            });
        }
        if to.len() != self.config.input_dim {
            return Err(LearningError::DimensionMismatch {
                expected: self.config.input_dim,
                actual: to.len(),
            });
        }

        let model = self.model.read().await;

        // Two-node graph: one feature row per endpoint.
        let mut features = Array2::zeros((2, self.config.input_dim));
        for (j, &val) in from.iter().enumerate() {
            features[[0, j]] = val;
        }
        for (j, &val) in to.iter().enumerate() {
            features[[1, j]] = val;
        }

        // Identity adjacency: the two nodes start unconnected.
        let adj_matrix = Array2::<f32>::eye(2);

        let output = model.forward(&features, &adj_matrix)?;

        let from_refined: Vec<f32> = output.row(0).to_vec();
        let to_refined: Vec<f32> = output.row(1).to_vec();

        // Map cosine similarity from [-1, 1] to an edge weight in [0, 1].
        let similarity = cosine_similarity(&from_refined, &to_refined);
        let weight = (similarity + 1.0) / 2.0;

        Ok(weight)
    }

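    /// Marks the current session, if any, as completed and persists the
    /// update when a repository is attached.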
    #[instrument(skip(self), err)]
    pub async fn complete_session(&self) -> LearningResult<()> {
        let mut session_guard = self.current_session.write().await;

        if let Some(ref mut session) = *session_guard {
            session.complete();

            if let Some(ref repo) = self.repository {
                repo.update_session(session).await?;
            }

            info!(session_id = %session.id, "Completed learning session");
        }

        Ok(())
    }

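    /// Marks the current session, if any, as failed with the given error
    /// message and persists the update when a repository is attached.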
    #[instrument(skip(self, error), err)]
    pub async fn fail_session(&self, error: impl Into<String>) -> LearningResult<()> {
        let error_msg = error.into();
        let mut session_guard = self.current_session.write().await;

        if let Some(ref mut session) = *session_guard {
            session.fail(&error_msg);

            if let Some(ref repo) = self.repository {
                repo.update_session(session).await?;
            }

            warn!(session_id = %session.id, error = %error_msg, "Failed learning session");
        }

        Ok(())
    }

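    /// Returns a clone of the current session, if one exists.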
    pub async fn get_session(&self) -> Option<LearningSession> {
        self.current_session.read().await.clone()
    }

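    /// Snapshots the current model parameters and Fisher information as EWC
    /// state, which subsequent training can use to penalize drift from the
    /// consolidated weights.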
    #[instrument(skip(self, graph), err)]
    pub async fn consolidate_ewc(&self, graph: &TransitionGraph) -> LearningResult<()> {
        let model = self.model.read().await;
        let fisher = self.compute_fisher_information(&model, graph)?;
        let state = EwcState::new(model.get_parameters(), fisher);

        *self.ewc_state.write().await = Some(state);

        info!("Consolidated EWC state");
        Ok(())
    }

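    /// Builds a degree-normalized adjacency matrix with self-loops, in the
    /// GCN style: `adj[i][j] / sqrt(deg(i) * deg(j))`.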
    fn build_adjacency_matrix(&self, graph: &TransitionGraph) -> Array2<f32> {
        let n = graph.num_nodes();
        let mut adj = Array2::zeros((n, n));

        // Self-loops keep each node's own features in the aggregation.
        for i in 0..n {
            adj[[i, i]] = 1.0;
        }

        // Weighted edges; mirror them for undirected graphs.
        for &(from, to, weight) in &graph.edges {
            adj[[from, to]] = weight;
            if !graph.directed {
                adj[[to, from]] = weight;
            }
        }

        // Symmetric normalization: adj[i][j] / sqrt(deg(i) * deg(j)).
        let degrees: Vec<f32> = (0..n).map(|i| adj.row(i).sum()).collect();
        for i in 0..n {
            for j in 0..n {
                if degrees[i] > 0.0 && degrees[j] > 0.0 {
                    adj[[i, j]] /= (degrees[i] * degrees[j]).sqrt();
                }
            }
        }

        adj
    }

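    /// Stacks the graph's node embeddings into an `(n, dim)` feature matrix.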
    fn build_feature_matrix(&self, graph: &TransitionGraph) -> Array2<f32> {
        let n = graph.num_nodes();
        let dim = graph.embedding_dim().unwrap_or(self.config.input_dim);
        let mut features = Array2::zeros((n, dim));

        for (i, emb) in graph.embeddings.iter().enumerate() {
            for (j, &val) in emb.iter().enumerate() {
                features[[i, j]] = val;
            }
        }

        features
    }

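    /// Computes the edge-weighted InfoNCE loss averaged over edges, plus a
    /// ranking accuracy: the fraction of edges whose positive endpoint is
    /// more similar to the anchor than every other node.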
    async fn compute_loss(
        &self,
        graph: &TransitionGraph,
        output: &Array2<f32>,
    ) -> LearningResult<(f32, f32)> {
        let n = graph.num_nodes();
        if n == 0 {
            return Ok((0.0, 0.0));
        }

        let mut total_loss = 0.0;
        let mut correct = 0usize;
        let mut total = 0usize;

        let hp = &self.config.hyperparameters;

        // Each edge is an (anchor, positive) pair for the contrastive loss.
        for &(from, to, weight) in &graph.edges {
            let anchor: Vec<f32> = output.row(from).to_vec();
            let positive: Vec<f32> = output.row(to).to_vec();

            // Sample non-endpoint nodes as negatives.
            let negatives: Vec<Vec<f32>> = (0..n)
                .filter(|&i| i != from && i != to)
                .take(hp.negative_ratio)
                .map(|i| output.row(i).to_vec())
                .collect();

            if !negatives.is_empty() {
                let neg_refs: Vec<&[f32]> = negatives.iter().map(|v| v.as_slice()).collect();

                let loss = loss::info_nce_loss(&anchor, &positive, &neg_refs, hp.temperature);
                total_loss += loss * weight;
            }

            // Ranking accuracy: the positive should beat every other node.
            let pos_sim = cosine_similarity(&anchor, &positive);
            let all_closer = (0..n)
                .filter(|&i| i != from && i != to)
                .all(|i| {
                    let neg: Vec<f32> = output.row(i).to_vec();
                    cosine_similarity(&anchor, &neg) < pos_sim
                });

            if all_closer {
                correct += 1;
            }
            total += 1;
        }

        let avg_loss = if graph.edges.is_empty() {
            0.0
        } else {
            total_loss / graph.edges.len() as f32
        };

        let accuracy = if total == 0 {
            0.0
        } else {
            correct as f32 / total as f32
        };

        Ok((avg_loss, accuracy))
    }

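    /// Produces a heuristic gradient estimate rather than true
    /// backpropagation: the first layer uses a feature/output correlation,
    /// deeper layers a small variance-scaled constant.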
    fn compute_gradients(
        &self,
        _graph: &TransitionGraph,
        features: &Array2<f32>,
        output: &Array2<f32>,
        _adj_matrix: &Array2<f32>,
        model: &GnnModel,
    ) -> LearningResult<Vec<Array2<f32>>> {
        let num_layers = model.num_layers();
        let mut gradients = Vec::with_capacity(num_layers);
        let batch_size = features.nrows() as f32;

        for layer_idx in 0..num_layers {
            let (in_dim, out_dim) = model.layer_dims(layer_idx);

            // Center the outputs so the correlation below is mean-free.
            let output_centered = &output.mapv(|x| x - output.mean().unwrap_or(0.0));

            let grad = if layer_idx == 0 {
                // First layer: feature/output correlation, truncated to the
                // layer's dimensions.
                let output_slice = if output.ncols() >= out_dim {
                    output_centered.slice(ndarray::s![.., ..out_dim]).to_owned()
                } else {
                    Array2::zeros((output.nrows(), out_dim))
                };
                let feat_slice = if features.ncols() >= in_dim {
                    features.slice(ndarray::s![.., ..in_dim]).to_owned()
                } else {
                    Array2::zeros((features.nrows(), in_dim))
                };
                feat_slice.t().dot(&output_slice) / batch_size
            } else {
                // Deeper layers: small constant scaled by output variance.
                let variance = output.var(0.0);
                Array2::from_elem((in_dim, out_dim), 0.01 * variance.sqrt())
            };

            // Transpose the estimate before handing it to update_weights.
            let scaled_grad = grad.t().to_owned();
            gradients.push(scaled_grad);
        }

        Ok(gradients)
    }

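    /// Global L2 norm across all layer gradients.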
    fn compute_gradient_norm(&self, gradients: &[Array2<f32>]) -> f32 {
        gradients
            .iter()
            .map(|g| g.iter().map(|&x| x * x).sum::<f32>())
            .sum::<f32>()
            .sqrt()
    }

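    /// Rescales the gradients so their global norm does not exceed `max_norm`.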
    fn clip_gradients(&self, gradients: Vec<Array2<f32>>, max_norm: f32) -> Vec<Array2<f32>> {
        let current_norm = self.compute_gradient_norm(&gradients);
        if current_norm <= max_norm {
            return gradients;
        }

        let scale = max_norm / current_norm;
        gradients.into_iter().map(|g| g * scale).collect()
    }

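    /// Cosine-annealed schedule: `base_lr * (1 + cos(pi * progress)) / 2`,
    /// decaying from the configured base rate toward zero over the configured
    /// number of epochs.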
    fn compute_learning_rate(&self, epoch: usize) -> f32 {
        let base_lr = self.config.hyperparameters.learning_rate;
        let total_epochs = self.config.hyperparameters.epochs;

        let progress = epoch as f32 / total_epochs as f32;
        let cosine_factor = (1.0 + (progress * std::f32::consts::PI).cos()) / 2.0;

        base_lr * cosine_factor
    }

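    /// Fisher information estimate for EWC; currently returns the default
    /// (empty) value.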
    fn compute_fisher_information(
        &self,
        _model: &GnnModel,
        _graph: &TransitionGraph,
    ) -> LearningResult<crate::ewc::FisherInformation> {
        Ok(crate::ewc::FisherInformation::default())
    }
}

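/// Cosine similarity of two equal-length vectors, returning 0.0 for
/// mismatched lengths, empty input, or near-zero norms.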
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a < 1e-10 || norm_b < 1e-10 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);

        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &c)).abs() < 1e-6);

        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &d) + 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_learning_service_creation() {
        let config = LearningConfig::default();
        let service = LearningService::new(config.clone());

        assert_eq!(service.model_type(), GnnModelType::Gcn);
        assert_eq!(service.config().input_dim, 768);
    }

    #[tokio::test]
    async fn test_start_session() {
        let config = LearningConfig::default();
        let service = LearningService::new(config);

        let session_id = service.start_session().await.unwrap();
        assert!(!session_id.is_empty());

        let session = service.get_session().await.unwrap();
        assert_eq!(session.status, TrainingStatus::Running);
    }

    #[tokio::test]
    async fn test_train_epoch() {
        let mut config = LearningConfig::default();
        config.input_dim = 8;
        config.output_dim = 4;
        config.hyperparameters.hidden_dim = 8;

        let service = LearningService::new(config);
        service.start_session().await.unwrap();

        let mut graph = TransitionGraph::new();
        graph.add_node(EmbeddingId::new("n1"), vec![0.1; 8], None);
        graph.add_node(EmbeddingId::new("n2"), vec![0.2; 8], None);
        graph.add_node(EmbeddingId::new("n3"), vec![0.3; 8], None);
        graph.add_edge(0, 1, 0.8);
        graph.add_edge(1, 2, 0.7);

        let metrics = service.train_epoch(&graph).await.unwrap();
        assert_eq!(metrics.epoch, 1);
        assert!(metrics.loss >= 0.0);
    }

    #[tokio::test]
    async fn test_refine_embeddings() {
        let mut config = LearningConfig::default();
        config.input_dim = 8;
        config.output_dim = 4;
        config.hyperparameters.hidden_dim = 8;

        let service = LearningService::new(config);
        service.start_session().await.unwrap();

        let embeddings = vec![
            (EmbeddingId::new("e1"), vec![0.1; 8]),
            (EmbeddingId::new("e2"), vec![0.2; 8]),
        ];

        let refined = service.refine_embeddings(&embeddings).await.unwrap();
        assert_eq!(refined.len(), 2);
        assert_eq!(refined[0].dim(), 4);
    }

    #[tokio::test]
    async fn test_predict_edge() {
        let mut config = LearningConfig::default();
        config.input_dim = 8;
        config.output_dim = 4;
        config.hyperparameters.hidden_dim = 8;

        let service = LearningService::new(config);

        let from = vec![0.1; 8];
        let to = vec![0.1; 8];

        let weight = service.predict_edge(&from, &to).await.unwrap();
        assert!(weight >= 0.0 && weight <= 1.0);
    }

    #[tokio::test]
    async fn test_empty_graph_error() {
        let config = LearningConfig::default();
        let service = LearningService::new(config);
        service.start_session().await.unwrap();

        let graph = TransitionGraph::new();
        let result = service.train_epoch(&graph).await;

        assert!(matches!(result, Err(LearningError::EmptyGraph)));
    }

    #[tokio::test]
    async fn test_dimension_mismatch() {
        let mut config = LearningConfig::default();
        config.input_dim = 768;

        let service = LearningService::new(config);
        service.start_session().await.unwrap();

        let mut graph = TransitionGraph::new();
        graph.add_node(EmbeddingId::new("n1"), vec![0.1; 128], None);

        let result = service.train_epoch(&graph).await;
        assert!(matches!(
            result,
            Err(LearningError::DimensionMismatch { .. })
        ));
    }
}