//! Core types for GNN-based embedding refinement: embedding identifiers,
//! model and training configuration, training session state, transition
//! graphs, and refined embedding outputs.

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;

/// Opaque identifier for an embedding.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EmbeddingId(pub String);

impl EmbeddingId {
    /// Creates an identifier from any string-like value.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    /// Generates a fresh random (UUID v4) identifier.
    #[must_use]
    pub fn generate() -> Self {
        Self(Uuid::new_v4().to_string())
    }

    /// Returns the identifier as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl From<String> for EmbeddingId {
    fn from(s: String) -> Self {
        Self(s)
    }
}

impl From<&str> for EmbeddingId {
    fn from(s: &str) -> Self {
        Self(s.to_string())
    }
}

/// UTC timestamp used throughout this module.
pub type Timestamp = DateTime<Utc>;

/// Supported GNN architectures.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GnnModelType {
    /// Graph Convolutional Network.
    Gcn,
    /// GraphSAGE (sample-and-aggregate).
    GraphSage,
    /// Graph Attention Network.
    Gat,
}

impl Default for GnnModelType {
    fn default() -> Self {
        Self::Gcn
    }
}

impl GnnModelType {
    /// Approximate number of trainable parameters in one layer (weights plus
    /// bias or attention vectors) for the given input/output dimensions.
    #[must_use]
    pub fn params_per_layer(&self, input_dim: usize, output_dim: usize) -> usize {
        match self {
            Self::Gcn => input_dim * output_dim + output_dim,
            Self::GraphSage => 2 * input_dim * output_dim + output_dim,
            Self::Gat => input_dim * output_dim + 2 * output_dim,
        }
    }

    /// Recommended number of attention heads; only GAT uses more than one.
    #[must_use]
    pub fn recommended_heads(&self) -> usize {
        match self {
            Self::Gat => 8,
            _ => 1,
        }
    }
}

/// Lifecycle state of a training session.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TrainingStatus {
    Pending,
    Running,
    Completed,
    Failed,
    Paused,
    Cancelled,
}

impl TrainingStatus {
    /// Returns `true` once the session can no longer change state.
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
    }

    /// Returns `true` if the session can be resumed (only from `Paused`).
    #[must_use]
    pub fn can_resume(&self) -> bool {
        matches!(self, Self::Paused)
    }

    /// Returns `true` while training is actively running.
    #[must_use]
    pub fn is_active(&self) -> bool {
        matches!(self, Self::Running)
    }
}

/// Metrics recorded for a single training epoch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    pub loss: f32,
    pub accuracy: f32,
    pub epoch: usize,
    pub learning_rate: f32,
    pub validation_loss: Option<f32>,
    pub validation_accuracy: Option<f32>,
    pub gradient_norm: Option<f32>,
    pub epoch_time_ms: u64,
    #[serde(default)]
    pub custom_metrics: HashMap<String, f32>,
}

impl Default for TrainingMetrics {
    fn default() -> Self {
        Self {
            loss: f32::INFINITY,
            accuracy: 0.0,
            epoch: 0,
            learning_rate: 0.001,
            validation_loss: None,
            validation_accuracy: None,
            gradient_norm: None,
            epoch_time_ms: 0,
            custom_metrics: HashMap::new(),
        }
    }
}

impl TrainingMetrics {
    /// Creates metrics for one epoch with the remaining fields defaulted.
    #[must_use]
    pub fn new(epoch: usize, loss: f32, accuracy: f32, learning_rate: f32) -> Self {
        Self {
            loss,
            accuracy,
            epoch,
            learning_rate,
            ..Default::default()
        }
    }

    /// Attaches validation loss and accuracy (builder style).
    #[must_use]
    pub fn with_validation(mut self, loss: f32, accuracy: f32) -> Self {
        self.validation_loss = Some(loss);
        self.validation_accuracy = Some(accuracy);
        self
    }

    /// Records an arbitrary named metric.
    pub fn add_custom_metric(&mut self, name: impl Into<String>, value: f32) {
        self.custom_metrics.insert(name.into(), value);
    }

    /// Returns `true` if this epoch's loss is lower than the previous one's.
    #[must_use]
    pub fn is_improving(&self, previous: &Self) -> bool {
        self.loss < previous.loss
    }
}

/// Hyperparameters shared across the supported GNN trainers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperParameters {
    pub learning_rate: f32,
    pub weight_decay: f32,
    pub dropout: f32,
    pub epochs: usize,
    pub batch_size: usize,
    pub early_stopping_patience: Option<usize>,
    pub gradient_clip: Option<f32>,
    /// Temperature for contrastive objectives.
    pub temperature: f32,
    /// Margin for triplet-loss objectives.
    pub triplet_margin: f32,
    /// Regularization strength for elastic weight consolidation.
    pub ewc_lambda: f32,
    pub num_layers: usize,
    pub hidden_dim: usize,
    /// Attention heads (used by GAT).
    pub num_heads: usize,
    /// Negative samples per positive example.
    pub negative_ratio: usize,
}

impl Default for HyperParameters {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            weight_decay: 5e-4,
            dropout: 0.5,
            epochs: 200,
            batch_size: 32,
            early_stopping_patience: Some(20),
            gradient_clip: Some(1.0),
            temperature: 0.07,
            triplet_margin: 1.0,
            ewc_lambda: 5000.0,
            num_layers: 2,
            hidden_dim: 256,
            num_heads: 8,
            negative_ratio: 5,
        }
    }
}

/// Full configuration for a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningConfig {
    pub model_type: GnnModelType,
    pub input_dim: usize,
    pub output_dim: usize,
    pub hyperparameters: HyperParameters,
    pub mixed_precision: bool,
    pub device: Device,
    pub seed: Option<u64>,
    pub gradient_checkpointing: bool,
}

impl Default for LearningConfig {
    fn default() -> Self {
        Self {
            model_type: GnnModelType::Gcn,
            input_dim: 768,
            output_dim: 256,
            hyperparameters: HyperParameters::default(),
            mixed_precision: false,
            device: Device::Cpu,
            seed: None,
            gradient_checkpointing: false,
        }
    }
}

/// Compute device on which to run training.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum Device {
    #[default]
    Cpu,
    /// CUDA device, identified by ordinal.
    Cuda(usize),
    Metal,
}

/// A single training session: configuration, live status, and metric history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningSession {
    pub id: String,
    pub model_type: GnnModelType,
    pub status: TrainingStatus,
    pub metrics: TrainingMetrics,
    pub started_at: Timestamp,
    pub updated_at: Timestamp,
    pub completed_at: Option<Timestamp>,
    pub config: LearningConfig,
    #[serde(default)]
    pub metrics_history: Vec<TrainingMetrics>,
    pub best_metrics: Option<TrainingMetrics>,
    pub error_message: Option<String>,
    pub checkpoint_count: usize,
}

impl LearningSession {
    /// Creates a pending session with a fresh UUID and default metrics.
    #[must_use]
    pub fn new(config: LearningConfig) -> Self {
        let now = Utc::now();
        Self {
            id: Uuid::new_v4().to_string(),
            model_type: config.model_type,
            status: TrainingStatus::Pending,
            metrics: TrainingMetrics::default(),
            started_at: now,
            updated_at: now,
            completed_at: None,
            config,
            metrics_history: Vec::new(),
            best_metrics: None,
            error_message: None,
            checkpoint_count: 0,
        }
    }

    /// Marks the session as running.
    pub fn start(&mut self) {
        self.status = TrainingStatus::Running;
        self.updated_at = Utc::now();
    }

    /// Records the metrics for the latest epoch, tracking the lowest-loss epoch as the best.
    pub fn update_metrics(&mut self, metrics: TrainingMetrics) {
        let is_best = self
            .best_metrics
            .as_ref()
            .map_or(true, |best| metrics.loss < best.loss);
        if is_best {
            self.best_metrics = Some(metrics.clone());
        }

        self.metrics = metrics.clone();
        self.metrics_history.push(metrics);
        self.updated_at = Utc::now();
    }

    /// Marks the session as successfully completed.
    pub fn complete(&mut self) {
        self.status = TrainingStatus::Completed;
        self.completed_at = Some(Utc::now());
        self.updated_at = Utc::now();
    }

    /// Marks the session as failed and records the error message.
    pub fn fail(&mut self, error: impl Into<String>) {
        self.status = TrainingStatus::Failed;
        self.error_message = Some(error.into());
        self.completed_at = Some(Utc::now());
        self.updated_at = Utc::now();
    }

    /// Pauses the session; has no effect unless it is currently running.
    pub fn pause(&mut self) {
        if self.status == TrainingStatus::Running {
            self.status = TrainingStatus::Paused;
            self.updated_at = Utc::now();
        }
    }

    /// Resumes a paused session; has no effect otherwise.
    pub fn resume(&mut self) {
        if self.status == TrainingStatus::Paused {
            self.status = TrainingStatus::Running;
            self.updated_at = Utc::now();
        }
    }

    /// Elapsed training time, up to completion or the current moment.
    #[must_use]
    pub fn duration(&self) -> chrono::Duration {
        let end = self.completed_at.unwrap_or_else(Utc::now);
        end - self.started_at
    }

    /// Returns `true` if the loss has not improved for more than
    /// `early_stopping_patience` epochs; always `false` when no patience is set.
    #[must_use]
    pub fn should_early_stop(&self) -> bool {
        if let Some(patience) = self.config.hyperparameters.early_stopping_patience {
            if self.metrics_history.len() <= patience {
                return false;
            }

            let best_epoch = self
                .metrics_history
                .iter()
                .enumerate()
                // `total_cmp` gives a total order over f32, so a NaN loss cannot
                // panic the way `partial_cmp(..).unwrap()` would.
                .min_by(|(_, a), (_, b)| a.loss.total_cmp(&b.loss))
                .map(|(i, _)| i)
                .unwrap_or(0);

            self.metrics_history.len() - best_epoch > patience
        } else {
            false
        }
    }
}

/// A node in a transition graph: an embedding plus optional features, label, and metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    pub id: EmbeddingId,
    pub embedding: Vec<f32>,
    pub features: Option<Vec<f32>>,
    pub label: Option<usize>,
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}

impl GraphNode {
    /// Creates a node with no features, label, or metadata.
    #[must_use]
    pub fn new(id: EmbeddingId, embedding: Vec<f32>) -> Self {
        Self {
            id,
            embedding,
            features: None,
            label: None,
            metadata: HashMap::new(),
        }
    }

    /// Dimensionality of the node's embedding.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.embedding.len()
    }
}

/// A weighted edge between two node indices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
    pub from: usize,
    pub to: usize,
    pub weight: f32,
    pub edge_type: Option<String>,
}

impl GraphEdge {
    /// Creates an untyped weighted edge.
    #[must_use]
    pub fn new(from: usize, to: usize, weight: f32) -> Self {
        Self {
            from,
            to,
            weight,
            edge_type: None,
        }
    }

    /// Creates an edge with an explicit type label.
    #[must_use]
    pub fn typed(from: usize, to: usize, weight: f32, edge_type: impl Into<String>) -> Self {
        Self {
            from,
            to,
            weight,
            edge_type: Some(edge_type.into()),
        }
    }
}

/// A graph over embeddings stored in parallel arrays: node ids, embedding
/// vectors, index-based weighted edges, and optional per-node labels.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransitionGraph {
    pub nodes: Vec<EmbeddingId>,
    pub embeddings: Vec<Vec<f32>>,
    pub edges: Vec<(usize, usize, f32)>,
    #[serde(default)]
    pub labels: Vec<Option<usize>>,
    pub num_classes: Option<usize>,
    pub directed: bool,
}

impl Default for TransitionGraph {
    fn default() -> Self {
        Self::new()
    }
}

impl TransitionGraph {
    /// Creates an empty directed graph.
    #[must_use]
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            embeddings: Vec::new(),
            edges: Vec::new(),
            labels: Vec::new(),
            num_classes: None,
            directed: true,
        }
    }

    /// Creates an empty undirected graph; edges added via `add_edge` are stored
    /// in both directions.
    #[must_use]
    pub fn undirected() -> Self {
        Self {
            directed: false,
            ..Self::new()
        }
    }

    /// Appends a node with its embedding and optional label.
    pub fn add_node(&mut self, id: EmbeddingId, embedding: Vec<f32>, label: Option<usize>) {
        self.nodes.push(id);
        self.embeddings.push(embedding);
        self.labels.push(label);
    }

    /// Adds a weighted edge between existing node indices.
    ///
    /// # Panics
    /// Panics if either index is out of range.
    pub fn add_edge(&mut self, from: usize, to: usize, weight: f32) {
        assert!(from < self.nodes.len(), "Invalid 'from' node index");
        assert!(to < self.nodes.len(), "Invalid 'to' node index");
        self.edges.push((from, to, weight));

        if !self.directed {
            self.edges.push((to, from, weight));
        }
    }

    /// Number of nodes.
    #[must_use]
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Number of stored edge entries (an undirected edge counts twice).
    #[must_use]
    pub fn num_edges(&self) -> usize {
        self.edges.len()
    }

    /// Embedding dimensionality, or `None` if the graph has no nodes.
    #[must_use]
    pub fn embedding_dim(&self) -> Option<usize> {
        self.embeddings.first().map(Vec::len)
    }

    /// Outgoing neighbors of a node with their edge weights.
    #[must_use]
    pub fn neighbors(&self, node_idx: usize) -> Vec<(usize, f32)> {
        self.edges
            .iter()
            .filter(|(from, _, _)| *from == node_idx)
            .map(|(_, to, weight)| (*to, *weight))
            .collect()
    }

    /// Adjacency list keyed by node index.
    #[must_use]
    pub fn adjacency_list(&self) -> Vec<Vec<(usize, f32)>> {
        let mut adj = vec![Vec::new(); self.nodes.len()];
        for &(from, to, weight) in &self.edges {
            adj[from].push((to, weight));
        }
        adj
    }

    /// Degree of every node. Undirected edges are already stored in both
    /// directions by `add_edge`, so counting the `from` endpoint of each stored
    /// edge gives out-degree for directed graphs and plain degree for
    /// undirected ones without double counting.
    #[must_use]
    pub fn degrees(&self) -> Vec<usize> {
        let mut degrees = vec![0; self.nodes.len()];
        for &(from, _, _) in &self.edges {
            degrees[from] += 1;
        }
        degrees
    }

    /// Checks that the parallel arrays agree and every edge references a valid node.
    pub fn validate(&self) -> Result<(), String> {
        if self.nodes.len() != self.embeddings.len() {
            return Err("Nodes and embeddings count mismatch".to_string());
        }
        if !self.labels.is_empty() && self.labels.len() != self.nodes.len() {
            return Err("Labels count mismatch".to_string());
        }
        for &(from, to, _) in &self.edges {
            if from >= self.nodes.len() || to >= self.nodes.len() {
                return Err(format!("Invalid edge: ({from}, {to})"));
            }
        }
        Ok(())
    }
}

/// Output of a refinement pass over a single embedding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefinedEmbedding {
    pub original_id: EmbeddingId,
    pub refined_vector: Vec<f32>,
    pub refinement_score: f32,
    pub session_id: Option<String>,
    pub refined_at: Timestamp,
    /// Euclidean distance between the original and refined vectors, if computed.
    pub delta_norm: Option<f32>,
    pub confidence: f32,
}

impl RefinedEmbedding {
    /// Creates a refined embedding timestamped now, with confidence initialized
    /// to the refinement score.
    #[must_use]
    pub fn new(
        original_id: EmbeddingId,
        refined_vector: Vec<f32>,
        refinement_score: f32,
    ) -> Self {
        Self {
            original_id,
            refined_vector,
            refinement_score,
            session_id: None,
            refined_at: Utc::now(),
            delta_norm: None,
            confidence: refinement_score,
        }
    }

    /// Computes the L2 distance to the original vector; silently skips
    /// mismatched dimensions.
    pub fn compute_delta(&mut self, original: &[f32]) {
        if original.len() != self.refined_vector.len() {
            return;
        }
        let delta: f32 = original
            .iter()
            .zip(&self.refined_vector)
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        self.delta_norm = Some(delta.sqrt());
    }

    /// Dimensionality of the refined vector.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.refined_vector.len()
    }

    /// Scales the refined vector to unit L2 norm; near-zero vectors are left unchanged.
    pub fn normalize(&mut self) {
        let norm: f32 = self.refined_vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for x in &mut self.refined_vector {
                *x /= norm;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_id() {
        let id = EmbeddingId::new("test-123");
        assert_eq!(id.as_str(), "test-123");

        let generated = EmbeddingId::generate();
        assert!(!generated.as_str().is_empty());
    }
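
    // Added example (not part of the original suite): both `From` conversions
    // should agree with `EmbeddingId::new`.
    #[test]
    fn test_embedding_id_from() {
        assert_eq!(EmbeddingId::from("abc"), EmbeddingId::new("abc"));
        assert_eq!(EmbeddingId::from(String::from("abc")), EmbeddingId::new("abc"));
    }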

    #[test]
    fn test_gnn_model_type() {
        assert_eq!(GnnModelType::default(), GnnModelType::Gcn);
        assert_eq!(GnnModelType::Gat.recommended_heads(), 8);
        assert_eq!(GnnModelType::Gcn.recommended_heads(), 1);
    }
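
    // Added example (not part of the original suite): sanity-checks the
    // per-layer parameter formulas by hand at small dimensions.
    // GCN: 4*3 + 3 = 15; GraphSAGE: 2*4*3 + 3 = 27; GAT: 4*3 + 2*3 = 18.
    #[test]
    fn test_params_per_layer() {
        assert_eq!(GnnModelType::Gcn.params_per_layer(4, 3), 15);
        assert_eq!(GnnModelType::GraphSage.params_per_layer(4, 3), 27);
        assert_eq!(GnnModelType::Gat.params_per_layer(4, 3), 18);
    }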

    #[test]
    fn test_training_status() {
        assert!(!TrainingStatus::Running.is_terminal());
        assert!(TrainingStatus::Completed.is_terminal());
        assert!(TrainingStatus::Failed.is_terminal());
        assert!(TrainingStatus::Paused.can_resume());
        assert!(!TrainingStatus::Completed.can_resume());
    }

    #[test]
    fn test_training_metrics() {
        let metrics = TrainingMetrics::new(1, 0.5, 0.8, 0.001);
        assert_eq!(metrics.epoch, 1);
        assert_eq!(metrics.loss, 0.5);

        let better = TrainingMetrics::new(2, 0.3, 0.9, 0.001);
        assert!(better.is_improving(&metrics));
    }
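
    // Added example (not part of the original suite): exercises the
    // builder-style `with_validation` helper and `add_custom_metric`;
    // the metric name "f1" and all values are arbitrary.
    #[test]
    fn test_metrics_validation_and_custom() {
        let mut metrics = TrainingMetrics::new(3, 0.4, 0.85, 0.001).with_validation(0.45, 0.82);
        metrics.add_custom_metric("f1", 0.79);

        assert_eq!(metrics.validation_loss, Some(0.45));
        assert_eq!(metrics.validation_accuracy, Some(0.82));
        assert_eq!(metrics.custom_metrics.get("f1"), Some(&0.79));
    }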

    #[test]
    fn test_learning_session() {
        let config = LearningConfig::default();
        let mut session = LearningSession::new(config);

        assert_eq!(session.status, TrainingStatus::Pending);

        session.start();
        assert_eq!(session.status, TrainingStatus::Running);

        let metrics = TrainingMetrics::new(1, 0.5, 0.8, 0.001);
        session.update_metrics(metrics);
        assert_eq!(session.metrics_history.len(), 1);

        session.complete();
        assert_eq!(session.status, TrainingStatus::Completed);
        assert!(session.completed_at.is_some());
    }
773
774 #[test]
775 fn test_transition_graph() {
776 let mut graph = TransitionGraph::new();
777
778 let emb1 = vec![0.1, 0.2, 0.3];
779 let emb2 = vec![0.4, 0.5, 0.6];
780
781 graph.add_node(EmbeddingId::new("n1"), emb1, Some(0));
782 graph.add_node(EmbeddingId::new("n2"), emb2, Some(1));
783 graph.add_edge(0, 1, 0.8);
784
785 assert_eq!(graph.num_nodes(), 2);
786 assert_eq!(graph.num_edges(), 1);
787 assert_eq!(graph.embedding_dim(), Some(3));
788
789 let neighbors = graph.neighbors(0);
790 assert_eq!(neighbors.len(), 1);
791 assert_eq!(neighbors[0], (1, 0.8));
792
793 assert!(graph.validate().is_ok());
794 }
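
    // Added example (not part of the original suite): an undirected graph
    // stores each edge in both directions, so both endpoints see each other as
    // neighbors; `validate` catches an edge that points at a missing node.
    #[test]
    fn test_undirected_graph_and_validation() {
        let mut graph = TransitionGraph::undirected();
        graph.add_node(EmbeddingId::new("a"), vec![1.0, 0.0], None);
        graph.add_node(EmbeddingId::new("b"), vec![0.0, 1.0], None);
        graph.add_edge(0, 1, 0.5);

        assert_eq!(graph.num_edges(), 2);
        assert_eq!(graph.neighbors(0), vec![(1, 0.5)]);
        assert_eq!(graph.neighbors(1), vec![(0, 0.5)]);
        assert_eq!(graph.degrees(), vec![1, 1]);

        // Bypass `add_edge` (which would assert) to simulate a corrupted graph.
        graph.edges.push((0, 9, 1.0));
        assert!(graph.validate().is_err());
    }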

    #[test]
    fn test_refined_embedding() {
        let original = vec![1.0, 0.0, 0.0];
        let refined = vec![0.9, 0.1, 0.0];

        let mut re = RefinedEmbedding::new(EmbeddingId::new("test"), refined, 0.95);

        re.compute_delta(&original);
        assert!(re.delta_norm.is_some());

        re.normalize();
        let norm: f32 = re.refined_vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_early_stopping() {
        let mut config = LearningConfig::default();
        config.hyperparameters.early_stopping_patience = Some(3);

        let mut session = LearningSession::new(config);
        session.start();

        // Loss keeps improving: no early stop yet.
        for i in 0..5 {
            let loss = 1.0 - (i as f32 * 0.1);
            session.update_metrics(TrainingMetrics::new(i, loss, 0.8, 0.001));
        }
        assert!(!session.should_early_stop());

        // Loss plateaus for longer than the patience window: early stop.
        for i in 5..10 {
            session.update_metrics(TrainingMetrics::new(i, 0.6, 0.8, 0.001));
        }
        assert!(session.should_early_stop());
    }
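
    // Added example (not part of the original suite): restates the values of
    // `LearningConfig::default()` defined above as a regression check.
    #[test]
    fn test_default_config() {
        let config = LearningConfig::default();
        assert_eq!(config.model_type, GnnModelType::Gcn);
        assert_eq!(config.input_dim, 768);
        assert_eq!(config.output_dim, 256);
        assert_eq!(config.device, Device::Cpu);
        assert_eq!(config.seed, None);
        assert!(!config.mixed_precision);
    }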
}