1use std::pin::Pin;
13
14use zeph_common::memory::{
15 AsyncMemoryRouter, ContextMemoryBackend, GraphRecallParams, MemCorrection, MemDocumentChunk,
16 MemGraphFact, MemGraphNeighbor, MemPersonaFact, MemReasoningStrategy, MemRecalledMessage,
17 MemSessionSummary, MemSummary, MemTrajectoryEntry, MemTreeNode, RecallView,
18};
19use zeph_memory::semantic::SemanticMemory;
20use zeph_memory::{ConversationId, RecallView as MemRecallView, RecalledFact};
21
22fn map_persona_fact(r: zeph_memory::PersonaFactRow) -> MemPersonaFact {
23 MemPersonaFact {
24 category: r.category,
25 content: r.content,
26 }
27}
28
29fn map_trajectory_entry(r: zeph_memory::TrajectoryEntryRow) -> MemTrajectoryEntry {
30 MemTrajectoryEntry {
31 intent: r.intent,
32 outcome: r.outcome,
33 confidence: r.confidence,
34 }
35}
36
37fn map_tree_node(r: zeph_memory::MemoryTreeRow) -> MemTreeNode {
38 MemTreeNode { content: r.content }
39}
40
41fn map_summary(r: zeph_memory::semantic::Summary) -> MemSummary {
42 MemSummary {
43 first_message_id: r.first_message_id.map(|m| m.0),
44 last_message_id: r.last_message_id.map(|m| m.0),
45 content: r.content,
46 }
47}
48
49fn map_reasoning_strategy(s: zeph_memory::ReasoningStrategy) -> MemReasoningStrategy {
50 MemReasoningStrategy {
51 id: s.id,
52 outcome: s.outcome.as_str().to_owned(),
53 summary: s.summary,
54 }
55}
56
57fn map_correction(c: zeph_memory::UserCorrectionRow) -> MemCorrection {
58 MemCorrection {
59 correction_text: c.correction_text,
60 }
61}
62
63fn map_recalled_message(r: zeph_memory::RecalledMessage) -> MemRecalledMessage {
64 use zeph_llm::provider::Role;
65 let role = match r.message.role {
66 Role::User => "user",
67 Role::Assistant => "assistant",
68 Role::System => "system",
69 }
70 .to_owned();
71 MemRecalledMessage {
72 role,
73 content: r.message.content,
74 score: r.score,
75 }
76}
77
78fn map_graph_fact(rf: RecalledFact) -> MemGraphFact {
79 MemGraphFact {
80 fact: rf.fact.fact,
81 confidence: rf.fact.confidence,
82 activation_score: rf.activation_score,
83 neighbors: rf
84 .neighbors
85 .into_iter()
86 .map(|n| MemGraphNeighbor {
87 fact: n.fact,
88 confidence: n.confidence,
89 })
90 .collect(),
91 provenance_snippet: rf.provenance_snippet,
92 }
93}
94
95fn map_session_summary(r: zeph_memory::semantic::SessionSummaryResult) -> MemSessionSummary {
96 MemSessionSummary {
97 summary_text: r.summary_text,
98 score: r.score,
99 }
100}
101
102pub struct SemanticMemoryBackend {
104 inner: std::sync::Arc<SemanticMemory>,
105}
106
107impl SemanticMemoryBackend {
108 #[must_use]
110 pub fn new(inner: std::sync::Arc<SemanticMemory>) -> Self {
111 Self { inner }
112 }
113}
114
/// Heap-allocated, pinned, `Send` future borrowing for `'a` and resolving to
/// `Result<T, Box<dyn Error + Send + Sync>>` — the return shape of every
/// `ContextMemoryBackend` method implemented in this module.
type BoxFut<'a, T> = Pin<
    Box<
        dyn std::future::Future<
                Output = Result<T, Box<dyn std::error::Error + Send + Sync>>,
            > + Send
            + 'a,
    >,
