use crate::embeddings::{BoxedEmbeddingProvider, HashEmbedding};
use crate::error::{Result, RuvectorError};
use crate::types::*;
use crate::vector_db::VectorDB;
use parking_lot::RwLock;
use redb::{Database, ReadableTable, TableDefinition};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
36
/// redb table of [`ReflexionEpisode`] records keyed by episode id; values are JSON bytes
/// (written with `serde_json::to_vec` in `AgenticDB::store_episode`).
const REFLEXION_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("reflexion_episodes");
/// redb table of bincode-encoded [`Skill`] records keyed by skill id.
const SKILLS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("skills_library");
/// redb table of bincode-encoded [`CausalEdge`] records keyed by edge id.
const CAUSAL_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("causal_edges");
/// redb table of bincode-encoded [`LearningSession`] records keyed by session id.
const LEARNING_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("learning_sessions");
42
/// A self-critique episode recorded by `AgenticDB::store_episode`.
///
/// Persisted as JSON in the `reflexion_episodes` redb table and also indexed in the vector
/// store under the id `reflexion_<id>` with `type = "reflexion"` metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflexionEpisode {
    /// UUID v4 string assigned when the episode is stored.
    pub id: String,
    /// The task this episode is about.
    pub task: String,
    /// Actions taken during the attempt, in the order supplied by the caller.
    pub actions: Vec<String>,
    /// Observations collected during the attempt.
    pub observations: Vec<String>,
    /// Critique of the attempt; this is the text that gets embedded.
    pub critique: String,
    /// Embedding of `critique` from the database's embedding provider.
    pub embedding: Vec<f32>,
    /// Creation time in Unix seconds (UTC).
    pub timestamp: i64,
    /// Free-form metadata; `store_episode` always sets this to `None`.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
