use crate::embeddings::{BoxedEmbeddingProvider, EmbeddingProvider, HashEmbedding};
use crate::error::{Result, RuvectorError};
use crate::types::*;
use crate::vector_db::VectorDB;
use parking_lot::RwLock;
use redb::{Database, ReadableTable, TableDefinition};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
37
/// redb table of reflexion episodes: episode id -> JSON-encoded `ReflexionEpisode`.
const REFLEXION_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("reflexion_episodes");
/// redb table of skills: skill id -> bincode-encoded `Skill`.
const SKILLS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("skills_library");
/// redb table of causal edges: edge id -> bincode-encoded `CausalEdge`.
const CAUSAL_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("causal_edges");
/// redb table of learning sessions: session id -> bincode-encoded `LearningSession`.
const LEARNING_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("learning_sessions");
43
/// A stored reflexion episode: a task, what was done, what was observed, and a critique.
///
/// Unlike the other record types in this file, episodes are persisted as JSON
/// (`serde_json`) rather than bincode; the type does not derive the bincode traits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflexionEpisode {
    /// Unique episode id; `store_episode` generates a UUID v4 string.
    pub id: String,
    /// Description of the task that was attempted.
    pub task: String,
    /// Actions taken during the episode, in the order supplied by the caller.
    pub actions: Vec<String>,
    /// Observations recorded during the episode, in the order supplied by the caller.
    pub observations: Vec<String>,
    /// Free-text critique of the episode.
    pub critique: String,
    /// Embedding of `critique` (this is the vector indexed for similarity search).
    pub embedding: Vec<f32>,
    /// Creation time as returned by `chrono::Utc::now().timestamp()`.
    pub timestamp: i64,
    /// Optional extra metadata; `store_episode` always writes `None`.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