>;
122
123impl ContextMemoryBackend for SemanticMemoryBackend {
124 fn load_persona_facts(&self, min_confidence: f64) -> BoxFut<'_, Vec<MemPersonaFact>> {
125 Box::pin(async move {
126 let rows = self
127 .inner
128 .sqlite()
129 .load_persona_facts(min_confidence)
130 .await
131 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
132 Ok(rows.into_iter().map(map_persona_fact).collect())
133 })
134 }
135
136 fn load_trajectory_entries<'a>(
137 &'a self,
138 tier: Option<&'a str>,
139 top_k: usize,
140 ) -> BoxFut<'a, Vec<MemTrajectoryEntry>> {
141 Box::pin(async move {
142 let rows = self
143 .inner
144 .sqlite()
145 .load_trajectory_entries(tier, top_k)
146 .await
147 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
148 Ok(rows.into_iter().map(map_trajectory_entry).collect())
149 })
150 }
151
152 fn load_tree_nodes(&self, level: u32, top_k: usize) -> BoxFut<'_, Vec<MemTreeNode>> {
153 Box::pin(async move {
154 let rows = self
155 .inner
156 .sqlite()
157 .load_tree_level(level.into(), top_k)
158 .await
159 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
160 Ok(rows.into_iter().map(map_tree_node).collect())
161 })
162 }
163
164 fn load_summaries(&self, conversation_id: i64) -> BoxFut<'_, Vec<MemSummary>> {
165 Box::pin(async move {
166 let cid = ConversationId(conversation_id);
167 let rows = self
168 .inner
169 .load_summaries(cid)
170 .await
171 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
172 Ok(rows.into_iter().map(map_summary).collect())
173 })
174 }
175
176 fn retrieve_reasoning_strategies<'a>(
177 &'a self,
178 query: &'a str,
179 top_k: usize,
180 ) -> BoxFut<'a, Vec<MemReasoningStrategy>> {
181 Box::pin(async move {
182 let strategies = self
183 .inner
184 .retrieve_reasoning_strategies(query, top_k)
185 .await
186 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
187 Ok(strategies.into_iter().map(map_reasoning_strategy).collect())
188 })
189 }
190
191 fn mark_reasoning_used<'a>(&'a self, ids: &'a [String]) -> BoxFut<'a, ()> {
192 Box::pin(async move {
193 if let Some(ref reasoning) = self.inner.reasoning {
194 reasoning
195 .mark_used(ids)
196 .await
197 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
198 }
199 Ok(())
200 })
201 }
202
203 fn retrieve_corrections<'a>(
204 &'a self,
205 query: &'a str,
206 limit: usize,
207 min_score: f32,
208 ) -> BoxFut<'a, Vec<MemCorrection>> {
209 Box::pin(async move {
210 let corrections = self
211 .inner
212 .retrieve_similar_corrections(query, limit, min_score)
213 .await
214 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
215 Ok(corrections.into_iter().map(map_correction).collect())
216 })
217 }
218
219 fn recall<'a>(
220 &'a self,
221 query: &'a str,
222 limit: usize,
223 router: Option<&'a dyn AsyncMemoryRouter>,
224 ) -> BoxFut<'a, Vec<MemRecalledMessage>> {
225 Box::pin(async move {
226 let recalled = if let Some(r) = router {
227 self.inner
228 .recall_routed_async(query, limit, None, r)
229 .await
230 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?
231 } else {
232 self.inner
233 .recall(query, limit, None)
234 .await
235 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?
236 };
237 Ok(recalled.into_iter().map(map_recalled_message).collect())
238 })
239 }
240
241 fn recall_graph_facts<'a>(
242 &'a self,
243 query: &'a str,
244 params: GraphRecallParams<'a>,
245 ) -> BoxFut<'a, Vec<MemGraphFact>> {
246 Box::pin(async move {
247 let mem_view = match params.view {
248 RecallView::Head => MemRecallView::Head,
249 RecallView::ZoomIn => MemRecallView::ZoomIn,
250 RecallView::ZoomOut => MemRecallView::ZoomOut,
251 };
252 let mem_edge_types: Vec<zeph_memory::EdgeType> = params
253 .edge_types
254 .iter()
255 .map(|e| {
256 use zeph_common::memory::EdgeType as CE;
257 use zeph_memory::EdgeType as ME;
258 match e {
259 CE::Semantic => ME::Semantic,
260 CE::Temporal => ME::Temporal,
261 CE::Causal => ME::Causal,
262 CE::Entity => ME::Entity,
263 }
264 })
265 .collect();
266 let sa_params = params.spreading_activation.map(|p| {
267 zeph_memory::graph::SpreadingActivationParams {
268 decay_lambda: p.decay_lambda,
269 max_hops: p.max_hops,
270 activation_threshold: p.activation_threshold,
271 inhibition_threshold: p.inhibition_threshold,
272 max_activated_nodes: p.max_activated_nodes,
273 temporal_decay_rate: p.temporal_decay_rate,
274 seed_structural_weight: p.seed_structural_weight,
275 seed_community_cap: p.seed_community_cap,
276 }
277 });
278 let recalled = self
279 .inner
280 .recall_graph_view(
281 query,
282 params.limit,
283 mem_view,
284 params.zoom_out_neighbor_cap,
285 params.max_hops,
286 params.temporal_decay_rate,
287 &mem_edge_types,
288 sa_params,
289 )
290 .await
291 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
292 Ok(recalled.into_iter().map(map_graph_fact).collect())
293 })
294 }
295
296 fn search_session_summaries<'a>(
297 &'a self,
298 query: &'a str,
299 limit: usize,
300 current_conversation_id: Option<i64>,
301 ) -> BoxFut<'a, Vec<MemSessionSummary>> {
302 Box::pin(async move {
303 let cid = current_conversation_id.map(ConversationId);
304 let results = self
305 .inner
306 .search_session_summaries(query, limit, cid)
307 .await
308 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
309 Ok(results.into_iter().map(map_session_summary).collect())
310 })
311 }
312
313 fn search_document_collection<'a>(
314 &'a self,
315 collection: &'a str,
316 query: &'a str,
317 top_k: usize,
318 ) -> BoxFut<'a, Vec<MemDocumentChunk>> {
319 Box::pin(async move {
320 let points = self
321 .inner
322 .search_document_collection(collection, query, top_k)
323 .await
324 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
325 Ok(points
326 .into_iter()
327 .map(|p| {
328 let text = p
329 .payload
330 .get("text")
331 .and_then(|v| v.as_str())
332 .unwrap_or_default()
333 .to_owned();
334 MemDocumentChunk { text }
335 })
336 .collect())
337 })
338 }
339}
340
341pub struct TokenCounterAdapter(std::sync::Arc<zeph_memory::TokenCounter>);
344
345impl TokenCounterAdapter {
346 #[must_use]
348 pub fn new(inner: std::sync::Arc<zeph_memory::TokenCounter>) -> Self {
349 Self(inner)
350 }
351}
352
353impl zeph_context::summarization::MessageTokenCounter for TokenCounterAdapter {
354 fn count_message_tokens(&self, msg: &zeph_llm::provider::Message) -> usize {
355 self.0.count_message_tokens(msg)
356 }
357}
358
359#[must_use]
367pub fn build_memory_router(
368 manager: &zeph_context::manager::ContextManager,
369) -> Box<dyn zeph_common::memory::AsyncMemoryRouter + Send + Sync> {
370 use zeph_config::StoreRoutingStrategy;
371
372 if !manager.routing.enabled {
373 return Box::new(zeph_memory::HeuristicRouter);
374 }
375 let fallback = manager.routing.fallback_route;
376 match manager.routing.strategy {
377 StoreRoutingStrategy::Heuristic => Box::new(zeph_memory::HeuristicRouter),
378 StoreRoutingStrategy::Llm => {
379 let Some(provider) = manager.store_routing_provider.clone() else {
380 tracing::warn!(
381 "store_routing: strategy=llm but no provider resolved; \
382 falling back to heuristic"
383 );
384 return Box::new(zeph_memory::HeuristicRouter);
385 };
386 Box::new(zeph_memory::LlmRouter::new(provider, fallback))
387 }
388 StoreRoutingStrategy::Hybrid => {
389 let Some(provider) = manager.store_routing_provider.clone() else {
390 tracing::warn!(
391 "store_routing: strategy=hybrid but no provider resolved; \
392 falling back to heuristic"
393 );
394 return Box::new(zeph_memory::HeuristicRouter);
395 };
396 Box::new(zeph_memory::HybridRouter::new(
397 provider,
398 fallback,
399 manager.routing.confidence_threshold,
400 ))
401 }
402 }
403}
404
#[cfg(test)]
mod tests {
    // Unit tests for the pure row -> DTO `map_*` helpers in this module. The async
    // `ContextMemoryBackend` methods and `build_memory_router` are not exercised here.
    use zeph_llm::provider::{Message, Role};
    use zeph_memory::graph::types::{EdgeType, GraphFact};
    use zeph_memory::semantic::{SessionSummaryResult, Summary};
    use zeph_memory::types::{ConversationId, MessageId};
    use zeph_memory::{
        MemoryTreeRow, Outcome, PersonaFactRow, ReasoningStrategy, RecalledFact, RecalledMessage,
        TrajectoryEntryRow, UserCorrectionRow,
    };