56
/// A reusable skill recorded by `AgenticDB::create_skill`.
///
/// Persisted bincode-encoded in the `skills_library` table (field order is part of the binary
/// encoding, so do not reorder fields) and indexed in the vector store under `skill_<id>`.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Skill {
    /// UUID v4 string assigned at creation.
    pub id: String,
    /// Human-readable skill name.
    pub name: String,
    /// Description of the skill; this is the text that gets embedded.
    pub description: String,
    /// Free-form parameter map (presumably parameter name -> type; not interpreted here).
    pub parameters: HashMap<String, String>,
    /// Example invocations; `auto_consolidate` stores the raw action sequence here.
    pub examples: Vec<String>,
    /// Embedding of `description`.
    pub embedding: Vec<f32>,
    /// Usage counter; `create_skill` initialises it to 0.
    pub usage_count: usize,
    /// Success rate; `create_skill` initialises it to 0.0.
    pub success_rate: f64,
    /// Creation time in Unix seconds (UTC).
    pub created_at: i64,
    /// Last update time in Unix seconds (UTC).
    pub updated_at: i64,
}
71
/// A causal relationship recorded by `AgenticDB::add_causal_edge`.
///
/// Persisted bincode-encoded in the `causal_edges` table (field order is part of the
/// encoding) and indexed in the vector store under `causal_<id>`, with `confidence` copied
/// into the vector metadata so `query_with_utility` can read it.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct CausalEdge {
    /// UUID v4 string assigned at creation.
    pub id: String,
    /// Cause labels.
    pub causes: Vec<String>,
    /// Effect labels.
    pub effects: Vec<String>,
    /// Caller-supplied confidence in the relationship.
    pub confidence: f64,
    /// Context description; this is the text that gets embedded.
    pub context: String,
    /// Embedding of `context`.
    pub embedding: Vec<f32>,
    /// Observation count; `add_causal_edge` initialises it to 1.
    pub observations: usize,
    /// Creation time in Unix seconds (UTC).
    pub timestamp: i64,
}
84
/// A reinforcement-learning session created by `AgenticDB::start_session`.
///
/// Persisted bincode-encoded in the `learning_sessions` table (field order is part of the
/// encoding); `add_experience` appends to `experiences` and bumps `updated_at`.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct LearningSession {
    /// UUID v4 string assigned at creation.
    pub id: String,
    /// Algorithm label supplied by the caller (e.g. "Q-Learning"); not interpreted here.
    pub algorithm: String,
    /// Declared state dimensionality (not validated against stored experiences).
    pub state_dim: usize,
    /// Declared action dimensionality; predictions return vectors of this length.
    pub action_dim: usize,
    /// Recorded transitions, in insertion order.
    pub experiences: Vec<Experience>,
    /// Opaque model parameters (presumably serialized weights); `start_session` sets `None`.
    pub model_params: Option<Vec<u8>>,
    /// Creation time in Unix seconds (UTC).
    pub created_at: i64,
    /// Last modification time in Unix seconds (UTC).
    pub updated_at: i64,
}
97
/// One `(state, action, reward, next_state, done)` transition stored in a [`LearningSession`].
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Experience {
    /// State before the action.
    pub state: Vec<f32>,
    /// Action taken.
    pub action: Vec<f32>,
    /// Reward received.
    pub reward: f64,
    /// State after the action.
    pub next_state: Vec<f32>,
    /// Whether the episode ended with this transition.
    pub done: bool,
    /// Time the experience was recorded, in Unix seconds (UTC).
    pub timestamp: i64,
}
108
/// Output of `AgenticDB::predict_with_confidence`.
///
/// All fields are zero when no stored experience is close enough to the query state.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Prediction {
    /// Reward-weighted average of the actions of nearby experiences (length `action_dim`).
    pub action: Vec<f32>,
    /// Mean reward minus 1.96 population standard deviations of the nearby rewards.
    pub confidence_lower: f64,
    /// Mean reward plus 1.96 population standard deviations of the nearby rewards.
    pub confidence_upper: f64,
    /// Mean reward of the nearby experiences.
    pub mean_confidence: f64,
}
117
/// A causal-edge search hit re-ranked by `AgenticDB::query_with_utility`.
///
/// `utility_score = alpha * similarity_score + beta * causal_uplift - latency_penalty`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UtilitySearchResult {
    /// The underlying vector-store hit.
    pub result: SearchResult,
    /// Combined ranking score (higher is better).
    pub utility_score: f64,
    /// `1 / (1 + score)` of the raw hit, i.e. the raw score is treated as a distance.
    pub similarity_score: f64,
    /// The edge's `confidence` metadata, or 0.0 if absent.
    pub causal_uplift: f64,
    /// Elapsed query time in seconds multiplied by `gamma`.
    pub latency_penalty: f64,
}
127
/// Agent-memory database combining a vector index with structured redb tables.
pub struct AgenticDB {
    /// Vector index holding raw entries plus the embeddings of every agentic record.
    vector_db: Arc<VectorDB>,
    /// redb database (file `<storage_path>.agentic`) holding the structured records.
    db: Arc<Database>,
    /// Configured dimensionality; stored but not read by the code visible here.
    _dimensions: usize,
    /// Provider used to turn text into embeddings.
    embedding_provider: BoxedEmbeddingProvider,
}
135
136impl AgenticDB {
137 pub fn new(options: DbOptions) -> Result<Self> {
139 let embedding_provider = Arc::new(HashEmbedding::new(options.dimensions));
140 Self::with_embedding_provider(options, embedding_provider)
141 }
142
143 pub fn with_embedding_provider(
181 options: DbOptions,
182 embedding_provider: BoxedEmbeddingProvider,
183 ) -> Result<Self> {
184 if options.dimensions != embedding_provider.dimensions() {
186 return Err(RuvectorError::InvalidDimension(format!(
187 "Options dimensions ({}) do not match embedding provider dimensions ({})",
188 options.dimensions,
189 embedding_provider.dimensions()
190 )));
191 }
192
193 let vector_db = Arc::new(VectorDB::new(options.clone())?);
195
196 let agentic_path = format!("{}.agentic", options.storage_path);
198 let db = Arc::new(Database::create(&agentic_path)?);
199
200 let write_txn = db.begin_write()?;
202 {
203 let _ = write_txn.open_table(REFLEXION_TABLE)?;
204 let _ = write_txn.open_table(SKILLS_TABLE)?;
205 let _ = write_txn.open_table(CAUSAL_TABLE)?;
206 let _ = write_txn.open_table(LEARNING_TABLE)?;
207 }
208 write_txn.commit()?;
209
210 Ok(Self {
211 vector_db,
212 db,
213 _dimensions: options.dimensions,
214 embedding_provider,
215 })
216 }
217
218 pub fn with_dimensions(dimensions: usize) -> Result<Self> {
220 let options = DbOptions {
221 dimensions,
222 ..DbOptions::default()
223 };
224 Self::new(options)
225 }
226
227 pub fn embedding_provider_name(&self) -> &str {
229 self.embedding_provider.name()
230 }
231
232 pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
236 self.vector_db.insert(entry)
237 }
238
239 pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
241 self.vector_db.insert_batch(entries)
242 }
243
244 pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
246 self.vector_db.search(query)
247 }
248
249 pub fn delete(&self, id: &str) -> Result<bool> {
251 self.vector_db.delete(id)
252 }
253
254 pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
256 self.vector_db.get(id)
257 }
258
259 pub fn store_episode(
263 &self,
264 task: String,
265 actions: Vec<String>,
266 observations: Vec<String>,
267 critique: String,
268 ) -> Result<String> {
269 let id = uuid::Uuid::new_v4().to_string();
270
271 let embedding = self.generate_text_embedding(&critique)?;
273
274 let episode = ReflexionEpisode {
275 id: id.clone(),
276 task,
277 actions,
278 observations,
279 critique,
280 embedding: embedding.clone(),
281 timestamp: chrono::Utc::now().timestamp(),
282 metadata: None,
283 };
284
285 let write_txn = self.db.begin_write()?;
287 {
288 let mut table = write_txn.open_table(REFLEXION_TABLE)?;
289 let json = serde_json::to_vec(&episode)
291 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
292 table.insert(id.as_str(), json.as_slice())?;
293 }
294 write_txn.commit()?;
295
296 self.vector_db.insert(VectorEntry {
298 id: Some(format!("reflexion_{}", id)),
299 vector: embedding,
300 metadata: Some({
301 let mut meta = HashMap::new();
302 meta.insert("type".to_string(), serde_json::json!("reflexion"));
303 meta.insert("episode_id".to_string(), serde_json::json!(id.clone()));
304 meta
305 }),
306 })?;
307
308 Ok(id)
309 }
310
311 pub fn retrieve_similar_episodes(
313 &self,
314 query: &str,
315 k: usize,
316 ) -> Result<Vec<ReflexionEpisode>> {
317 let query_embedding = self.generate_text_embedding(query)?;
319
320 let results = self.vector_db.search(SearchQuery {
322 vector: query_embedding,
323 k,
324 filter: Some({
325 let mut filter = HashMap::new();
326 filter.insert("type".to_string(), serde_json::json!("reflexion"));
327 filter
328 }),
329 ef_search: None,
330 })?;
331
332 let mut episodes = Vec::new();
334 let read_txn = self.db.begin_read()?;
335 let table = read_txn.open_table(REFLEXION_TABLE)?;
336
337 for result in results {
338 if let Some(metadata) = result.metadata {
339 if let Some(episode_id) = metadata.get("episode_id") {
340 let id = episode_id.as_str().unwrap();
341 if let Some(data) = table.get(id)? {
342 let episode: ReflexionEpisode = serde_json::from_slice(data.value())
344 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
345 episodes.push(episode);
346 }
347 }
348 }
349 }
350
351 Ok(episodes)
352 }
353
354 pub fn create_skill(
358 &self,
359 name: String,
360 description: String,
361 parameters: HashMap<String, String>,
362 examples: Vec<String>,
363 ) -> Result<String> {
364 let id = uuid::Uuid::new_v4().to_string();
365
366 let embedding = self.generate_text_embedding(&description)?;
368
369 let skill = Skill {
370 id: id.clone(),
371 name,
372 description,
373 parameters,
374 examples,
375 embedding: embedding.clone(),
376 usage_count: 0,
377 success_rate: 0.0,
378 created_at: chrono::Utc::now().timestamp(),
379 updated_at: chrono::Utc::now().timestamp(),
380 };
381
382 let write_txn = self.db.begin_write()?;
384 {
385 let mut table = write_txn.open_table(SKILLS_TABLE)?;
386 let data = bincode::encode_to_vec(&skill, bincode::config::standard())
387 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
388 table.insert(id.as_str(), data.as_slice())?;
389 }
390 write_txn.commit()?;
391
392 self.vector_db.insert(VectorEntry {
394 id: Some(format!("skill_{}", id)),
395 vector: embedding,
396 metadata: Some({
397 let mut meta = HashMap::new();
398 meta.insert("type".to_string(), serde_json::json!("skill"));
399 meta.insert("skill_id".to_string(), serde_json::json!(id.clone()));
400 meta
401 }),
402 })?;
403
404 Ok(id)
405 }
406
407 pub fn search_skills(&self, query_description: &str, k: usize) -> Result<Vec<Skill>> {
409 let query_embedding = self.generate_text_embedding(query_description)?;
410
411 let results = self.vector_db.search(SearchQuery {
412 vector: query_embedding,
413 k,
414 filter: Some({
415 let mut filter = HashMap::new();
416 filter.insert("type".to_string(), serde_json::json!("skill"));
417 filter
418 }),
419 ef_search: None,
420 })?;
421
422 let mut skills = Vec::new();
423 let read_txn = self.db.begin_read()?;
424 let table = read_txn.open_table(SKILLS_TABLE)?;
425
426 for result in results {
427 if let Some(metadata) = result.metadata {
428 if let Some(skill_id) = metadata.get("skill_id") {
429 let id = skill_id.as_str().unwrap();
430 if let Some(data) = table.get(id)? {
431 let (skill, _): (Skill, usize) =
432 bincode::decode_from_slice(data.value(), bincode::config::standard())
433 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
434 skills.push(skill);
435 }
436 }
437 }
438 }
439
440 Ok(skills)
441 }
442
443 pub fn auto_consolidate(
445 &self,
446 action_sequences: Vec<Vec<String>>,
447 success_threshold: usize,
448 ) -> Result<Vec<String>> {
449 let mut skill_ids = Vec::new();
450
451 for sequence in action_sequences {
453 if sequence.len() >= success_threshold {
454 let description = format!("Skill: {}", sequence.join(" -> "));
455 let skill_id = self.create_skill(
456 format!("Auto-Skill-{}", uuid::Uuid::new_v4()),
457 description,
458 HashMap::new(),
459 sequence.clone(),
460 )?;
461 skill_ids.push(skill_id);
462 }
463 }
464
465 Ok(skill_ids)
466 }
467
468 pub fn add_causal_edge(
472 &self,
473 causes: Vec<String>,
474 effects: Vec<String>,
475 confidence: f64,
476 context: String,
477 ) -> Result<String> {
478 let id = uuid::Uuid::new_v4().to_string();
479
480 let embedding = self.generate_text_embedding(&context)?;
482
483 let edge = CausalEdge {
484 id: id.clone(),
485 causes,
486 effects,
487 confidence,
488 context,
489 embedding: embedding.clone(),
490 observations: 1,
491 timestamp: chrono::Utc::now().timestamp(),
492 };
493
494 let write_txn = self.db.begin_write()?;
496 {
497 let mut table = write_txn.open_table(CAUSAL_TABLE)?;
498 let data = bincode::encode_to_vec(&edge, bincode::config::standard())
499 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
500 table.insert(id.as_str(), data.as_slice())?;
501 }
502 write_txn.commit()?;
503
504 self.vector_db.insert(VectorEntry {
506 id: Some(format!("causal_{}", id)),
507 vector: embedding,
508 metadata: Some({
509 let mut meta = HashMap::new();
510 meta.insert("type".to_string(), serde_json::json!("causal"));
511 meta.insert("causal_id".to_string(), serde_json::json!(id.clone()));
512 meta.insert("confidence".to_string(), serde_json::json!(confidence));
513 meta
514 }),
515 })?;
516
517 Ok(id)
518 }
519
520 pub fn query_with_utility(
522 &self,
523 query: &str,
524 k: usize,
525 alpha: f64,
526 beta: f64,
527 gamma: f64,
528 ) -> Result<Vec<UtilitySearchResult>> {
529 let start_time = std::time::Instant::now();
530 let query_embedding = self.generate_text_embedding(query)?;
531
532 let results = self.vector_db.search(SearchQuery {
534 vector: query_embedding,
535 k: k * 2, filter: Some({
537 let mut filter = HashMap::new();
538 filter.insert("type".to_string(), serde_json::json!("causal"));
539 filter
540 }),
541 ef_search: None,
542 })?;
543
544 let mut utility_results = Vec::new();
545
546 for result in results {
547 let similarity_score = 1.0 / (1.0 + result.score as f64); let causal_uplift = if let Some(ref metadata) = result.metadata {
551 metadata
552 .get("confidence")
553 .and_then(|v| v.as_f64())
554 .unwrap_or(0.0)
555 } else {
556 0.0
557 };
558
559 let latency = start_time.elapsed().as_secs_f64();
560 let latency_penalty = latency * gamma;
561
562 let utility_score = alpha * similarity_score + beta * causal_uplift - latency_penalty;
564
565 utility_results.push(UtilitySearchResult {
566 result,
567 utility_score,
568 similarity_score,
569 causal_uplift,
570 latency_penalty,
571 });
572 }
573
574 utility_results.sort_by(|a, b| b.utility_score.partial_cmp(&a.utility_score).unwrap());
576 utility_results.truncate(k);
577
578 Ok(utility_results)
579 }
580
581 pub fn start_session(
585 &self,
586 algorithm: String,
587 state_dim: usize,
588 action_dim: usize,
589 ) -> Result<String> {
590 let id = uuid::Uuid::new_v4().to_string();
591
592 let session = LearningSession {
593 id: id.clone(),
594 algorithm,
595 state_dim,
596 action_dim,
597 experiences: Vec::new(),
598 model_params: None,
599 created_at: chrono::Utc::now().timestamp(),
600 updated_at: chrono::Utc::now().timestamp(),
601 };
602
603 let write_txn = self.db.begin_write()?;
604 {
605 let mut table = write_txn.open_table(LEARNING_TABLE)?;
606 let data = bincode::encode_to_vec(&session, bincode::config::standard())
607 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
608 table.insert(id.as_str(), data.as_slice())?;
609 }
610 write_txn.commit()?;
611
612 Ok(id)
613 }
614
615 pub fn add_experience(
617 &self,
618 session_id: &str,
619 state: Vec<f32>,
620 action: Vec<f32>,
621 reward: f64,
622 next_state: Vec<f32>,
623 done: bool,
624 ) -> Result<()> {
625 let read_txn = self.db.begin_read()?;
626 let table = read_txn.open_table(LEARNING_TABLE)?;
627
628 let data = table
629 .get(session_id)?
630 .ok_or_else(|| RuvectorError::VectorNotFound(session_id.to_string()))?;
631
632 let (mut session, _): (LearningSession, usize) =
633 bincode::decode_from_slice(data.value(), bincode::config::standard())
634 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
635
636 drop(table);
637 drop(read_txn);
638
639 session.experiences.push(Experience {
641 state,
642 action,
643 reward,
644 next_state,
645 done,
646 timestamp: chrono::Utc::now().timestamp(),
647 });
648 session.updated_at = chrono::Utc::now().timestamp();
649
650 let write_txn = self.db.begin_write()?;
652 {
653 let mut table = write_txn.open_table(LEARNING_TABLE)?;
654 let data = bincode::encode_to_vec(&session, bincode::config::standard())
655 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
656 table.insert(session_id, data.as_slice())?;
657 }
658 write_txn.commit()?;
659
660 Ok(())
661 }
662
663 pub fn predict_with_confidence(&self, session_id: &str, state: Vec<f32>) -> Result<Prediction> {
665 let read_txn = self.db.begin_read()?;
666 let table = read_txn.open_table(LEARNING_TABLE)?;
667
668 let data = table
669 .get(session_id)?
670 .ok_or_else(|| RuvectorError::VectorNotFound(session_id.to_string()))?;
671
672 let (session, _): (LearningSession, usize) =
673 bincode::decode_from_slice(data.value(), bincode::config::standard())
674 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
675
676 let mut similar_actions = Vec::new();
678 let mut rewards = Vec::new();
679
680 for exp in &session.experiences {
681 let distance = euclidean_distance(&state, &exp.state);
682 if distance < 1.0 {
683 similar_actions.push(exp.action.clone());
685 rewards.push(exp.reward);
686 }
687 }
688
689 if similar_actions.is_empty() {
690 return Ok(Prediction {
692 action: vec![0.0; session.action_dim],
693 confidence_lower: 0.0,
694 confidence_upper: 0.0,
695 mean_confidence: 0.0,
696 });
697 }
698
699 let total_reward: f64 = rewards.iter().sum();
701 let mut action = vec![0.0; session.action_dim];
702
703 for (act, reward) in similar_actions.iter().zip(rewards.iter()) {
704 let weight = reward / total_reward;
705 for (i, val) in act.iter().enumerate() {
706 action[i] += val * weight as f32;
707 }
708 }
709
710 let mean_reward = total_reward / rewards.len() as f64;
712 let std_dev = calculate_std_dev(&rewards, mean_reward);
713
714 Ok(Prediction {
715 action,
716 confidence_lower: mean_reward - 1.96 * std_dev,
717 confidence_upper: mean_reward + 1.96 * std_dev,
718 mean_confidence: mean_reward,
719 })
720 }
721
722 pub fn get_session(&self, session_id: &str) -> Result<Option<LearningSession>> {
724 let read_txn = self.db.begin_read()?;
725 let table = read_txn.open_table(LEARNING_TABLE)?;
726
727 if let Some(data) = table.get(session_id)? {
728 let (session, _): (LearningSession, usize) =
729 bincode::decode_from_slice(data.value(), bincode::config::standard())
730 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
731 Ok(Some(session))
732 } else {
733 Ok(None)
734 }
735 }
736
737 fn generate_text_embedding(&self, text: &str) -> Result<Vec<f32>> {
760 self.embedding_provider.embed(text)
761 }
762}
763
/// Euclidean (L2) distance between `a` and `b`.
///
/// Only the overlapping prefix is compared; trailing elements of the longer slice are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
772
/// Population standard deviation of `values` around the caller-supplied `mean`.
///
/// Returns `0.0` for an empty slice instead of the `NaN` that `0.0 / 0.0` would produce.
fn calculate_std_dev(values: &[f64], mean: f64) -> f64 {
    if values.is_empty() {
        return 0.0;
    }
    let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
    variance.sqrt()
}
777
/// Borrowed view over an [`AgenticDB`] that stores state -> action policies.
///
/// Policies live only in the vector index (as `policy_<id>` entries keyed by state
/// embedding); nothing is written to the redb tables.
pub struct PolicyMemoryStore<'a> {
    /// Database whose vector index holds the policies.
    db: &'a AgenticDB,
}
794
/// The action half of a stored policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyAction {
    /// Action label.
    pub action: String,
    /// Reward observed for taking the action.
    pub reward: f64,
    /// Q-value associated with the action; used by `get_best_action` to pick a winner.
    pub q_value: f64,
    /// Embedding of the state the action was taken in.
    pub state_embedding: Vec<f32>,
    /// Unix seconds when stored; entries rebuilt from the vector index may report 0 when no
    /// timestamp is available in the metadata.
    pub timestamp: i64,
}
809
/// A stored policy: which action was taken in which state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyEntry {
    /// Policy id (UUID v4 string); the vector entry id is `policy_<id>`.
    pub id: String,
    /// Caller-supplied identifier of the state.
    pub state_id: String,
    /// The action taken and its scores.
    pub action: PolicyAction,
    /// Free-form metadata; entries rebuilt from search results carry `None`.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
