1use std::pin::Pin;
13
14use zeph_common::memory::{
15 AsyncMemoryRouter, ContextMemoryBackend, GraphRecallParams, MemCorrection, MemDocumentChunk,
16 MemGraphFact, MemGraphNeighbor, MemPersonaFact, MemReasoningStrategy, MemRecalledMessage,
17 MemSessionSummary, MemSummary, MemTrajectoryEntry, MemTreeNode, RecallView,
18};
19use zeph_memory::semantic::SemanticMemory;
20use zeph_memory::{ConversationId, RecallView as MemRecallView, RecalledFact};
21
/// Boxes a concrete error into the `Box<dyn Error + Send + Sync>` error type
/// returned by the backend futures in this file.
///
/// Intended for use as `.map_err(box_err)`.
fn box_err<E>(e: E) -> Box<dyn std::error::Error + Send + Sync>
where
    E: std::error::Error + Send + Sync + 'static,
{
    e.into()
}
27
28fn map_persona_fact(r: zeph_memory::PersonaFactRow) -> MemPersonaFact {
29 MemPersonaFact {
30 category: r.category,
31 content: r.content,
32 }
33}
34
35fn map_trajectory_entry(r: zeph_memory::TrajectoryEntryRow) -> MemTrajectoryEntry {
36 MemTrajectoryEntry {
37 intent: r.intent,
38 outcome: r.outcome,
39 confidence: r.confidence,
40 }
41}
42
43fn map_tree_node(r: zeph_memory::MemoryTreeRow) -> MemTreeNode {
44 MemTreeNode { content: r.content }
45}
46
47fn map_summary(r: zeph_memory::semantic::Summary) -> MemSummary {
48 MemSummary {
49 first_message_id: r.first_message_id.map(|m| m.0),
50 last_message_id: r.last_message_id.map(|m| m.0),
51 content: r.content,
52 }
53}
54
55fn map_reasoning_strategy(s: zeph_memory::ReasoningStrategy) -> MemReasoningStrategy {
56 MemReasoningStrategy {
57 id: s.id,
58 outcome: s.outcome.as_str().to_owned(),
59 summary: s.summary,
60 }
61}
62
63fn map_correction(c: zeph_memory::UserCorrectionRow) -> MemCorrection {
64 MemCorrection {
65 correction_text: c.correction_text,
66 }
67}
68
69fn map_recalled_message(r: zeph_memory::RecalledMessage) -> MemRecalledMessage {
70 use zeph_llm::provider::Role;
71 let role = match r.message.role {
72 Role::Assistant => "assistant",
73 Role::System => "system",
74 Role::User | _ => "user",
75 }
76 .to_owned();
77 MemRecalledMessage {
78 role,
79 content: r.message.content,
80 score: r.score,
81 }
82}
83
84fn map_graph_fact(rf: RecalledFact) -> MemGraphFact {
85 MemGraphFact {
86 fact: rf.fact.fact,
87 confidence: rf.fact.confidence,
88 activation_score: rf.activation_score,
89 neighbors: rf
90 .neighbors
91 .into_iter()
92 .map(|n| MemGraphNeighbor {
93 fact: n.fact,
94 confidence: n.confidence,
95 })
96 .collect(),
97 provenance_snippet: rf.provenance_snippet,
98 }
99}
100
101fn map_session_summary(r: zeph_memory::semantic::SessionSummaryResult) -> MemSessionSummary {
102 MemSessionSummary {
103 summary_text: r.summary_text,
104 score: r.score,
105 }
106}
107
/// Adapter that exposes a shared [`SemanticMemory`] through the
/// [`ContextMemoryBackend`] trait.
pub struct SemanticMemoryBackend {
    /// Shared handle to the semantic memory that the trait methods delegate to.
    inner: std::sync::Arc<SemanticMemory>,
}
112
113impl SemanticMemoryBackend {
114 #[must_use]
116 pub fn new(inner: std::sync::Arc<SemanticMemory>) -> Self {
117 Self { inner }
118 }
119}
120
/// Pinned, boxed, `Send` future borrowing for `'a` that resolves to
/// `Result<T, Box<dyn Error + Send + Sync>>`.
///
/// Used as the return type of every [`ContextMemoryBackend`] method implemented in
/// this file.
type BoxFut<'a, T> = Pin<
    Box<
        dyn std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>
            + Send
            + 'a,
    >,
