// sevensense_learning/application/services.rs

//! Learning service implementation.
//!
//! Provides the main application service for GNN-based learning,
//! including training, embedding refinement, and edge prediction.

6use std::sync::Arc;
7use std::time::Instant;
8
9use ndarray::Array2;
10use rayon::prelude::*;
11use tokio::sync::RwLock;
12use tracing::{debug, info, instrument, warn};
13
14use crate::domain::entities::{
15    EmbeddingId, GnnModelType, LearningConfig, LearningSession, RefinedEmbedding,
16    TrainingMetrics, TrainingStatus, TransitionGraph,
17};
18use crate::domain::repository::LearningRepository;
19use crate::ewc::{EwcRegularizer, EwcState};
20use crate::infrastructure::gnn_model::{GnnError, GnnModel};
21use crate::loss;
22
23/// Error type for learning service operations
24#[derive(Debug, thiserror::Error)]
25pub enum LearningError {
26    /// Invalid configuration
27    #[error("Invalid configuration: {0}")]
28    InvalidConfig(String),
29
30    /// Training error
31    #[error("Training error: {0}")]
32    TrainingError(String),
33
34    /// Model error
35    #[error("Model error: {0}")]
36    ModelError(String),
37
38    /// Data error
39    #[error("Data error: {0}")]
40    DataError(String),
41
42    /// Repository error
43    #[error("Repository error: {0}")]
44    RepositoryError(#[from] crate::domain::repository::RepositoryError),
45
46    /// GNN model error
47    #[error("GNN error: {0}")]
48    GnnError(#[from] GnnError),
49
50    /// Session not found
51    #[error("Session not found: {0}")]
52    SessionNotFound(String),
53
54    /// Session already running
55    #[error("A training session is already running")]
56    SessionAlreadyRunning,
57
58    /// Dimension mismatch
59    #[error("Dimension mismatch: expected {expected}, got {actual}")]
60    DimensionMismatch { expected: usize, actual: usize },
61
62    /// Empty graph
63    #[error("Graph is empty or invalid")]
64    EmptyGraph,
65
66    /// Internal error
67    #[error("Internal error: {0}")]
68    Internal(String),
69}
70
71/// Result type for learning operations
72pub type LearningResult<T> = Result<T, LearningError>;
73
74/// Main learning service for GNN-based embedding refinement.
75///
76/// This service manages:
77/// - GNN model training on transition graphs
78/// - Embedding refinement through message passing
79/// - Edge prediction for relationship modeling
80/// - Continual learning with EWC regularization
81pub struct LearningService {
82    /// The GNN model
83    model: Arc<RwLock<GnnModel>>,
84    /// Service configuration
85    config: LearningConfig,
86    /// EWC state for continual learning
87    ewc_state: Arc<RwLock<Option<EwcState>>>,
88    /// Optional repository for persistence
89    repository: Option<Arc<dyn LearningRepository>>,
90    /// Current active session
91    current_session: Arc<RwLock<Option<LearningSession>>>,
92}
93
94impl LearningService {
95    /// Create a new learning service with the given configuration
96    #[must_use]
97    pub fn new(config: LearningConfig) -> Self {
98        let model = GnnModel::new(
99            config.model_type,
100            config.input_dim,
101            config.output_dim,
102            config.hyperparameters.num_layers,
103            config.hyperparameters.hidden_dim,
104            config.hyperparameters.num_heads,
105            config.hyperparameters.dropout,
106        );
107
108        Self {
109            model: Arc::new(RwLock::new(model)),
110            config,
111            ewc_state: Arc::new(RwLock::new(None)),
112            repository: None,
113            current_session: Arc::new(RwLock::new(None)),
114        }
115    }
116
117    /// Create a learning service with a repository
118    #[must_use]
119    pub fn with_repository(mut self, repository: Arc<dyn LearningRepository>) -> Self {
120        self.repository = Some(repository);
121        self
122    }
123
124    /// Get the current configuration
125    #[must_use]
126    pub fn config(&self) -> &LearningConfig {
127        &self.config
128    }
129
130    /// Get the model type
131    #[must_use]
132    pub fn model_type(&self) -> GnnModelType {
133        self.config.model_type
134    }
135
136    /// Start a new training session
137    #[instrument(skip(self), err)]
138    pub async fn start_session(&self) -> LearningResult<String> {
139        // Check if a session is already running
140        {
141            let session = self.current_session.read().await;
142            if let Some(ref s) = *session {
143                if s.status.is_active() {
144                    return Err(LearningError::SessionAlreadyRunning);
145                }
146            }
147        }
148
149        let mut session = LearningSession::new(self.config.clone());
150        session.start();
151
152        let session_id = session.id.clone();
153
154        // Persist if repository available
155        if let Some(ref repo) = self.repository {
156            repo.save_session(&session).await?;
157        }
158
159        *self.current_session.write().await = Some(session);
160
161        info!(session_id = %session_id, "Started new learning session");
162        Ok(session_id)
163    }
164
165    /// Train a single epoch on the transition graph
166    ///
167    /// # Arguments
168    /// * `graph` - The transition graph to train on
169    ///
170    /// # Returns
171    /// Training metrics for the epoch
172    #[instrument(skip(self, graph), fields(nodes = graph.num_nodes(), edges = graph.num_edges()), err)]
173    pub async fn train_epoch(&self, graph: &TransitionGraph) -> LearningResult<TrainingMetrics> {
174        let start_time = Instant::now();
175
176        // Validate graph
177        if graph.num_nodes() == 0 {
178            return Err(LearningError::EmptyGraph);
179        }
180
181        if let Some(dim) = graph.embedding_dim() {
182            if dim != self.config.input_dim {
183                return Err(LearningError::DimensionMismatch {
184                    expected: self.config.input_dim,
185                    actual: dim,
186                });
187            }
188        }
189
190        // Ensure we have an active session
191        let mut session_guard = self.current_session.write().await;
192        let session = session_guard
193            .as_mut()
194            .ok_or_else(|| LearningError::TrainingError("No active session".to_string()))?;
195
196        let current_epoch = session.metrics.epoch + 1;
197        let lr = self.compute_learning_rate(current_epoch);
198
199        // Build adjacency matrix
200        let adj_matrix = self.build_adjacency_matrix(graph);
201
202        // Build feature matrix from embeddings
203        let features = self.build_feature_matrix(graph);
204
205        // Forward pass through GNN
206        let mut model = self.model.write().await;
207        let output = model.forward(&features, &adj_matrix)?;
208
209        // Compute loss using contrastive learning
210        let (loss, accuracy) = self.compute_loss(graph, &output).await?;
211
212        // Compute gradients and update weights
213        let gradients = self.compute_gradients(graph, &features, &output, &adj_matrix, &model)?;
214        let grad_norm = self.compute_gradient_norm(&gradients);
215
216        // Apply gradient clipping if configured
217        let clipped_gradients = if let Some(clip_value) = self.config.hyperparameters.gradient_clip {
218            self.clip_gradients(gradients, clip_value)
219        } else {
220            gradients
221        };
222
223        // Update model weights
224        model.update_weights(&clipped_gradients, lr, self.config.hyperparameters.weight_decay);
225
226        // Apply EWC regularization if available
227        if let Some(ref ewc_state) = *self.ewc_state.read().await {
228            let ewc_reg = EwcRegularizer::new(self.config.hyperparameters.ewc_lambda);
229            let ewc_loss = ewc_reg.compute_penalty(&model, ewc_state);
230            debug!(ewc_loss = ewc_loss, "Applied EWC regularization");
231        }
232
233        let epoch_time_ms = start_time.elapsed().as_millis() as u64;
234
235        let metrics = TrainingMetrics {
236            loss,
237            accuracy,
238            epoch: current_epoch,
239            learning_rate: lr,
240            validation_loss: None,
241            validation_accuracy: None,
242            gradient_norm: Some(grad_norm),
243            epoch_time_ms,
244            custom_metrics: Default::default(),
245        };
246
247        // Update session metrics
248        session.update_metrics(metrics.clone());
249
250        // Persist session if repository available
251        drop(model); // Release write lock before async operation
252        if let Some(ref repo) = self.repository {
253            repo.update_session(session).await?;
254        }
255
256        info!(
257            epoch = current_epoch,
258            loss = loss,
259            accuracy = accuracy,
260            time_ms = epoch_time_ms,
261            "Completed training epoch"
262        );
263
264        Ok(metrics)
265    }
266
267    /// Refine embeddings using the trained GNN model
268    ///
269    /// # Arguments
270    /// * `embeddings` - Input embeddings to refine
271    ///
272    /// # Returns
273    /// Refined embeddings with quality scores
274    #[instrument(skip(self, embeddings), fields(count = embeddings.len()), err)]
275    pub async fn refine_embeddings(
276        &self,
277        embeddings: &[(EmbeddingId, Vec<f32>)],
278    ) -> LearningResult<Vec<RefinedEmbedding>> {
279        if embeddings.is_empty() {
280            return Ok(Vec::new());
281        }
282
283        // Validate dimensions
284        if let Some((_, emb)) = embeddings.first() {
285            if emb.len() != self.config.input_dim {
286                return Err(LearningError::DimensionMismatch {
287                    expected: self.config.input_dim,
288                    actual: emb.len(),
289                });
290            }
291        }
292
293        let model = self.model.read().await;
294
295        // Build a simple graph where each embedding is a node
296        // Connected based on cosine similarity
297        let n = embeddings.len();
298        let similarity_threshold = 0.5;
299
300        // Build feature matrix
301        let mut features = Array2::zeros((n, self.config.input_dim));
302        for (i, (_, emb)) in embeddings.iter().enumerate() {
303            for (j, &val) in emb.iter().enumerate() {
304                features[[i, j]] = val;
305            }
306        }
307
308        // Build adjacency matrix based on similarity
309        let mut adj_matrix = Array2::<f32>::eye(n);
310        for i in 0..n {
311            for j in (i + 1)..n {
312                let sim = cosine_similarity(&embeddings[i].1, &embeddings[j].1);
313                if sim > similarity_threshold {
314                    adj_matrix[[i, j]] = sim;
315                    adj_matrix[[j, i]] = sim;
316                }
317            }
318        }
319
320        // Normalize adjacency matrix
321        let degrees: Vec<f32> = (0..n)
322            .map(|i| adj_matrix.row(i).sum())
323            .collect();
324        for i in 0..n {
325            for j in 0..n {
326                if degrees[i] > 0.0 && degrees[j] > 0.0 {
327                    adj_matrix[[i, j]] /= (degrees[i] * degrees[j]).sqrt();
328                }
329            }
330        }
331
332        // Forward pass
333        let output = model.forward(&features, &adj_matrix)?;
334
335        // Create refined embeddings
336        let session_id = self
337            .current_session
338            .read()
339            .await
340            .as_ref()
341            .map(|s| s.id.clone());
342
343        let refined: Vec<RefinedEmbedding> = embeddings
344            .par_iter()
345            .enumerate()
346            .map(|(i, (id, original))| {
347                let refined_vec: Vec<f32> = output.row(i).to_vec();
348
349                // Compute refinement score based on change magnitude
350                let delta = original
351                    .iter()
352                    .zip(&refined_vec)
353                    .map(|(a, b)| (a - b).powi(2))
354                    .sum::<f32>()
355                    .sqrt();
356
357                let score = 1.0 / (1.0 + delta); // Higher score for smaller changes
358
359                let mut refined = RefinedEmbedding::new(id.clone(), refined_vec, score);
360                refined.session_id = session_id.clone();
361                refined.delta_norm = Some(delta);
362                refined.normalize();
363                refined
364            })
365            .collect();
366
367        info!(count = refined.len(), "Refined embeddings");
368
369        // Persist if repository available
370        if let Some(ref repo) = self.repository {
371            repo.save_refined_embeddings(&refined).await?;
372        }
373
374        Ok(refined)
375    }
376
377    /// Predict edge weight between two embeddings
378    ///
379    /// # Arguments
380    /// * `from` - Source embedding
381    /// * `to` - Target embedding
382    ///
383    /// # Returns
384    /// Predicted edge weight (0.0 to 1.0)
385    #[instrument(skip(self, from, to), err)]
386    pub async fn predict_edge(&self, from: &[f32], to: &[f32]) -> LearningResult<f32> {
387        // Validate dimensions
388        if from.len() != self.config.input_dim {
389            return Err(LearningError::DimensionMismatch {
390                expected: self.config.input_dim,
391                actual: from.len(),
392            });
393        }
394        if to.len() != self.config.input_dim {
395            return Err(LearningError::DimensionMismatch {
396                expected: self.config.input_dim,
397                actual: to.len(),
398            });
399        }
400
401        let model = self.model.read().await;
402
403        // Create a mini-graph with two nodes
404        let mut features = Array2::zeros((2, self.config.input_dim));
405        for (j, &val) in from.iter().enumerate() {
406            features[[0, j]] = val;
407        }
408        for (j, &val) in to.iter().enumerate() {
409            features[[1, j]] = val;
410        }
411
412        // Simple adjacency (self-loops only initially)
413        let adj_matrix = Array2::<f32>::eye(2);
414
415        // Forward pass
416        let output = model.forward(&features, &adj_matrix)?;
417
418        // Compute similarity of refined embeddings
419        let from_refined: Vec<f32> = output.row(0).to_vec();
420        let to_refined: Vec<f32> = output.row(1).to_vec();
421
422        let similarity = cosine_similarity(&from_refined, &to_refined);
423        let weight = (similarity + 1.0) / 2.0; // Map from [-1, 1] to [0, 1]
424
425        Ok(weight)
426    }
427
428    /// Complete the current training session
429    #[instrument(skip(self), err)]
430    pub async fn complete_session(&self) -> LearningResult<()> {
431        let mut session_guard = self.current_session.write().await;
432
433        if let Some(ref mut session) = *session_guard {
434            session.complete();
435
436            // Compute and store Fisher information for EWC
437            // This would be done in a real implementation with the final model state
438
439            if let Some(ref repo) = self.repository {
440                repo.update_session(session).await?;
441            }
442
443            info!(session_id = %session.id, "Completed learning session");
444        }
445
446        Ok(())
447    }
448
449    /// Fail the current session with an error
450    #[instrument(skip(self, error), err)]
451    pub async fn fail_session(&self, error: impl Into<String>) -> LearningResult<()> {
452        let error_msg = error.into();
453        let mut session_guard = self.current_session.write().await;
454
455        if let Some(ref mut session) = *session_guard {
456            session.fail(&error_msg);
457
458            if let Some(ref repo) = self.repository {
459                repo.update_session(session).await?;
460            }
461
462            warn!(session_id = %session.id, error = %error_msg, "Failed learning session");
463        }
464
465        Ok(())
466    }
467
468    /// Get the current session status
469    pub async fn get_session(&self) -> Option<LearningSession> {
470        self.current_session.read().await.clone()
471    }
472
473    /// Save EWC state from current model for future regularization
474    #[instrument(skip(self, graph), err)]
475    pub async fn consolidate_ewc(&self, graph: &TransitionGraph) -> LearningResult<()> {
476        let model = self.model.read().await;
477        let fisher = self.compute_fisher_information(&model, graph)?;
478        let state = EwcState::new(model.get_parameters(), fisher);
479
480        *self.ewc_state.write().await = Some(state);
481
482        info!("Consolidated EWC state");
483        Ok(())
484    }
485
486    // =========== Private Helper Methods ===========
487
488    fn build_adjacency_matrix(&self, graph: &TransitionGraph) -> Array2<f32> {
489        let n = graph.num_nodes();
490        let mut adj = Array2::zeros((n, n));
491
492        // Add self-loops
493        for i in 0..n {
494            adj[[i, i]] = 1.0;
495        }
496
497        // Add edges
498        for &(from, to, weight) in &graph.edges {
499            adj[[from, to]] = weight;
500            if !graph.directed {
501                adj[[to, from]] = weight;
502            }
503        }
504
505        // Symmetric normalization: D^(-1/2) * A * D^(-1/2)
506        let degrees: Vec<f32> = (0..n).map(|i| adj.row(i).sum()).collect();
507        for i in 0..n {
508            for j in 0..n {
509                if degrees[i] > 0.0 && degrees[j] > 0.0 {
510                    adj[[i, j]] /= (degrees[i] * degrees[j]).sqrt();
511                }
512            }
513        }
514
515        adj
516    }
517
518    fn build_feature_matrix(&self, graph: &TransitionGraph) -> Array2<f32> {
519        let n = graph.num_nodes();
520        let dim = graph.embedding_dim().unwrap_or(self.config.input_dim);
521        let mut features = Array2::zeros((n, dim));
522
523        for (i, emb) in graph.embeddings.iter().enumerate() {
524            for (j, &val) in emb.iter().enumerate() {
525                features[[i, j]] = val;
526            }
527        }
528
529        features
530    }
531
532    async fn compute_loss(
533        &self,
534        graph: &TransitionGraph,
535        output: &Array2<f32>,
536    ) -> LearningResult<(f32, f32)> {
537        let n = graph.num_nodes();
538        if n == 0 {
539            return Ok((0.0, 0.0));
540        }
541
542        let mut total_loss = 0.0;
543        let mut correct = 0usize;
544        let mut total = 0usize;
545
546        let hp = &self.config.hyperparameters;
547
548        // For each edge, compute contrastive loss
549        for &(from, to, weight) in &graph.edges {
550            let anchor: Vec<f32> = output.row(from).to_vec();
551            let positive: Vec<f32> = output.row(to).to_vec();
552
553            // Sample negative nodes (nodes not connected to anchor)
554            let negatives: Vec<Vec<f32>> = (0..n)
555                .filter(|&i| i != from && i != to)
556                .take(hp.negative_ratio)
557                .map(|i| output.row(i).to_vec())
558                .collect();
559
560            if !negatives.is_empty() {
561                let neg_refs: Vec<&[f32]> = negatives.iter().map(|v| v.as_slice()).collect();
562
563                // InfoNCE loss
564                let loss = loss::info_nce_loss(&anchor, &positive, &neg_refs, hp.temperature);
565                total_loss += loss * weight;
566            }
567
568            // Compute accuracy based on whether positive is closer than negatives
569            let pos_sim = cosine_similarity(&anchor, &positive);
570            let all_closer = (0..n)
571                .filter(|&i| i != from && i != to)
572                .all(|i| {
573                    let neg: Vec<f32> = output.row(i).to_vec();
574                    cosine_similarity(&anchor, &neg) < pos_sim
575                });
576
577            if all_closer {
578                correct += 1;
579            }
580            total += 1;
581        }
582
583        let avg_loss = if graph.edges.is_empty() {
584            0.0
585        } else {
586            total_loss / graph.edges.len() as f32
587        };
588
589        let accuracy = if total == 0 {
590            0.0
591        } else {
592            correct as f32 / total as f32
593        };
594
595        Ok((avg_loss, accuracy))
596    }
597
598    fn compute_gradients(
599        &self,
600        _graph: &TransitionGraph,
601        features: &Array2<f32>,
602        output: &Array2<f32>,
603        _adj_matrix: &Array2<f32>,
604        model: &GnnModel,
605    ) -> LearningResult<Vec<Array2<f32>>> {
606        // Simplified gradient computation
607        // In practice, this would use automatic differentiation
608        let num_layers = model.num_layers();
609        let mut gradients = Vec::with_capacity(num_layers);
610        let batch_size = features.nrows() as f32;
611
612        for layer_idx in 0..num_layers {
613            let (in_dim, out_dim) = model.layer_dims(layer_idx);
614
615            // Compute gradient approximation based on output variance
616            // This is a simplified placeholder - real backprop would use chain rule
617            let output_centered = &output.mapv(|x| x - output.mean().unwrap_or(0.0));
618
619            // Approximate gradient as outer product scaled by learning signal
620            let grad = if layer_idx == 0 {
621                // Input layer: gradient is features^T * output_signal / batch_size
622                let output_slice = if output.ncols() >= out_dim {
623                    output_centered.slice(ndarray::s![.., ..out_dim]).to_owned()
624                } else {
625                    Array2::zeros((output.nrows(), out_dim))
626                };
627                let feat_slice = if features.ncols() >= in_dim {
628                    features.slice(ndarray::s![.., ..in_dim]).to_owned()
629                } else {
630                    Array2::zeros((features.nrows(), in_dim))
631                };
632                feat_slice.t().dot(&output_slice) / batch_size
633            } else {
634                // Hidden layers: use small random gradients scaled by output variance
635                let variance = output.var(0.0);
636                Array2::from_elem((in_dim, out_dim), 0.01 * variance.sqrt())
637            };
638
639            // Reshape to (out_dim, in_dim)
640            let scaled_grad = grad.t().to_owned();
641            gradients.push(scaled_grad);
642        }
643
644        Ok(gradients)
645    }
646
647    fn compute_gradient_norm(&self, gradients: &[Array2<f32>]) -> f32 {
648        gradients
649            .iter()
650            .map(|g| g.iter().map(|&x| x * x).sum::<f32>())
651            .sum::<f32>()
652            .sqrt()
653    }
654
655    fn clip_gradients(&self, gradients: Vec<Array2<f32>>, max_norm: f32) -> Vec<Array2<f32>> {
656        let current_norm = self.compute_gradient_norm(&gradients);
657        if current_norm <= max_norm {
658            return gradients;
659        }
660
661        let scale = max_norm / current_norm;
662        gradients.into_iter().map(|g| g * scale).collect()
663    }
664
665    fn compute_learning_rate(&self, epoch: usize) -> f32 {
666        let base_lr = self.config.hyperparameters.learning_rate;
667        let total_epochs = self.config.hyperparameters.epochs;
668
669        // Cosine annealing schedule
670        let progress = epoch as f32 / total_epochs as f32;
671        let cosine_factor = (1.0 + (progress * std::f32::consts::PI).cos()) / 2.0;
672
673        base_lr * cosine_factor
674    }
675
676    fn compute_fisher_information(
677        &self,
678        _model: &GnnModel,
679        _graph: &TransitionGraph,
680    ) -> LearningResult<crate::ewc::FisherInformation> {
681        // Simplified Fisher information computation
682        // In practice, this would compute the diagonal of the Fisher matrix
683        Ok(crate::ewc::FisherInformation::default())
684    }
685}
686
/// Cosine similarity `a·b / (|a| |b|)` of two equally long vectors.
///
/// Yields `0.0` rather than a similarity when the slices are empty, differ in
/// length, or when either vector's norm is below `1e-10`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Plain left-to-right inner product of two slices.
    let inner = |u: &[f32], v: &[f32]| -> f32 { u.iter().zip(v).map(|(p, q)| p * q).sum() };

    match (inner(a, a).sqrt(), inner(b, b).sqrt()) {
        (len_a, len_b) if len_a < 1e-10 || len_b < 1e-10 => 0.0,
        (len_a, len_b) => inner(a, b) / (len_a * len_b),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration (8 -> 4 dims, hidden 8) that keeps the GNN cheap in tests.
    fn small_config() -> LearningConfig {
        let mut cfg = LearningConfig::default();
        cfg.input_dim = 8;
        cfg.output_dim = 4;
        cfg.hyperparameters.hidden_dim = 8;
        cfg
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let same = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];
        let opposite = vec![-1.0, 0.0, 0.0];

        assert!((cosine_similarity(&x_axis, &same) - 1.0).abs() < 1e-6);
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < 1e-6);
        assert!((cosine_similarity(&x_axis, &opposite) + 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_learning_service_creation() {
        let svc = LearningService::new(LearningConfig::default());

        assert_eq!(svc.model_type(), GnnModelType::Gcn);
        assert_eq!(svc.config().input_dim, 768);
    }

    #[tokio::test]
    async fn test_start_session() {
        let svc = LearningService::new(LearningConfig::default());

        let id = svc.start_session().await.unwrap();
        assert!(!id.is_empty());

        let current = svc.get_session().await.unwrap();
        assert_eq!(current.status, TrainingStatus::Running);
    }

    #[tokio::test]
    async fn test_train_epoch() {
        let svc = LearningService::new(small_config());
        svc.start_session().await.unwrap();

        let mut g = TransitionGraph::new();
        for (name, value) in [("n1", 0.1), ("n2", 0.2), ("n3", 0.3)].iter() {
            g.add_node(EmbeddingId::new(*name), vec![*value; 8], None);
        }
        g.add_edge(0, 1, 0.8);
        g.add_edge(1, 2, 0.7);

        let m = svc.train_epoch(&g).await.unwrap();
        assert_eq!(m.epoch, 1);
        assert!(m.loss >= 0.0);
    }

    #[tokio::test]
    async fn test_refine_embeddings() {
        let svc = LearningService::new(small_config());
        svc.start_session().await.unwrap();

        let inputs = vec![
            (EmbeddingId::new("e1"), vec![0.1; 8]),
            (EmbeddingId::new("e2"), vec![0.2; 8]),
        ];

        let out = svc.refine_embeddings(&inputs).await.unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].dim(), 4); // Output dimension
    }

    #[tokio::test]
    async fn test_predict_edge() {
        let svc = LearningService::new(small_config());

        let src = vec![0.1; 8];
        let dst = vec![0.1; 8]; // Same embedding should have high weight

        let w = svc.predict_edge(&src, &dst).await.unwrap();
        assert!((0.0..=1.0).contains(&w));
    }

    #[tokio::test]
    async fn test_empty_graph_error() {
        let svc = LearningService::new(LearningConfig::default());
        svc.start_session().await.unwrap();

        let outcome = svc.train_epoch(&TransitionGraph::new()).await;

        assert!(matches!(outcome, Err(LearningError::EmptyGraph)));
    }

    #[tokio::test]
    async fn test_dimension_mismatch() {
        let mut cfg = LearningConfig::default();
        cfg.input_dim = 768;

        let svc = LearningService::new(cfg);
        svc.start_session().await.unwrap();

        let mut g = TransitionGraph::new();
        g.add_node(EmbeddingId::new("n1"), vec![0.1; 128], None); // Wrong dimension

        let outcome = svc.train_epoch(&g).await;
        assert!(matches!(
            outcome,
            Err(LearningError::DimensionMismatch { .. })
        ));
    }
}