822
823impl<'a> PolicyMemoryStore<'a> {
824 pub fn new(db: &'a AgenticDB) -> Self {
826 Self { db }
827 }
828
829 pub fn store_policy(
831 &self,
832 state_id: &str,
833 state_embedding: Vec<f32>,
834 action: &str,
835 reward: f64,
836 q_value: f64,
837 ) -> Result<String> {
838 let id = uuid::Uuid::new_v4().to_string();
839 let timestamp = chrono::Utc::now().timestamp();
840
841 let _entry = PolicyEntry {
842 id: id.clone(),
843 state_id: state_id.to_string(),
844 action: PolicyAction {
845 action: action.to_string(),
846 reward,
847 q_value,
848 state_embedding: state_embedding.clone(),
849 timestamp,
850 },
851 metadata: None,
852 };
853
854 self.db.vector_db.insert(VectorEntry {
856 id: Some(format!("policy_{}", id)),
857 vector: state_embedding,
858 metadata: Some({
859 let mut meta = HashMap::new();
860 meta.insert("type".to_string(), serde_json::json!("policy"));
861 meta.insert("policy_id".to_string(), serde_json::json!(id.clone()));
862 meta.insert("state_id".to_string(), serde_json::json!(state_id));
863 meta.insert("action".to_string(), serde_json::json!(action));
864 meta.insert("reward".to_string(), serde_json::json!(reward));
865 meta.insert("q_value".to_string(), serde_json::json!(q_value));
866 meta
867 }),
868 })?;
869
870 Ok(id)
871 }
872
873 pub fn retrieve_similar_states(
875 &self,
876 state_embedding: &[f32],
877 k: usize,
878 ) -> Result<Vec<PolicyEntry>> {
879 let results = self.db.vector_db.search(SearchQuery {
880 vector: state_embedding.to_vec(),
881 k,
882 filter: Some({
883 let mut filter = HashMap::new();
884 filter.insert("type".to_string(), serde_json::json!("policy"));
885 filter
886 }),
887 ef_search: None,
888 })?;
889
890 let mut entries = Vec::new();
891 for result in results {
892 if let Some(metadata) = result.metadata {
893 let policy_id = metadata
894 .get("policy_id")
895 .and_then(|v| v.as_str())
896 .unwrap_or("");
897 let state_id = metadata
898 .get("state_id")
899 .and_then(|v| v.as_str())
900 .unwrap_or("");
901 let action = metadata
902 .get("action")
903 .and_then(|v| v.as_str())
904 .unwrap_or("");
905 let reward = metadata
906 .get("reward")
907 .and_then(|v| v.as_f64())
908 .unwrap_or(0.0);
909 let q_value = metadata
910 .get("q_value")
911 .and_then(|v| v.as_f64())
912 .unwrap_or(0.0);
913
914 entries.push(PolicyEntry {
915 id: policy_id.to_string(),
916 state_id: state_id.to_string(),
917 action: PolicyAction {
918 action: action.to_string(),
919 reward,
920 q_value,
921 state_embedding: result.vector.unwrap_or_default(),
922 timestamp: 0,
923 },
924 metadata: None,
925 });
926 }
927 }
928
929 Ok(entries)
930 }
931
932 pub fn get_best_action(&self, state_embedding: &[f32], k: usize) -> Result<Option<String>> {
934 let similar = self.retrieve_similar_states(state_embedding, k)?;
935
936 similar
937 .into_iter()
938 .max_by(|a, b| a.action.q_value.partial_cmp(&b.action.q_value).unwrap())
939 .map(|entry| Ok(entry.action.action))
940 .transpose()
941 }
942
943 pub fn update_q_value(&self, policy_id: &str, _new_q_value: f64) -> Result<()> {
945 let _ = self.db.vector_db.delete(&format!("policy_{}", policy_id));
948 Ok(())
949 }
950}
951
/// Borrowed view over the conversation turns of one session, stored in the vector index.
///
/// Each turn is indexed as `session_<session_id>_<turn_id>` and carries an `expires_at`
/// timestamp derived from `ttl_seconds`.
pub struct SessionStateIndex<'a> {
    /// Database whose vector index holds the turns.
    db: &'a AgenticDB,
    /// Session whose turns this index reads and writes.
    session_id: String,
    /// Lifetime of newly added turns, in seconds.
    ttl_seconds: i64,
}
961
/// One conversation turn reconstructed from vector-index metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTurn {
    /// Turn id (UUID v4 string).
    pub id: String,
    /// Session the turn belongs to.
    pub session_id: String,
    /// Caller-supplied ordinal of the turn within the session.
    pub turn_number: usize,
    /// Speaker role label (free-form, e.g. "user").
    pub role: String,
    /// Turn text; this is what gets embedded.
    pub content: String,
    /// Embedding returned by the vector store, or empty if it returned none.
    pub embedding: Vec<f32>,
    /// Creation time in Unix seconds (UTC).
    pub timestamp: i64,
    /// Unix-seconds time after which the turn is treated as expired.
    pub expires_at: i64,
}
982
983impl<'a> SessionStateIndex<'a> {
984 pub fn new(db: &'a AgenticDB, session_id: &str, ttl_seconds: i64) -> Self {
986 Self {
987 db,
988 session_id: session_id.to_string(),
989 ttl_seconds,
990 }
991 }
992
993 pub fn add_turn(&self, turn_number: usize, role: &str, content: &str) -> Result<String> {
995 let id = uuid::Uuid::new_v4().to_string();
996 let timestamp = chrono::Utc::now().timestamp();
997 let expires_at = timestamp + self.ttl_seconds;
998
999 let embedding = self.db.generate_text_embedding(content)?;
1001
1002 self.db.vector_db.insert(VectorEntry {
1004 id: Some(format!("session_{}_{}", self.session_id, id)),
1005 vector: embedding,
1006 metadata: Some({
1007 let mut meta = HashMap::new();
1008 meta.insert("type".to_string(), serde_json::json!("session_turn"));
1009 meta.insert(
1010 "session_id".to_string(),
1011 serde_json::json!(self.session_id.clone()),
1012 );
1013 meta.insert("turn_id".to_string(), serde_json::json!(id.clone()));
1014 meta.insert("turn_number".to_string(), serde_json::json!(turn_number));
1015 meta.insert("role".to_string(), serde_json::json!(role));
1016 meta.insert("content".to_string(), serde_json::json!(content));
1017 meta.insert("timestamp".to_string(), serde_json::json!(timestamp));
1018 meta.insert("expires_at".to_string(), serde_json::json!(expires_at));
1019 meta
1020 }),
1021 })?;
1022
1023 Ok(id)
1024 }
1025
1026 pub fn find_relevant_turns(&self, query: &str, k: usize) -> Result<Vec<SessionTurn>> {
1028 let query_embedding = self.db.generate_text_embedding(query)?;
1029 let current_time = chrono::Utc::now().timestamp();
1030
1031 let results = self.db.vector_db.search(SearchQuery {
1032 vector: query_embedding,
1033 k: k * 2, filter: Some({
1035 let mut filter = HashMap::new();
1036 filter.insert("type".to_string(), serde_json::json!("session_turn"));
1037 filter.insert(
1038 "session_id".to_string(),
1039 serde_json::json!(self.session_id.clone()),
1040 );
1041 filter
1042 }),
1043 ef_search: None,
1044 })?;
1045
1046 let mut turns = Vec::new();
1047 for result in results {
1048 if let Some(metadata) = result.metadata {
1049 let expires_at = metadata
1050 .get("expires_at")
1051 .and_then(|v| v.as_i64())
1052 .unwrap_or(0);
1053
1054 if expires_at < current_time {
1056 continue;
1057 }
1058
1059 turns.push(SessionTurn {
1060 id: metadata
1061 .get("turn_id")
1062 .and_then(|v| v.as_str())
1063 .unwrap_or("")
1064 .to_string(),
1065 session_id: self.session_id.clone(),
1066 turn_number: metadata
1067 .get("turn_number")
1068 .and_then(|v| v.as_u64())
1069 .unwrap_or(0) as usize,
1070 role: metadata
1071 .get("role")
1072 .and_then(|v| v.as_str())
1073 .unwrap_or("")
1074 .to_string(),
1075 content: metadata
1076 .get("content")
1077 .and_then(|v| v.as_str())
1078 .unwrap_or("")
1079 .to_string(),
1080 embedding: result.vector.unwrap_or_default(),
1081 timestamp: metadata
1082 .get("timestamp")
1083 .and_then(|v| v.as_i64())
1084 .unwrap_or(0),
1085 expires_at,
1086 });
1087
1088 if turns.len() >= k {
1089 break;
1090 }
1091 }
1092 }
1093
1094 Ok(turns)
1095 }
1096
1097 pub fn get_session_context(&self) -> Result<Vec<SessionTurn>> {
1099 let mut turns = self.find_relevant_turns("", 1000)?;
1100 turns.sort_by_key(|t| t.turn_number);
1101 Ok(turns)
1102 }
1103
1104 pub fn cleanup_expired(&self) -> Result<usize> {
1106 let current_time = chrono::Utc::now().timestamp();
1107 let all_turns = self.find_relevant_turns("", 10000)?;
1108 let mut deleted = 0;
1109
1110 for turn in all_turns {
1111 if turn.expires_at < current_time {
1112 let _ = self
1113 .db
1114 .vector_db
1115 .delete(&format!("session_{}_{}", self.session_id, turn.id));
1116 deleted += 1;
1117 }
1118 }
1119
1120 Ok(deleted)
1121 }
1122}
1123
/// Append-only, hash-linked audit log of agent actions stored in the vector index.
///
/// `last_hash` is the hash of the most recent entry appended through this instance. It
/// starts as `None`, so every new `WitnessLog` begins a fresh chain.
pub struct WitnessLog<'a> {
    /// Database whose vector index holds the witness entries.
    db: &'a AgenticDB,
    /// Hash of the last entry appended via this instance, if any.
    last_hash: RwLock<Option<String>>,
}
1131
/// One witness-log record reconstructed from vector-index metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WitnessEntry {
    /// Entry id (UUID v4 string); the vector entry id is `witness_<id>`.
    pub id: String,
    /// Hash of the preceding entry in the chain, or `None` for a chain's first entry.
    pub prev_hash: Option<String>,
    /// 16-hex-digit hash of this entry computed by `WitnessLog::compute_hash`.
    pub hash: String,
    /// Agent that performed the action.
    pub agent_id: String,
    /// Action category label.
    pub action_type: String,
    /// Free-form action details.
    pub details: String,
    /// Embedding returned by the vector store, or empty if it returned none.
    pub embedding: Vec<f32>,
    /// Append time in Unix seconds (UTC).
    pub timestamp: i64,
    /// Free-form metadata; entries rebuilt from search results carry `None`.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
