1use crate::embeddings::{BoxedEmbeddingProvider, HashEmbedding};
28use crate::error::{Result, RuvectorError};
29use crate::types::*;
30use crate::vector_db::VectorDB;
31use parking_lot::RwLock;
32use redb::{Database, TableDefinition};
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35use std::sync::Arc;
36
// redb table definitions. Each table maps a string record id to a serialized
// byte payload.

/// Reflexion episodes; payloads are JSON-encoded `ReflexionEpisode` values.
const REFLEXION_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("reflexion_episodes");
/// Skill library; payloads are bincode-encoded `Skill` values.
const SKILLS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("skills_library");
/// Causal edges; payloads are bincode-encoded `CausalEdge` values.
const CAUSAL_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("causal_edges");
/// Learning sessions; payloads are bincode-encoded `LearningSession` values.
const LEARNING_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("learning_sessions");
42
/// A reflexion episode: one task attempt together with a critique of it.
///
/// `store_episode` persists these as JSON in the reflexion table and indexes
/// the embedding in the vector store under the id `reflexion_<id>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflexionEpisode {
    /// Unique identifier (a UUID v4 string when created by `store_episode`).
    pub id: String,
    /// Description of the task that was attempted.
    pub task: String,
    /// Actions taken during the attempt.
    pub actions: Vec<String>,
    /// Observations recorded during the attempt.
    pub observations: Vec<String>,
    /// Critique of the attempt; this is the text that gets embedded.
    pub critique: String,
    /// Embedding of `critique`.
    pub embedding: Vec<f32>,
    /// Creation time as a Unix timestamp (`chrono::Utc::now().timestamp()`).
    pub timestamp: i64,
    /// Optional free-form metadata; `store_episode` always writes `None`.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
56
/// A reusable skill stored in the skill library.
///
/// `create_skill` persists these with bincode in the skills table and indexes
/// the embedding in the vector store under the id `skill_<id>`.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Skill {
    /// Unique identifier (a UUID v4 string when created by `create_skill`).
    pub id: String,
    /// Human-readable skill name.
    pub name: String,
    /// Description of the skill; this is the text that gets embedded.
    pub description: String,
    /// Parameter map supplied by the caller.
    // NOTE(review): presumably parameter name -> type name (the tests use
    // "input" -> "string"); confirm with callers.
    pub parameters: HashMap<String, String>,
    /// Usage examples supplied by the caller.
    pub examples: Vec<String>,
    /// Embedding of `description`.
    pub embedding: Vec<f32>,
    /// Usage counter; initialised to 0 by `create_skill`.
    pub usage_count: usize,
    /// Success rate; initialised to 0.0 by `create_skill`.
    pub success_rate: f64,
    /// Creation time as a Unix timestamp.
    pub created_at: i64,
    /// Last-update time as a Unix timestamp.
    pub updated_at: i64,
}
71
/// A causal relationship linking a set of causes to a set of effects.
///
/// `add_causal_edge` persists these with bincode in the causal table and
/// indexes the embedding in the vector store under the id `causal_<id>`.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct CausalEdge {
    /// Unique identifier (a UUID v4 string when created by `add_causal_edge`).
    pub id: String,
    /// Cause labels.
    pub causes: Vec<String>,
    /// Effect labels.
    pub effects: Vec<String>,
    /// Caller-supplied confidence; also copied into the vector metadata,
    /// where `query_with_utility` reads it as the causal uplift.
    pub confidence: f64,
    /// Context text; this is the text that gets embedded.
    pub context: String,
    /// Embedding of `context`.
    pub embedding: Vec<f32>,
    /// Observation counter; initialised to 1 by `add_causal_edge`.
    pub observations: usize,
    /// Creation time as a Unix timestamp.
    pub timestamp: i64,
}
84
/// A learning session and every experience recorded in it.
///
/// Persisted with bincode in the learning-sessions table, keyed by `id`.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct LearningSession {
    /// Unique identifier (a UUID v4 string when created by `start_session`).
    pub id: String,
    /// Caller-supplied algorithm name (the tests use "Q-Learning").
    pub algorithm: String,
    /// Declared state dimensionality.
    pub state_dim: usize,
    /// Declared action dimensionality; sizes the action vector returned by
    /// `predict_with_confidence`.
    pub action_dim: usize,
    /// Experiences appended by `add_experience`, in insertion order.
    pub experiences: Vec<Experience>,
    /// Opaque model parameters; `start_session` writes `None`.
    pub model_params: Option<Vec<u8>>,
    /// Creation time as a Unix timestamp.
    pub created_at: i64,
    /// Last-update time as a Unix timestamp; refreshed by `add_experience`.
    pub updated_at: i64,
}
97
/// A single recorded transition within a `LearningSession`.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Experience {
    /// State vector before the action.
    pub state: Vec<f32>,
    /// Action vector that was taken.
    pub action: Vec<f32>,
    /// Reward associated with the transition.
    pub reward: f64,
    /// State vector after the action.
    pub next_state: Vec<f32>,
    /// Caller-supplied termination flag.
    pub done: bool,
    /// Recording time as a Unix timestamp.
    pub timestamp: i64,
}
108
/// Result of `AgenticDB::predict_with_confidence`.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Prediction {
    /// Predicted action vector (length equals the session's `action_dim`).
    pub action: Vec<f32>,
    /// Lower bound: mean reward minus 1.96 standard deviations.
    pub confidence_lower: f64,
    /// Upper bound: mean reward plus 1.96 standard deviations.
    pub confidence_upper: f64,
    /// Mean reward of the experiences the prediction was built from.
    pub mean_confidence: f64,
}
117
/// A search hit annotated with the components of its utility score, as
/// produced by `AgenticDB::query_with_utility`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UtilitySearchResult {
    /// The underlying vector-search hit.
    pub result: SearchResult,
    /// Combined score: `alpha * similarity + beta * causal_uplift - latency_penalty`.
    pub utility_score: f64,
    /// Similarity term, computed as `1 / (1 + result.score)`.
    pub similarity_score: f64,
    /// Causal term: the `confidence` value from the hit's metadata (0.0 if absent).
    pub causal_uplift: f64,
    /// Latency term: elapsed seconds multiplied by `gamma`.
    pub latency_penalty: f64,
}
127
/// Agent-memory database combining a vector index with a redb key-value store
/// for reflexion episodes, skills, causal edges, and learning sessions.
pub struct AgenticDB {
    /// Vector store holding the embeddings and their metadata.
    vector_db: Arc<VectorDB>,
    /// redb database holding the full serialized records.
    db: Arc<Database>,
    /// Dimensionality taken from the options at construction; not read by the
    /// code in this file.
    _dimensions: usize,
    /// Provider used to turn text into embedding vectors.
    embedding_provider: BoxedEmbeddingProvider,
}
135
136impl AgenticDB {
137 pub fn new(options: DbOptions) -> Result<Self> {
139 let embedding_provider = Arc::new(HashEmbedding::new(options.dimensions));
140 Self::with_embedding_provider(options, embedding_provider)
141 }
142
143 pub fn with_embedding_provider(
181 options: DbOptions,
182 embedding_provider: BoxedEmbeddingProvider,
183 ) -> Result<Self> {
184 if options.dimensions != embedding_provider.dimensions() {
186 return Err(RuvectorError::InvalidDimension(format!(
187 "Options dimensions ({}) do not match embedding provider dimensions ({})",
188 options.dimensions,
189 embedding_provider.dimensions()
190 )));
191 }
192
193 let vector_db = Arc::new(VectorDB::new(options.clone())?);
195
196 let agentic_path = format!("{}.agentic", options.storage_path);
198 let db = Arc::new(Database::create(&agentic_path)?);
199
200 let write_txn = db.begin_write()?;
202 {
203 let _ = write_txn.open_table(REFLEXION_TABLE)?;
204 let _ = write_txn.open_table(SKILLS_TABLE)?;
205 let _ = write_txn.open_table(CAUSAL_TABLE)?;
206 let _ = write_txn.open_table(LEARNING_TABLE)?;
207 }
208 write_txn.commit()?;
209
210 Ok(Self {
211 vector_db,
212 db,
213 _dimensions: options.dimensions,
214 embedding_provider,
215 })
216 }
217
218 pub fn with_dimensions(dimensions: usize) -> Result<Self> {
220 let options = DbOptions {
221 dimensions,
222 ..DbOptions::default()
223 };
224 Self::new(options)
225 }
226
    /// Returns the name reported by the configured embedding provider.
    pub fn embedding_provider_name(&self) -> &str {
        self.embedding_provider.name()
    }
231
    /// Inserts a single vector entry; delegates directly to the underlying
    /// vector store and returns the id it reports.
    pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
        self.vector_db.insert(entry)
    }