57
/// A reusable skill stored in the skills library (bincode-encoded in `SKILLS_TABLE`).
///
/// NOTE: field order is part of the bincode wire format; reordering fields breaks
/// decoding of previously stored records.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Skill {
    /// Unique skill id; `create_skill` generates a UUID v4 string.
    pub id: String,
    /// Human-readable skill name.
    pub name: String,
    /// Free-text description; its embedding is what skill search matches against.
    pub description: String,
    /// Parameter map supplied by the caller (the test uses name -> type, e.g. "input" -> "string").
    pub parameters: HashMap<String, String>,
    /// Usage examples supplied by the caller.
    pub examples: Vec<String>,
    /// Embedding of `description`.
    pub embedding: Vec<f32>,
    /// Usage counter; initialised to 0 by `create_skill` and not updated anywhere in this file.
    pub usage_count: usize,
    /// Success rate; initialised to 0.0 by `create_skill` and not updated anywhere in this file.
    pub success_rate: f64,
    /// Creation time as returned by `chrono::Utc::now().timestamp()`.
    pub created_at: i64,
    /// Last-update time; set at creation alongside `created_at`.
    pub updated_at: i64,
}
72
/// A causal relationship between a set of causes and a set of effects
/// (bincode-encoded in `CAUSAL_TABLE`).
///
/// NOTE: field order is part of the bincode wire format; do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct CausalEdge {
    /// Unique edge id; `add_causal_edge` generates a UUID v4 string.
    pub id: String,
    /// Cause labels supplied by the caller.
    pub causes: Vec<String>,
    /// Effect labels supplied by the caller.
    pub effects: Vec<String>,
    /// Caller-supplied confidence; also copied into the vector index metadata and
    /// used as the "causal uplift" term by `query_with_utility`.
    /// NOTE(review): no range is enforced in this file — presumably 0.0..=1.0, confirm.
    pub confidence: f64,
    /// Free-text context; its embedding is what causal queries match against.
    pub context: String,
    /// Embedding of `context`.
    pub embedding: Vec<f32>,
    /// Observation count; initialised to 1 by `add_causal_edge`.
    pub observations: usize,
    /// Creation time as returned by `chrono::Utc::now().timestamp()`.
    pub timestamp: i64,
}
85
/// A learning session: configuration plus the experiences recorded so far
/// (bincode-encoded in `LEARNING_TABLE`).
///
/// The whole session, including every experience, is stored as a single record and is
/// re-encoded on each `add_experience` call.
///
/// NOTE: field order is part of the bincode wire format; do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct LearningSession {
    /// Unique session id; `start_session` generates a UUID v4 string.
    pub id: String,
    /// Caller-supplied algorithm name (the test uses "Q-Learning"); not interpreted in this file.
    pub algorithm: String,
    /// Declared state dimensionality; stored but not validated against experiences in this file.
    pub state_dim: usize,
    /// Declared action dimensionality; sizes the action vector returned by
    /// `predict_with_confidence`.
    pub action_dim: usize,
    /// Recorded experiences, in insertion order.
    pub experiences: Vec<Experience>,
    /// Opaque serialized model parameters; `start_session` sets `None` and nothing in
    /// this file writes it afterwards.
    pub model_params: Option<Vec<u8>>,
    /// Creation time as returned by `chrono::Utc::now().timestamp()`.
    pub created_at: i64,
    /// Time of the most recent `add_experience` (or creation time if none yet).
    pub updated_at: i64,
}
98
/// One recorded transition within a `LearningSession`.
///
/// NOTE: field order is part of the bincode wire format; do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Experience {
    /// State in which the action was taken.
    pub state: Vec<f32>,
    /// Action that was taken.
    pub action: Vec<f32>,
    /// Reward received; used as the averaging weight in `predict_with_confidence`.
    pub reward: f64,
    /// State reached after the action.
    pub next_state: Vec<f32>,
    /// Caller-supplied flag — presumably marks the end of an episode; not read in this file.
    pub done: bool,
    /// Time the experience was recorded (`chrono::Utc::now().timestamp()`).
    pub timestamp: i64,
}
109
/// Result of `predict_with_confidence`: a blended action plus reward statistics.
///
/// The three "confidence" fields are statistics of the rewards of the experiences that
/// contributed to the prediction, not probabilities.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Prediction {
    /// Predicted action; all zeros when no stored experience was close enough to the query.
    pub action: Vec<f32>,
    /// Mean reward minus 1.96 standard deviations (0.0 when nothing matched).
    pub confidence_lower: f64,
    /// Mean reward plus 1.96 standard deviations (0.0 when nothing matched).
    pub confidence_upper: f64,
    /// Mean reward of the contributing experiences (0.0 when nothing matched).
    pub mean_confidence: f64,
}
118
/// A search hit annotated with the utility terms computed by `query_with_utility`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UtilitySearchResult {
    /// The underlying vector-search hit.
    pub result: SearchResult,
    /// `alpha * similarity_score + beta * causal_uplift - latency_penalty`.
    pub utility_score: f64,
    /// `1 / (1 + result.score)`.
    pub similarity_score: f64,
    /// The `confidence` value from the hit's metadata, or 0.0 when absent or non-numeric.
    pub causal_uplift: f64,
    /// Elapsed query time in seconds multiplied by `gamma`.
    pub latency_penalty: f64,
}
128
/// Agent-memory database combining a vector index with a redb store for structured
/// records (reflexion episodes, skills, causal edges, learning sessions).
pub struct AgenticDB {
    /// Vector index used for plain vector operations and for similarity lookups over
    /// the embeddings of stored records.
    vector_db: Arc<VectorDB>,
    /// redb database holding the full records; opened at `<storage_path>.agentic`.
    db: Arc<Database>,
    /// Vector dimensionality, copied from `DbOptions::dimensions` at construction.
    dimensions: usize,
    /// Provider used to turn text into embeddings.
    embedding_provider: BoxedEmbeddingProvider,
}
136
137impl AgenticDB {
138 pub fn new(options: DbOptions) -> Result<Self> {
140 let embedding_provider = Arc::new(HashEmbedding::new(options.dimensions));
141 Self::with_embedding_provider(options, embedding_provider)
142 }
143
144 pub fn with_embedding_provider(
182 options: DbOptions,
183 embedding_provider: BoxedEmbeddingProvider,
184 ) -> Result<Self> {
185 if options.dimensions != embedding_provider.dimensions() {
187 return Err(RuvectorError::InvalidDimension(format!(
188 "Options dimensions ({}) do not match embedding provider dimensions ({})",
189 options.dimensions,
190 embedding_provider.dimensions()
191 )));
192 }
193
194 let vector_db = Arc::new(VectorDB::new(options.clone())?);
196
197 let agentic_path = format!("{}.agentic", options.storage_path);
199 let db = Arc::new(Database::create(&agentic_path)?);
200
201 let write_txn = db.begin_write()?;
203 {
204 let _ = write_txn.open_table(REFLEXION_TABLE)?;
205 let _ = write_txn.open_table(SKILLS_TABLE)?;
206 let _ = write_txn.open_table(CAUSAL_TABLE)?;
207 let _ = write_txn.open_table(LEARNING_TABLE)?;
208 }
209 write_txn.commit()?;
210
211 Ok(Self {
212 vector_db,
213 db,
214 dimensions: options.dimensions,
215 embedding_provider,
216 })
217 }
218
219 pub fn with_dimensions(dimensions: usize) -> Result<Self> {
221 let mut options = DbOptions::default();
222 options.dimensions = dimensions;
223 Self::new(options)
224 }
225
    /// Returns the name reported by the configured embedding provider.
    pub fn embedding_provider_name(&self) -> &str {
        self.embedding_provider.name()
    }
230
    /// Inserts a single vector entry into the underlying vector DB and returns its id.
    ///
    /// Pure delegation to `VectorDB::insert`; nothing is written to the redb tables.
    pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
        self.vector_db.insert(entry)
    }