    use super::*;

    /// Persona row fixture. Only `category` and `content` are expected in the DTO.
    fn make_persona_row() -> PersonaFactRow {
        PersonaFactRow {
            id: 1,
            category: "preference".to_owned(),
            content: "prefers short answers".to_owned(),
            confidence: 0.9,
            evidence_count: 3,
            source_conversation_id: None,
            supersedes_id: None,
            created_at: "2026-01-01".to_owned(),
            updated_at: "2026-01-02".to_owned(),
        }
    }

    /// Trajectory row fixture. `intent`, `outcome`, and `confidence` are expected in the DTO.
    fn make_trajectory_row() -> TrajectoryEntryRow {
        TrajectoryEntryRow {
            id: 1,
            conversation_id: Some(42),
            turn_index: 5,
            kind: "procedural".to_owned(),
            intent: "read a file".to_owned(),
            outcome: "file read successfully".to_owned(),
            tools_used: "read_file".to_owned(),
            confidence: 0.85,
            created_at: "2026-01-01".to_owned(),
            updated_at: "2026-01-01".to_owned(),
        }
    }

    /// Memory-tree row fixture. Only `content` is expected in the DTO.
    fn make_tree_row() -> MemoryTreeRow {
        MemoryTreeRow {
            id: 1,
            level: 0,
            parent_id: None,
            content: "node content here".to_owned(),
            source_ids: "1,2,3".to_owned(),
            token_count: 10,
            consolidated_at: None,
            created_at: "2026-01-01".to_owned(),
        }
    }