>;
128
129impl ContextMemoryBackend for SemanticMemoryBackend {
130 fn load_persona_facts(&self, min_confidence: f64) -> BoxFut<'_, Vec<MemPersonaFact>> {
131 Box::pin(async move {
132 let rows = self
133 .inner
134 .sqlite()
135 .load_persona_facts(min_confidence)
136 .await
137 .map_err(box_err)?;
138 Ok(rows.into_iter().map(map_persona_fact).collect())
139 })
140 }
141
142 fn load_trajectory_entries<'a>(
143 &'a self,
144 tier: Option<&'a str>,
145 top_k: usize,
146 ) -> BoxFut<'a, Vec<MemTrajectoryEntry>> {
147 Box::pin(async move {
148 let rows = self
149 .inner
150 .sqlite()
151 .load_trajectory_entries(tier, top_k)
152 .await
153 .map_err(box_err)?;
154 Ok(rows.into_iter().map(map_trajectory_entry).collect())
155 })
156 }
157
158 fn load_tree_nodes(&self, level: u32, top_k: usize) -> BoxFut<'_, Vec<MemTreeNode>> {
159 Box::pin(async move {
160 let rows = self
161 .inner
162 .sqlite()
163 .load_tree_level(level.into(), top_k)
164 .await
165 .map_err(box_err)?;
166 Ok(rows.into_iter().map(map_tree_node).collect())
167 })
168 }
169
170 fn load_summaries(&self, conversation_id: i64) -> BoxFut<'_, Vec<MemSummary>> {
171 Box::pin(async move {
172 let cid = ConversationId(conversation_id);
173 let rows = self.inner.load_summaries(cid).await.map_err(box_err)?;
174 Ok(rows.into_iter().map(map_summary).collect())
175 })
176 }
177
178 fn retrieve_reasoning_strategies<'a>(
179 &'a self,
180 query: &'a str,
181 top_k: usize,
182 ) -> BoxFut<'a, Vec<MemReasoningStrategy>> {
183 Box::pin(async move {
184 let strategies = self
185 .inner
186 .retrieve_reasoning_strategies(query, top_k)
187 .await
188 .map_err(box_err)?;
189 Ok(strategies.into_iter().map(map_reasoning_strategy).collect())
190 })
191 }
192
193 fn mark_reasoning_used<'a>(&'a self, ids: &'a [String]) -> BoxFut<'a, ()> {
194 Box::pin(async move {
195 if let Some(ref reasoning) = self.inner.reasoning {
196 reasoning.mark_used(ids).await.map_err(box_err)?;
197 }
198 Ok(())
199 })
200 }
201
202 fn retrieve_corrections<'a>(
203 &'a self,
204 query: &'a str,
205 limit: usize,
206 min_score: f32,
207 ) -> BoxFut<'a, Vec<MemCorrection>> {
208 Box::pin(async move {
209 let corrections = self
210 .inner
211 .retrieve_similar_corrections(query, limit, min_score)
212 .await
213 .map_err(box_err)?;
214 Ok(corrections.into_iter().map(map_correction).collect())
215 })
216 }
217
218 fn recall<'a>(
219 &'a self,
220 query: &'a str,
221 limit: usize,
222 router: Option<&'a dyn AsyncMemoryRouter>,
223 ) -> BoxFut<'a, Vec<MemRecalledMessage>> {
224 Box::pin(async move {
225 let recalled = if let Some(r) = router {
226 self.inner
227 .recall_routed_async(query, limit, None, r, None)
228 .await
229 .map_err(box_err)?
230 } else {
231 self.inner
232 .recall(query, limit, None)
233 .await
234 .map_err(box_err)?
235 };
236 Ok(recalled.into_iter().map(map_recalled_message).collect())
237 })
238 }
239
240 fn recall_graph_facts<'a>(
241 &'a self,
242 query: &'a str,
243 params: GraphRecallParams<'a>,
244 ) -> BoxFut<'a, Vec<MemGraphFact>> {
245 Box::pin(async move {
246 let mem_view = match params.view {
247 RecallView::ZoomIn => MemRecallView::ZoomIn,
248 RecallView::ZoomOut => MemRecallView::ZoomOut,
249 _ => MemRecallView::Head,
250 };
251 let mem_edge_types: Vec<zeph_memory::EdgeType> = params
252 .edge_types
253 .iter()
254 .map(|e| {
255 use zeph_common::memory::EdgeType as CE;
256 use zeph_memory::EdgeType as ME;
257 match e {
258 CE::Temporal => ME::Temporal,
259 CE::Causal => ME::Causal,
260 CE::Entity => ME::Entity,
261 _ => ME::Semantic,
262 }
263 })
264 .collect();
265 let sa_params = params.spreading_activation.map(|p| {
266 zeph_memory::graph::SpreadingActivationParams {
267 decay_lambda: p.decay_lambda,
268 max_hops: p.max_hops,
269 activation_threshold: p.activation_threshold,
270 inhibition_threshold: p.inhibition_threshold,
271 max_activated_nodes: p.max_activated_nodes,
272 temporal_decay_rate: p.temporal_decay_rate,
273 seed_structural_weight: p.seed_structural_weight,
274 seed_community_cap: p.seed_community_cap,
275 alpha: p.alpha,
276 }
277 });
278 let recalled = self
279 .inner
280 .recall_graph_view(
281 query,
282 params.limit,
283 mem_view,
284 params.zoom_out_neighbor_cap,
285 params.max_hops,
286 params.temporal_decay_rate,
287 &mem_edge_types,
288 sa_params,
289 )
290 .await
291 .map_err(box_err)?;
292 Ok(recalled.into_iter().map(map_graph_fact).collect())
293 })
294 }
295
296 fn search_session_summaries<'a>(
297 &'a self,
298 query: &'a str,
299 limit: usize,
300 current_conversation_id: Option<i64>,
301 ) -> BoxFut<'a, Vec<MemSessionSummary>> {
302 Box::pin(async move {
303 let cid = current_conversation_id.map(ConversationId);
304 let results = self
305 .inner
306 .search_session_summaries(query, limit, cid)
307 .await
308 .map_err(box_err)?;
309 Ok(results.into_iter().map(map_session_summary).collect())
310 })
311 }
312
313 fn search_document_collection<'a>(
314 &'a self,
315 collection: &'a str,
316 query: &'a str,
317 top_k: usize,
318 ) -> BoxFut<'a, Vec<MemDocumentChunk>> {
319 Box::pin(async move {
320 let points = self
321 .inner
322 .search_document_collection(collection, query, top_k)
323 .await
324 .map_err(box_err)?;
325 Ok(points
326 .into_iter()
327 .map(|p| {
328 let text = p
329 .payload
330 .get("text")
331 .and_then(|v| v.as_str())
332 .unwrap_or_default()
333 .to_owned();
334 MemDocumentChunk { text }
335 })
336 .collect())
337 })
338 }
339}
340
/// Newtype around a shared `zeph_memory::TokenCounter`.
pub struct TokenCounterAdapter(std::sync::Arc<zeph_memory::TokenCounter>);
344
345impl TokenCounterAdapter {
346 #[must_use]
348 pub fn new(inner: std::sync::Arc<zeph_memory::TokenCounter>) -> Self {
349 Self(inner)
350 }
351}
352
353impl zeph_context::summarization::MessageTokenCounter for TokenCounterAdapter {
354 fn count_message_tokens(&self, msg: &zeph_llm::provider::Message) -> usize {
355 self.0.count_message_tokens(msg)
356 }
357}
358
359#[must_use]
367pub fn build_memory_router(
368 manager: &zeph_context::manager::ContextManager,
369) -> Box<dyn zeph_common::memory::AsyncMemoryRouter + Send + Sync> {
370 use zeph_config::StoreRoutingStrategy;
371
372 if !manager.routing.enabled {
373 return Box::new(zeph_memory::HeuristicRouter);
374 }
375 let fallback = manager.routing.fallback_route;
376 match manager.routing.strategy {
377 StoreRoutingStrategy::Llm => {
378 let Some(provider) = manager.store_routing_provider.clone() else {
379 tracing::warn!(
380 "store_routing: strategy=llm but no provider resolved; \
381 falling back to heuristic"
382 );
383 return Box::new(zeph_memory::HeuristicRouter);
384 };
385 Box::new(zeph_memory::LlmRouter::new(provider, fallback))
386 }
387 StoreRoutingStrategy::Hybrid => {
388 let Some(provider) = manager.store_routing_provider.clone() else {
389 tracing::warn!(
390 "store_routing: strategy=hybrid but no provider resolved; \
391 falling back to heuristic"
392 );
393 return Box::new(zeph_memory::HeuristicRouter);
394 };
395 Box::new(zeph_memory::HybridRouter::new(
396 provider,
397 fallback,
398 manager.routing.confidence_threshold,
399 ))
400 }
401 _ => Box::new(zeph_memory::HeuristicRouter),
402 }
403}
404
#[cfg(test)]
mod tests {
    //! Unit tests for the row-to-DTO mapping helpers in the parent module.