238
    /// Inserts several vector entries; delegates directly to the underlying
    /// vector store and returns the ids it reports.
    pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
        self.vector_db.insert_batch(entries)
    }
243
    /// Runs a vector search; delegates directly to the underlying vector store.
    pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
        self.vector_db.search(query)
    }
248
    /// Deletes the vector entry with the given id; delegates directly to the
    /// underlying vector store.
    // NOTE(review): the returned bool presumably reports whether an entry was
    // removed — confirm against `VectorDB::delete`.
    pub fn delete(&self, id: &str) -> Result<bool> {
        self.vector_db.delete(id)
    }
253
    /// Looks up a vector entry by id; delegates directly to the underlying
    /// vector store.
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        self.vector_db.get(id)
    }
258
259 pub fn store_episode(
263 &self,
264 task: String,
265 actions: Vec<String>,
266 observations: Vec<String>,
267 critique: String,
268 ) -> Result<String> {
269 let id = uuid::Uuid::new_v4().to_string();
270
271 let embedding = self.generate_text_embedding(&critique)?;
273
274 let episode = ReflexionEpisode {
275 id: id.clone(),
276 task,
277 actions,
278 observations,
279 critique,
280 embedding: embedding.clone(),
281 timestamp: chrono::Utc::now().timestamp(),
282 metadata: None,
283 };
284
285 let write_txn = self.db.begin_write()?;
287 {
288 let mut table = write_txn.open_table(REFLEXION_TABLE)?;
289 let json = serde_json::to_vec(&episode)
291 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
292 table.insert(id.as_str(), json.as_slice())?;
293 }
294 write_txn.commit()?;
295
296 self.vector_db.insert(VectorEntry {
298 id: Some(format!("reflexion_{}", id)),
299 vector: embedding,
300 metadata: Some({
301 let mut meta = HashMap::new();
302 meta.insert("type".to_string(), serde_json::json!("reflexion"));
303 meta.insert("episode_id".to_string(), serde_json::json!(id.clone()));
304 meta
305 }),
306 })?;
307
308 Ok(id)
309 }
310
311 pub fn retrieve_similar_episodes(
313 &self,
314 query: &str,
315 k: usize,
316 ) -> Result<Vec<ReflexionEpisode>> {
317 let query_embedding = self.generate_text_embedding(query)?;
319
320 let results = self.vector_db.search(SearchQuery {
322 vector: query_embedding,
323 k,
324 filter: Some({
325 let mut filter = HashMap::new();
326 filter.insert("type".to_string(), serde_json::json!("reflexion"));
327 filter
328 }),
329 ef_search: None,
330 })?;
331
332 let mut episodes = Vec::new();
334 let read_txn = self.db.begin_read()?;
335 let table = read_txn.open_table(REFLEXION_TABLE)?;
336
337 for result in results {
338 if let Some(metadata) = result.metadata {
339 if let Some(episode_id) = metadata.get("episode_id") {
340 let id = episode_id.as_str().unwrap();
341 if let Some(data) = table.get(id)? {
342 let episode: ReflexionEpisode = serde_json::from_slice(data.value())
344 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
345 episodes.push(episode);
346 }
347 }
348 }
349 }
350
351 Ok(episodes)
352 }
353
354 pub fn create_skill(
358 &self,
359 name: String,
360 description: String,
361 parameters: HashMap<String, String>,
362 examples: Vec<String>,
363 ) -> Result<String> {
364 let id = uuid::Uuid::new_v4().to_string();
365
366 let embedding = self.generate_text_embedding(&description)?;
368
369 let skill = Skill {
370 id: id.clone(),
371 name,
372 description,
373 parameters,
374 examples,
375 embedding: embedding.clone(),
376 usage_count: 0,
377 success_rate: 0.0,
378 created_at: chrono::Utc::now().timestamp(),
379 updated_at: chrono::Utc::now().timestamp(),
380 };
381
382 let write_txn = self.db.begin_write()?;
384 {
385 let mut table = write_txn.open_table(SKILLS_TABLE)?;
386 let data = bincode::encode_to_vec(&skill, bincode::config::standard())
387 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
388 table.insert(id.as_str(), data.as_slice())?;
389 }
390 write_txn.commit()?;
391
392 self.vector_db.insert(VectorEntry {
394 id: Some(format!("skill_{}", id)),
395 vector: embedding,
396 metadata: Some({
397 let mut meta = HashMap::new();
398 meta.insert("type".to_string(), serde_json::json!("skill"));
399 meta.insert("skill_id".to_string(), serde_json::json!(id.clone()));
400 meta
401 }),
402 })?;
403
404 Ok(id)
405 }
406
407 pub fn search_skills(&self, query_description: &str, k: usize) -> Result<Vec<Skill>> {
409 let query_embedding = self.generate_text_embedding(query_description)?;
410
411 let results = self.vector_db.search(SearchQuery {
412 vector: query_embedding,
413 k,
414 filter: Some({
415 let mut filter = HashMap::new();
416 filter.insert("type".to_string(), serde_json::json!("skill"));
417 filter
418 }),
419 ef_search: None,
420 })?;
421
422 let mut skills = Vec::new();
423 let read_txn = self.db.begin_read()?;
424 let table = read_txn.open_table(SKILLS_TABLE)?;
425
426 for result in results {
427 if let Some(metadata) = result.metadata {
428 if let Some(skill_id) = metadata.get("skill_id") {
429 let id = skill_id.as_str().unwrap();
430 if let Some(data) = table.get(id)? {
431 let (skill, _): (Skill, usize) =
432 bincode::decode_from_slice(data.value(), bincode::config::standard())
433 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
434 skills.push(skill);
435 }
436 }
437 }
438 }
439
440 Ok(skills)
441 }
442
443 pub fn auto_consolidate(
445 &self,
446 action_sequences: Vec<Vec<String>>,
447 success_threshold: usize,
448 ) -> Result<Vec<String>> {
449 let mut skill_ids = Vec::new();
450
451 for sequence in action_sequences {
453 if sequence.len() >= success_threshold {
454 let description = format!("Skill: {}", sequence.join(" -> "));
455 let skill_id = self.create_skill(
456 format!("Auto-Skill-{}", uuid::Uuid::new_v4()),
457 description,
458 HashMap::new(),
459 sequence.clone(),
460 )?;
461 skill_ids.push(skill_id);
462 }
463 }
464
465 Ok(skill_ids)
466 }
467
468 pub fn add_causal_edge(
472 &self,
473 causes: Vec<String>,
474 effects: Vec<String>,
475 confidence: f64,
476 context: String,
477 ) -> Result<String> {
478 let id = uuid::Uuid::new_v4().to_string();
479
480 let embedding = self.generate_text_embedding(&context)?;
482
483 let edge = CausalEdge {
484 id: id.clone(),
485 causes,
486 effects,
487 confidence,
488 context,
489 embedding: embedding.clone(),
490 observations: 1,
491 timestamp: chrono::Utc::now().timestamp(),
492 };
493
494 let write_txn = self.db.begin_write()?;
496 {
497 let mut table = write_txn.open_table(CAUSAL_TABLE)?;
498 let data = bincode::encode_to_vec(&edge, bincode::config::standard())
499 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
500 table.insert(id.as_str(), data.as_slice())?;
501 }
502 write_txn.commit()?;
503
504 self.vector_db.insert(VectorEntry {
506 id: Some(format!("causal_{}", id)),
507 vector: embedding,
508 metadata: Some({
509 let mut meta = HashMap::new();
510 meta.insert("type".to_string(), serde_json::json!("causal"));
511 meta.insert("causal_id".to_string(), serde_json::json!(id.clone()));
512 meta.insert("confidence".to_string(), serde_json::json!(confidence));
513 meta
514 }),
515 })?;
516
517 Ok(id)
518 }
519
520 pub fn query_with_utility(
522 &self,
523 query: &str,
524 k: usize,
525 alpha: f64,
526 beta: f64,
527 gamma: f64,
528 ) -> Result<Vec<UtilitySearchResult>> {
529 let start_time = std::time::Instant::now();
530 let query_embedding = self.generate_text_embedding(query)?;
531
532 let results = self.vector_db.search(SearchQuery {
534 vector: query_embedding,
535 k: k * 2, filter: Some({
537 let mut filter = HashMap::new();
538 filter.insert("type".to_string(), serde_json::json!("causal"));
539 filter
540 }),
541 ef_search: None,
542 })?;
543
544 let mut utility_results = Vec::new();
545
546 for result in results {
547 let similarity_score = 1.0 / (1.0 + result.score as f64); let causal_uplift = if let Some(ref metadata) = result.metadata {
551 metadata
552 .get("confidence")
553 .and_then(|v| v.as_f64())
554 .unwrap_or(0.0)
555 } else {
556 0.0
557 };
558
559 let latency = start_time.elapsed().as_secs_f64();
560 let latency_penalty = latency * gamma;
561
562 let utility_score = alpha * similarity_score + beta * causal_uplift - latency_penalty;
564
565 utility_results.push(UtilitySearchResult {
566 result,
567 utility_score,
568 similarity_score,
569 causal_uplift,
570 latency_penalty,
571 });
572 }
573
574 utility_results.sort_by(|a, b| b.utility_score.partial_cmp(&a.utility_score).unwrap());
576 utility_results.truncate(k);
577
578 Ok(utility_results)
579 }
580
581 pub fn start_session(
585 &self,
586 algorithm: String,
587 state_dim: usize,
588 action_dim: usize,
589 ) -> Result<String> {
590 let id = uuid::Uuid::new_v4().to_string();
591
592 let session = LearningSession {
593 id: id.clone(),
594 algorithm,
595 state_dim,
596 action_dim,
597 experiences: Vec::new(),
598 model_params: None,
599 created_at: chrono::Utc::now().timestamp(),
600 updated_at: chrono::Utc::now().timestamp(),
601 };
602
603 let write_txn = self.db.begin_write()?;
604 {
605 let mut table = write_txn.open_table(LEARNING_TABLE)?;
606 let data = bincode::encode_to_vec(&session, bincode::config::standard())
607 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
608 table.insert(id.as_str(), data.as_slice())?;
609 }
610 write_txn.commit()?;
611
612 Ok(id)
613 }
614
615 pub fn add_experience(
617 &self,
618 session_id: &str,
619 state: Vec<f32>,
620 action: Vec<f32>,
621 reward: f64,
622 next_state: Vec<f32>,
623 done: bool,
624 ) -> Result<()> {
625 let read_txn = self.db.begin_read()?;
626 let table = read_txn.open_table(LEARNING_TABLE)?;
627
628 let data = table
629 .get(session_id)?
630 .ok_or_else(|| RuvectorError::VectorNotFound(session_id.to_string()))?;
631
632 let (mut session, _): (LearningSession, usize) =
633 bincode::decode_from_slice(data.value(), bincode::config::standard())
634 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
635
636 drop(table);
637 drop(read_txn);
638
639 session.experiences.push(Experience {
641 state,
642 action,
643 reward,
644 next_state,
645 done,
646 timestamp: chrono::Utc::now().timestamp(),
647 });
648 session.updated_at = chrono::Utc::now().timestamp();
649
650 let write_txn = self.db.begin_write()?;
652 {
653 let mut table = write_txn.open_table(LEARNING_TABLE)?;
654 let data = bincode::encode_to_vec(&session, bincode::config::standard())
655 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
656 table.insert(session_id, data.as_slice())?;
657 }
658 write_txn.commit()?;
659
660 Ok(())
661 }
662
663 pub fn predict_with_confidence(&self, session_id: &str, state: Vec<f32>) -> Result<Prediction> {
665 let read_txn = self.db.begin_read()?;
666 let table = read_txn.open_table(LEARNING_TABLE)?;
667
668 let data = table
669 .get(session_id)?
670 .ok_or_else(|| RuvectorError::VectorNotFound(session_id.to_string()))?;
671
672 let (session, _): (LearningSession, usize) =
673 bincode::decode_from_slice(data.value(), bincode::config::standard())
674 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
675
676 let mut similar_actions = Vec::new();
678 let mut rewards = Vec::new();
679
680 for exp in &session.experiences {
681 let distance = euclidean_distance(&state, &exp.state);
682 if distance < 1.0 {
683 similar_actions.push(exp.action.clone());
685 rewards.push(exp.reward);
686 }
687 }
688
689 if similar_actions.is_empty() {
690 return Ok(Prediction {
692 action: vec![0.0; session.action_dim],
693 confidence_lower: 0.0,
694 confidence_upper: 0.0,
695 mean_confidence: 0.0,
696 });
697 }
698
699 let total_reward: f64 = rewards.iter().sum();
701 let mut action = vec![0.0; session.action_dim];
702
703 for (act, reward) in similar_actions.iter().zip(rewards.iter()) {
704 let weight = reward / total_reward;
705 for (i, val) in act.iter().enumerate() {
706 action[i] += val * weight as f32;
707 }
708 }
709
710 let mean_reward = total_reward / rewards.len() as f64;
712 let std_dev = calculate_std_dev(&rewards, mean_reward);
713
714 Ok(Prediction {
715 action,
716 confidence_lower: mean_reward - 1.96 * std_dev,
717 confidence_upper: mean_reward + 1.96 * std_dev,
718 mean_confidence: mean_reward,
719 })
720 }
721
722 pub fn get_session(&self, session_id: &str) -> Result<Option<LearningSession>> {
724 let read_txn = self.db.begin_read()?;
725 let table = read_txn.open_table(LEARNING_TABLE)?;
726
727 if let Some(data) = table.get(session_id)? {
728 let (session, _): (LearningSession, usize) =
729 bincode::decode_from_slice(data.value(), bincode::config::standard())
730 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
731 Ok(Some(session))
732 } else {
733 Ok(None)
734 }
735 }
736
    /// Embeds `text` with the configured embedding provider.
    fn generate_text_embedding(&self, text: &str) -> Result<Vec<f32>> {
        self.embedding_provider.embed(text)
    }