1154
1155impl<'a> WitnessLog<'a> {
1156 pub fn new(db: &'a AgenticDB) -> Self {
1158 Self {
1159 db,
1160 last_hash: RwLock::new(None),
1161 }
1162 }
1163
1164 fn compute_hash(
1166 prev_hash: &Option<String>,
1167 agent_id: &str,
1168 action_type: &str,
1169 details: &str,
1170 timestamp: i64,
1171 ) -> String {
1172 use std::collections::hash_map::DefaultHasher;
1173 use std::hash::{Hash, Hasher};
1174
1175 let mut hasher = DefaultHasher::new();
1176 if let Some(prev) = prev_hash {
1177 prev.hash(&mut hasher);
1178 }
1179 agent_id.hash(&mut hasher);
1180 action_type.hash(&mut hasher);
1181 details.hash(&mut hasher);
1182 timestamp.hash(&mut hasher);
1183 format!("{:016x}", hasher.finish())
1184 }
1185
1186 pub fn append(&self, agent_id: &str, action_type: &str, details: &str) -> Result<String> {
1188 let id = uuid::Uuid::new_v4().to_string();
1189 let timestamp = chrono::Utc::now().timestamp();
1190
1191 let prev_hash = self.last_hash.read().clone();
1193
1194 let hash = Self::compute_hash(&prev_hash, agent_id, action_type, details, timestamp);
1196
1197 let embedding = self
1199 .db
1200 .generate_text_embedding(&format!("{} {} {}", agent_id, action_type, details))?;
1201
1202 self.db.vector_db.insert(VectorEntry {
1204 id: Some(format!("witness_{}", id)),
1205 vector: embedding.clone(),
1206 metadata: Some({
1207 let mut meta = HashMap::new();
1208 meta.insert("type".to_string(), serde_json::json!("witness"));
1209 meta.insert("witness_id".to_string(), serde_json::json!(id.clone()));
1210 meta.insert("agent_id".to_string(), serde_json::json!(agent_id));
1211 meta.insert("action_type".to_string(), serde_json::json!(action_type));
1212 meta.insert("details".to_string(), serde_json::json!(details));
1213 meta.insert("timestamp".to_string(), serde_json::json!(timestamp));
1214 meta.insert("hash".to_string(), serde_json::json!(hash.clone()));
1215 if let Some(ref prev) = prev_hash {
1216 meta.insert("prev_hash".to_string(), serde_json::json!(prev));
1217 }
1218 meta
1219 }),
1220 })?;
1221
1222 *self.last_hash.write() = Some(hash.clone());
1224
1225 Ok(id)
1226 }
1227
1228 pub fn search(&self, query: &str, k: usize) -> Result<Vec<WitnessEntry>> {
1230 let query_embedding = self.db.generate_text_embedding(query)?;
1231
1232 let results = self.db.vector_db.search(SearchQuery {
1233 vector: query_embedding,
1234 k,
1235 filter: Some({
1236 let mut filter = HashMap::new();
1237 filter.insert("type".to_string(), serde_json::json!("witness"));
1238 filter
1239 }),
1240 ef_search: None,
1241 })?;
1242
1243 let mut entries = Vec::new();
1244 for result in results {
1245 if let Some(metadata) = result.metadata {
1246 entries.push(WitnessEntry {
1247 id: metadata
1248 .get("witness_id")
1249 .and_then(|v| v.as_str())
1250 .unwrap_or("")
1251 .to_string(),
1252 prev_hash: metadata
1253 .get("prev_hash")
1254 .and_then(|v| v.as_str())
1255 .map(|s| s.to_string()),
1256 hash: metadata
1257 .get("hash")
1258 .and_then(|v| v.as_str())
1259 .unwrap_or("")
1260 .to_string(),
1261 agent_id: metadata
1262 .get("agent_id")
1263 .and_then(|v| v.as_str())
1264 .unwrap_or("")
1265 .to_string(),
1266 action_type: metadata
1267 .get("action_type")
1268 .and_then(|v| v.as_str())
1269 .unwrap_or("")
1270 .to_string(),
1271 details: metadata
1272 .get("details")
1273 .and_then(|v| v.as_str())
1274 .unwrap_or("")
1275 .to_string(),
1276 embedding: result.vector.unwrap_or_default(),
1277 timestamp: metadata
1278 .get("timestamp")
1279 .and_then(|v| v.as_i64())
1280 .unwrap_or(0),
1281 metadata: None,
1282 });
1283 }
1284 }
1285
1286 Ok(entries)
1287 }
1288
1289 pub fn get_by_agent(&self, agent_id: &str, k: usize) -> Result<Vec<WitnessEntry>> {
1291 self.search(agent_id, k)
1293 }
1294
1295 pub fn verify_chain(&self) -> Result<bool> {
1297 let entries = self.search("", 10000)?;
1298
1299 let mut sorted_entries = entries;
1301 sorted_entries.sort_by_key(|e| e.timestamp);
1302
1303 for i in 1..sorted_entries.len() {
1305 let prev = &sorted_entries[i - 1];
1306 let curr = &sorted_entries[i];
1307
1308 if let Some(ref prev_hash) = curr.prev_hash {
1309 if prev_hash != &prev.hash {
1310 return Ok(false);
1311 }
1312 }
1313 }
1314
1315 Ok(true)
1316 }
1317}
1318
1319impl AgenticDB {
1320 pub fn policy_memory(&self) -> PolicyMemoryStore<'_> {
1322 PolicyMemoryStore::new(self)
1323 }
1324
1325 pub fn session_index(&self, session_id: &str, ttl_seconds: i64) -> SessionStateIndex<'_> {
1327 SessionStateIndex::new(self, session_id, ttl_seconds)
1328 }
1329
1330 pub fn witness_log(&self) -> WitnessLog<'_> {
1332 WitnessLog::new(self)
1333 }
1334}
1335
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Creates a 128-dimensional [`AgenticDB`] stored inside a fresh temporary directory.
    ///
    /// The [`TempDir`] guard is returned alongside the database because dropping it deletes
    /// the directory. Callers must keep it bound for the whole test (e.g.
    /// `let (_dir, db) = ...`); `let (_, db)` would drop it immediately.
    fn create_test_db() -> Result<(TempDir, AgenticDB)> {
        let dir = tempdir().expect("failed to create temporary directory");
        let options = DbOptions {
            storage_path: dir.path().join("test.db").to_string_lossy().to_string(),
            dimensions: 128,
            ..DbOptions::default()
        };
        let db = AgenticDB::new(options)?;
        Ok((dir, db))
    }

    #[test]
    fn test_reflexion_episode() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let id = db.store_episode(
            "Solve math problem".to_string(),
            vec!["read problem".to_string(), "calculate".to_string()],
            vec!["got 42".to_string()],
            "Should have shown work".to_string(),
        )?;

        let episodes = db.retrieve_similar_episodes("math problem solving", 5)?;
        assert!(!episodes.is_empty());
        assert_eq!(episodes[0].id, id);

        Ok(())
    }

    #[test]
    fn test_skill_library() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let mut params = HashMap::new();
        params.insert("input".to_string(), "string".to_string());

        let skill_id = db.create_skill(
            "Parse JSON".to_string(),
            "Parse JSON from string".to_string(),
            params,
            vec!["json.parse()".to_string()],
        )?;

        let skills = db.search_skills("parse json data", 5)?;
        assert!(!skills.is_empty());
        assert!(skills.iter().any(|skill| skill.id == skill_id));

        Ok(())
    }

    #[test]
    fn test_causal_edge() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let edge_id = db.add_causal_edge(
            vec!["rain".to_string()],
            vec!["wet ground".to_string()],
            0.95,
            "Weather observation".to_string(),
        )?;
        assert!(!edge_id.is_empty());

        let results = db.query_with_utility("weather patterns", 5, 0.7, 0.2, 0.1)?;
        assert!(!results.is_empty());
        // The only causal edge carries its confidence through as the uplift.
        assert!((results[0].causal_uplift - 0.95).abs() < 1e-9);

        Ok(())
    }

    #[test]
    fn test_learning_session() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let session_id = db.start_session("Q-Learning".to_string(), 4, 2)?;

        db.add_experience(
            &session_id,
            vec![1.0, 0.0, 0.0, 0.0],
            vec![1.0, 0.0],
            1.0,
            vec![0.0, 1.0, 0.0, 0.0],
            false,
        )?;

        let prediction = db.predict_with_confidence(&session_id, vec![1.0, 0.0, 0.0, 0.0])?;
        assert_eq!(prediction.action.len(), 2);
        // A single matching experience with reward 1.0 gets full weight.
        assert_eq!(prediction.action, vec![1.0, 0.0]);
        assert_eq!(prediction.mean_confidence, 1.0);

        Ok(())
    }

    #[test]
    fn test_auto_consolidate() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let sequences = vec![
            vec![
                "step1".to_string(),
                "step2".to_string(),
                "step3".to_string(),
            ],
            vec![
                "action1".to_string(),
                "action2".to_string(),
                "action3".to_string(),
            ],
        ];

        let skill_ids = db.auto_consolidate(sequences, 3)?;
        assert_eq!(skill_ids.len(), 2);

        Ok(())
    }
}