1use crate::embeddings::{BoxedEmbeddingProvider, EmbeddingProvider, HashEmbedding};
28use crate::error::{Result, RuvectorError};
29use crate::types::*;
30use crate::vector_db::VectorDB;
31use parking_lot::RwLock;
32use redb::{Database, TableDefinition};
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35use std::path::Path;
36use std::sync::Arc;
37
/// redb table of reflexion episodes: episode id -> JSON-serialized `ReflexionEpisode`.
const REFLEXION_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("reflexion_episodes");
/// redb table of skills: skill id -> bincode-encoded `Skill`.
const SKILLS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("skills_library");
/// redb table of causal edges: edge id -> bincode-encoded `CausalEdge`.
const CAUSAL_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("causal_edges");
/// redb table of learning sessions: session id -> bincode-encoded `LearningSession`.
const LEARNING_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("learning_sessions");
43
/// A recorded task attempt together with its self-critique.
///
/// Persisted as JSON in the reflexion table and indexed in the vector store
/// by the embedding of `critique`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReflexionEpisode {
    /// UUID v4 string assigned when the episode is stored.
    pub id: String,
    /// Description of the task that was attempted.
    pub task: String,
    /// Actions taken during the attempt.
    pub actions: Vec<String>,
    /// Observations collected during the attempt.
    pub observations: Vec<String>,
    /// Critique text; this is what the embedding is computed from.
    pub critique: String,
    /// Embedding of `critique` produced by the configured provider.
    pub embedding: Vec<f32>,
    /// Creation time as Unix seconds (UTC).
    pub timestamp: i64,
    /// Optional free-form metadata; `store_episode` always writes `None`.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
57
/// A reusable skill stored in the skills library.
///
/// Persisted with bincode, so field order is part of the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Skill {
    /// UUID v4 string assigned at creation.
    pub id: String,
    /// Human-readable skill name.
    pub name: String,
    /// Description; the skill's embedding is computed from this text.
    pub description: String,
    /// Parameter map (string keys to string values; semantics are caller-defined).
    pub parameters: HashMap<String, String>,
    /// Usage examples.
    pub examples: Vec<String>,
    /// Embedding of `description`.
    pub embedding: Vec<f32>,
    /// Usage counter; initialised to 0 by `create_skill`.
    pub usage_count: usize,
    /// Success rate; initialised to 0.0 by `create_skill`.
    pub success_rate: f64,
    /// Creation time as Unix seconds (UTC).
    pub created_at: i64,
    /// Last-update time as Unix seconds (UTC).
    pub updated_at: i64,
}
72
/// A cause/effect relationship with an attached confidence.
///
/// Persisted with bincode, so field order is part of the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct CausalEdge {
    /// UUID v4 string assigned at creation.
    pub id: String,
    /// Cause labels.
    pub causes: Vec<String>,
    /// Effect labels.
    pub effects: Vec<String>,
    /// Caller-supplied confidence; also copied into the vector index metadata.
    pub confidence: f64,
    /// Context text; the edge's embedding is computed from this.
    pub context: String,
    /// Embedding of `context`.
    pub embedding: Vec<f32>,
    /// Observation counter; initialised to 1 by `add_causal_edge`.
    pub observations: usize,
    /// Creation time as Unix seconds (UTC).
    pub timestamp: i64,
}
85
/// A learning session and the experiences recorded for it.
///
/// Persisted with bincode, so field order is part of the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct LearningSession {
    /// UUID v4 string assigned by `start_session`.
    pub id: String,
    /// Caller-supplied algorithm label (stored as given).
    pub algorithm: String,
    /// Declared state dimensionality.
    pub state_dim: usize,
    /// Declared action dimensionality; sizes the predicted action vector.
    pub action_dim: usize,
    /// Experiences appended via `add_experience`.
    pub experiences: Vec<Experience>,
    /// Optional opaque model parameters; `start_session` writes `None`.
    pub model_params: Option<Vec<u8>>,
    /// Creation time as Unix seconds (UTC).
    pub created_at: i64,
    /// Last-update time as Unix seconds (UTC).
    pub updated_at: i64,
}
98
/// One (state, action, reward, next state, done) transition inside a learning session.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Experience {
    /// State in which the action was taken.
    pub state: Vec<f32>,
    /// Action vector that was taken.
    pub action: Vec<f32>,
    /// Reward received for the transition.
    pub reward: f64,
    /// State reached after the action.
    pub next_state: Vec<f32>,
    /// Caller-supplied terminal flag.
    pub done: bool,
    /// Time the experience was recorded, as Unix seconds (UTC).
    pub timestamp: i64,
}
109
/// Result of `predict_with_confidence`.
#[derive(Debug, Clone, Serialize, Deserialize, bincode::Encode, bincode::Decode)]
pub struct Prediction {
    /// Predicted action vector (length `action_dim` of the session).
    pub action: Vec<f32>,
    /// Mean reward of the matching experiences minus 1.96 standard deviations.
    pub confidence_lower: f64,
    /// Mean reward of the matching experiences plus 1.96 standard deviations.
    pub confidence_upper: f64,
    /// Mean reward of the matching experiences.
    pub mean_confidence: f64,
}
118
/// A search hit re-scored by `query_with_utility`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UtilitySearchResult {
    /// The underlying vector search hit.
    pub result: SearchResult,
    /// `alpha * similarity_score + beta * causal_uplift - latency_penalty`.
    pub utility_score: f64,
    /// `1 / (1 + score)` of the underlying hit.
    pub similarity_score: f64,
    /// The `confidence` value from the hit's metadata, or 0.0 when absent.
    pub causal_uplift: f64,
    /// Elapsed query time in seconds multiplied by `gamma`.
    pub latency_penalty: f64,
}
128
/// Agent memory store combining a vector index with a redb key-value database.
pub struct AgenticDB {
    /// Vector index used for all similarity searches.
    vector_db: Arc<VectorDB>,
    /// redb database holding the episode, skill, causal-edge and session tables.
    db: Arc<Database>,
    /// Embedding dimensionality, copied from `DbOptions::dimensions` at construction.
    dimensions: usize,
    /// Provider used to turn text into embedding vectors.
    embedding_provider: BoxedEmbeddingProvider,
}
136
137impl AgenticDB {
138 pub fn new(options: DbOptions) -> Result<Self> {
140 let embedding_provider = Arc::new(HashEmbedding::new(options.dimensions));
141 Self::with_embedding_provider(options, embedding_provider)
142 }
143
144 pub fn with_embedding_provider(
182 options: DbOptions,
183 embedding_provider: BoxedEmbeddingProvider,
184 ) -> Result<Self> {
185 if options.dimensions != embedding_provider.dimensions() {
187 return Err(RuvectorError::InvalidDimension(format!(
188 "Options dimensions ({}) do not match embedding provider dimensions ({})",
189 options.dimensions,
190 embedding_provider.dimensions()
191 )));
192 }
193
194 let vector_db = Arc::new(VectorDB::new(options.clone())?);
196
197 let agentic_path = format!("{}.agentic", options.storage_path);
199 let db = Arc::new(Database::create(&agentic_path)?);
200
201 let write_txn = db.begin_write()?;
203 {
204 let _ = write_txn.open_table(REFLEXION_TABLE)?;
205 let _ = write_txn.open_table(SKILLS_TABLE)?;
206 let _ = write_txn.open_table(CAUSAL_TABLE)?;
207 let _ = write_txn.open_table(LEARNING_TABLE)?;
208 }
209 write_txn.commit()?;
210
211 Ok(Self {
212 vector_db,
213 db,
214 dimensions: options.dimensions,
215 embedding_provider,
216 })
217 }
218
    /// Creates an `AgenticDB` from `DbOptions::default()` with only `dimensions` overridden.
    ///
    /// # Errors
    /// Propagates any failure from [`AgenticDB::new`].
    pub fn with_dimensions(dimensions: usize) -> Result<Self> {
        let mut options = DbOptions::default();
        options.dimensions = dimensions;
        Self::new(options)
    }
225
    /// Returns the name reported by the configured embedding provider.
    pub fn embedding_provider_name(&self) -> &str {
        self.embedding_provider.name()
    }
230
    /// Inserts a single entry into the underlying vector store (direct delegation).
    pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
        self.vector_db.insert(entry)
    }
