1use crate::embeddings::{BoxedEmbeddingProvider, EmbeddingProvider, HashEmbedding};
28use crate::error::{Result, RuvectorError};
29use crate::types::*;
30use crate::vector_db::VectorDB;
31use parking_lot::RwLock;
32use redb::{Database, TableDefinition};
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35use std::path::Path;
36use std::sync::Arc;
37
38const REFLEXION_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("reflexion_episodes");
40const SKILLS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("skills_library");
41const CAUSAL_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("causal_edges");
42const LEARNING_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("learning_sessions");
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ReflexionEpisode {
48 pub id: String,
49 pub task: String,
50 pub actions: Vec<String>,
51 pub observations: Vec<String>,
52 pub critique: String,
53 pub embedding: Vec<f32>,
54 pub timestamp: i64,
55 pub metadata: Option<HashMap<String, serde_json::Value>>,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
60pub struct Skill {
61 pub id: String,
62 pub name: String,
63 pub description: String,
64 pub parameters: HashMap<String, String>,
65 pub examples: Vec<String>,
66 pub embedding: Vec<f32>,
67 pub usage_count: usize,
68 pub success_rate: f64,
69 pub created_at: i64,
70 pub updated_at: i64,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
75pub struct CausalEdge {
76 pub id: String,
77 pub causes: Vec<String>, pub effects: Vec<String>, pub confidence: f64,
80 pub context: String,
81 pub embedding: Vec<f32>,
82 pub observations: usize,
83 pub timestamp: i64,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
88pub struct LearningSession {
89 pub id: String,
90 pub algorithm: String, pub state_dim: usize,
92 pub action_dim: usize,
93 pub experiences: Vec<Experience>,
94 pub model_params: Option<Vec<u8>>, pub created_at: i64,
96 pub updated_at: i64,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
101pub struct Experience {
102 pub state: Vec<f32>,
103 pub action: Vec<f32>,
104 pub reward: f64,
105 pub next_state: Vec<f32>,
106 pub done: bool,
107 pub timestamp: i64,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
112pub struct Prediction {
113 pub action: Vec<f32>,
114 pub confidence_lower: f64,
115 pub confidence_upper: f64,
116 pub mean_confidence: f64,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct UtilitySearchResult {
122 pub result: SearchResult,
123 pub utility_score: f64,
124 pub similarity_score: f64,
125 pub causal_uplift: f64,
126 pub latency_penalty: f64,
127}
128
129pub struct AgenticDB {
131 vector_db: Arc<VectorDB>,
132 db: Arc<Database>,
133 dimensions: usize,
134 embedding_provider: BoxedEmbeddingProvider,
135}
136
137impl AgenticDB {
138 pub fn new(options: DbOptions) -> Result<Self> {
140 let embedding_provider = Arc::new(HashEmbedding::new(options.dimensions));
141 Self::with_embedding_provider(options, embedding_provider)
142 }
143
144 pub fn with_embedding_provider(
182 options: DbOptions,
183 embedding_provider: BoxedEmbeddingProvider,
184 ) -> Result<Self> {
185 if options.dimensions != embedding_provider.dimensions() {
187 return Err(RuvectorError::InvalidDimension(
188 format!(
189 "Options dimensions ({}) do not match embedding provider dimensions ({})",
190 options.dimensions,
191 embedding_provider.dimensions()
192 )
193 ));
194 }
195
196 let vector_db = Arc::new(VectorDB::new(options.clone())?);
198
199 let agentic_path = format!("{}.agentic", options.storage_path);
201 let db = Arc::new(Database::create(&agentic_path)?);
202
203 let write_txn = db.begin_write()?;
205 {
206 let _ = write_txn.open_table(REFLEXION_TABLE)?;
207 let _ = write_txn.open_table(SKILLS_TABLE)?;
208 let _ = write_txn.open_table(CAUSAL_TABLE)?;
209 let _ = write_txn.open_table(LEARNING_TABLE)?;
210 }
211 write_txn.commit()?;
212
213 Ok(Self {
214 vector_db,
215 db,
216 dimensions: options.dimensions,
217 embedding_provider,
218 })
219 }
220
221 pub fn with_dimensions(dimensions: usize) -> Result<Self> {
223 let mut options = DbOptions::default();
224 options.dimensions = dimensions;
225 Self::new(options)
226 }
227
228 pub fn embedding_provider_name(&self) -> &str {
230 self.embedding_provider.name()
231 }
232
233 pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
237 self.vector_db.insert(entry)
238 }
239
240 pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
242 self.vector_db.insert_batch(entries)
243 }
244
245 pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
247 self.vector_db.search(query)
248 }
249
250 pub fn delete(&self, id: &str) -> Result<bool> {
252 self.vector_db.delete(id)
253 }
254
255 pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
257 self.vector_db.get(id)
258 }
259
260 pub fn store_episode(
264 &self,
265 task: String,
266 actions: Vec<String>,
267 observations: Vec<String>,
268 critique: String,
269 ) -> Result<String> {
270 let id = uuid::Uuid::new_v4().to_string();
271
272 let embedding = self.generate_text_embedding(&critique)?;
274
275 let episode = ReflexionEpisode {
276 id: id.clone(),
277 task,
278 actions,
279 observations,
280 critique,
281 embedding: embedding.clone(),
282 timestamp: chrono::Utc::now().timestamp(),
283 metadata: None,
284 };
285
286 let write_txn = self.db.begin_write()?;
288 {
289 let mut table = write_txn.open_table(REFLEXION_TABLE)?;
290 let json = serde_json::to_vec(&episode)
292 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
293 table.insert(id.as_str(), json.as_slice())?;
294 }
295 write_txn.commit()?;
296
297 self.vector_db.insert(VectorEntry {
299 id: Some(format!("reflexion_{}", id)),
300 vector: embedding,
301 metadata: Some({
302 let mut meta = HashMap::new();
303 meta.insert("type".to_string(), serde_json::json!("reflexion"));
304 meta.insert("episode_id".to_string(), serde_json::json!(id.clone()));
305 meta
306 }),
307 })?;
308
309 Ok(id)
310 }
311
312 pub fn retrieve_similar_episodes(
314 &self,
315 query: &str,
316 k: usize,
317 ) -> Result<Vec<ReflexionEpisode>> {
318 let query_embedding = self.generate_text_embedding(query)?;
320
321 let results = self.vector_db.search(SearchQuery {
323 vector: query_embedding,
324 k,
325 filter: Some({
326 let mut filter = HashMap::new();
327 filter.insert("type".to_string(), serde_json::json!("reflexion"));
328 filter
329 }),
330 ef_search: None,
331 })?;
332
333 let mut episodes = Vec::new();
335 let read_txn = self.db.begin_read()?;
336 let table = read_txn.open_table(REFLEXION_TABLE)?;
337
338 for result in results {
339 if let Some(metadata) = result.metadata {
340 if let Some(episode_id) = metadata.get("episode_id") {
341 let id = episode_id.as_str().unwrap();
342 if let Some(data) = table.get(id)? {
343 let episode: ReflexionEpisode = serde_json::from_slice(data.value())
345 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
346 episodes.push(episode);
347 }
348 }
349 }
350 }
351
352 Ok(episodes)
353 }
354
355 pub fn create_skill(
359 &self,
360 name: String,
361 description: String,
362 parameters: HashMap<String, String>,
363 examples: Vec<String>,
364 ) -> Result<String> {
365 let id = uuid::Uuid::new_v4().to_string();
366
367 let embedding = self.generate_text_embedding(&description)?;
369
370 let skill = Skill {
371 id: id.clone(),
372 name,
373 description,
374 parameters,
375 examples,
376 embedding: embedding.clone(),
377 usage_count: 0,
378 success_rate: 0.0,
379 created_at: chrono::Utc::now().timestamp(),
380 updated_at: chrono::Utc::now().timestamp(),
381 };
382
383 let write_txn = self.db.begin_write()?;
385 {
386 let mut table = write_txn.open_table(SKILLS_TABLE)?;
387 let data = bincode::encode_to_vec(&skill, bincode::config::standard())
388 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
389 table.insert(id.as_str(), data.as_slice())?;
390 }
391 write_txn.commit()?;
392
393 self.vector_db.insert(VectorEntry {
395 id: Some(format!("skill_{}", id)),
396 vector: embedding,
397 metadata: Some({
398 let mut meta = HashMap::new();
399 meta.insert("type".to_string(), serde_json::json!("skill"));
400 meta.insert("skill_id".to_string(), serde_json::json!(id.clone()));
401 meta
402 }),
403 })?;
404
405 Ok(id)
406 }
407
408 pub fn search_skills(&self, query_description: &str, k: usize) -> Result<Vec<Skill>> {
410 let query_embedding = self.generate_text_embedding(query_description)?;
411
412 let results = self.vector_db.search(SearchQuery {
413 vector: query_embedding,
414 k,
415 filter: Some({
416 let mut filter = HashMap::new();
417 filter.insert("type".to_string(), serde_json::json!("skill"));
418 filter
419 }),
420 ef_search: None,
421 })?;
422
423 let mut skills = Vec::new();
424 let read_txn = self.db.begin_read()?;
425 let table = read_txn.open_table(SKILLS_TABLE)?;
426
427 for result in results {
428 if let Some(metadata) = result.metadata {
429 if let Some(skill_id) = metadata.get("skill_id") {
430 let id = skill_id.as_str().unwrap();
431 if let Some(data) = table.get(id)? {
432 let (skill, _): (Skill, usize) =
433 bincode::decode_from_slice(data.value(), bincode::config::standard())
434 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
435 skills.push(skill);
436 }
437 }
438 }
439 }
440
441 Ok(skills)
442 }
443
444 pub fn auto_consolidate(
446 &self,
447 action_sequences: Vec<Vec<String>>,
448 success_threshold: usize,
449 ) -> Result<Vec<String>> {
450 let mut skill_ids = Vec::new();
451
452 for sequence in action_sequences {
454 if sequence.len() >= success_threshold {
455 let description = format!("Skill: {}", sequence.join(" -> "));
456 let skill_id = self.create_skill(
457 format!("Auto-Skill-{}", uuid::Uuid::new_v4()),
458 description,
459 HashMap::new(),
460 sequence.clone(),
461 )?;
462 skill_ids.push(skill_id);
463 }
464 }
465
466 Ok(skill_ids)
467 }
468
469 pub fn add_causal_edge(
473 &self,
474 causes: Vec<String>,
475 effects: Vec<String>,
476 confidence: f64,
477 context: String,
478 ) -> Result<String> {
479 let id = uuid::Uuid::new_v4().to_string();
480
481 let embedding = self.generate_text_embedding(&context)?;
483
484 let edge = CausalEdge {
485 id: id.clone(),
486 causes,
487 effects,
488 confidence,
489 context,
490 embedding: embedding.clone(),
491 observations: 1,
492 timestamp: chrono::Utc::now().timestamp(),
493 };
494
495 let write_txn = self.db.begin_write()?;
497 {
498 let mut table = write_txn.open_table(CAUSAL_TABLE)?;
499 let data = bincode::encode_to_vec(&edge, bincode::config::standard())
500 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
501 table.insert(id.as_str(), data.as_slice())?;
502 }
503 write_txn.commit()?;
504
505 self.vector_db.insert(VectorEntry {
507 id: Some(format!("causal_{}", id)),
508 vector: embedding,
509 metadata: Some({
510 let mut meta = HashMap::new();
511 meta.insert("type".to_string(), serde_json::json!("causal"));
512 meta.insert("causal_id".to_string(), serde_json::json!(id.clone()));
513 meta.insert("confidence".to_string(), serde_json::json!(confidence));
514 meta
515 }),
516 })?;
517
518 Ok(id)
519 }
520
521 pub fn query_with_utility(
523 &self,
524 query: &str,
525 k: usize,
526 alpha: f64,
527 beta: f64,
528 gamma: f64,
529 ) -> Result<Vec<UtilitySearchResult>> {
530 let start_time = std::time::Instant::now();
531 let query_embedding = self.generate_text_embedding(query)?;
532
533 let results = self.vector_db.search(SearchQuery {
535 vector: query_embedding,
536 k: k * 2, filter: Some({
538 let mut filter = HashMap::new();
539 filter.insert("type".to_string(), serde_json::json!("causal"));
540 filter
541 }),
542 ef_search: None,
543 })?;
544
545 let mut utility_results = Vec::new();
546
547 for result in results {
548 let similarity_score = 1.0 / (1.0 + result.score as f64); let causal_uplift = if let Some(ref metadata) = result.metadata {
552 metadata
553 .get("confidence")
554 .and_then(|v| v.as_f64())
555 .unwrap_or(0.0)
556 } else {
557 0.0
558 };
559
560 let latency = start_time.elapsed().as_secs_f64();
561 let latency_penalty = latency * gamma;
562
563 let utility_score = alpha * similarity_score + beta * causal_uplift - latency_penalty;
565
566 utility_results.push(UtilitySearchResult {
567 result,
568 utility_score,
569 similarity_score,
570 causal_uplift,
571 latency_penalty,
572 });
573 }
574
575 utility_results.sort_by(|a, b| b.utility_score.partial_cmp(&a.utility_score).unwrap());
577 utility_results.truncate(k);
578
579 Ok(utility_results)
580 }
581
582 pub fn start_session(
586 &self,
587 algorithm: String,
588 state_dim: usize,
589 action_dim: usize,
590 ) -> Result<String> {
591 let id = uuid::Uuid::new_v4().to_string();
592
593 let session = LearningSession {
594 id: id.clone(),
595 algorithm,
596 state_dim,
597 action_dim,
598 experiences: Vec::new(),
599 model_params: None,
600 created_at: chrono::Utc::now().timestamp(),
601 updated_at: chrono::Utc::now().timestamp(),
602 };
603
604 let write_txn = self.db.begin_write()?;
605 {
606 let mut table = write_txn.open_table(LEARNING_TABLE)?;
607 let data = bincode::encode_to_vec(&session, bincode::config::standard())
608 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
609 table.insert(id.as_str(), data.as_slice())?;
610 }
611 write_txn.commit()?;
612
613 Ok(id)
614 }
615
616 pub fn add_experience(
618 &self,
619 session_id: &str,
620 state: Vec<f32>,
621 action: Vec<f32>,
622 reward: f64,
623 next_state: Vec<f32>,
624 done: bool,
625 ) -> Result<()> {
626 let read_txn = self.db.begin_read()?;
627 let table = read_txn.open_table(LEARNING_TABLE)?;
628
629 let data = table
630 .get(session_id)?
631 .ok_or_else(|| RuvectorError::VectorNotFound(session_id.to_string()))?;
632
633 let (mut session, _): (LearningSession, usize) =
634 bincode::decode_from_slice(data.value(), bincode::config::standard())
635 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
636
637 drop(table);
638 drop(read_txn);
639
640 session.experiences.push(Experience {
642 state,
643 action,
644 reward,
645 next_state,
646 done,
647 timestamp: chrono::Utc::now().timestamp(),
648 });
649 session.updated_at = chrono::Utc::now().timestamp();
650
651 let write_txn = self.db.begin_write()?;
653 {
654 let mut table = write_txn.open_table(LEARNING_TABLE)?;
655 let data = bincode::encode_to_vec(&session, bincode::config::standard())
656 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
657 table.insert(session_id, data.as_slice())?;
658 }
659 write_txn.commit()?;
660
661 Ok(())
662 }
663
664 pub fn predict_with_confidence(&self, session_id: &str, state: Vec<f32>) -> Result<Prediction> {
666 let read_txn = self.db.begin_read()?;
667 let table = read_txn.open_table(LEARNING_TABLE)?;
668
669 let data = table
670 .get(session_id)?
671 .ok_or_else(|| RuvectorError::VectorNotFound(session_id.to_string()))?;
672
673 let (session, _): (LearningSession, usize) =
674 bincode::decode_from_slice(data.value(), bincode::config::standard())
675 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
676
677 let mut similar_actions = Vec::new();
679 let mut rewards = Vec::new();
680
681 for exp in &session.experiences {
682 let distance = euclidean_distance(&state, &exp.state);
683 if distance < 1.0 {
684 similar_actions.push(exp.action.clone());
686 rewards.push(exp.reward);
687 }
688 }
689
690 if similar_actions.is_empty() {
691 return Ok(Prediction {
693 action: vec![0.0; session.action_dim],
694 confidence_lower: 0.0,
695 confidence_upper: 0.0,
696 mean_confidence: 0.0,
697 });
698 }
699
700 let total_reward: f64 = rewards.iter().sum();
702 let mut action = vec![0.0; session.action_dim];
703
704 for (act, reward) in similar_actions.iter().zip(rewards.iter()) {
705 let weight = reward / total_reward;
706 for (i, val) in act.iter().enumerate() {
707 action[i] += val * weight as f32;
708 }
709 }
710
711 let mean_reward = total_reward / rewards.len() as f64;
713 let std_dev = calculate_std_dev(&rewards, mean_reward);
714
715 Ok(Prediction {
716 action,
717 confidence_lower: mean_reward - 1.96 * std_dev,
718 confidence_upper: mean_reward + 1.96 * std_dev,
719 mean_confidence: mean_reward,
720 })
721 }
722
723 pub fn get_session(&self, session_id: &str) -> Result<Option<LearningSession>> {
725 let read_txn = self.db.begin_read()?;
726 let table = read_txn.open_table(LEARNING_TABLE)?;
727
728 if let Some(data) = table.get(session_id)? {
729 let (session, _): (LearningSession, usize) =
730 bincode::decode_from_slice(data.value(), bincode::config::standard())
731 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
732 Ok(Some(session))
733 } else {
734 Ok(None)
735 }
736 }
737
738 fn generate_text_embedding(&self, text: &str) -> Result<Vec<f32>> {
761 self.embedding_provider.embed(text)
762 }
763}
764
/// Euclidean (L2) distance between two vectors.
///
/// Components are paired positionally; when the slices differ in length the
/// extra components of the longer one are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
773
/// Population standard deviation of `values` around the supplied `mean`.
///
/// The variance is the mean of the squared deviations (divided by `n`, not
/// `n - 1`). An empty slice yields `0.0` instead of the NaN that dividing by
/// a zero length would produce.
fn calculate_std_dev(values: &[f64], mean: f64) -> f64 {
    if values.is_empty() {
        return 0.0;
    }
    let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
    variance.sqrt()
}
778
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Builds a 128-dimensional database inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the database because the
    /// directory is deleted when the guard is dropped; callers must keep it
    /// alive for as long as they use the database. Bind it as
    /// `let (_dir, db) = ...` so the database is dropped before the directory.
    fn create_test_db() -> Result<(TempDir, AgenticDB)> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 128;
        let db = AgenticDB::new(options)?;
        Ok((dir, db))
    }

    #[test]
    fn test_reflexion_episode() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let id = db.store_episode(
            "Solve math problem".to_string(),
            vec!["read problem".to_string(), "calculate".to_string()],
            vec!["got 42".to_string()],
            "Should have shown work".to_string(),
        )?;

        let episodes = db.retrieve_similar_episodes("math problem solving", 5)?;
        assert!(!episodes.is_empty());
        assert_eq!(episodes[0].id, id);

        Ok(())
    }

    #[test]
    fn test_skill_library() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let mut params = HashMap::new();
        params.insert("input".to_string(), "string".to_string());

        let skill_id = db.create_skill(
            "Parse JSON".to_string(),
            "Parse JSON from string".to_string(),
            params,
            vec!["json.parse()".to_string()],
        )?;

        let skills = db.search_skills("parse json data", 5)?;
        assert!(!skills.is_empty());
        // Only one skill exists, so it must be the one that was just created.
        assert_eq!(skills[0].id, skill_id);

        Ok(())
    }

    #[test]
    fn test_causal_edge() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let edge_id = db.add_causal_edge(
            vec!["rain".to_string()],
            vec!["wet ground".to_string()],
            0.95,
            "Weather observation".to_string(),
        )?;

        let results = db.query_with_utility("weather patterns", 5, 0.7, 0.2, 0.1)?;
        assert!(!results.is_empty());

        // The single indexed edge must come back carrying its own id.
        let returned_id = results[0]
            .result
            .metadata
            .as_ref()
            .and_then(|meta| meta.get("causal_id"))
            .and_then(|value| value.as_str());
        assert_eq!(returned_id, Some(edge_id.as_str()));

        Ok(())
    }

    #[test]
    fn test_learning_session() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let session_id = db.start_session("Q-Learning".to_string(), 4, 2)?;

        db.add_experience(
            &session_id,
            vec![1.0, 0.0, 0.0, 0.0],
            vec![1.0, 0.0],
            1.0,
            vec![0.0, 1.0, 0.0, 0.0],
            false,
        )?;

        let prediction = db.predict_with_confidence(&session_id, vec![1.0, 0.0, 0.0, 0.0])?;
        assert_eq!(prediction.action.len(), 2);

        Ok(())
    }

    #[test]
    fn test_auto_consolidate() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let sequences = vec![
            vec![
                "step1".to_string(),
                "step2".to_string(),
                "step3".to_string(),
            ],
            vec![
                "action1".to_string(),
                "action2".to_string(),
                "action3".to_string(),
            ],
        ];

        let skill_ids = db.auto_consolidate(sequences, 3)?;
        assert_eq!(skill_ids.len(), 2);

        Ok(())
    }
}