sevensense_learning/domain/repository.rs

//! Repository traits for the learning domain.
//!
//! Defines the persistence abstraction for learning sessions, refined
//! embeddings, transition graphs, and model checkpoints.
5
6use async_trait::async_trait;
7use std::sync::Arc;
8
9use super::entities::{
10    EmbeddingId, LearningSession, RefinedEmbedding, TrainingStatus, TransitionGraph,
11};
12
13/// Error type for repository operations
14#[derive(Debug, thiserror::Error)]
15pub enum RepositoryError {
16    /// Session not found
17    #[error("Learning session not found: {0}")]
18    SessionNotFound(String),
19
20    /// Embedding not found
21    #[error("Embedding not found: {0}")]
22    EmbeddingNotFound(String),
23
24    /// Graph not found or empty
25    #[error("Transition graph not found")]
26    GraphNotFound,
27
28    /// Serialization error
29    #[error("Serialization error: {0}")]
30    SerializationError(String),
31
32    /// Storage error
33    #[error("Storage error: {0}")]
34    StorageError(String),
35
36    /// Connection error
37    #[error("Connection error: {0}")]
38    ConnectionError(String),
39
40    /// Validation error
41    #[error("Validation error: {0}")]
42    ValidationError(String),
43
44    /// Concurrent modification error
45    #[error("Concurrent modification detected for: {0}")]
46    ConcurrentModification(String),
47
48    /// Internal error
49    #[error("Internal repository error: {0}")]
50    Internal(String),
51}
52
53impl From<serde_json::Error> for RepositoryError {
54    fn from(e: serde_json::Error) -> Self {
55        Self::SerializationError(e.to_string())
56    }
57}
58
59/// Result type for repository operations
60pub type RepositoryResult<T> = Result<T, RepositoryError>;
61
62/// Repository trait for learning persistence operations.
63///
64/// Implementors should provide durable storage for:
65/// - Learning sessions and their state
66/// - Refined embeddings
67/// - Transition graphs
68#[async_trait]
69pub trait LearningRepository: Send + Sync {
70    // =========== Session Operations ===========
71
72    /// Save a learning session
73    async fn save_session(&self, session: &LearningSession) -> RepositoryResult<()>;
74
75    /// Get a learning session by ID
76    async fn get_session(&self, id: &str) -> RepositoryResult<Option<LearningSession>>;
77
78    /// Update an existing session
79    async fn update_session(&self, session: &LearningSession) -> RepositoryResult<()>;
80
81    /// Delete a session
82    async fn delete_session(&self, id: &str) -> RepositoryResult<()>;
83
84    /// List sessions with optional status filter
85    async fn list_sessions(
86        &self,
87        status: Option<TrainingStatus>,
88        limit: Option<usize>,
89    ) -> RepositoryResult<Vec<LearningSession>>;
90
91    // =========== Embedding Operations ===========
92
93    /// Save refined embeddings (batch)
94    async fn save_refined_embeddings(
95        &self,
96        embeddings: &[RefinedEmbedding],
97    ) -> RepositoryResult<()>;
98
99    /// Get a refined embedding by original ID
100    async fn get_refined_embedding(
101        &self,
102        original_id: &EmbeddingId,
103    ) -> RepositoryResult<Option<RefinedEmbedding>>;
104
105    /// Get multiple refined embeddings
106    async fn get_refined_embeddings(
107        &self,
108        ids: &[EmbeddingId],
109    ) -> RepositoryResult<Vec<RefinedEmbedding>>;
110
111    /// Delete refined embeddings for a session
112    async fn delete_refined_embeddings(&self, session_id: &str) -> RepositoryResult<usize>;
113
114    // =========== Graph Operations ===========
115
116    /// Get the current transition graph
117    async fn get_transition_graph(&self) -> RepositoryResult<TransitionGraph>;
118
119    /// Save a transition graph
120    async fn save_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()>;
121
122    /// Update the transition graph (incremental)
123    async fn update_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()>;
124
125    /// Clear the transition graph
126    async fn clear_transition_graph(&self) -> RepositoryResult<()>;
127
128    // =========== Checkpoint Operations ===========
129
130    /// Save a model checkpoint
131    async fn save_checkpoint(
132        &self,
133        session_id: &str,
134        epoch: usize,
135        data: &[u8],
136    ) -> RepositoryResult<String>;
137
138    /// Load a model checkpoint
139    async fn load_checkpoint(
140        &self,
141        session_id: &str,
142        epoch: Option<usize>,
143    ) -> RepositoryResult<Option<Vec<u8>>>;
144
145    /// List available checkpoints for a session
146    async fn list_checkpoints(&self, session_id: &str) -> RepositoryResult<Vec<(usize, String)>>;
147
148    /// Delete checkpoints for a session
149    async fn delete_checkpoints(&self, session_id: &str) -> RepositoryResult<usize>;
150}
151
152/// Extension trait for repository operations
153#[async_trait]
154pub trait LearningRepositoryExt: LearningRepository {
155    /// Get the latest session for a model type
156    async fn get_latest_session(
157        &self,
158        model_type: crate::GnnModelType,
159    ) -> RepositoryResult<Option<LearningSession>> {
160        let sessions = self.list_sessions(None, Some(100)).await?;
161        Ok(sessions
162            .into_iter()
163            .filter(|s| s.model_type == model_type)
164            .max_by_key(|s| s.started_at))
165    }
166
167    /// Get all completed sessions
168    async fn get_completed_sessions(&self) -> RepositoryResult<Vec<LearningSession>> {
169        self.list_sessions(Some(TrainingStatus::Completed), None).await
170    }
171
172    /// Check if any session is currently running
173    async fn has_running_session(&self) -> RepositoryResult<bool> {
174        let sessions = self.list_sessions(Some(TrainingStatus::Running), Some(1)).await?;
175        Ok(!sessions.is_empty())
176    }
177
178    /// Get embeddings refined in a specific session
179    async fn get_session_embeddings(
180        &self,
181        session_id: &str,
182    ) -> RepositoryResult<Vec<RefinedEmbedding>> {
183        // Default implementation - may be overridden for efficiency
184        let session = self.get_session(session_id).await?;
185        if session.is_none() {
186            return Err(RepositoryError::SessionNotFound(session_id.to_string()));
187        }
188
189        // This would need to be implemented properly in concrete implementations
190        Ok(Vec::new())
191    }
192}
193
194// Blanket implementation
195impl<T: LearningRepository + ?Sized> LearningRepositoryExt for T {}
196
197/// A thread-safe repository handle
198pub type DynLearningRepository = Arc<dyn LearningRepository>;
199
200/// Unit of work pattern for transactional operations
201#[async_trait]
202pub trait UnitOfWork: Send + Sync {
203    /// Begin a transaction
204    async fn begin(&self) -> RepositoryResult<()>;
205
206    /// Commit the transaction
207    async fn commit(&self) -> RepositoryResult<()>;
208
209    /// Rollback the transaction
210    async fn rollback(&self) -> RepositoryResult<()>;
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tokio::sync::RwLock;

    /// Identifier reported for a stored checkpoint: `"{session_id}-{epoch}"`.
    fn checkpoint_id(session_id: &str, epoch: usize) -> String {
        format!("{session_id}-{epoch}")
    }

    /// In-memory [`LearningRepository`] used as a test double.
    ///
    /// Each store sits behind its own `tokio::sync::RwLock`, which is what lets
    /// the `&self` trait methods mutate it.
    struct InMemoryRepository {
        /// Sessions keyed by `LearningSession::id`.
        sessions: RwLock<HashMap<String, LearningSession>>,
        /// Refined embeddings keyed by the string inside their `original_id`.
        embeddings: RwLock<HashMap<String, RefinedEmbedding>>,
        /// The single transition graph; `None` until one is saved.
        graph: RwLock<Option<TransitionGraph>>,
        /// Per-session `(epoch, data)` checkpoints in first-save order, with at
        /// most one entry per epoch.
        checkpoints: RwLock<HashMap<String, Vec<(usize, Vec<u8>)>>>,
    }

    impl InMemoryRepository {
        /// Creates a repository with every store empty.
        fn new() -> Self {
            Self {
                sessions: RwLock::new(HashMap::new()),
                embeddings: RwLock::new(HashMap::new()),
                graph: RwLock::new(None),
                checkpoints: RwLock::new(HashMap::new()),
            }
        }
    }

    #[async_trait]
    impl LearningRepository for InMemoryRepository {
        async fn save_session(&self, session: &LearningSession) -> RepositoryResult<()> {
            let mut sessions = self.sessions.write().await;
            sessions.insert(session.id.clone(), session.clone());
            Ok(())
        }

        async fn get_session(&self, id: &str) -> RepositoryResult<Option<LearningSession>> {
            let sessions = self.sessions.read().await;
            Ok(sessions.get(id).cloned())
        }

        /// Upsert: a session that was never saved is simply inserted.
        async fn update_session(&self, session: &LearningSession) -> RepositoryResult<()> {
            self.save_session(session).await
        }

        /// Deleting an unknown ID is a no-op, not an error.
        async fn delete_session(&self, id: &str) -> RepositoryResult<()> {
            let mut sessions = self.sessions.write().await;
            sessions.remove(id);
            Ok(())
        }

        /// Returns matching sessions newest-first by `started_at`, then applies `limit`.
        async fn list_sessions(
            &self,
            status: Option<TrainingStatus>,
            limit: Option<usize>,
        ) -> RepositoryResult<Vec<LearningSession>> {
            let sessions = self.sessions.read().await;
            let mut result: Vec<_> = sessions
                .values()
                .filter(|s| status.map_or(true, |st| s.status == st))
                .cloned()
                .collect();
            result.sort_by(|a, b| b.started_at.cmp(&a.started_at));
            if let Some(limit) = limit {
                result.truncate(limit);
            }
            Ok(result)
        }

        async fn save_refined_embeddings(
            &self,
            embeddings: &[RefinedEmbedding],
        ) -> RepositoryResult<()> {
            let mut store = self.embeddings.write().await;
            for emb in embeddings {
                store.insert(emb.original_id.0.clone(), emb.clone());
            }
            Ok(())
        }

        async fn get_refined_embedding(
            &self,
            original_id: &EmbeddingId,
        ) -> RepositoryResult<Option<RefinedEmbedding>> {
            let store = self.embeddings.read().await;
            Ok(store.get(&original_id.0).cloned())
        }

        /// Unknown IDs are silently skipped.
        async fn get_refined_embeddings(
            &self,
            ids: &[EmbeddingId],
        ) -> RepositoryResult<Vec<RefinedEmbedding>> {
            let store = self.embeddings.read().await;
            Ok(ids
                .iter()
                .filter_map(|id| store.get(&id.0).cloned())
                .collect())
        }

        /// NOTE(review): the link between a `RefinedEmbedding` and its session is
        /// not visible here, so this fake ignores `_session_id`, clears every
        /// stored embedding, and returns how many it cleared.
        async fn delete_refined_embeddings(&self, _session_id: &str) -> RepositoryResult<usize> {
            let mut store = self.embeddings.write().await;
            let count = store.len();
            store.clear();
            Ok(count)
        }

        async fn get_transition_graph(&self) -> RepositoryResult<TransitionGraph> {
            let graph = self.graph.read().await;
            graph.clone().ok_or(RepositoryError::GraphNotFound)
        }

        async fn save_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()> {
            let mut store = self.graph.write().await;
            *store = Some(graph.clone());
            Ok(())
        }

        /// No merging: the update simply replaces the stored graph.
        async fn update_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()> {
            self.save_transition_graph(graph).await
        }

        async fn clear_transition_graph(&self) -> RepositoryResult<()> {
            let mut store = self.graph.write().await;
            *store = None;
            Ok(())
        }

        /// Re-saving an epoch overwrites that epoch's data in place, so each
        /// epoch (and therefore each returned id) appears at most once.
        async fn save_checkpoint(
            &self,
            session_id: &str,
            epoch: usize,
            data: &[u8],
        ) -> RepositoryResult<String> {
            let mut store = self.checkpoints.write().await;
            let checkpoints = store.entry(session_id.to_string()).or_default();
            let existing = checkpoints.iter().position(|(e, _)| *e == epoch);
            match existing {
                Some(index) => checkpoints[index].1 = data.to_vec(),
                None => checkpoints.push((epoch, data.to_vec())),
            }
            Ok(checkpoint_id(session_id, epoch))
        }

        /// `Some(epoch)` selects that epoch; `None` returns the last entry in the
        /// session's list, i.e. the most recently *added* epoch.
        async fn load_checkpoint(
            &self,
            session_id: &str,
            epoch: Option<usize>,
        ) -> RepositoryResult<Option<Vec<u8>>> {
            let store = self.checkpoints.read().await;
            let found = store.get(session_id).and_then(|checkpoints| match epoch {
                Some(epoch) => checkpoints.iter().find(|(e, _)| *e == epoch),
                None => checkpoints.last(),
            });
            Ok(found.map(|(_, data)| data.clone()))
        }

        async fn list_checkpoints(
            &self,
            session_id: &str,
        ) -> RepositoryResult<Vec<(usize, String)>> {
            let store = self.checkpoints.read().await;
            Ok(store
                .get(session_id)
                .map(|checkpoints| {
                    checkpoints
                        .iter()
                        .map(|(epoch, _)| (*epoch, checkpoint_id(session_id, *epoch)))
                        .collect()
                })
                .unwrap_or_default())
        }

        async fn delete_checkpoints(&self, session_id: &str) -> RepositoryResult<usize> {
            let mut store = self.checkpoints.write().await;
            Ok(store
                .remove(session_id)
                .map_or(0, |checkpoints| checkpoints.len()))
        }
    }

    #[tokio::test]
    async fn test_in_memory_repository() {
        let repo = InMemoryRepository::new();
        let config = crate::LearningConfig::default();
        let session = crate::LearningSession::new(config);

        // Save and retrieve session
        repo.save_session(&session).await.unwrap();
        let retrieved = repo.get_session(&session.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, session.id);

        // List sessions
        let sessions = repo.list_sessions(None, None).await.unwrap();
        assert_eq!(sessions.len(), 1);

        // Delete session
        repo.delete_session(&session.id).await.unwrap();
        let retrieved = repo.get_session(&session.id).await.unwrap();
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_transition_graph_operations() {
        let repo = InMemoryRepository::new();

        // Graph should not exist initially
        assert!(repo.get_transition_graph().await.is_err());

        // Save graph
        let mut graph = TransitionGraph::new();
        graph.add_node(
            crate::EmbeddingId::new("n1"),
            vec![0.1, 0.2, 0.3],
            None,
        );
        repo.save_transition_graph(&graph).await.unwrap();

        // Retrieve graph
        let retrieved = repo.get_transition_graph().await.unwrap();
        assert_eq!(retrieved.num_nodes(), 1);

        // Clear graph
        repo.clear_transition_graph().await.unwrap();
        assert!(repo.get_transition_graph().await.is_err());
    }

    #[tokio::test]
    async fn test_checkpoint_operations() {
        let repo = InMemoryRepository::new();
        let session_id = "test-session";

        // Save checkpoints
        repo.save_checkpoint(session_id, 1, b"data1").await.unwrap();
        repo.save_checkpoint(session_id, 2, b"data2").await.unwrap();

        // List checkpoints
        let checkpoints = repo.list_checkpoints(session_id).await.unwrap();
        assert_eq!(checkpoints.len(), 2);

        // Load specific checkpoint
        let data = repo.load_checkpoint(session_id, Some(1)).await.unwrap();
        assert_eq!(data, Some(b"data1".to_vec()));

        // Load latest checkpoint
        let data = repo.load_checkpoint(session_id, None).await.unwrap();
        assert_eq!(data, Some(b"data2".to_vec()));

        // Delete checkpoints
        let count = repo.delete_checkpoints(session_id).await.unwrap();
        assert_eq!(count, 2);
        assert!(repo.list_checkpoints(session_id).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_checkpoint_resave_overwrites_epoch() {
        let repo = InMemoryRepository::new();
        let session_id = "resave-session";

        repo.save_checkpoint(session_id, 1, b"old").await.unwrap();
        repo.save_checkpoint(session_id, 2, b"two").await.unwrap();
        let id = repo.save_checkpoint(session_id, 1, b"new").await.unwrap();
        assert_eq!(id, "resave-session-1");

        // Still one entry per epoch, so no duplicated ids.
        let checkpoints = repo.list_checkpoints(session_id).await.unwrap();
        assert_eq!(
            checkpoints,
            vec![
                (1, "resave-session-1".to_string()),
                (2, "resave-session-2".to_string()),
            ]
        );

        // The overwritten epoch yields the fresh data, not the stale blob.
        let data = repo.load_checkpoint(session_id, Some(1)).await.unwrap();
        assert_eq!(data, Some(b"new".to_vec()));

        // Overwriting in place does not reorder: epoch 2 is still the last added.
        let data = repo.load_checkpoint(session_id, None).await.unwrap();
        assert_eq!(data, Some(b"two".to_vec()));
    }
}
462}