sevensense_learning/domain/
repository.rs1use async_trait::async_trait;
7use std::sync::Arc;
8
9use super::entities::{
10 EmbeddingId, LearningSession, RefinedEmbedding, TrainingStatus, TransitionGraph,
11};
12
13#[derive(Debug, thiserror::Error)]
15pub enum RepositoryError {
16 #[error("Learning session not found: {0}")]
18 SessionNotFound(String),
19
20 #[error("Embedding not found: {0}")]
22 EmbeddingNotFound(String),
23
24 #[error("Transition graph not found")]
26 GraphNotFound,
27
28 #[error("Serialization error: {0}")]
30 SerializationError(String),
31
32 #[error("Storage error: {0}")]
34 StorageError(String),
35
36 #[error("Connection error: {0}")]
38 ConnectionError(String),
39
40 #[error("Validation error: {0}")]
42 ValidationError(String),
43
44 #[error("Concurrent modification detected for: {0}")]
46 ConcurrentModification(String),
47
48 #[error("Internal repository error: {0}")]
50 Internal(String),
51}
52
53impl From<serde_json::Error> for RepositoryError {
54 fn from(e: serde_json::Error) -> Self {
55 Self::SerializationError(e.to_string())
56 }
57}
58
59pub type RepositoryResult<T> = Result<T, RepositoryError>;
61
62#[async_trait]
69pub trait LearningRepository: Send + Sync {
70 async fn save_session(&self, session: &LearningSession) -> RepositoryResult<()>;
74
75 async fn get_session(&self, id: &str) -> RepositoryResult<Option<LearningSession>>;
77
78 async fn update_session(&self, session: &LearningSession) -> RepositoryResult<()>;
80
81 async fn delete_session(&self, id: &str) -> RepositoryResult<()>;
83
84 async fn list_sessions(
86 &self,
87 status: Option<TrainingStatus>,
88 limit: Option<usize>,
89 ) -> RepositoryResult<Vec<LearningSession>>;
90
91 async fn save_refined_embeddings(
95 &self,
96 embeddings: &[RefinedEmbedding],
97 ) -> RepositoryResult<()>;
98
99 async fn get_refined_embedding(
101 &self,
102 original_id: &EmbeddingId,
103 ) -> RepositoryResult<Option<RefinedEmbedding>>;
104
105 async fn get_refined_embeddings(
107 &self,
108 ids: &[EmbeddingId],
109 ) -> RepositoryResult<Vec<RefinedEmbedding>>;
110
111 async fn delete_refined_embeddings(&self, session_id: &str) -> RepositoryResult<usize>;
113
114 async fn get_transition_graph(&self) -> RepositoryResult<TransitionGraph>;
118
119 async fn save_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()>;
121
122 async fn update_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()>;
124
125 async fn clear_transition_graph(&self) -> RepositoryResult<()>;
127
128 async fn save_checkpoint(
132 &self,
133 session_id: &str,
134 epoch: usize,
135 data: &[u8],
136 ) -> RepositoryResult<String>;
137
138 async fn load_checkpoint(
140 &self,
141 session_id: &str,
142 epoch: Option<usize>,
143 ) -> RepositoryResult<Option<Vec<u8>>>;
144
145 async fn list_checkpoints(&self, session_id: &str) -> RepositoryResult<Vec<(usize, String)>>;
147
148 async fn delete_checkpoints(&self, session_id: &str) -> RepositoryResult<usize>;
150}
151
152#[async_trait]
154pub trait LearningRepositoryExt: LearningRepository {
155 async fn get_latest_session(
157 &self,
158 model_type: crate::GnnModelType,
159 ) -> RepositoryResult<Option<LearningSession>> {
160 let sessions = self.list_sessions(None, Some(100)).await?;
161 Ok(sessions
162 .into_iter()
163 .filter(|s| s.model_type == model_type)
164 .max_by_key(|s| s.started_at))
165 }
166
167 async fn get_completed_sessions(&self) -> RepositoryResult<Vec<LearningSession>> {
169 self.list_sessions(Some(TrainingStatus::Completed), None).await
170 }
171
172 async fn has_running_session(&self) -> RepositoryResult<bool> {
174 let sessions = self.list_sessions(Some(TrainingStatus::Running), Some(1)).await?;
175 Ok(!sessions.is_empty())
176 }
177
178 async fn get_session_embeddings(
180 &self,
181 session_id: &str,
182 ) -> RepositoryResult<Vec<RefinedEmbedding>> {
183 let session = self.get_session(session_id).await?;
185 if session.is_none() {
186 return Err(RepositoryError::SessionNotFound(session_id.to_string()));
187 }
188
189 Ok(Vec::new())
191 }
192}
193
194impl<T: LearningRepository + ?Sized> LearningRepositoryExt for T {}
196
197pub type DynLearningRepository = Arc<dyn LearningRepository>;
199
200#[async_trait]
202pub trait UnitOfWork: Send + Sync {
203 async fn begin(&self) -> RepositoryResult<()>;
205
206 async fn commit(&self) -> RepositoryResult<()>;
208
209 async fn rollback(&self) -> RepositoryResult<()>;
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tokio::sync::RwLock;

    /// Minimal in-memory [`LearningRepository`] used to exercise the trait.
    ///
    /// Each kind of state lives behind its own `tokio` `RwLock`; nothing is
    /// persisted.
    struct InMemoryRepository {
        /// Sessions keyed by `LearningSession::id`.
        sessions: RwLock<HashMap<String, LearningSession>>,
        /// Refined embeddings keyed by the inner string of their `original_id`.
        embeddings: RwLock<HashMap<String, RefinedEmbedding>>,
        /// The single transition graph, if one has been saved.
        graph: RwLock<Option<TransitionGraph>>,
        /// Checkpoints per session id, sorted by epoch, at most one per epoch.
        checkpoints: RwLock<HashMap<String, Vec<(usize, Vec<u8>)>>>,
    }

    impl InMemoryRepository {
        fn new() -> Self {
            Self {
                sessions: RwLock::new(HashMap::new()),
                embeddings: RwLock::new(HashMap::new()),
                graph: RwLock::new(None),
                checkpoints: RwLock::new(HashMap::new()),
            }
        }

        /// Identifier format shared by `save_checkpoint` and `list_checkpoints`.
        fn checkpoint_id(session_id: &str, epoch: usize) -> String {
            format!("{session_id}-{epoch}")
        }
    }

    #[async_trait]
    impl LearningRepository for InMemoryRepository {
        async fn save_session(&self, session: &LearningSession) -> RepositoryResult<()> {
            let mut sessions = self.sessions.write().await;
            sessions.insert(session.id.clone(), session.clone());
            Ok(())
        }

        async fn get_session(&self, id: &str) -> RepositoryResult<Option<LearningSession>> {
            let sessions = self.sessions.read().await;
            Ok(sessions.get(id).cloned())
        }

        async fn update_session(&self, session: &LearningSession) -> RepositoryResult<()> {
            // Upsert semantics: updating is the same as saving.
            self.save_session(session).await
        }

        async fn delete_session(&self, id: &str) -> RepositoryResult<()> {
            let mut sessions = self.sessions.write().await;
            sessions.remove(id);
            Ok(())
        }

        async fn list_sessions(
            &self,
            status: Option<TrainingStatus>,
            limit: Option<usize>,
        ) -> RepositoryResult<Vec<LearningSession>> {
            let sessions = self.sessions.read().await;
            let mut result: Vec<_> = sessions
                .values()
                .filter(|s| status.map_or(true, |st| s.status == st))
                .cloned()
                .collect();
            // Newest first, so `limit` keeps the most recently started sessions.
            result.sort_by(|a, b| b.started_at.cmp(&a.started_at));
            if let Some(limit) = limit {
                result.truncate(limit);
            }
            Ok(result)
        }

        async fn save_refined_embeddings(
            &self,
            embeddings: &[RefinedEmbedding],
        ) -> RepositoryResult<()> {
            let mut store = self.embeddings.write().await;
            for emb in embeddings {
                store.insert(emb.original_id.0.clone(), emb.clone());
            }
            Ok(())
        }

        async fn get_refined_embedding(
            &self,
            original_id: &EmbeddingId,
        ) -> RepositoryResult<Option<RefinedEmbedding>> {
            let store = self.embeddings.read().await;
            Ok(store.get(&original_id.0).cloned())
        }

        async fn get_refined_embeddings(
            &self,
            ids: &[EmbeddingId],
        ) -> RepositoryResult<Vec<RefinedEmbedding>> {
            let store = self.embeddings.read().await;
            // Unknown ids are skipped rather than reported.
            Ok(ids
                .iter()
                .filter_map(|id| store.get(&id.0).cloned())
                .collect())
        }

        async fn delete_refined_embeddings(&self, _session_id: &str) -> RepositoryResult<usize> {
            // Embeddings are keyed only by `original_id` here, with no session
            // association. This test double therefore clears *every* stored
            // embedding and reports how many were removed.
            let mut store = self.embeddings.write().await;
            let count = store.len();
            store.clear();
            Ok(count)
        }

        async fn get_transition_graph(&self) -> RepositoryResult<TransitionGraph> {
            let graph = self.graph.read().await;
            graph.clone().ok_or(RepositoryError::GraphNotFound)
        }

        async fn save_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()> {
            let mut store = self.graph.write().await;
            *store = Some(graph.clone());
            Ok(())
        }

        async fn update_transition_graph(&self, graph: &TransitionGraph) -> RepositoryResult<()> {
            self.save_transition_graph(graph).await
        }

        async fn clear_transition_graph(&self) -> RepositoryResult<()> {
            let mut store = self.graph.write().await;
            *store = None;
            Ok(())
        }

        async fn save_checkpoint(
            &self,
            session_id: &str,
            epoch: usize,
            data: &[u8],
        ) -> RepositoryResult<String> {
            let mut store = self.checkpoints.write().await;
            let checkpoints = store.entry(session_id.to_string()).or_default();
            // Keep one entry per epoch, sorted by epoch. Re-saving an epoch
            // yields the same id, so it must replace the old data rather than
            // add a shadowed duplicate. Sorting makes `last()` the highest epoch.
            match checkpoints.binary_search_by_key(&epoch, |(e, _)| *e) {
                Ok(pos) => checkpoints[pos].1 = data.to_vec(),
                Err(pos) => checkpoints.insert(pos, (epoch, data.to_vec())),
            }
            Ok(Self::checkpoint_id(session_id, epoch))
        }

        async fn load_checkpoint(
            &self,
            session_id: &str,
            epoch: Option<usize>,
        ) -> RepositoryResult<Option<Vec<u8>>> {
            let store = self.checkpoints.read().await;
            if let Some(checkpoints) = store.get(session_id) {
                let found = match epoch {
                    Some(epoch) => checkpoints.iter().find(|(e, _)| *e == epoch),
                    // No epoch requested: the list is epoch-sorted, so the
                    // last entry is the latest checkpoint.
                    None => checkpoints.last(),
                };
                return Ok(found.map(|(_, data)| data.clone()));
            }
            Ok(None)
        }

        async fn list_checkpoints(
            &self,
            session_id: &str,
        ) -> RepositoryResult<Vec<(usize, String)>> {
            let store = self.checkpoints.read().await;
            Ok(store
                .get(session_id)
                .map(|checkpoints| {
                    checkpoints
                        .iter()
                        .map(|(e, _)| (*e, Self::checkpoint_id(session_id, *e)))
                        .collect()
                })
                .unwrap_or_default())
        }

        async fn delete_checkpoints(&self, session_id: &str) -> RepositoryResult<usize> {
            let mut store = self.checkpoints.write().await;
            Ok(store
                .remove(session_id)
                .map_or(0, |checkpoints| checkpoints.len()))
        }
    }

    #[tokio::test]
    async fn test_in_memory_repository() {
        let repo = InMemoryRepository::new();
        let config = crate::LearningConfig::default();
        let session = crate::LearningSession::new(config);

        repo.save_session(&session).await.unwrap();
        let retrieved = repo.get_session(&session.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, session.id);

        let sessions = repo.list_sessions(None, None).await.unwrap();
        assert_eq!(sessions.len(), 1);

        repo.delete_session(&session.id).await.unwrap();
        let retrieved = repo.get_session(&session.id).await.unwrap();
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_transition_graph_operations() {
        let repo = InMemoryRepository::new();

        assert!(matches!(
            repo.get_transition_graph().await,
            Err(RepositoryError::GraphNotFound)
        ));

        let mut graph = TransitionGraph::new();
        graph.add_node(crate::EmbeddingId::new("n1"), vec![0.1, 0.2, 0.3], None);
        repo.save_transition_graph(&graph).await.unwrap();

        let retrieved = repo.get_transition_graph().await.unwrap();
        assert_eq!(retrieved.num_nodes(), 1);

        repo.clear_transition_graph().await.unwrap();
        assert!(matches!(
            repo.get_transition_graph().await,
            Err(RepositoryError::GraphNotFound)
        ));
    }

    #[tokio::test]
    async fn test_checkpoint_operations() {
        let repo = InMemoryRepository::new();
        let session_id = "test-session";

        repo.save_checkpoint(session_id, 1, b"data1").await.unwrap();
        repo.save_checkpoint(session_id, 2, b"data2").await.unwrap();

        let checkpoints = repo.list_checkpoints(session_id).await.unwrap();
        assert_eq!(checkpoints.len(), 2);

        let data = repo.load_checkpoint(session_id, Some(1)).await.unwrap();
        assert_eq!(data, Some(b"data1".to_vec()));

        let data = repo.load_checkpoint(session_id, None).await.unwrap();
        assert_eq!(data, Some(b"data2".to_vec()));

        let count = repo.delete_checkpoints(session_id).await.unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn test_checkpoint_resave_replaces_epoch() {
        let repo = InMemoryRepository::new();
        let session_id = "test-session";

        let first_id = repo.save_checkpoint(session_id, 2, b"old").await.unwrap();
        repo.save_checkpoint(session_id, 1, b"data1").await.unwrap();
        let second_id = repo.save_checkpoint(session_id, 2, b"new").await.unwrap();
        assert_eq!(first_id, second_id);

        // One entry per epoch, ordered by epoch regardless of save order.
        let checkpoints = repo.list_checkpoints(session_id).await.unwrap();
        assert_eq!(
            checkpoints,
            vec![
                (1, "test-session-1".to_string()),
                (2, "test-session-2".to_string()),
            ]
        );

        let data = repo.load_checkpoint(session_id, Some(2)).await.unwrap();
        assert_eq!(data, Some(b"new".to_vec()));
        let data = repo.load_checkpoint(session_id, None).await.unwrap();
        assert_eq!(data, Some(b"new".to_vec()));
    }

    #[tokio::test]
    async fn test_session_embeddings_requires_existing_session() {
        let repo = InMemoryRepository::new();
        let result = repo.get_session_embeddings("missing").await;
        assert!(matches!(
            result,
            Err(RepositoryError::SessionNotFound(id)) if id == "missing"
        ));
    }
}