237
    /// Inserts several entries into the underlying vector store (direct delegation).
    pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
        self.vector_db.insert_batch(entries)
    }
242
    /// Runs `query` against the underlying vector store (direct delegation).
    pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
        self.vector_db.search(query)
    }
247
    /// Deletes the entry with the given id from the underlying vector store.
    ///
    /// The returned flag comes straight from `VectorDB::delete`; presumably it
    /// reports whether an entry was removed — TODO confirm against `VectorDB`.
    pub fn delete(&self, id: &str) -> Result<bool> {
        self.vector_db.delete(id)
    }
252
    /// Looks up an entry by id in the underlying vector store (direct delegation).
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        self.vector_db.get(id)
    }
257
258 pub fn store_episode(
262 &self,
263 task: String,
264 actions: Vec<String>,
265 observations: Vec<String>,
266 critique: String,
267 ) -> Result<String> {
268 let id = uuid::Uuid::new_v4().to_string();
269
270 let embedding = self.generate_text_embedding(&critique)?;
272
273 let episode = ReflexionEpisode {
274 id: id.clone(),
275 task,
276 actions,
277 observations,
278 critique,
279 embedding: embedding.clone(),
280 timestamp: chrono::Utc::now().timestamp(),
281 metadata: None,
282 };
283
284 let write_txn = self.db.begin_write()?;
286 {
287 let mut table = write_txn.open_table(REFLEXION_TABLE)?;
288 let json = serde_json::to_vec(&episode)
290 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
291 table.insert(id.as_str(), json.as_slice())?;
292 }
293 write_txn.commit()?;
294
295 self.vector_db.insert(VectorEntry {
297 id: Some(format!("reflexion_{}", id)),
298 vector: embedding,
299 metadata: Some({
300 let mut meta = HashMap::new();
301 meta.insert("type".to_string(), serde_json::json!("reflexion"));
302 meta.insert("episode_id".to_string(), serde_json::json!(id.clone()));
303 meta
304 }),
305 })?;
306
307 Ok(id)
308 }
309
310 pub fn retrieve_similar_episodes(
312 &self,
313 query: &str,
314 k: usize,
315 ) -> Result<Vec<ReflexionEpisode>> {
316 let query_embedding = self.generate_text_embedding(query)?;
318
319 let results = self.vector_db.search(SearchQuery {
321 vector: query_embedding,
322 k,
323 filter: Some({
324 let mut filter = HashMap::new();
325 filter.insert("type".to_string(), serde_json::json!("reflexion"));
326 filter
327 }),
328 ef_search: None,
329 })?;
330
331 let mut episodes = Vec::new();
333 let read_txn = self.db.begin_read()?;
334 let table = read_txn.open_table(REFLEXION_TABLE)?;
335
336 for result in results {
337 if let Some(metadata) = result.metadata {
338 if let Some(episode_id) = metadata.get("episode_id") {
339 let id = episode_id.as_str().unwrap();
340 if let Some(data) = table.get(id)? {
341 let episode: ReflexionEpisode = serde_json::from_slice(data.value())
343 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
344 episodes.push(episode);
345 }
346 }
347 }
348 }
349
350 Ok(episodes)
351 }
352
353 pub fn create_skill(
357 &self,
358 name: String,
359 description: String,
360 parameters: HashMap<String, String>,
361 examples: Vec<String>,
362 ) -> Result<String> {
363 let id = uuid::Uuid::new_v4().to_string();
364
365 let embedding = self.generate_text_embedding(&description)?;
367
368 let skill = Skill {
369 id: id.clone(),
370 name,
371 description,
372 parameters,
373 examples,
374 embedding: embedding.clone(),
375 usage_count: 0,
376 success_rate: 0.0,
377 created_at: chrono::Utc::now().timestamp(),
378 updated_at: chrono::Utc::now().timestamp(),
379 };
380
381 let write_txn = self.db.begin_write()?;
383 {
384 let mut table = write_txn.open_table(SKILLS_TABLE)?;
385 let data = bincode::encode_to_vec(&skill, bincode::config::standard())
386 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
387 table.insert(id.as_str(), data.as_slice())?;
388 }
389 write_txn.commit()?;
390
391 self.vector_db.insert(VectorEntry {
393 id: Some(format!("skill_{}", id)),
394 vector: embedding,
395 metadata: Some({
396 let mut meta = HashMap::new();
397 meta.insert("type".to_string(), serde_json::json!("skill"));
398 meta.insert("skill_id".to_string(), serde_json::json!(id.clone()));
399 meta
400 }),
401 })?;
402
403 Ok(id)
404 }
405
406 pub fn search_skills(&self, query_description: &str, k: usize) -> Result<Vec<Skill>> {
408 let query_embedding = self.generate_text_embedding(query_description)?;
409
410 let results = self.vector_db.search(SearchQuery {
411 vector: query_embedding,
412 k,
413 filter: Some({
414 let mut filter = HashMap::new();
415 filter.insert("type".to_string(), serde_json::json!("skill"));
416 filter
417 }),
418 ef_search: None,
419 })?;
420
421 let mut skills = Vec::new();
422 let read_txn = self.db.begin_read()?;
423 let table = read_txn.open_table(SKILLS_TABLE)?;
424
425 for result in results {
426 if let Some(metadata) = result.metadata {
427 if let Some(skill_id) = metadata.get("skill_id") {
428 let id = skill_id.as_str().unwrap();
429 if let Some(data) = table.get(id)? {
430 let (skill, _): (Skill, usize) =
431 bincode::decode_from_slice(data.value(), bincode::config::standard())
432 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
433 skills.push(skill);
434 }
435 }
436 }
437 }
438
439 Ok(skills)
440 }
441
442 pub fn auto_consolidate(
444 &self,
445 action_sequences: Vec<Vec<String>>,
446 success_threshold: usize,
447 ) -> Result<Vec<String>> {
448 let mut skill_ids = Vec::new();
449
450 for sequence in action_sequences {
452 if sequence.len() >= success_threshold {
453 let description = format!("Skill: {}", sequence.join(" -> "));
454 let skill_id = self.create_skill(
455 format!("Auto-Skill-{}", uuid::Uuid::new_v4()),
456 description,
457 HashMap::new(),
458 sequence.clone(),
459 )?;
460 skill_ids.push(skill_id);
461 }
462 }
463
464 Ok(skill_ids)
465 }
466
467 pub fn add_causal_edge(
471 &self,
472 causes: Vec<String>,
473 effects: Vec<String>,
474 confidence: f64,
475 context: String,
476 ) -> Result<String> {
477 let id = uuid::Uuid::new_v4().to_string();
478
479 let embedding = self.generate_text_embedding(&context)?;
481
482 let edge = CausalEdge {
483 id: id.clone(),
484 causes,
485 effects,
486 confidence,
487 context,
488 embedding: embedding.clone(),
489 observations: 1,
490 timestamp: chrono::Utc::now().timestamp(),
491 };
492
493 let write_txn = self.db.begin_write()?;
495 {
496 let mut table = write_txn.open_table(CAUSAL_TABLE)?;
497 let data = bincode::encode_to_vec(&edge, bincode::config::standard())
498 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
499 table.insert(id.as_str(), data.as_slice())?;
500 }
501 write_txn.commit()?;
502
503 self.vector_db.insert(VectorEntry {
505 id: Some(format!("causal_{}", id)),
506 vector: embedding,
507 metadata: Some({
508 let mut meta = HashMap::new();
509 meta.insert("type".to_string(), serde_json::json!("causal"));
510 meta.insert("causal_id".to_string(), serde_json::json!(id.clone()));
511 meta.insert("confidence".to_string(), serde_json::json!(confidence));
512 meta
513 }),
514 })?;
515
516 Ok(id)
517 }
518
519 pub fn query_with_utility(
521 &self,
522 query: &str,
523 k: usize,
524 alpha: f64,
525 beta: f64,
526 gamma: f64,
527 ) -> Result<Vec<UtilitySearchResult>> {
528 let start_time = std::time::Instant::now();
529 let query_embedding = self.generate_text_embedding(query)?;
530
531 let results = self.vector_db.search(SearchQuery {
533 vector: query_embedding,
534 k: k * 2, filter: Some({
536 let mut filter = HashMap::new();
537 filter.insert("type".to_string(), serde_json::json!("causal"));
538 filter
539 }),
540 ef_search: None,
541 })?;
542
543 let mut utility_results = Vec::new();
544
545 for result in results {
546 let similarity_score = 1.0 / (1.0 + result.score as f64); let causal_uplift = if let Some(ref metadata) = result.metadata {
550 metadata
551 .get("confidence")
552 .and_then(|v| v.as_f64())
553 .unwrap_or(0.0)
554 } else {
555 0.0
556 };
557
558 let latency = start_time.elapsed().as_secs_f64();
559 let latency_penalty = latency * gamma;
560
561 let utility_score = alpha * similarity_score + beta * causal_uplift - latency_penalty;
563
564 utility_results.push(UtilitySearchResult {
565 result,
566 utility_score,
567 similarity_score,
568 causal_uplift,
569 latency_penalty,
570 });
571 }
572
573 utility_results.sort_by(|a, b| b.utility_score.partial_cmp(&a.utility_score).unwrap());
575 utility_results.truncate(k);
576
577 Ok(utility_results)
578 }
579
580 pub fn start_session(
584 &self,
585 algorithm: String,
586 state_dim: usize,
587 action_dim: usize,
588 ) -> Result<String> {
589 let id = uuid::Uuid::new_v4().to_string();
590
591 let session = LearningSession {
592 id: id.clone(),
593 algorithm,
594 state_dim,
595 action_dim,
596 experiences: Vec::new(),
597 model_params: None,
598 created_at: chrono::Utc::now().timestamp(),
599 updated_at: chrono::Utc::now().timestamp(),
600 };
601
602 let write_txn = self.db.begin_write()?;
603 {
604 let mut table = write_txn.open_table(LEARNING_TABLE)?;
605 let data = bincode::encode_to_vec(&session, bincode::config::standard())
606 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
607 table.insert(id.as_str(), data.as_slice())?;
608 }
609 write_txn.commit()?;
610
611 Ok(id)
612 }
613
614 pub fn add_experience(
616 &self,
617 session_id: &str,
618 state: Vec<f32>,
619 action: Vec<f32>,
620 reward: f64,
621 next_state: Vec<f32>,
622 done: bool,
623 ) -> Result<()> {
624 let read_txn = self.db.begin_read()?;
625 let table = read_txn.open_table(LEARNING_TABLE)?;
626
627 let data = table
628 .get(session_id)?
629 .ok_or_else(|| RuvectorError::VectorNotFound(session_id.to_string()))?;
630
631 let (mut session, _): (LearningSession, usize) =
632 bincode::decode_from_slice(data.value(), bincode::config::standard())
633 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
634
635 drop(table);
636 drop(read_txn);
637
638 session.experiences.push(Experience {
640 state,
641 action,
642 reward,
643 next_state,
644 done,
645 timestamp: chrono::Utc::now().timestamp(),
646 });
647 session.updated_at = chrono::Utc::now().timestamp();
648
649 let write_txn = self.db.begin_write()?;
651 {
652 let mut table = write_txn.open_table(LEARNING_TABLE)?;
653 let data = bincode::encode_to_vec(&session, bincode::config::standard())
654 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
655 table.insert(session_id, data.as_slice())?;
656 }
657 write_txn.commit()?;
658
659 Ok(())
660 }
661
662 pub fn predict_with_confidence(&self, session_id: &str, state: Vec<f32>) -> Result<Prediction> {
664 let read_txn = self.db.begin_read()?;
665 let table = read_txn.open_table(LEARNING_TABLE)?;
666
667 let data = table
668 .get(session_id)?
669 .ok_or_else(|| RuvectorError::VectorNotFound(session_id.to_string()))?;
670
671 let (session, _): (LearningSession, usize) =
672 bincode::decode_from_slice(data.value(), bincode::config::standard())
673 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
674
675 let mut similar_actions = Vec::new();
677 let mut rewards = Vec::new();
678
679 for exp in &session.experiences {
680 let distance = euclidean_distance(&state, &exp.state);
681 if distance < 1.0 {
682 similar_actions.push(exp.action.clone());
684 rewards.push(exp.reward);
685 }
686 }
687
688 if similar_actions.is_empty() {
689 return Ok(Prediction {
691 action: vec![0.0; session.action_dim],
692 confidence_lower: 0.0,
693 confidence_upper: 0.0,
694 mean_confidence: 0.0,
695 });
696 }
697
698 let total_reward: f64 = rewards.iter().sum();
700 let mut action = vec![0.0; session.action_dim];
701
702 for (act, reward) in similar_actions.iter().zip(rewards.iter()) {
703 let weight = reward / total_reward;
704 for (i, val) in act.iter().enumerate() {
705 action[i] += val * weight as f32;
706 }
707 }
708
709 let mean_reward = total_reward / rewards.len() as f64;
711 let std_dev = calculate_std_dev(&rewards, mean_reward);
712
713 Ok(Prediction {
714 action,
715 confidence_lower: mean_reward - 1.96 * std_dev,
716 confidence_upper: mean_reward + 1.96 * std_dev,
717 mean_confidence: mean_reward,
718 })
719 }
720
721 pub fn get_session(&self, session_id: &str) -> Result<Option<LearningSession>> {
723 let read_txn = self.db.begin_read()?;
724 let table = read_txn.open_table(LEARNING_TABLE)?;
725
726 if let Some(data) = table.get(session_id)? {
727 let (session, _): (LearningSession, usize) =
728 bincode::decode_from_slice(data.value(), bincode::config::standard())
729 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
730 Ok(Some(session))
731 } else {
732 Ok(None)
733 }
734 }
735
    /// Embeds `text` with the configured embedding provider (direct delegation).
    fn generate_text_embedding(&self, text: &str) -> Result<Vec<f32>> {
        self.embedding_provider.embed(text)
    }