237
    /// Inserts several vector entries into the underlying vector DB and returns their ids.
    ///
    /// Pure delegation to `VectorDB::insert_batch`.
    pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
        self.vector_db.insert_batch(entries)
    }
242
    /// Runs a raw vector search against the underlying vector DB.
    ///
    /// Pure delegation to `VectorDB::search`. No type filter is added here, so hits may
    /// include the index entries created for episodes, skills and causal edges.
    pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
        self.vector_db.search(query)
    }
247
    /// Deletes the vector entry with the given id from the underlying vector DB.
    ///
    /// Pure delegation to `VectorDB::delete`; records in the redb tables are not touched.
    pub fn delete(&self, id: &str) -> Result<bool> {
        self.vector_db.delete(id)
    }
252
    /// Looks up a vector entry by id in the underlying vector DB.
    ///
    /// Pure delegation to `VectorDB::get`.
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        self.vector_db.get(id)
    }
257
258 pub fn store_episode(
262 &self,
263 task: String,
264 actions: Vec<String>,
265 observations: Vec<String>,
266 critique: String,
267 ) -> Result<String> {
268 let id = uuid::Uuid::new_v4().to_string();
269
270 let embedding = self.generate_text_embedding(&critique)?;
272
273 let episode = ReflexionEpisode {
274 id: id.clone(),
275 task,
276 actions,
277 observations,
278 critique,
279 embedding: embedding.clone(),
280 timestamp: chrono::Utc::now().timestamp(),
281 metadata: None,
282 };
283
284 let write_txn = self.db.begin_write()?;
286 {
287 let mut table = write_txn.open_table(REFLEXION_TABLE)?;
288 let json = serde_json::to_vec(&episode)
290 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
291 table.insert(id.as_str(), json.as_slice())?;
292 }
293 write_txn.commit()?;
294
295 self.vector_db.insert(VectorEntry {
297 id: Some(format!("reflexion_{}", id)),
298 vector: embedding,
299 metadata: Some({
300 let mut meta = HashMap::new();
301 meta.insert("type".to_string(), serde_json::json!("reflexion"));
302 meta.insert("episode_id".to_string(), serde_json::json!(id.clone()));
303 meta
304 }),
305 })?;
306
307 Ok(id)
308 }
309
310 pub fn retrieve_similar_episodes(
312 &self,
313 query: &str,
314 k: usize,
315 ) -> Result<Vec<ReflexionEpisode>> {
316 let query_embedding = self.generate_text_embedding(query)?;
318
319 let results = self.vector_db.search(SearchQuery {
321 vector: query_embedding,
322 k,
323 filter: Some({
324 let mut filter = HashMap::new();
325 filter.insert("type".to_string(), serde_json::json!("reflexion"));
326 filter
327 }),
328 ef_search: None,
329 })?;
330
331 let mut episodes = Vec::new();
333 let read_txn = self.db.begin_read()?;
334 let table = read_txn.open_table(REFLEXION_TABLE)?;
335
336 for result in results {
337 if let Some(metadata) = result.metadata {
338 if let Some(episode_id) = metadata.get("episode_id") {
339 let id = episode_id.as_str().unwrap();
340 if let Some(data) = table.get(id)? {
341 let episode: ReflexionEpisode = serde_json::from_slice(data.value())
343 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
344 episodes.push(episode);
345 }
346 }
347 }
348 }
349
350 Ok(episodes)
351 }
352
353 pub fn create_skill(
357 &self,
358 name: String,
359 description: String,
360 parameters: HashMap<String, String>,
361 examples: Vec<String>,
362 ) -> Result<String> {
363 let id = uuid::Uuid::new_v4().to_string();
364
365 let embedding = self.generate_text_embedding(&description)?;
367
368 let skill = Skill {
369 id: id.clone(),
370 name,
371 description,
372 parameters,
373 examples,
374 embedding: embedding.clone(),
375 usage_count: 0,
376 success_rate: 0.0,
377 created_at: chrono::Utc::now().timestamp(),
378 updated_at: chrono::Utc::now().timestamp(),
379 };
380
381 let write_txn = self.db.begin_write()?;
383 {
384 let mut table = write_txn.open_table(SKILLS_TABLE)?;
385 let data = bincode::encode_to_vec(&skill, bincode::config::standard())
386 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
387 table.insert(id.as_str(), data.as_slice())?;
388 }
389 write_txn.commit()?;
390
391 self.vector_db.insert(VectorEntry {
393 id: Some(format!("skill_{}", id)),
394 vector: embedding,
395 metadata: Some({
396 let mut meta = HashMap::new();
397 meta.insert("type".to_string(), serde_json::json!("skill"));
398 meta.insert("skill_id".to_string(), serde_json::json!(id.clone()));
399 meta
400 }),
401 })?;
402
403 Ok(id)
404 }
405
406 pub fn search_skills(&self, query_description: &str, k: usize) -> Result<Vec<Skill>> {
408 let query_embedding = self.generate_text_embedding(query_description)?;
409
410 let results = self.vector_db.search(SearchQuery {
411 vector: query_embedding,
412 k,
413 filter: Some({
414 let mut filter = HashMap::new();
415 filter.insert("type".to_string(), serde_json::json!("skill"));
416 filter
417 }),
418 ef_search: None,
419 })?;
420
421 let mut skills = Vec::new();
422 let read_txn = self.db.begin_read()?;
423 let table = read_txn.open_table(SKILLS_TABLE)?;
424
425 for result in results {
426 if let Some(metadata) = result.metadata {
427 if let Some(skill_id) = metadata.get("skill_id") {
428 let id = skill_id.as_str().unwrap();
429 if let Some(data) = table.get(id)? {
430 let (skill, _): (Skill, usize) =
431 bincode::decode_from_slice(data.value(), bincode::config::standard())
432 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
433 skills.push(skill);
434 }
435 }
436 }
437 }
438
439 Ok(skills)
440 }
441
442 pub fn auto_consolidate(
444 &self,
445 action_sequences: Vec<Vec<String>>,
446 success_threshold: usize,
447 ) -> Result<Vec<String>> {
448 let mut skill_ids = Vec::new();
449
450 for sequence in action_sequences {
452 if sequence.len() >= success_threshold {
453 let description = format!("Skill: {}", sequence.join(" -> "));
454 let skill_id = self.create_skill(
455 format!("Auto-Skill-{}", uuid::Uuid::new_v4()),
456 description,
457 HashMap::new(),
458 sequence.clone(),
459 )?;
460 skill_ids.push(skill_id);
461 }
462 }
463
464 Ok(skill_ids)
465 }
466
467 pub fn add_causal_edge(
471 &self,
472 causes: Vec<String>,
473 effects: Vec<String>,
474 confidence: f64,
475 context: String,
476 ) -> Result<String> {
477 let id = uuid::Uuid::new_v4().to_string();
478
479 let embedding = self.generate_text_embedding(&context)?;
481
482 let edge = CausalEdge {
483 id: id.clone(),
484 causes,
485 effects,
486 confidence,
487 context,
488 embedding: embedding.clone(),
489 observations: 1,
490 timestamp: chrono::Utc::now().timestamp(),
491 };
492
493 let write_txn = self.db.begin_write()?;
495 {
496 let mut table = write_txn.open_table(CAUSAL_TABLE)?;
497 let data = bincode::encode_to_vec(&edge, bincode::config::standard())
498 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
499 table.insert(id.as_str(), data.as_slice())?;
500 }
501 write_txn.commit()?;
502
503 self.vector_db.insert(VectorEntry {
505 id: Some(format!("causal_{}", id)),
506 vector: embedding,
507 metadata: Some({
508 let mut meta = HashMap::new();
509 meta.insert("type".to_string(), serde_json::json!("causal"));
510 meta.insert("causal_id".to_string(), serde_json::json!(id.clone()));
511 meta.insert("confidence".to_string(), serde_json::json!(confidence));
512 meta
513 }),
514 })?;
515
516 Ok(id)
517 }
518
519 pub fn query_with_utility(
521 &self,
522 query: &str,
523 k: usize,
524 alpha: f64,
525 beta: f64,
526 gamma: f64,
527 ) -> Result<Vec<UtilitySearchResult>> {
528 let start_time = std::time::Instant::now();
529 let query_embedding = self.generate_text_embedding(query)?;
530
531 let results = self.vector_db.search(SearchQuery {
533 vector: query_embedding,
534 k: k * 2, filter: Some({
536 let mut filter = HashMap::new();
537 filter.insert("type".to_string(), serde_json::json!("causal"));
538 filter
539 }),
540 ef_search: None,
541 })?;
542
543 let mut utility_results = Vec::new();
544
545 for result in results {
546 let similarity_score = 1.0 / (1.0 + result.score as f64); let causal_uplift = if let Some(ref metadata) = result.metadata {
550 metadata
551 .get("confidence")
552 .and_then(|v| v.as_f64())
553 .unwrap_or(0.0)
554 } else {
555 0.0
556 };
557
558 let latency = start_time.elapsed().as_secs_f64();
559 let latency_penalty = latency * gamma;
560
561 let utility_score = alpha * similarity_score + beta * causal_uplift - latency_penalty;
563
564 utility_results.push(UtilitySearchResult {
565 result,
566 utility_score,
567 similarity_score,
568 causal_uplift,
569 latency_penalty,
570 });
571 }
572
573 utility_results.sort_by(|a, b| b.utility_score.partial_cmp(&a.utility_score).unwrap());
575 utility_results.truncate(k);
576
577 Ok(utility_results)
578 }
579
580 pub fn start_session(
584 &self,
585 algorithm: String,
586 state_dim: usize,
587 action_dim: usize,
588 ) -> Result<String> {
589 let id = uuid::Uuid::new_v4().to_string();
590
591 let session = LearningSession {
592 id: id.clone(),
593 algorithm,
594 state_dim,
595 action_dim,
596 experiences: Vec::new(),
597 model_params: None,
598 created_at: chrono::Utc::now().timestamp(),
599 updated_at: chrono::Utc::now().timestamp(),
600 };
601
602 let write_txn = self.db.begin_write()?;
603 {
604 let mut table = write_txn.open_table(LEARNING_TABLE)?;
605 let data = bincode::encode_to_vec(&session, bincode::config::standard())
606 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
607 table.insert(id.as_str(), data.as_slice())?;
608 }
609 write_txn.commit()?;
610
611 Ok(id)
612 }
613
614 pub fn add_experience(
616 &self,
617 session_id: &str,
618 state: Vec<f32>,
619 action: Vec<f32>,
620 reward: f64,
621 next_state: Vec<f32>,
622 done: bool,
623 ) -> Result<()> {
624 let read_txn = self.db.begin_read()?;
625 let table = read_txn.open_table(LEARNING_TABLE)?;
626
627 let data = table
628 .get(session_id)?
629 .ok_or_else(|| RuvectorError::VectorNotFound(session_id.to_string()))?;
630
631 let (mut session, _): (LearningSession, usize) =
632 bincode::decode_from_slice(data.value(), bincode::config::standard())
633 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
634
635 drop(table);
636 drop(read_txn);
637
638 session.experiences.push(Experience {
640 state,
641 action,
642 reward,
643 next_state,
644 done,
645 timestamp: chrono::Utc::now().timestamp(),
646 });
647 session.updated_at = chrono::Utc::now().timestamp();
648
649 let write_txn = self.db.begin_write()?;
651 {
652 let mut table = write_txn.open_table(LEARNING_TABLE)?;
653 let data = bincode::encode_to_vec(&session, bincode::config::standard())
654 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
655 table.insert(session_id, data.as_slice())?;
656 }
657 write_txn.commit()?;
658
659 Ok(())
660 }
661
662 pub fn predict_with_confidence(&self, session_id: &str, state: Vec<f32>) -> Result<Prediction> {
664 let read_txn = self.db.begin_read()?;
665 let table = read_txn.open_table(LEARNING_TABLE)?;
666
667 let data = table
668 .get(session_id)?
669 .ok_or_else(|| RuvectorError::VectorNotFound(session_id.to_string()))?;
670
671 let (session, _): (LearningSession, usize) =
672 bincode::decode_from_slice(data.value(), bincode::config::standard())
673 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
674
675 let mut similar_actions = Vec::new();
677 let mut rewards = Vec::new();
678
679 for exp in &session.experiences {
680 let distance = euclidean_distance(&state, &exp.state);
681 if distance < 1.0 {
682 similar_actions.push(exp.action.clone());
684 rewards.push(exp.reward);
685 }
686 }
687
688 if similar_actions.is_empty() {
689 return Ok(Prediction {
691 action: vec![0.0; session.action_dim],
692 confidence_lower: 0.0,
693 confidence_upper: 0.0,
694 mean_confidence: 0.0,
695 });
696 }
697
698 let total_reward: f64 = rewards.iter().sum();
700 let mut action = vec![0.0; session.action_dim];
701
702 for (act, reward) in similar_actions.iter().zip(rewards.iter()) {
703 let weight = reward / total_reward;
704 for (i, val) in act.iter().enumerate() {
705 action[i] += val * weight as f32;
706 }
707 }
708
709 let mean_reward = total_reward / rewards.len() as f64;
711 let std_dev = calculate_std_dev(&rewards, mean_reward);
712
713 Ok(Prediction {
714 action,
715 confidence_lower: mean_reward - 1.96 * std_dev,
716 confidence_upper: mean_reward + 1.96 * std_dev,
717 mean_confidence: mean_reward,
718 })
719 }
720
721 pub fn get_session(&self, session_id: &str) -> Result<Option<LearningSession>> {
723 let read_txn = self.db.begin_read()?;
724 let table = read_txn.open_table(LEARNING_TABLE)?;
725
726 if let Some(data) = table.get(session_id)? {
727 let (session, _): (LearningSession, usize) =
728 bincode::decode_from_slice(data.value(), bincode::config::standard())
729 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
730 Ok(Some(session))
731 } else {
732 Ok(None)
733 }
734 }
735
    /// Embeds `text` with the configured embedding provider.
    ///
    /// Single internal entry point for text embedding; the provider's errors are
    /// returned unchanged.
    fn generate_text_embedding(&self, text: &str) -> Result<Vec<f32>> {
        self.embedding_provider.embed(text)
    }