762}
763
/// Euclidean (L2) distance between two vectors.
///
/// Vectors of different lengths are not comparable and yield
/// `f32::INFINITY`. Zipping them silently would measure only the common
/// prefix, so for example an empty vector would appear to be at distance 0
/// from everything.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }

    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>()
        .sqrt()
}
772
/// Population standard deviation of `values` around the supplied `mean`.
///
/// Returns 0.0 for an empty slice; dividing by a zero length would otherwise
/// produce NaN.
fn calculate_std_dev(values: &[f64], mean: f64) -> f64 {
    if values.is_empty() {
        return 0.0;
    }

    let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
    variance.sqrt()
}
777
/// Borrowed view over an `AgenticDB` that stores and retrieves policy entries
/// (state embedding, action, reward, Q-value) in the vector store.
pub struct PolicyMemoryStore<'a> {
    /// Database whose vector store holds the policy entries.
    db: &'a AgenticDB,
}
794
/// An action taken in some state, with its reward and Q-value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyAction {
    /// Action label.
    pub action: String,
    /// Reward recorded for the action.
    pub reward: f64,
    /// Q-value recorded for the state/action pair.
    pub q_value: f64,
    /// Embedding of the state the action was taken in.
    pub state_embedding: Vec<f32>,
    /// Unix timestamp of when the entry was recorded.
    pub timestamp: i64,
}
809
/// A stored policy record: a state identifier plus the action taken in it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyEntry {
    /// Policy id (the vector-store id is `policy_<id>`).
    pub id: String,
    /// Caller-supplied state identifier.
    pub state_id: String,
    /// The action and its statistics.
    pub action: PolicyAction,
    /// Optional free-form metadata; the code in this file always writes `None`.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