761}
762
/// Euclidean (L2) distance between two vectors.
///
/// Components are paired positionally; when the slices differ in length the
/// surplus components of the longer one are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_total: f32 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();
    squared_total.sqrt()
}
771
/// Population standard deviation of `values` around the supplied `mean`.
///
/// Returns 0.0 for an empty slice instead of the NaN that `0.0 / 0.0` would produce.
fn calculate_std_dev(values: &[f64], mean: f64) -> f64 {
    if values.is_empty() {
        return 0.0;
    }
    let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
    variance.sqrt()
}
776
/// View over an [`AgenticDB`] for storing and querying policy entries
/// (state embedding, action, reward, q-value) in its vector store.
pub struct PolicyMemoryStore<'a> {
    /// Database whose vector store holds the `policy_*` entries.
    db: &'a AgenticDB,
}
793
/// An action taken in a state, with its observed reward and q-value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyAction {
    /// Action label.
    pub action: String,
    /// Reward recorded for the action.
    pub reward: f64,
    /// Q-value recorded for the action.
    pub q_value: f64,
    /// Embedding of the state the action was taken in.
    pub state_embedding: Vec<f32>,
    /// Recording time as Unix seconds (UTC).
    pub timestamp: i64,
}
808
/// A stored policy record: which action was taken for which state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyEntry {
    /// Policy id (UUID v4 string; the vector entry id is `policy_<id>`).
    pub id: String,
    /// Caller-supplied state identifier.
    pub state_id: String,
    /// The action and its statistics.
    pub action: PolicyAction,
    /// Optional free-form metadata.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
