sevensense_learning/domain/
entities.rs

1//! Domain entities for the learning bounded context.
2//!
3//! This module defines the core domain entities including:
4//! - Learning sessions for tracking training state
5//! - GNN model types and training metrics
6//! - Transition graphs for embedding relationships
7//! - Refined embeddings as output of the learning process
8
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use uuid::Uuid;
13
14/// Unique identifier for an embedding
15#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
16pub struct EmbeddingId(pub String);
17
18impl EmbeddingId {
19    /// Create a new embedding ID
20    #[must_use]
21    pub fn new(id: impl Into<String>) -> Self {
22        Self(id.into())
23    }
24
25    /// Generate a new random embedding ID
26    #[must_use]
27    pub fn generate() -> Self {
28        Self(Uuid::new_v4().to_string())
29    }
30
31    /// Get the inner string value
32    #[must_use]
33    pub fn as_str(&self) -> &str {
34        &self.0
35    }
36}
37
38impl From<String> for EmbeddingId {
39    fn from(s: String) -> Self {
40        Self(s)
41    }
42}
43
44impl From<&str> for EmbeddingId {
45    fn from(s: &str) -> Self {
46        Self(s.to_string())
47    }
48}
49
/// UTC wall-clock timestamp used consistently across the learning domain.
pub type Timestamp = DateTime<Utc>;
52
53/// Types of GNN models supported by the learning system
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
55#[serde(rename_all = "snake_case")]
56pub enum GnnModelType {
57    /// Graph Convolutional Network
58    /// Uses spectral convolutions on graph-structured data
59    Gcn,
60    /// GraphSAGE (SAmple and aggreGatE)
61    /// Learns node embeddings through neighborhood sampling and aggregation
62    GraphSage,
63    /// Graph Attention Network
64    /// Uses attention mechanisms to weight neighbor contributions
65    Gat,
66}
67
68impl Default for GnnModelType {
69    fn default() -> Self {
70        Self::Gcn
71    }
72}
73
74impl GnnModelType {
75    /// Get the number of learnable parameters per layer (approximate)
76    #[must_use]
77    pub fn params_per_layer(&self, input_dim: usize, output_dim: usize) -> usize {
78        match self {
79            Self::Gcn => input_dim * output_dim + output_dim,
80            Self::GraphSage => 2 * input_dim * output_dim + output_dim,
81            Self::Gat => input_dim * output_dim + 2 * output_dim,
82        }
83    }
84
85    /// Get recommended number of attention heads (only relevant for GAT)
86    #[must_use]
87    pub fn recommended_heads(&self) -> usize {
88        match self {
89            Self::Gat => 8,
90            _ => 1,
91        }
92    }
93}
94
95/// Status of a training session
96#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
97#[serde(rename_all = "snake_case")]
98pub enum TrainingStatus {
99    /// Session created but training not started
100    Pending,
101    /// Training is currently running
102    Running,
103    /// Training completed successfully
104    Completed,
105    /// Training failed with an error
106    Failed,
107    /// Training was paused
108    Paused,
109    /// Training was cancelled by user
110    Cancelled,
111}
112
113impl TrainingStatus {
114    /// Check if the status represents a terminal state
115    #[must_use]
116    pub fn is_terminal(&self) -> bool {
117        matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
118    }
119
120    /// Check if training can be resumed from this status
121    #[must_use]
122    pub fn can_resume(&self) -> bool {
123        matches!(self, Self::Paused)
124    }
125
126    /// Check if training is active
127    #[must_use]
128    pub fn is_active(&self) -> bool {
129        matches!(self, Self::Running)
130    }
131}
132
/// Metrics collected during a single training epoch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Current loss value (lower is better)
    pub loss: f32,
    /// Training accuracy (0.0 to 1.0)
    pub accuracy: f32,
    /// Current epoch number
    pub epoch: usize,
    /// Learning rate in effect for this epoch
    pub learning_rate: f32,
    /// Validation loss (None when no validation set was provided)
    pub validation_loss: Option<f32>,
    /// Validation accuracy (None when no validation set was provided)
    pub validation_accuracy: Option<f32>,
    /// Gradient norm, for monitoring training stability
    pub gradient_norm: Option<f32>,
    /// Wall-clock time taken for this epoch in milliseconds
    pub epoch_time_ms: u64,
    /// Additional custom metrics keyed by name; missing field deserializes to empty
    #[serde(default)]
    pub custom_metrics: HashMap<String, f32>,
}
156
157impl Default for TrainingMetrics {
158    fn default() -> Self {
159        Self {
160            loss: f32::INFINITY,
161            accuracy: 0.0,
162            epoch: 0,
163            learning_rate: 0.001,
164            validation_loss: None,
165            validation_accuracy: None,
166            gradient_norm: None,
167            epoch_time_ms: 0,
168            custom_metrics: HashMap::new(),
169        }
170    }
171}
172
173impl TrainingMetrics {
174    /// Create new metrics for an epoch
175    #[must_use]
176    pub fn new(epoch: usize, loss: f32, accuracy: f32, learning_rate: f32) -> Self {
177        Self {
178            loss,
179            accuracy,
180            epoch,
181            learning_rate,
182            ..Default::default()
183        }
184    }
185
186    /// Set validation metrics
187    #[must_use]
188    pub fn with_validation(mut self, loss: f32, accuracy: f32) -> Self {
189        self.validation_loss = Some(loss);
190        self.validation_accuracy = Some(accuracy);
191        self
192    }
193
194    /// Add a custom metric
195    pub fn add_custom_metric(&mut self, name: impl Into<String>, value: f32) {
196        self.custom_metrics.insert(name.into(), value);
197    }
198
199    /// Check if training is converging (loss is decreasing)
200    #[must_use]
201    pub fn is_improving(&self, previous: &Self) -> bool {
202        self.loss < previous.loss
203    }
204}
205
/// Hyperparameters for training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperParameters {
    /// Initial learning rate
    pub learning_rate: f32,
    /// Weight decay (L2 regularization) coefficient
    pub weight_decay: f32,
    /// Dropout probability
    pub dropout: f32,
    /// Total number of training epochs
    pub epochs: usize,
    /// Batch size for training
    pub batch_size: usize,
    /// Early stopping patience in epochs without improvement;
    /// None disables early stopping entirely
    pub early_stopping_patience: Option<usize>,
    /// Gradient clipping threshold; presumably None disables clipping — confirm in trainer
    pub gradient_clip: Option<f32>,
    /// Temperature for contrastive loss
    pub temperature: f32,
    /// Margin for triplet loss
    pub triplet_margin: f32,
    /// EWC lambda (importance given to preserving old-task knowledge)
    pub ewc_lambda: f32,
    /// Number of GNN layers
    pub num_layers: usize,
    /// Hidden dimension size
    pub hidden_dim: usize,
    /// Number of attention heads (only meaningful for GAT)
    pub num_heads: usize,
    /// Negative sample ratio for contrastive learning
    pub negative_ratio: usize,
}
238
239impl Default for HyperParameters {
240    fn default() -> Self {
241        Self {
242            learning_rate: 0.001,
243            weight_decay: 5e-4,
244            dropout: 0.5,
245            epochs: 200,
246            batch_size: 32,
247            early_stopping_patience: Some(20),
248            gradient_clip: Some(1.0),
249            temperature: 0.07,
250            triplet_margin: 1.0,
251            ewc_lambda: 5000.0,
252            num_layers: 2,
253            hidden_dim: 256,
254            num_heads: 8,
255            negative_ratio: 5,
256        }
257    }
258}
259
/// Configuration for the learning service
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningConfig {
    /// Type of GNN model to use
    pub model_type: GnnModelType,
    /// Input embedding dimension
    pub input_dim: usize,
    /// Output (refined) embedding dimension
    pub output_dim: usize,
    /// Training hyperparameters
    pub hyperparameters: HyperParameters,
    /// Enable mixed precision training
    pub mixed_precision: bool,
    /// Device to use for training
    pub device: Device,
    /// Random seed for reproducibility; None means non-deterministic seeding
    pub seed: Option<u64>,
    /// Enable gradient checkpointing to trade compute for memory
    pub gradient_checkpointing: bool,
}
280
281impl Default for LearningConfig {
282    fn default() -> Self {
283        Self {
284            model_type: GnnModelType::Gcn,
285            input_dim: 768,
286            output_dim: 256,
287            hyperparameters: HyperParameters::default(),
288            mixed_precision: false,
289            device: Device::Cpu,
290            seed: None,
291            gradient_checkpointing: false,
292        }
293    }
294}
295
/// Device for computation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum Device {
    /// CPU computation (the default)
    #[default]
    Cpu,
    /// CUDA GPU computation; payload is presumably the CUDA device ordinal —
    /// confirm against the training backend
    Cuda(usize),
    /// Metal GPU (Apple Silicon)
    Metal,
}
308
/// A learning session tracking the state of a training run
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningSession {
    /// Unique session identifier (a UUID string)
    pub id: String,
    /// Type of GNN model being trained
    pub model_type: GnnModelType,
    /// Current training status
    pub status: TrainingStatus,
    /// Metrics from the most recently recorded epoch
    pub metrics: TrainingMetrics,
    /// When the session was started
    pub started_at: Timestamp,
    /// When the session was last updated
    pub updated_at: Timestamp,
    /// When the session completed or failed (None while still in progress)
    pub completed_at: Option<Timestamp>,
    /// Configuration used for this session
    pub config: LearningConfig,
    /// History of metrics, one entry per epoch; missing field deserializes to empty
    #[serde(default)]
    pub metrics_history: Vec<TrainingMetrics>,
    /// Lowest-loss metrics achieved so far during training
    pub best_metrics: Option<TrainingMetrics>,
    /// Error message if training failed
    pub error_message: Option<String>,
    /// Number of checkpoints saved
    pub checkpoint_count: usize,
}
338
339impl LearningSession {
340    /// Create a new learning session
341    #[must_use]
342    pub fn new(config: LearningConfig) -> Self {
343        let now = Utc::now();
344        Self {
345            id: Uuid::new_v4().to_string(),
346            model_type: config.model_type,
347            status: TrainingStatus::Pending,
348            metrics: TrainingMetrics::default(),
349            started_at: now,
350            updated_at: now,
351            completed_at: None,
352            config,
353            metrics_history: Vec::new(),
354            best_metrics: None,
355            error_message: None,
356            checkpoint_count: 0,
357        }
358    }
359
360    /// Start the training session
361    pub fn start(&mut self) {
362        self.status = TrainingStatus::Running;
363        self.updated_at = Utc::now();
364    }
365
366    /// Update metrics for a completed epoch
367    pub fn update_metrics(&mut self, metrics: TrainingMetrics) {
368        // Update best metrics if this is an improvement
369        if self.best_metrics.is_none()
370            || metrics.loss < self.best_metrics.as_ref().unwrap().loss
371        {
372            self.best_metrics = Some(metrics.clone());
373        }
374
375        self.metrics = metrics.clone();
376        self.metrics_history.push(metrics);
377        self.updated_at = Utc::now();
378    }
379
380    /// Mark the session as completed
381    pub fn complete(&mut self) {
382        self.status = TrainingStatus::Completed;
383        self.completed_at = Some(Utc::now());
384        self.updated_at = Utc::now();
385    }
386
387    /// Mark the session as failed
388    pub fn fail(&mut self, error: impl Into<String>) {
389        self.status = TrainingStatus::Failed;
390        self.error_message = Some(error.into());
391        self.completed_at = Some(Utc::now());
392        self.updated_at = Utc::now();
393    }
394
395    /// Pause the training session
396    pub fn pause(&mut self) {
397        if self.status == TrainingStatus::Running {
398            self.status = TrainingStatus::Paused;
399            self.updated_at = Utc::now();
400        }
401    }
402
403    /// Resume a paused session
404    pub fn resume(&mut self) {
405        if self.status == TrainingStatus::Paused {
406            self.status = TrainingStatus::Running;
407            self.updated_at = Utc::now();
408        }
409    }
410
411    /// Get the training duration
412    #[must_use]
413    pub fn duration(&self) -> chrono::Duration {
414        let end = self.completed_at.unwrap_or_else(Utc::now);
415        end - self.started_at
416    }
417
418    /// Check if training should stop early
419    #[must_use]
420    pub fn should_early_stop(&self) -> bool {
421        if let Some(patience) = self.config.hyperparameters.early_stopping_patience {
422            if self.metrics_history.len() <= patience {
423                return false;
424            }
425
426            let best_epoch = self
427                .metrics_history
428                .iter()
429                .enumerate()
430                .min_by(|(_, a), (_, b)| a.loss.partial_cmp(&b.loss).unwrap())
431                .map(|(i, _)| i)
432                .unwrap_or(0);
433
434            self.metrics_history.len() - best_epoch > patience
435        } else {
436            false
437        }
438    }
439}
440
/// A node in the transition graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Embedding ID for this node
    pub id: EmbeddingId,
    /// The embedding vector
    pub embedding: Vec<f32>,
    /// Optional extra node features, beyond the embedding itself
    pub features: Option<Vec<f32>>,
    /// Node label (for supervised learning)
    pub label: Option<usize>,
    /// Free-form string metadata; missing field deserializes to empty
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
456
457impl GraphNode {
458    /// Create a new graph node
459    #[must_use]
460    pub fn new(id: EmbeddingId, embedding: Vec<f32>) -> Self {
461        Self {
462            id,
463            embedding,
464            features: None,
465            label: None,
466            metadata: HashMap::new(),
467        }
468    }
469
470    /// Get the embedding dimension
471    #[must_use]
472    pub fn dim(&self) -> usize {
473        self.embedding.len()
474    }
475}
476
/// An edge in the transition graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
    /// Source node index
    pub from: usize,
    /// Target node index
    pub to: usize,
    /// Edge weight (e.g., similarity score)
    pub weight: f32,
    /// Edge type for heterogeneous graphs; None for untyped edges
    pub edge_type: Option<String>,
}
489
490impl GraphEdge {
491    /// Create a new edge
492    #[must_use]
493    pub fn new(from: usize, to: usize, weight: f32) -> Self {
494        Self {
495            from,
496            to,
497            weight,
498            edge_type: None,
499        }
500    }
501
502    /// Create a typed edge
503    #[must_use]
504    pub fn typed(from: usize, to: usize, weight: f32, edge_type: impl Into<String>) -> Self {
505        Self {
506            from,
507            to,
508            weight,
509            edge_type: Some(edge_type.into()),
510        }
511    }
512}
513
/// A graph representing transitions between embeddings
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransitionGraph {
    /// Nodes in the graph (embedding IDs)
    pub nodes: Vec<EmbeddingId>,
    /// Node embeddings (parallel to `nodes`)
    pub embeddings: Vec<Vec<f32>>,
    /// Edges as (from_index, to_index, weight) tuples; when `directed` is
    /// false, `add_edge` stores each logical edge in both directions
    pub edges: Vec<(usize, usize, f32)>,
    /// Optional node labels for supervised learning (parallel to `nodes`);
    /// missing field deserializes to empty
    #[serde(default)]
    pub labels: Vec<Option<usize>>,
    /// Number of unique classes (if labeled)
    pub num_classes: Option<usize>,
    /// Whether the graph is directed
    pub directed: bool,
}
531
532impl Default for TransitionGraph {
533    fn default() -> Self {
534        Self::new()
535    }
536}
537
538impl TransitionGraph {
539    /// Create a new empty transition graph
540    #[must_use]
541    pub fn new() -> Self {
542        Self {
543            nodes: Vec::new(),
544            embeddings: Vec::new(),
545            edges: Vec::new(),
546            labels: Vec::new(),
547            num_classes: None,
548            directed: true,
549        }
550    }
551
552    /// Create an undirected graph
553    #[must_use]
554    pub fn undirected() -> Self {
555        Self {
556            directed: false,
557            ..Self::new()
558        }
559    }
560
561    /// Add a node to the graph
562    pub fn add_node(&mut self, id: EmbeddingId, embedding: Vec<f32>, label: Option<usize>) {
563        self.nodes.push(id);
564        self.embeddings.push(embedding);
565        self.labels.push(label);
566    }
567
568    /// Add an edge to the graph
569    pub fn add_edge(&mut self, from: usize, to: usize, weight: f32) {
570        assert!(from < self.nodes.len(), "Invalid 'from' node index");
571        assert!(to < self.nodes.len(), "Invalid 'to' node index");
572        self.edges.push((from, to, weight));
573
574        // For undirected graphs, add reverse edge
575        if !self.directed {
576            self.edges.push((to, from, weight));
577        }
578    }
579
580    /// Get the number of nodes
581    #[must_use]
582    pub fn num_nodes(&self) -> usize {
583        self.nodes.len()
584    }
585
586    /// Get the number of edges
587    #[must_use]
588    pub fn num_edges(&self) -> usize {
589        self.edges.len()
590    }
591
592    /// Get the embedding dimension (assumes all embeddings have same dimension)
593    #[must_use]
594    pub fn embedding_dim(&self) -> Option<usize> {
595        self.embeddings.first().map(Vec::len)
596    }
597
598    /// Get neighbors of a node
599    #[must_use]
600    pub fn neighbors(&self, node_idx: usize) -> Vec<(usize, f32)> {
601        self.edges
602            .iter()
603            .filter(|(from, _, _)| *from == node_idx)
604            .map(|(_, to, weight)| (*to, *weight))
605            .collect()
606    }
607
608    /// Get the adjacency list representation
609    #[must_use]
610    pub fn adjacency_list(&self) -> Vec<Vec<(usize, f32)>> {
611        let mut adj = vec![Vec::new(); self.nodes.len()];
612        for &(from, to, weight) in &self.edges {
613            adj[from].push((to, weight));
614        }
615        adj
616    }
617
618    /// Compute node degrees
619    #[must_use]
620    pub fn degrees(&self) -> Vec<usize> {
621        let mut degrees = vec![0; self.nodes.len()];
622        for &(from, to, _) in &self.edges {
623            degrees[from] += 1;
624            if !self.directed {
625                degrees[to] += 1;
626            }
627        }
628        degrees
629    }
630
631    /// Validate the graph structure
632    pub fn validate(&self) -> Result<(), String> {
633        if self.nodes.len() != self.embeddings.len() {
634            return Err("Nodes and embeddings count mismatch".to_string());
635        }
636        if !self.labels.is_empty() && self.labels.len() != self.nodes.len() {
637            return Err("Labels count mismatch".to_string());
638        }
639        for &(from, to, _) in &self.edges {
640            if from >= self.nodes.len() || to >= self.nodes.len() {
641                return Err(format!("Invalid edge: ({from}, {to})"));
642            }
643        }
644        Ok(())
645    }
646}
647
/// A refined embedding produced by the learning process
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefinedEmbedding {
    /// ID of the original embedding that was refined
    pub original_id: EmbeddingId,
    /// The refined embedding vector
    pub refined_vector: Vec<f32>,
    /// Score indicating quality of refinement (0.0 to 1.0)
    pub refinement_score: f32,
    /// The session that produced this refinement, if known
    pub session_id: Option<String>,
    /// Timestamp of refinement
    pub refined_at: Timestamp,
    /// Euclidean distance from the original vector (filled in by `compute_delta`)
    pub delta_norm: Option<f32>,
    /// Confidence in the refinement; initialized to `refinement_score`
    pub confidence: f32,
}
666
667impl RefinedEmbedding {
668    /// Create a new refined embedding
669    #[must_use]
670    pub fn new(
671        original_id: EmbeddingId,
672        refined_vector: Vec<f32>,
673        refinement_score: f32,
674    ) -> Self {
675        Self {
676            original_id,
677            refined_vector,
678            refinement_score,
679            session_id: None,
680            refined_at: Utc::now(),
681            delta_norm: None,
682            confidence: refinement_score,
683        }
684    }
685
686    /// Compute the delta norm from original embedding
687    pub fn compute_delta(&mut self, original: &[f32]) {
688        if original.len() != self.refined_vector.len() {
689            return;
690        }
691        let delta: f32 = original
692            .iter()
693            .zip(&self.refined_vector)
694            .map(|(a, b)| (a - b).powi(2))
695            .sum();
696        self.delta_norm = Some(delta.sqrt());
697    }
698
699    /// Get the embedding dimension
700    #[must_use]
701    pub fn dim(&self) -> usize {
702        self.refined_vector.len()
703    }
704
705    /// Normalize the refined vector to unit length
706    pub fn normalize(&mut self) {
707        let norm: f32 = self.refined_vector.iter().map(|x| x * x).sum::<f32>().sqrt();
708        if norm > 1e-10 {
709            for x in &mut self.refined_vector {
710                *x /= norm;
711            }
712        }
713    }
714}
715
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_id() {
        let id = EmbeddingId::new("test-123");
        assert_eq!(id.as_str(), "test-123");

        // Generated IDs are UUID-backed and therefore never empty.
        let generated = EmbeddingId::generate();
        assert!(!generated.as_str().is_empty());
    }

    #[test]
    fn test_gnn_model_type() {
        assert_eq!(GnnModelType::default(), GnnModelType::Gcn);
        // Only GAT uses multi-head attention.
        assert_eq!(GnnModelType::Gat.recommended_heads(), 8);
        assert_eq!(GnnModelType::Gcn.recommended_heads(), 1);
    }

    #[test]
    fn test_training_status() {
        assert!(!TrainingStatus::Running.is_terminal());
        assert!(TrainingStatus::Completed.is_terminal());
        assert!(TrainingStatus::Failed.is_terminal());
        // Paused is the only resumable state.
        assert!(TrainingStatus::Paused.can_resume());
        assert!(!TrainingStatus::Completed.can_resume());
    }

    #[test]
    fn test_training_metrics() {
        let metrics = TrainingMetrics::new(1, 0.5, 0.8, 0.001);
        assert_eq!(metrics.epoch, 1);
        assert_eq!(metrics.loss, 0.5);

        // Strictly lower loss counts as an improvement.
        let better = TrainingMetrics::new(2, 0.3, 0.9, 0.001);
        assert!(better.is_improving(&metrics));
    }

    #[test]
    fn test_learning_session() {
        let config = LearningConfig::default();
        let mut session = LearningSession::new(config);

        // New sessions start out pending, then follow the lifecycle:
        // Pending -> Running -> Completed.
        assert_eq!(session.status, TrainingStatus::Pending);

        session.start();
        assert_eq!(session.status, TrainingStatus::Running);

        let metrics = TrainingMetrics::new(1, 0.5, 0.8, 0.001);
        session.update_metrics(metrics);
        assert_eq!(session.metrics_history.len(), 1);

        session.complete();
        assert_eq!(session.status, TrainingStatus::Completed);
        assert!(session.completed_at.is_some());
    }

    #[test]
    fn test_transition_graph() {
        let mut graph = TransitionGraph::new();

        let emb1 = vec![0.1, 0.2, 0.3];
        let emb2 = vec![0.4, 0.5, 0.6];

        graph.add_node(EmbeddingId::new("n1"), emb1, Some(0));
        graph.add_node(EmbeddingId::new("n2"), emb2, Some(1));
        graph.add_edge(0, 1, 0.8);

        // Directed graph: a single add_edge stores exactly one entry.
        assert_eq!(graph.num_nodes(), 2);
        assert_eq!(graph.num_edges(), 1);
        assert_eq!(graph.embedding_dim(), Some(3));

        let neighbors = graph.neighbors(0);
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0], (1, 0.8));

        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_refined_embedding() {
        let original = vec![1.0, 0.0, 0.0];
        let refined = vec![0.9, 0.1, 0.0];

        let mut re = RefinedEmbedding::new(
            EmbeddingId::new("test"),
            refined,
            0.95,
        );

        re.compute_delta(&original);
        assert!(re.delta_norm.is_some());

        // After normalization the vector has unit L2 norm.
        re.normalize();
        let norm: f32 = re.refined_vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_early_stopping() {
        let mut config = LearningConfig::default();
        config.hyperparameters.early_stopping_patience = Some(3);

        let mut session = LearningSession::new(config);
        session.start();

        // Improving metrics: losses 1.0 down to 0.6; best epoch is the most
        // recent one, so patience is never exceeded.
        for i in 0..5 {
            let loss = 1.0 - (i as f32 * 0.1);
            session.update_metrics(TrainingMetrics::new(i, loss, 0.8, 0.001));
        }
        assert!(!session.should_early_stop());

        // Non-improving metrics: the best epoch stays at index 4; after five
        // flat epochs, 10 - 4 = 6 > patience(3), so training should stop.
        for i in 5..10 {
            session.update_metrics(TrainingMetrics::new(i, 0.6, 0.8, 0.001));
        }
        assert!(session.should_early_stop());
    }
}