    use zeph_llm::provider::{Message, Role};
    use zeph_memory::graph::types::{EdgeType, GraphFact};
    use zeph_memory::semantic::{SessionSummaryResult, Summary};
    use zeph_memory::types::{ConversationId, MessageId};
    use zeph_memory::{
        MemoryTreeRow, Outcome, PersonaFactRow, ReasoningStrategy, RecalledFact, RecalledMessage,
        TrajectoryEntryRow, UserCorrectionRow,
    };

    use super::*;

    /// Returns `true` when two `f32` values differ by less than `f32::EPSILON`.
    fn near(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    // ---- fixtures ---------------------------------------------------------

    fn make_persona_row() -> PersonaFactRow {
        PersonaFactRow {
            id: 1,
            category: "preference".to_owned(),
            content: "prefers short answers".to_owned(),
            confidence: 0.9,
            evidence_count: 3,
            source_conversation_id: None,
            supersedes_id: None,
            created_at: "2026-01-01".to_owned(),
            updated_at: "2026-01-02".to_owned(),
        }
    }

    fn make_trajectory_row() -> TrajectoryEntryRow {
        TrajectoryEntryRow {
            id: 1,
            conversation_id: Some(42),
            turn_index: 5,
            kind: "procedural".to_owned(),
            intent: "read a file".to_owned(),
            outcome: "file read successfully".to_owned(),
            tools_used: "read_file".to_owned(),
            confidence: 0.85,
            created_at: "2026-01-01".to_owned(),
            updated_at: "2026-01-01".to_owned(),
        }
    }

    fn make_tree_row() -> MemoryTreeRow {
        MemoryTreeRow {
            id: 1,
            level: 0,
            parent_id: None,
            content: "node content here".to_owned(),
            source_ids: "1,2,3".to_owned(),
            token_count: 10,
            consolidated_at: None,
            created_at: "2026-01-01".to_owned(),
        }
    }

    fn make_summary() -> Summary {
        Summary {
            id: 1,
            conversation_id: ConversationId(10),
            content: "summary of the conversation".to_owned(),
            first_message_id: Some(MessageId(5)),
            last_message_id: Some(MessageId(20)),
            token_estimate: 100,
        }
    }

    fn make_reasoning_strategy() -> ReasoningStrategy {
        ReasoningStrategy {
            id: "strat-uuid-1".to_owned(),
            summary: "break the problem into parts".to_owned(),
            outcome: Outcome::Success,
            task_hint: "code refactoring task".to_owned(),
            created_at: 1_700_000_000,
            last_used_at: 1_700_000_100,
            use_count: 3,
            embedded_at: Some(1_700_000_050),
        }
    }

    fn make_correction_row() -> UserCorrectionRow {
        UserCorrectionRow {
            id: 1,
            session_id: Some(7),
            original_output: "wrong output".to_owned(),
            correction_text: "use bullet points".to_owned(),
            skill_name: Some("formatting".to_owned()),
            correction_kind: "explicit_rejection".to_owned(),
            created_at: "2026-01-01".to_owned(),
        }
    }

    fn make_recalled_message(role: Role) -> RecalledMessage {
        let message = Message {
            role,
            content: "hello world".to_owned(),
            ..Default::default()
        };
        RecalledMessage {
            message,
            score: 0.75,
        }
    }

    fn make_graph_fact() -> GraphFact {
        GraphFact {
            entity_name: "Rust".to_owned(),
            relation: "uses".to_owned(),
            target_name: "LLVM".to_owned(),
            fact: "Rust uses LLVM".to_owned(),
            entity_match_score: 0.9,
            hop_distance: 0,
            confidence: 0.95,
            valid_from: None,
            edge_type: EdgeType::Semantic,
            retrieval_count: 1,
            edge_id: Some(10),
        }
    }

    fn make_recalled_fact() -> RecalledFact {
        RecalledFact::from_graph_fact(make_graph_fact())
    }

    /// Builds a session-summary hit with the given text, score and conversation id.
    fn session_summary(text: &str, score: f32, conversation: i64) -> SessionSummaryResult {
        SessionSummaryResult {
            summary_text: text.to_owned(),
            score,
            conversation_id: ConversationId(conversation),
        }
    }

    fn make_session_summary() -> SessionSummaryResult {
        session_summary("yesterday's session about Rust", 0.88, 99)
    }

    // ---- map_persona_fact -------------------------------------------------

    #[test]
    fn persona_fact_maps_fields() {
        let mapped = map_persona_fact(make_persona_row());
        assert_eq!(mapped.category, "preference");
        assert_eq!(mapped.content, "prefers short answers");
    }

    // ---- map_trajectory_entry ---------------------------------------------

    #[test]
    fn trajectory_entry_maps_fields() {
        let mapped = map_trajectory_entry(make_trajectory_row());
        assert_eq!(mapped.intent, "read a file");
        assert_eq!(mapped.outcome, "file read successfully");
        assert!((mapped.confidence - 0.85).abs() < f64::EPSILON);
    }

    // ---- map_tree_node ----------------------------------------------------

    #[test]
    fn tree_node_maps_content() {
        let mapped = map_tree_node(make_tree_row());
        assert_eq!(mapped.content, "node content here");
    }

    // ---- map_summary ------------------------------------------------------

    #[test]
    fn summary_maps_all_fields() {
        let mapped = map_summary(make_summary());
        assert_eq!(mapped.first_message_id, Some(5));
        assert_eq!(mapped.last_message_id, Some(20));
        assert_eq!(mapped.content, "summary of the conversation");
    }

    #[test]
    fn summary_none_message_ids_stay_none() {
        let without_ids = Summary {
            id: 2,
            conversation_id: ConversationId(1),
            content: "shutdown summary".to_owned(),
            first_message_id: None,
            last_message_id: None,
            token_estimate: 50,
        };
        let mapped = map_summary(without_ids);
        assert!(mapped.first_message_id.is_none());
        assert!(mapped.last_message_id.is_none());
    }

    // ---- map_reasoning_strategy -------------------------------------------

    #[test]
    fn reasoning_strategy_maps_success_outcome() {
        let mapped = map_reasoning_strategy(make_reasoning_strategy());
        assert_eq!(mapped.id, "strat-uuid-1");
        assert_eq!(mapped.outcome, "success");
        assert_eq!(mapped.summary, "break the problem into parts");
    }

    #[test]
    fn reasoning_strategy_maps_failure_outcome() {
        let mut strategy = make_reasoning_strategy();
        strategy.outcome = Outcome::Failure;
        let mapped = map_reasoning_strategy(strategy);
        assert_eq!(mapped.outcome, "failure");
    }

    // ---- map_correction ---------------------------------------------------

    #[test]
    fn correction_maps_text() {
        let mapped = map_correction(make_correction_row());
        assert_eq!(mapped.correction_text, "use bullet points");
    }

    // ---- map_recalled_message ---------------------------------------------

    #[test]
    fn recalled_message_maps_user_role() {
        let mapped = map_recalled_message(make_recalled_message(Role::User));
        assert_eq!(mapped.role, "user");
        assert_eq!(mapped.content, "hello world");
        assert!(near(mapped.score, 0.75));
    }

    #[test]
    fn recalled_message_maps_assistant_role() {
        let mapped = map_recalled_message(make_recalled_message(Role::Assistant));
        assert_eq!(mapped.role, "assistant");
        assert!(near(mapped.score, 0.75));
    }

    #[test]
    fn recalled_message_maps_system_role() {
        let mapped = map_recalled_message(make_recalled_message(Role::System));
        assert_eq!(mapped.role, "system");
        assert!(near(mapped.score, 0.75));
    }

    // ---- map_graph_fact ---------------------------------------------------

    #[test]
    fn graph_fact_maps_basic_fields() {
        let mapped = map_graph_fact(make_recalled_fact());
        assert_eq!(mapped.fact, "Rust uses LLVM");
        assert!(near(mapped.confidence, 0.95));
        assert!(mapped.activation_score.is_none());
        assert!(mapped.neighbors.is_empty());
        assert!(mapped.provenance_snippet.is_none());
    }

    #[test]
    fn graph_fact_maps_activation_score() {
        let mut recalled = make_recalled_fact();
        recalled.activation_score = Some(0.82);
        let mapped = map_graph_fact(recalled);
        assert!(matches!(mapped.activation_score, Some(score) if near(score, 0.82)));
    }

    #[test]
    fn graph_fact_maps_neighbors() {
        let neighbor = GraphFact {
            entity_name: "LLVM".to_owned(),
            relation: "supports".to_owned(),
            target_name: "WebAssembly".to_owned(),
            fact: "LLVM supports WebAssembly".to_owned(),
            entity_match_score: 0.5,
            hop_distance: 1,
            confidence: 0.8,
            valid_from: None,
            edge_type: EdgeType::Semantic,
            retrieval_count: 0,
            edge_id: None,
        };
        let mut recalled = make_recalled_fact();
        recalled.neighbors.push(neighbor);

        let mapped = map_graph_fact(recalled);
        assert_eq!(mapped.neighbors.len(), 1);
        assert_eq!(mapped.neighbors[0].fact, "LLVM supports WebAssembly");
        assert!(near(mapped.neighbors[0].confidence, 0.8));
    }

    #[test]
    fn graph_fact_maps_provenance_snippet() {
        let mut recalled = make_recalled_fact();
        recalled.provenance_snippet = Some("Rust compiler snippet".to_owned());
        let mapped = map_graph_fact(recalled);
        assert_eq!(
            mapped.provenance_snippet.as_deref(),
            Some("Rust compiler snippet")
        );
    }

    // ---- map_session_summary ----------------------------------------------

    #[test]
    fn session_summary_maps_fields() {
        let mapped = map_session_summary(make_session_summary());
        assert_eq!(mapped.summary_text, "yesterday's session about Rust");
        assert!(near(mapped.score, 0.88));
    }

    #[test]
    fn session_summary_score_zero() {
        let mapped = map_session_summary(session_summary("empty session", 0.0, 1));
        assert!(mapped.score.abs() < f32::EPSILON);
    }

    #[test]
    fn session_summary_score_one() {
        let mapped = map_session_summary(session_summary("perfect match", 1.0, 1));
        assert!(near(mapped.score, 1.0));
    }
}