821
822impl<'a> PolicyMemoryStore<'a> {
    /// Wraps `db` in a policy memory view; no data is read or written.
    pub fn new(db: &'a AgenticDB) -> Self {
        Self { db }
    }
827
828 pub fn store_policy(
830 &self,
831 state_id: &str,
832 state_embedding: Vec<f32>,
833 action: &str,
834 reward: f64,
835 q_value: f64,
836 ) -> Result<String> {
837 let id = uuid::Uuid::new_v4().to_string();
838 let timestamp = chrono::Utc::now().timestamp();
839
840 let entry = PolicyEntry {
841 id: id.clone(),
842 state_id: state_id.to_string(),
843 action: PolicyAction {
844 action: action.to_string(),
845 reward,
846 q_value,
847 state_embedding: state_embedding.clone(),
848 timestamp,
849 },
850 metadata: None,
851 };
852
853 self.db.vector_db.insert(VectorEntry {
855 id: Some(format!("policy_{}", id)),
856 vector: state_embedding,
857 metadata: Some({
858 let mut meta = HashMap::new();
859 meta.insert("type".to_string(), serde_json::json!("policy"));
860 meta.insert("policy_id".to_string(), serde_json::json!(id.clone()));
861 meta.insert("state_id".to_string(), serde_json::json!(state_id));
862 meta.insert("action".to_string(), serde_json::json!(action));
863 meta.insert("reward".to_string(), serde_json::json!(reward));
864 meta.insert("q_value".to_string(), serde_json::json!(q_value));
865 meta
866 }),
867 })?;
868
869 Ok(id)
870 }
871
872 pub fn retrieve_similar_states(
874 &self,
875 state_embedding: &[f32],
876 k: usize,
877 ) -> Result<Vec<PolicyEntry>> {
878 let results = self.db.vector_db.search(SearchQuery {
879 vector: state_embedding.to_vec(),
880 k,
881 filter: Some({
882 let mut filter = HashMap::new();
883 filter.insert("type".to_string(), serde_json::json!("policy"));
884 filter
885 }),
886 ef_search: None,
887 })?;
888
889 let mut entries = Vec::new();
890 for result in results {
891 if let Some(metadata) = result.metadata {
892 let policy_id = metadata
893 .get("policy_id")
894 .and_then(|v| v.as_str())
895 .unwrap_or("");
896 let state_id = metadata
897 .get("state_id")
898 .and_then(|v| v.as_str())
899 .unwrap_or("");
900 let action = metadata
901 .get("action")
902 .and_then(|v| v.as_str())
903 .unwrap_or("");
904 let reward = metadata
905 .get("reward")
906 .and_then(|v| v.as_f64())
907 .unwrap_or(0.0);
908 let q_value = metadata
909 .get("q_value")
910 .and_then(|v| v.as_f64())
911 .unwrap_or(0.0);
912
913 entries.push(PolicyEntry {
914 id: policy_id.to_string(),
915 state_id: state_id.to_string(),
916 action: PolicyAction {
917 action: action.to_string(),
918 reward,
919 q_value,
920 state_embedding: result.vector.unwrap_or_default(),
921 timestamp: 0,
922 },
923 metadata: None,
924 });
925 }
926 }
927
928 Ok(entries)
929 }
930
931 pub fn get_best_action(&self, state_embedding: &[f32], k: usize) -> Result<Option<String>> {
933 let similar = self.retrieve_similar_states(state_embedding, k)?;
934
935 similar
936 .into_iter()
937 .max_by(|a, b| a.action.q_value.partial_cmp(&b.action.q_value).unwrap())
938 .map(|entry| Ok(entry.action.action))
939 .transpose()
940 }
941
942 pub fn update_q_value(&self, policy_id: &str, new_q_value: f64) -> Result<()> {
944 let _ = self.db.vector_db.delete(&format!("policy_{}", policy_id));
947 Ok(())
948 }
949}
950
/// View over an [`AgenticDB`] that stores the turns of one conversation
/// session as vector entries with an expiry time.
pub struct SessionStateIndex<'a> {
    /// Database whose vector store holds the `session_*` entries.
    db: &'a AgenticDB,
    /// Session identifier used in entry ids and metadata filters.
    session_id: String,
    /// Lifetime added to a turn's timestamp to compute `expires_at`, in seconds.
    ttl_seconds: i64,
}
960
/// One turn of a session, rebuilt from vector-entry metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTurn {
    /// Turn id (UUID v4 string).
    pub id: String,
    /// Id of the session the turn belongs to.
    pub session_id: String,
    /// Caller-supplied turn number; used to order the session context.
    pub turn_number: usize,
    /// Caller-supplied role label.
    pub role: String,
    /// Turn text; the embedding is computed from this.
    pub content: String,
    /// Embedding of `content` (empty when the search hit carried no vector).
    pub embedding: Vec<f32>,
    /// Creation time as Unix seconds (UTC).
    pub timestamp: i64,
    /// Expiry time as Unix seconds (UTC): `timestamp + ttl_seconds`.
    pub expires_at: i64,
}
981
982impl<'a> SessionStateIndex<'a> {
    /// Creates a view for `session_id` whose turns expire `ttl_seconds` after
    /// they are added; no data is read or written.
    pub fn new(db: &'a AgenticDB, session_id: &str, ttl_seconds: i64) -> Self {
        Self {
            db,
            session_id: session_id.to_string(),
            ttl_seconds,
        }
    }
991
992 pub fn add_turn(&self, turn_number: usize, role: &str, content: &str) -> Result<String> {
994 let id = uuid::Uuid::new_v4().to_string();
995 let timestamp = chrono::Utc::now().timestamp();
996 let expires_at = timestamp + self.ttl_seconds;
997
998 let embedding = self.db.generate_text_embedding(content)?;
1000
1001 self.db.vector_db.insert(VectorEntry {
1003 id: Some(format!("session_{}_{}", self.session_id, id)),
1004 vector: embedding,
1005 metadata: Some({
1006 let mut meta = HashMap::new();
1007 meta.insert("type".to_string(), serde_json::json!("session_turn"));
1008 meta.insert(
1009 "session_id".to_string(),
1010 serde_json::json!(self.session_id.clone()),
1011 );
1012 meta.insert("turn_id".to_string(), serde_json::json!(id.clone()));
1013 meta.insert("turn_number".to_string(), serde_json::json!(turn_number));
1014 meta.insert("role".to_string(), serde_json::json!(role));
1015 meta.insert("content".to_string(), serde_json::json!(content));
1016 meta.insert("timestamp".to_string(), serde_json::json!(timestamp));
1017 meta.insert("expires_at".to_string(), serde_json::json!(expires_at));
1018 meta
1019 }),
1020 })?;
1021
1022 Ok(id)
1023 }
1024
1025 pub fn find_relevant_turns(&self, query: &str, k: usize) -> Result<Vec<SessionTurn>> {
1027 let query_embedding = self.db.generate_text_embedding(query)?;
1028 let current_time = chrono::Utc::now().timestamp();
1029
1030 let results = self.db.vector_db.search(SearchQuery {
1031 vector: query_embedding,
1032 k: k * 2, filter: Some({
1034 let mut filter = HashMap::new();
1035 filter.insert("type".to_string(), serde_json::json!("session_turn"));
1036 filter.insert(
1037 "session_id".to_string(),
1038 serde_json::json!(self.session_id.clone()),
1039 );
1040 filter
1041 }),
1042 ef_search: None,
1043 })?;
1044
1045 let mut turns = Vec::new();
1046 for result in results {
1047 if let Some(metadata) = result.metadata {
1048 let expires_at = metadata
1049 .get("expires_at")
1050 .and_then(|v| v.as_i64())
1051 .unwrap_or(0);
1052
1053 if expires_at < current_time {
1055 continue;
1056 }
1057
1058 turns.push(SessionTurn {
1059 id: metadata
1060 .get("turn_id")
1061 .and_then(|v| v.as_str())
1062 .unwrap_or("")
1063 .to_string(),
1064 session_id: self.session_id.clone(),
1065 turn_number: metadata
1066 .get("turn_number")
1067 .and_then(|v| v.as_u64())
1068 .unwrap_or(0) as usize,
1069 role: metadata
1070 .get("role")
1071 .and_then(|v| v.as_str())
1072 .unwrap_or("")
1073 .to_string(),
1074 content: metadata
1075 .get("content")
1076 .and_then(|v| v.as_str())
1077 .unwrap_or("")
1078 .to_string(),
1079 embedding: result.vector.unwrap_or_default(),
1080 timestamp: metadata
1081 .get("timestamp")
1082 .and_then(|v| v.as_i64())
1083 .unwrap_or(0),
1084 expires_at,
1085 });
1086
1087 if turns.len() >= k {
1088 break;
1089 }
1090 }
1091 }
1092
1093 Ok(turns)
1094 }
1095
1096 pub fn get_session_context(&self) -> Result<Vec<SessionTurn>> {
1098 let mut turns = self.find_relevant_turns("", 1000)?;
1099 turns.sort_by_key(|t| t.turn_number);
1100 Ok(turns)
1101 }
1102
1103 pub fn cleanup_expired(&self) -> Result<usize> {
1105 let current_time = chrono::Utc::now().timestamp();
1106 let all_turns = self.find_relevant_turns("", 10000)?;
1107 let mut deleted = 0;
1108
1109 for turn in all_turns {
1110 if turn.expires_at < current_time {
1111 let _ = self
1112 .db
1113 .vector_db
1114 .delete(&format!("session_{}_{}", self.session_id, turn.id));
1115 deleted += 1;
1116 }
1117 }
1118
1119 Ok(deleted)
1120 }
1121}
1122
/// Append-only, hash-chained action log stored in an [`AgenticDB`]'s vector store.
pub struct WitnessLog<'a> {
    /// Database whose vector store holds the `witness_*` entries.
    db: &'a AgenticDB,
    /// Hash of the most recent entry appended through this instance
    /// (`None` until the first append; not loaded from storage).
    last_hash: RwLock<Option<String>>,
}
1130
/// One entry of the witness log, rebuilt from vector-entry metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WitnessEntry {
    /// Entry id (UUID v4 string; the vector entry id is `witness_<id>`).
    pub id: String,
    /// Hash of the preceding entry, or `None` for the first entry of a chain.
    pub prev_hash: Option<String>,
    /// Hash of this entry (16 lowercase hex digits).
    pub hash: String,
    /// Identifier of the acting agent.
    pub agent_id: String,
    /// Caller-supplied action type label.
    pub action_type: String,
    /// Caller-supplied action details.
    pub details: String,
    /// Embedding of `"<agent_id> <action_type> <details>"`.
    pub embedding: Vec<f32>,
    /// Append time as Unix seconds (UTC).
    pub timestamp: i64,
    /// Optional free-form metadata; `search` always returns `None`.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