761}
762
/// Euclidean (L2) distance between two vectors.
///
/// Components are paired positionally; when the slices differ in length the extra
/// components of the longer one are ignored. Two empty slices have distance zero.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_deltas = a.iter().zip(b).map(|(&lhs, &rhs)| {
        let delta = lhs - rhs;
        delta.powi(2)
    });
    let sum_of_squares: f32 = squared_deltas.sum();
    sum_of_squares.sqrt()
}
771
/// Population standard deviation of `values` around the supplied `mean`.
///
/// The caller passes the mean so it is not recomputed. The sum of squared deviations
/// is divided by `values.len()` (population form, not the `n - 1` sample form).
///
/// Returns 0.0 for an empty slice; without this guard the division would be 0/0 and
/// the result NaN.
fn calculate_std_dev(values: &[f64], mean: f64) -> f64 {
    if values.is_empty() {
        return 0.0;
    }

    let sum_of_squares: f64 = values.iter().map(|x| (x - mean).powi(2)).sum();
    (sum_of_squares / values.len() as f64).sqrt()
}
776
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Creates a 128-dimensional `AgenticDB` backed by a fresh temporary directory.
    ///
    /// The `TempDir` is returned with the database because dropping it deletes the
    /// directory; callers must keep it alive for as long as the database is in use.
    /// It is the first tuple element so that binding `(_dir, db)` drops the database
    /// before the directory is removed.
    fn create_test_db() -> Result<(TempDir, AgenticDB)> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 128;
        let db = AgenticDB::new(options)?;
        Ok((dir, db))
    }

    /// A stored episode can be retrieved by similarity search.
    #[test]
    fn test_reflexion_episode() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let id = db.store_episode(
            "Solve math problem".to_string(),
            vec!["read problem".to_string(), "calculate".to_string()],
            vec!["got 42".to_string()],
            "Should have shown work".to_string(),
        )?;

        let episodes = db.retrieve_similar_episodes("math problem solving", 5)?;
        assert!(!episodes.is_empty());
        assert_eq!(episodes[0].id, id);

        Ok(())
    }

    /// A created skill is returned by skill search.
    #[test]
    fn test_skill_library() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let mut params = HashMap::new();
        params.insert("input".to_string(), "string".to_string());

        let skill_id = db.create_skill(
            "Parse JSON".to_string(),
            "Parse JSON from string".to_string(),
            params,
            vec!["json.parse()".to_string()],
        )?;

        let skills = db.search_skills("parse json data", 5)?;
        assert!(!skills.is_empty());
        assert!(skills.iter().any(|skill| skill.id == skill_id));

        Ok(())
    }

    /// A stored causal edge is returned by the utility query.
    #[test]
    fn test_causal_edge() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let edge_id = db.add_causal_edge(
            vec!["rain".to_string()],
            vec!["wet ground".to_string()],
            0.95,
            "Weather observation".to_string(),
        )?;

        let results = db.query_with_utility("weather patterns", 5, 0.7, 0.2, 0.1)?;
        assert!(!results.is_empty());
        assert!(results.iter().any(|hit| {
            hit.result
                .metadata
                .as_ref()
                .and_then(|meta| meta.get("causal_id"))
                .and_then(|value| value.as_str())
                == Some(edge_id.as_str())
        }));

        Ok(())
    }

    /// A prediction has the session's action dimensionality.
    #[test]
    fn test_learning_session() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let session_id = db.start_session("Q-Learning".to_string(), 4, 2)?;

        db.add_experience(
            &session_id,
            vec![1.0, 0.0, 0.0, 0.0],
            vec![1.0, 0.0],
            1.0,
            vec![0.0, 1.0, 0.0, 0.0],
            false,
        )?;

        let prediction = db.predict_with_confidence(&session_id, vec![1.0, 0.0, 0.0, 0.0])?;
        assert_eq!(prediction.action.len(), 2);

        Ok(())
    }

    /// Every sequence that meets the length threshold becomes a skill.
    #[test]
    fn test_auto_consolidate() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let sequences = vec![
            vec![
                "step1".to_string(),
                "step2".to_string(),
                "step3".to_string(),
            ],
            vec![
                "action1".to_string(),
                "action2".to_string(),
                "action3".to_string(),
            ],
        ];

        let skill_ids = db.auto_consolidate(sequences, 3)?;
        assert_eq!(skill_ids.len(), 2);

        Ok(())
    }
}