    /// Summary fixture with both message-id bounds set (5 and 20).
    fn make_summary() -> Summary {
        Summary {
            id: 1,
            conversation_id: ConversationId(10),
            content: "summary of the conversation".to_owned(),
            first_message_id: Some(MessageId(5)),
            last_message_id: Some(MessageId(20)),
            token_estimate: 100,
        }
    }

    /// Reasoning-strategy fixture with a `Success` outcome. Tests override `outcome` as needed.
    fn make_reasoning_strategy() -> ReasoningStrategy {
        ReasoningStrategy {
            id: "strat-uuid-1".to_owned(),
            summary: "break the problem into parts".to_owned(),
            outcome: Outcome::Success,
            task_hint: "code refactoring task".to_owned(),
            created_at: 1_700_000_000,
            last_used_at: 1_700_000_100,
            use_count: 3,
            embedded_at: Some(1_700_000_050),
        }
    }

    /// User-correction fixture. Only `correction_text` is expected in the DTO.
    fn make_correction_row() -> UserCorrectionRow {
        UserCorrectionRow {
            id: 1,
            session_id: Some(7),
            original_output: "wrong output".to_owned(),
            correction_text: "use bullet points".to_owned(),
            skill_name: Some("formatting".to_owned()),
            correction_kind: "explicit_rejection".to_owned(),
            created_at: "2026-01-01".to_owned(),
        }
    }

    /// Recalled-message fixture with the given `role`, fixed content, and score 0.75.
    fn make_recalled_message(role: Role) -> RecalledMessage {
        RecalledMessage {
            message: Message {
                role,
                content: "hello world".to_owned(),
                ..Default::default()
            },
            score: 0.75,
        }
    }