822
823impl<'a> PolicyMemoryStore<'a> {
    /// Creates a policy-memory view that borrows `db`.
    pub fn new(db: &'a AgenticDB) -> Self {
        Self { db }
    }
828
829 pub fn store_policy(
831 &self,
832 state_id: &str,
833 state_embedding: Vec<f32>,
834 action: &str,
835 reward: f64,
836 q_value: f64,
837 ) -> Result<String> {
838 let id = uuid::Uuid::new_v4().to_string();
839 let timestamp = chrono::Utc::now().timestamp();
840
841 let _entry = PolicyEntry {
842 id: id.clone(),
843 state_id: state_id.to_string(),
844 action: PolicyAction {
845 action: action.to_string(),
846 reward,
847 q_value,
848 state_embedding: state_embedding.clone(),
849 timestamp,
850 },
851 metadata: None,
852 };
853
854 self.db.vector_db.insert(VectorEntry {
856 id: Some(format!("policy_{}", id)),
857 vector: state_embedding,
858 metadata: Some({
859 let mut meta = HashMap::new();
860 meta.insert("type".to_string(), serde_json::json!("policy"));
861 meta.insert("policy_id".to_string(), serde_json::json!(id.clone()));
862 meta.insert("state_id".to_string(), serde_json::json!(state_id));
863 meta.insert("action".to_string(), serde_json::json!(action));
864 meta.insert("reward".to_string(), serde_json::json!(reward));
865 meta.insert("q_value".to_string(), serde_json::json!(q_value));
866 meta
867 }),
868 })?;
869
870 Ok(id)
871 }
872
873 pub fn retrieve_similar_states(
875 &self,
876 state_embedding: &[f32],
877 k: usize,
878 ) -> Result<Vec<PolicyEntry>> {
879 let results = self.db.vector_db.search(SearchQuery {
880 vector: state_embedding.to_vec(),
881 k,
882 filter: Some({
883 let mut filter = HashMap::new();
884 filter.insert("type".to_string(), serde_json::json!("policy"));
885 filter
886 }),
887 ef_search: None,
888 })?;
889
890 let mut entries = Vec::new();
891 for result in results {
892 if let Some(metadata) = result.metadata {
893 let policy_id = metadata
894 .get("policy_id")
895 .and_then(|v| v.as_str())
896 .unwrap_or("");
897 let state_id = metadata
898 .get("state_id")
899 .and_then(|v| v.as_str())
900 .unwrap_or("");
901 let action = metadata
902 .get("action")
903 .and_then(|v| v.as_str())
904 .unwrap_or("");
905 let reward = metadata
906 .get("reward")
907 .and_then(|v| v.as_f64())
908 .unwrap_or(0.0);
909 let q_value = metadata
910 .get("q_value")
911 .and_then(|v| v.as_f64())
912 .unwrap_or(0.0);
913
914 entries.push(PolicyEntry {
915 id: policy_id.to_string(),
916 state_id: state_id.to_string(),
917 action: PolicyAction {
918 action: action.to_string(),
919 reward,
920 q_value,
921 state_embedding: result.vector.unwrap_or_default(),
922 timestamp: 0,
923 },
924 metadata: None,
925 });
926 }
927 }
928
929 Ok(entries)
930 }
931
932 pub fn get_best_action(&self, state_embedding: &[f32], k: usize) -> Result<Option<String>> {
934 let similar = self.retrieve_similar_states(state_embedding, k)?;
935
936 similar
937 .into_iter()
938 .max_by(|a, b| a.action.q_value.partial_cmp(&b.action.q_value).unwrap())
939 .map(|entry| Ok(entry.action.action))
940 .transpose()
941 }
942
943 pub fn update_q_value(&self, policy_id: &str, new_q_value: f64) -> Result<()> {
949 let key = format!("policy_{}", policy_id);
950
951 let mut entry = self
953 .db
954 .vector_db
955 .get(&key)?
956 .ok_or_else(|| crate::error::RuvectorError::VectorNotFound(key.clone()))?;
957
958 let mut meta = entry.metadata.take().unwrap_or_default();
960 meta.insert("q_value".to_string(), serde_json::json!(new_q_value));
961 entry.metadata = Some(meta);
962
963 let _ = self.db.vector_db.delete(&key);
966 self.db.vector_db.insert(entry)?;
967 Ok(())
968 }
969}
970
/// Borrowed view over an `AgenticDB` that indexes the turns of one
/// conversation session, each with an expiry time.
pub struct SessionStateIndex<'a> {
    /// Database whose vector store holds the turns.
    db: &'a AgenticDB,
    /// Session this view reads and writes.
    session_id: String,
    /// Lifetime given to each new turn, in seconds (added to a Unix timestamp).
    ttl_seconds: i64,
}
980
/// One turn of a conversation session, as rebuilt from vector-store metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTurn {
    /// Turn id (the vector-store id is `session_<session_id>_<id>`).
    pub id: String,
    /// Session the turn belongs to.
    pub session_id: String,
    /// Caller-supplied position of the turn within the session.
    pub turn_number: usize,
    /// Caller-supplied speaker role.
    pub role: String,
    /// Turn text; this is the text that gets embedded.
    pub content: String,
    /// Embedding of `content`.
    pub embedding: Vec<f32>,
    /// Creation time as a Unix timestamp.
    pub timestamp: i64,
    /// Unix timestamp after which the turn is treated as expired.
    pub expires_at: i64,
}
1001
1002impl<'a> SessionStateIndex<'a> {
    /// Creates a session view over `db` for `session_id`; turns added through
    /// it expire `ttl_seconds` after they are created.
    pub fn new(db: &'a AgenticDB, session_id: &str, ttl_seconds: i64) -> Self {
        Self {
            db,
            session_id: session_id.to_string(),
            ttl_seconds,
        }
    }
1011
1012 pub fn add_turn(&self, turn_number: usize, role: &str, content: &str) -> Result<String> {
1014 let id = uuid::Uuid::new_v4().to_string();
1015 let timestamp = chrono::Utc::now().timestamp();
1016 let expires_at = timestamp + self.ttl_seconds;
1017
1018 let embedding = self.db.generate_text_embedding(content)?;
1020
1021 self.db.vector_db.insert(VectorEntry {
1023 id: Some(format!("session_{}_{}", self.session_id, id)),
1024 vector: embedding,
1025 metadata: Some({
1026 let mut meta = HashMap::new();
1027 meta.insert("type".to_string(), serde_json::json!("session_turn"));
1028 meta.insert(
1029 "session_id".to_string(),
1030 serde_json::json!(self.session_id.clone()),
1031 );
1032 meta.insert("turn_id".to_string(), serde_json::json!(id.clone()));
1033 meta.insert("turn_number".to_string(), serde_json::json!(turn_number));
1034 meta.insert("role".to_string(), serde_json::json!(role));
1035 meta.insert("content".to_string(), serde_json::json!(content));
1036 meta.insert("timestamp".to_string(), serde_json::json!(timestamp));
1037 meta.insert("expires_at".to_string(), serde_json::json!(expires_at));
1038 meta
1039 }),
1040 })?;
1041
1042 Ok(id)
1043 }
1044
1045 pub fn find_relevant_turns(&self, query: &str, k: usize) -> Result<Vec<SessionTurn>> {
1047 let query_embedding = self.db.generate_text_embedding(query)?;
1048 let current_time = chrono::Utc::now().timestamp();
1049
1050 let results = self.db.vector_db.search(SearchQuery {
1051 vector: query_embedding,
1052 k: k * 2, filter: Some({
1054 let mut filter = HashMap::new();
1055 filter.insert("type".to_string(), serde_json::json!("session_turn"));
1056 filter.insert(
1057 "session_id".to_string(),
1058 serde_json::json!(self.session_id.clone()),
1059 );
1060 filter
1061 }),
1062 ef_search: None,
1063 })?;
1064
1065 let mut turns = Vec::new();
1066 for result in results {
1067 if let Some(metadata) = result.metadata {
1068 let expires_at = metadata
1069 .get("expires_at")
1070 .and_then(|v| v.as_i64())
1071 .unwrap_or(0);
1072
1073 if expires_at < current_time {
1075 continue;
1076 }
1077
1078 turns.push(SessionTurn {
1079 id: metadata
1080 .get("turn_id")
1081 .and_then(|v| v.as_str())
1082 .unwrap_or("")
1083 .to_string(),
1084 session_id: self.session_id.clone(),
1085 turn_number: metadata
1086 .get("turn_number")
1087 .and_then(|v| v.as_u64())
1088 .unwrap_or(0) as usize,
1089 role: metadata
1090 .get("role")
1091 .and_then(|v| v.as_str())
1092 .unwrap_or("")
1093 .to_string(),
1094 content: metadata
1095 .get("content")
1096 .and_then(|v| v.as_str())
1097 .unwrap_or("")
1098 .to_string(),
1099 embedding: result.vector.unwrap_or_default(),
1100 timestamp: metadata
1101 .get("timestamp")
1102 .and_then(|v| v.as_i64())
1103 .unwrap_or(0),
1104 expires_at,
1105 });
1106
1107 if turns.len() >= k {
1108 break;
1109 }
1110 }
1111 }
1112
1113 Ok(turns)
1114 }
1115
1116 pub fn get_session_context(&self) -> Result<Vec<SessionTurn>> {
1118 let mut turns = self.find_relevant_turns("", 1000)?;
1119 turns.sort_by_key(|t| t.turn_number);
1120 Ok(turns)
1121 }
1122
1123 pub fn cleanup_expired(&self) -> Result<usize> {
1125 let current_time = chrono::Utc::now().timestamp();
1126 let all_turns = self.find_relevant_turns("", 10000)?;
1127 let mut deleted = 0;
1128
1129 for turn in all_turns {
1130 if turn.expires_at < current_time {
1131 let _ = self
1132 .db
1133 .vector_db
1134 .delete(&format!("session_{}_{}", self.session_id, turn.id));
1135 deleted += 1;
1136 }
1137 }
1138
1139 Ok(deleted)
1140 }
1141}
1142
/// Borrowed view over an `AgenticDB` that appends hash-chained witness entries
/// to the vector store.
pub struct WitnessLog<'a> {
    /// Database whose vector store holds the entries.
    db: &'a AgenticDB,
    /// Hash of the most recent entry appended through this instance; starts
    /// as `None`, so the first entry of each instance has no predecessor.
    last_hash: RwLock<Option<String>>,
}
1150
/// One entry of the witness log, as rebuilt from vector-store metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WitnessEntry {
    /// Entry id (the vector-store id is `witness_<id>`).
    pub id: String,
    /// Hash of the preceding entry, or `None` for the first entry of a chain.
    pub prev_hash: Option<String>,
    /// Hash over `prev_hash`, `agent_id`, `action_type`, `details`, and
    /// `timestamp`, rendered as 16 hex digits.
    pub hash: String,
    /// Identifier of the agent that performed the action.
    pub agent_id: String,
    /// Caller-supplied action category.
    pub action_type: String,
    /// Caller-supplied description of the action.
    pub details: String,
    /// Embedding of `"<agent_id> <action_type> <details>"`.
    pub embedding: Vec<f32>,
    /// Creation time as a Unix timestamp.
    pub timestamp: i64,
    /// Optional free-form metadata; `search` always returns `None`.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