1153
1154impl<'a> WitnessLog<'a> {
    /// Creates a log view over `db` with no predecessor hash, so the first
    /// entry appended through this instance has `prev_hash = None`.
    pub fn new(db: &'a AgenticDB) -> Self {
        Self {
            db,
            last_hash: RwLock::new(None),
        }
    }
1162
1163 fn compute_hash(
1165 prev_hash: &Option<String>,
1166 agent_id: &str,
1167 action_type: &str,
1168 details: &str,
1169 timestamp: i64,
1170 ) -> String {
1171 use std::collections::hash_map::DefaultHasher;
1172 use std::hash::{Hash, Hasher};
1173
1174 let mut hasher = DefaultHasher::new();
1175 if let Some(prev) = prev_hash {
1176 prev.hash(&mut hasher);
1177 }
1178 agent_id.hash(&mut hasher);
1179 action_type.hash(&mut hasher);
1180 details.hash(&mut hasher);
1181 timestamp.hash(&mut hasher);
1182 format!("{:016x}", hasher.finish())
1183 }
1184
1185 pub fn append(&self, agent_id: &str, action_type: &str, details: &str) -> Result<String> {
1187 let id = uuid::Uuid::new_v4().to_string();
1188 let timestamp = chrono::Utc::now().timestamp();
1189
1190 let prev_hash = self.last_hash.read().clone();
1192
1193 let hash = Self::compute_hash(&prev_hash, agent_id, action_type, details, timestamp);
1195
1196 let embedding = self
1198 .db
1199 .generate_text_embedding(&format!("{} {} {}", agent_id, action_type, details))?;
1200
1201 self.db.vector_db.insert(VectorEntry {
1203 id: Some(format!("witness_{}", id)),
1204 vector: embedding.clone(),
1205 metadata: Some({
1206 let mut meta = HashMap::new();
1207 meta.insert("type".to_string(), serde_json::json!("witness"));
1208 meta.insert("witness_id".to_string(), serde_json::json!(id.clone()));
1209 meta.insert("agent_id".to_string(), serde_json::json!(agent_id));
1210 meta.insert("action_type".to_string(), serde_json::json!(action_type));
1211 meta.insert("details".to_string(), serde_json::json!(details));
1212 meta.insert("timestamp".to_string(), serde_json::json!(timestamp));
1213 meta.insert("hash".to_string(), serde_json::json!(hash.clone()));
1214 if let Some(ref prev) = prev_hash {
1215 meta.insert("prev_hash".to_string(), serde_json::json!(prev));
1216 }
1217 meta
1218 }),
1219 })?;
1220
1221 *self.last_hash.write() = Some(hash.clone());
1223
1224 Ok(id)
1225 }
1226
1227 pub fn search(&self, query: &str, k: usize) -> Result<Vec<WitnessEntry>> {
1229 let query_embedding = self.db.generate_text_embedding(query)?;
1230
1231 let results = self.db.vector_db.search(SearchQuery {
1232 vector: query_embedding,
1233 k,
1234 filter: Some({
1235 let mut filter = HashMap::new();
1236 filter.insert("type".to_string(), serde_json::json!("witness"));
1237 filter
1238 }),
1239 ef_search: None,
1240 })?;
1241
1242 let mut entries = Vec::new();
1243 for result in results {
1244 if let Some(metadata) = result.metadata {
1245 entries.push(WitnessEntry {
1246 id: metadata
1247 .get("witness_id")
1248 .and_then(|v| v.as_str())
1249 .unwrap_or("")
1250 .to_string(),
1251 prev_hash: metadata
1252 .get("prev_hash")
1253 .and_then(|v| v.as_str())
1254 .map(|s| s.to_string()),
1255 hash: metadata
1256 .get("hash")
1257 .and_then(|v| v.as_str())
1258 .unwrap_or("")
1259 .to_string(),
1260 agent_id: metadata
1261 .get("agent_id")
1262 .and_then(|v| v.as_str())
1263 .unwrap_or("")
1264 .to_string(),
1265 action_type: metadata
1266 .get("action_type")
1267 .and_then(|v| v.as_str())
1268 .unwrap_or("")
1269 .to_string(),
1270 details: metadata
1271 .get("details")
1272 .and_then(|v| v.as_str())
1273 .unwrap_or("")
1274 .to_string(),
1275 embedding: result.vector.unwrap_or_default(),
1276 timestamp: metadata
1277 .get("timestamp")
1278 .and_then(|v| v.as_i64())
1279 .unwrap_or(0),
1280 metadata: None,
1281 });
1282 }
1283 }
1284
1285 Ok(entries)
1286 }
1287
1288 pub fn get_by_agent(&self, agent_id: &str, k: usize) -> Result<Vec<WitnessEntry>> {
1290 self.search(agent_id, k)
1292 }
1293
1294 pub fn verify_chain(&self) -> Result<bool> {
1296 let entries = self.search("", 10000)?;
1297
1298 let mut sorted_entries = entries;
1300 sorted_entries.sort_by_key(|e| e.timestamp);
1301
1302 for i in 1..sorted_entries.len() {
1304 let prev = &sorted_entries[i - 1];
1305 let curr = &sorted_entries[i];
1306
1307 if let Some(ref prev_hash) = curr.prev_hash {
1308 if prev_hash != &prev.hash {
1309 return Ok(false);
1310 }
1311 }
1312 }
1313
1314 Ok(true)
1315 }
1316}
1317
/// Constructors for the specialised views over an [`AgenticDB`].
impl AgenticDB {
    /// Returns a policy memory view borrowing this database.
    pub fn policy_memory(&self) -> PolicyMemoryStore<'_> {
        PolicyMemoryStore::new(self)
    }

    /// Returns a session index for `session_id` whose turns expire after `ttl_seconds`.
    pub fn session_index(&self, session_id: &str, ttl_seconds: i64) -> SessionStateIndex<'_> {
        SessionStateIndex::new(self, session_id, ttl_seconds)
    }

    /// Returns a fresh witness log view; each call starts with no predecessor hash.
    pub fn witness_log(&self) -> WitnessLog<'_> {
        WitnessLog::new(self)
    }
}
1334
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Creates a 128-dimensional database inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned first so that callers binding
    /// `let (_dir, db) = ...` drop the database before the directory is
    /// removed. Dropping the guard inside this function would delete the
    /// directory while the database is still in use.
    fn create_test_db() -> Result<(TempDir, AgenticDB)> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 128;
        let db = AgenticDB::new(options)?;
        Ok((dir, db))
    }

    #[test]
    fn test_reflexion_episode() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let id = db.store_episode(
            "Solve math problem".to_string(),
            vec!["read problem".to_string(), "calculate".to_string()],
            vec!["got 42".to_string()],
            "Should have shown work".to_string(),
        )?;

        let episodes = db.retrieve_similar_episodes("math problem solving", 5)?;
        assert!(!episodes.is_empty());
        assert_eq!(episodes[0].id, id);

        Ok(())
    }

    #[test]
    fn test_skill_library() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let mut params = HashMap::new();
        params.insert("input".to_string(), "string".to_string());

        let skill_id = db.create_skill(
            "Parse JSON".to_string(),
            "Parse JSON from string".to_string(),
            params,
            vec!["json.parse()".to_string()],
        )?;

        let skills = db.search_skills("parse json data", 5)?;
        assert!(!skills.is_empty());
        // Only one skill exists, so any hit must be the one just created.
        assert!(skills.iter().any(|skill| skill.id == skill_id));

        Ok(())
    }

    #[test]
    fn test_causal_edge() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let _edge_id = db.add_causal_edge(
            vec!["rain".to_string()],
            vec!["wet ground".to_string()],
            0.95,
            "Weather observation".to_string(),
        )?;

        let results = db.query_with_utility("weather patterns", 5, 0.7, 0.2, 0.1)?;
        assert!(!results.is_empty());

        Ok(())
    }

    #[test]
    fn test_learning_session() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let session_id = db.start_session("Q-Learning".to_string(), 4, 2)?;

        db.add_experience(
            &session_id,
            vec![1.0, 0.0, 0.0, 0.0],
            vec![1.0, 0.0],
            1.0,
            vec![0.0, 1.0, 0.0, 0.0],
            false,
        )?;

        let prediction = db.predict_with_confidence(&session_id, vec![1.0, 0.0, 0.0, 0.0])?;
        assert_eq!(prediction.action.len(), 2);

        Ok(())
    }

    #[test]
    fn test_auto_consolidate() -> Result<()> {
        let (_dir, db) = create_test_db()?;

        let sequences = vec![
            vec![
                "step1".to_string(),
                "step2".to_string(),
                "step3".to_string(),
            ],
            vec![
                "action1".to_string(),
                "action2".to_string(),
                "action3".to_string(),
            ],
        ];

        let skill_ids = db.auto_consolidate(sequences, 3)?;
        assert_eq!(skill_ids.len(), 2);

        Ok(())
    }
}