    /// Graph-fact fixture ("Rust uses LLVM", confidence 0.95).
    fn make_graph_fact() -> GraphFact {
        GraphFact {
            entity_name: "Rust".to_owned(),
            relation: "uses".to_owned(),
            target_name: "LLVM".to_owned(),
            fact: "Rust uses LLVM".to_owned(),
            entity_match_score: 0.9,
            hop_distance: 0,
            confidence: 0.95,
            valid_from: None,
            edge_type: EdgeType::Semantic,
            retrieval_count: 1,
            edge_id: Some(10),
        }
    }

    /// Wraps `make_graph_fact` with `RecalledFact::from_graph_fact`.
    /// `graph_fact_maps_basic_fields` relies on this leaving the activation score,
    /// neighbours, and provenance snippet empty.
    fn make_recalled_fact() -> RecalledFact {
        RecalledFact::from_graph_fact(make_graph_fact())
    }

    /// Session-summary search-hit fixture with score 0.88.
    fn make_session_summary() -> SessionSummaryResult {
        SessionSummaryResult {
            summary_text: "yesterday's session about Rust".to_owned(),
            score: 0.88,
            conversation_id: ConversationId(99),
        }
    }

    /// `category` and `content` are copied verbatim.
    #[test]
    fn persona_fact_maps_fields() {
        let row = make_persona_row();
        let dto = map_persona_fact(row);
        assert_eq!(dto.category, "preference");
        assert_eq!(dto.content, "prefers short answers");
    }

    /// `intent`/`outcome` are copied verbatim. `confidence` is compared with an epsilon
    /// tolerance because it is a float.
    #[test]
    fn trajectory_entry_maps_fields() {
        let row = make_trajectory_row();
        let dto = map_trajectory_entry(row);
        assert_eq!(dto.intent, "read a file");
        assert_eq!(dto.outcome, "file read successfully");
        assert!((dto.confidence - 0.85).abs() < f64::EPSILON);
    }

    /// The node text is the only field carried over.
    #[test]
    fn tree_node_maps_content() {
        let row = make_tree_row();
        let dto = map_tree_node(row);
        assert_eq!(dto.content, "node content here");
    }

    /// Message ids are unwrapped from `MessageId` to their raw values.
    #[test]
    fn summary_maps_all_fields() {
        let s = make_summary();
        let dto = map_summary(s);
        assert_eq!(dto.first_message_id, Some(5));
        assert_eq!(dto.last_message_id, Some(20));
        assert_eq!(dto.content, "summary of the conversation");
    }

    /// Absent message-id bounds stay `None` after mapping.
    #[test]
    fn summary_none_message_ids_stay_none() {
        let s = Summary {
            id: 2,
            conversation_id: ConversationId(1),
            content: "shutdown summary".to_owned(),
            first_message_id: None,
            last_message_id: None,
            token_estimate: 50,
        };
        let dto = map_summary(s);
        assert!(dto.first_message_id.is_none());
        assert!(dto.last_message_id.is_none());
    }

    /// `Outcome::Success` is rendered as the string "success".
    #[test]
    fn reasoning_strategy_maps_success_outcome() {
        let s = make_reasoning_strategy();
        let dto = map_reasoning_strategy(s);
        assert_eq!(dto.id, "strat-uuid-1");
        assert_eq!(dto.outcome, "success");
        assert_eq!(dto.summary, "break the problem into parts");
    }

    /// `Outcome::Failure` is rendered as the string "failure".
    #[test]
    fn reasoning_strategy_maps_failure_outcome() {
        let mut s = make_reasoning_strategy();
        s.outcome = Outcome::Failure;
        let dto = map_reasoning_strategy(s);
        assert_eq!(dto.outcome, "failure");
    }

    /// Only the correction text is carried over.
    #[test]
    fn correction_maps_text() {
        let row = make_correction_row();
        let dto = map_correction(row);
        assert_eq!(dto.correction_text, "use bullet points");
    }

    /// `Role::User` maps to "user"; content and score pass through.
    #[test]
    fn recalled_message_maps_user_role() {
        let rm = make_recalled_message(Role::User);
        let dto = map_recalled_message(rm);
        assert_eq!(dto.role, "user");
        assert_eq!(dto.content, "hello world");
        assert!((dto.score - 0.75).abs() < f32::EPSILON);
    }