1173
1174impl<'a> WitnessLog<'a> {
    /// Creates a witness log over `db` with no previous hash recorded, so the
    /// first entry appended through it starts a new chain.
    pub fn new(db: &'a AgenticDB) -> Self {
        Self {
            db,
            last_hash: RwLock::new(None),
        }
    }
1182
1183 fn compute_hash(
1185 prev_hash: &Option<String>,
1186 agent_id: &str,
1187 action_type: &str,
1188 details: &str,
1189 timestamp: i64,
1190 ) -> String {
1191 use std::collections::hash_map::DefaultHasher;
1192 use std::hash::{Hash, Hasher};
1193
1194 let mut hasher = DefaultHasher::new();
1195 if let Some(prev) = prev_hash {
1196 prev.hash(&mut hasher);
1197 }
1198 agent_id.hash(&mut hasher);
1199 action_type.hash(&mut hasher);
1200 details.hash(&mut hasher);
1201 timestamp.hash(&mut hasher);
1202 format!("{:016x}", hasher.finish())
1203 }
1204
1205 pub fn append(&self, agent_id: &str, action_type: &str, details: &str) -> Result<String> {
1207 let id = uuid::Uuid::new_v4().to_string();
1208 let timestamp = chrono::Utc::now().timestamp();
1209
1210 let prev_hash = self.last_hash.read().clone();
1212
1213 let hash = Self::compute_hash(&prev_hash, agent_id, action_type, details, timestamp);
1215
1216 let embedding = self
1218 .db
1219 .generate_text_embedding(&format!("{} {} {}", agent_id, action_type, details))?;
1220
1221 self.db.vector_db.insert(VectorEntry {
1223 id: Some(format!("witness_{}", id)),
1224 vector: embedding.clone(),
1225 metadata: Some({
1226 let mut meta = HashMap::new();
1227 meta.insert("type".to_string(), serde_json::json!("witness"));
1228 meta.insert("witness_id".to_string(), serde_json::json!(id.clone()));
1229 meta.insert("agent_id".to_string(), serde_json::json!(agent_id));
1230 meta.insert("action_type".to_string(), serde_json::json!(action_type));
1231 meta.insert("details".to_string(), serde_json::json!(details));
1232 meta.insert("timestamp".to_string(), serde_json::json!(timestamp));
1233 meta.insert("hash".to_string(), serde_json::json!(hash.clone()));
1234 if let Some(ref prev) = prev_hash {
1235 meta.insert("prev_hash".to_string(), serde_json::json!(prev));
1236 }
1237 meta
1238 }),
1239 })?;
1240
1241 *self.last_hash.write() = Some(hash.clone());
1243
1244 Ok(id)
1245 }
1246
1247 pub fn search(&self, query: &str, k: usize) -> Result<Vec<WitnessEntry>> {
1249 let query_embedding = self.db.generate_text_embedding(query)?;
1250
1251 let results = self.db.vector_db.search(SearchQuery {
1252 vector: query_embedding,
1253 k,
1254 filter: Some({
1255 let mut filter = HashMap::new();
1256 filter.insert("type".to_string(), serde_json::json!("witness"));
1257 filter
1258 }),
1259 ef_search: None,
1260 })?;
1261
1262 let mut entries = Vec::new();
1263 for result in results {
1264 if let Some(metadata) = result.metadata {
1265 entries.push(WitnessEntry {
1266 id: metadata
1267 .get("witness_id")
1268 .and_then(|v| v.as_str())
1269 .unwrap_or("")
1270 .to_string(),
1271 prev_hash: metadata
1272 .get("prev_hash")
1273 .and_then(|v| v.as_str())
1274 .map(|s| s.to_string()),
1275 hash: metadata
1276 .get("hash")
1277 .and_then(|v| v.as_str())
1278 .unwrap_or("")
1279 .to_string(),
1280 agent_id: metadata
1281 .get("agent_id")
1282 .and_then(|v| v.as_str())
1283 .unwrap_or("")
1284 .to_string(),
1285 action_type: metadata
1286 .get("action_type")
1287 .and_then(|v| v.as_str())
1288 .unwrap_or("")
1289 .to_string(),
1290 details: metadata
1291 .get("details")
1292 .and_then(|v| v.as_str())
1293 .unwrap_or("")
1294 .to_string(),
1295 embedding: result.vector.unwrap_or_default(),
1296 timestamp: metadata
1297 .get("timestamp")
1298 .and_then(|v| v.as_i64())
1299 .unwrap_or(0),
1300 metadata: None,
1301 });
1302 }
1303 }
1304
1305 Ok(entries)
1306 }
1307
1308 pub fn get_by_agent(&self, agent_id: &str, k: usize) -> Result<Vec<WitnessEntry>> {
1310 self.search(agent_id, k)
1312 }
1313
1314 pub fn verify_chain(&self) -> Result<bool> {
1316 let entries = self.search("", 10000)?;
1317
1318 let mut sorted_entries = entries;
1320 sorted_entries.sort_by_key(|e| e.timestamp);
1321
1322 for i in 1..sorted_entries.len() {
1324 let prev = &sorted_entries[i - 1];
1325 let curr = &sorted_entries[i];
1326
1327 if let Some(ref prev_hash) = curr.prev_hash {
1328 if prev_hash != &prev.hash {
1329 return Ok(false);
1330 }
1331 }
1332 }
1333
1334 Ok(true)
1335 }
1336}
1337
/// Accessors for the specialised memory views built on top of the database.
impl AgenticDB {
    /// Returns a policy-memory view borrowing this database.
    pub fn policy_memory(&self) -> PolicyMemoryStore<'_> {
        PolicyMemoryStore::new(self)
    }

    /// Returns a session view for `session_id` whose new turns expire after
    /// `ttl_seconds`.
    pub fn session_index(&self, session_id: &str, ttl_seconds: i64) -> SessionStateIndex<'_> {
        SessionStateIndex::new(self, session_id, ttl_seconds)
    }

    /// Returns a fresh witness log borrowing this database; each call starts
    /// with no previous hash recorded.
    pub fn witness_log(&self) -> WitnessLog<'_> {
        WitnessLog::new(self)
    }
}
1354
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Creates a 128-dimensional test database inside a fresh temp directory.
    ///
    /// The `TempDir` guard is returned alongside the database and must be kept
    /// alive for the duration of the test: dropping it removes the directory
    /// that holds the database files.
    fn create_test_db() -> Result<(TempDir, AgenticDB)> {
        let dir = tempdir().unwrap();
        let options = DbOptions {
            storage_path: dir.path().join("test.db").to_string_lossy().to_string(),
            dimensions: 128,
            ..DbOptions::default()
        };
        let db = AgenticDB::new(options)?;
        Ok((dir, db))
    }

    /// A stored episode can be found again through a similarity query.
    #[test]
    fn test_reflexion_episode() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let id = db.store_episode(
            "Solve math problem".to_string(),
            vec!["read problem".to_string(), "calculate".to_string()],
            vec!["got 42".to_string()],
            "Should have shown work".to_string(),
        )?;

        let episodes = db.retrieve_similar_episodes("math problem solving", 5)?;
        assert!(!episodes.is_empty());
        assert_eq!(episodes[0].id, id);

        Ok(())
    }

    /// Updating a Q-value keeps the policy entry retrievable (#562) and
    /// reports an error for unknown ids.
    #[test]
    fn test_update_q_value_preserves_entry() -> Result<()> {
        let (_dir, db) = create_test_db()?;
        let pm = db.policy_memory();
        let emb = vec![0.1_f32; 128];

        let id = pm.store_policy("state_1", emb.clone(), "shoot", 1.0, 0.5)?;

        pm.update_q_value(&id, 0.9)?;

        let hits = pm.retrieve_similar_states(&emb, 5)?;
        assert!(
            !hits.is_empty(),
            "policy was deleted by update_q_value (#562 regression)"
        );
        assert!(
            hits.iter().any(|p| (p.action.q_value - 0.9).abs() < 1e-9),
            "q_value not updated to 0.9; got {:?}",
            hits.iter().map(|p| p.action.q_value).collect::<Vec<_>>()
        );

        assert!(pm.update_q_value("does_not_exist", 1.0).is_err());

        Ok(())
    }

    /// A created skill is returned by a description search.
    #[test]
    fn test_skill_library() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let mut params = HashMap::new();
        params.insert("input".to_string(), "string".to_string());

        let skill_id = db.create_skill(
            "Parse JSON".to_string(),
            "Parse JSON from string".to_string(),
            params,
            vec!["json.parse()".to_string()],
        )?;

        let skills = db.search_skills("parse json data", 5)?;
        assert!(!skills.is_empty());
        // Only one skill exists, so it must be the one that was just created.
        assert!(skills.iter().any(|s| s.id == skill_id));

        Ok(())
    }

    /// A stored causal edge shows up in a utility-ranked query.
    #[test]
    fn test_causal_edge() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let edge_id = db.add_causal_edge(
            vec!["rain".to_string()],
            vec!["wet ground".to_string()],
            0.95,
            "Weather observation".to_string(),
        )?;
        assert!(!edge_id.is_empty());

        let results = db.query_with_utility("weather patterns", 5, 0.7, 0.2, 0.1)?;
        assert!(!results.is_empty());

        Ok(())
    }

    /// A prediction has one component per action dimension.
    #[test]
    fn test_learning_session() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let session_id = db.start_session("Q-Learning".to_string(), 4, 2)?;

        db.add_experience(
            &session_id,
            vec![1.0, 0.0, 0.0, 0.0],
            vec![1.0, 0.0],
            1.0,
            vec![0.0, 1.0, 0.0, 0.0],
            false,
        )?;

        let prediction = db.predict_with_confidence(&session_id, vec![1.0, 0.0, 0.0, 0.0])?;
        assert_eq!(prediction.action.len(), 2);

        Ok(())
    }

    /// Every sequence that meets the length threshold becomes a skill.
    #[test]
    fn test_auto_consolidate() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let sequences = vec![
            vec![
                "step1".to_string(),
                "step2".to_string(),
                "step3".to_string(),
            ],
            vec![
                "action1".to_string(),
                "action2".to_string(),
                "action3".to_string(),
            ],
        ];

        let skill_ids = db.auto_consolidate(sequences, 3)?;
        assert_eq!(skill_ids.len(), 2);

        Ok(())
    }
}