    /// `Role::Assistant` maps to "assistant".
    #[test]
    fn recalled_message_maps_assistant_role() {
        let rm = make_recalled_message(Role::Assistant);
        let dto = map_recalled_message(rm);
        assert_eq!(dto.role, "assistant");
        assert!((dto.score - 0.75).abs() < f32::EPSILON);
    }

    /// `Role::System` maps to "system".
    #[test]
    fn recalled_message_maps_system_role() {
        let rm = make_recalled_message(Role::System);
        let dto = map_recalled_message(rm);
        assert_eq!(dto.role, "system");
        assert!((dto.score - 0.75).abs() < f32::EPSILON);
    }

    /// A freshly wrapped fact maps its text and confidence and has no optional extras.
    #[test]
    fn graph_fact_maps_basic_fields() {
        let rf = make_recalled_fact();
        let dto = map_graph_fact(rf);
        assert_eq!(dto.fact, "Rust uses LLVM");
        assert!((dto.confidence - 0.95).abs() < f32::EPSILON);
        assert!(dto.activation_score.is_none());
        assert!(dto.neighbors.is_empty());
        assert!(dto.provenance_snippet.is_none());
    }

    /// A set activation score is forwarded unchanged.
    #[test]
    fn graph_fact_maps_activation_score() {
        let mut rf = make_recalled_fact();
        rf.activation_score = Some(0.82);
        let dto = map_graph_fact(rf);
        assert!(
            dto.activation_score
                .is_some_and(|s| (s - 0.82_f32).abs() < f32::EPSILON)
        );
    }

    /// Each neighbour is reduced to its `fact` and `confidence`.
    #[test]
    fn graph_fact_maps_neighbors() {
        let mut rf = make_recalled_fact();
        rf.neighbors.push(GraphFact {
            entity_name: "LLVM".to_owned(),
            relation: "supports".to_owned(),
            target_name: "WebAssembly".to_owned(),
            fact: "LLVM supports WebAssembly".to_owned(),
            entity_match_score: 0.5,
            hop_distance: 1,
            confidence: 0.8,
            valid_from: None,
            edge_type: EdgeType::Semantic,
            retrieval_count: 0,
            edge_id: None,
        });
        let dto = map_graph_fact(rf);
        assert_eq!(dto.neighbors.len(), 1);
        assert_eq!(dto.neighbors[0].fact, "LLVM supports WebAssembly");
        assert!((dto.neighbors[0].confidence - 0.8).abs() < f32::EPSILON);
    }

    /// A set provenance snippet is forwarded unchanged.
    #[test]
    fn graph_fact_maps_provenance_snippet() {
        let mut rf = make_recalled_fact();
        rf.provenance_snippet = Some("Rust compiler snippet".to_owned());
        let dto = map_graph_fact(rf);
        assert_eq!(
            dto.provenance_snippet.as_deref(),
            Some("Rust compiler snippet")
        );
    }

    /// Summary text and score are copied verbatim.
    #[test]
    fn session_summary_maps_fields() {
        let r = make_session_summary();
        let dto = map_session_summary(r);
        assert_eq!(dto.summary_text, "yesterday's session about Rust");
        assert!((dto.score - 0.88).abs() < f32::EPSILON);
    }

    /// Boundary: a score of 0.0 survives mapping.
    #[test]
    fn session_summary_score_zero() {
        let r = SessionSummaryResult {
            summary_text: "empty session".to_owned(),
            score: 0.0,
            conversation_id: ConversationId(1),
        };
        let dto = map_session_summary(r);
        assert!(dto.score.abs() < f32::EPSILON);
    }

    /// Boundary: a score of 1.0 survives mapping.
    #[test]
    fn session_summary_score_one() {
        let r = SessionSummaryResult {
            summary_text: "perfect match".to_owned(),
            score: 1.0,
            conversation_id: ConversationId(1),
        };
        let dto = map_session_summary(r);
        assert!((dto.score - 1.0_f32).abs() < f32::EPSILON);
    }
}