1use std::pin::Pin;
13
14use zeph_common::memory::{
15 AsyncMemoryRouter, ContextMemoryBackend, GraphRecallParams, MemCorrection, MemDocumentChunk,
16 MemGraphFact, MemGraphNeighbor, MemPersonaFact, MemReasoningStrategy, MemRecalledMessage,
17 MemSessionSummary, MemSummary, MemTrajectoryEntry, MemTreeNode, RecallView,
18};
19use zeph_memory::semantic::SemanticMemory;
20use zeph_memory::{ConversationId, RecallView as MemRecallView, RecalledFact};
21
22fn map_persona_fact(r: zeph_memory::PersonaFactRow) -> MemPersonaFact {
23 MemPersonaFact {
24 category: r.category,
25 content: r.content,
26 }
27}
28
29fn map_trajectory_entry(r: zeph_memory::TrajectoryEntryRow) -> MemTrajectoryEntry {
30 MemTrajectoryEntry {
31 intent: r.intent,
32 outcome: r.outcome,
33 confidence: r.confidence,
34 }
35}
36
37fn map_tree_node(r: zeph_memory::MemoryTreeRow) -> MemTreeNode {
38 MemTreeNode { content: r.content }
39}
40
41fn map_summary(r: zeph_memory::semantic::Summary) -> MemSummary {
42 MemSummary {
43 first_message_id: r.first_message_id.map(|m| m.0),
44 last_message_id: r.last_message_id.map(|m| m.0),
45 content: r.content,
46 }
47}
48
49fn map_reasoning_strategy(s: zeph_memory::ReasoningStrategy) -> MemReasoningStrategy {
50 MemReasoningStrategy {
51 id: s.id,
52 outcome: s.outcome.as_str().to_owned(),
53 summary: s.summary,
54 }
55}
56
57fn map_correction(c: zeph_memory::UserCorrectionRow) -> MemCorrection {
58 MemCorrection {
59 correction_text: c.correction_text,
60 }
61}
62
63fn map_recalled_message(r: zeph_memory::RecalledMessage) -> MemRecalledMessage {
64 use zeph_llm::provider::Role;
65 let role = match r.message.role {
66 Role::User => "user",
67 Role::Assistant => "assistant",
68 Role::System => "system",
69 }
70 .to_owned();
71 MemRecalledMessage {
72 role,
73 content: r.message.content,
74 score: r.score,
75 }
76}
77
78fn map_graph_fact(rf: RecalledFact) -> MemGraphFact {
79 MemGraphFact {
80 fact: rf.fact.fact,
81 confidence: rf.fact.confidence,
82 activation_score: rf.activation_score,
83 neighbors: rf
84 .neighbors
85 .into_iter()
86 .map(|n| MemGraphNeighbor {
87 fact: n.fact,
88 confidence: n.confidence,
89 })
90 .collect(),
91 provenance_snippet: rf.provenance_snippet,
92 }
93}
94
95fn map_session_summary(r: zeph_memory::semantic::SessionSummaryResult) -> MemSessionSummary {
96 MemSessionSummary {
97 summary_text: r.summary_text,
98 score: r.score,
99 }
100}
101
102pub struct SemanticMemoryBackend {
104 inner: std::sync::Arc<SemanticMemory>,
105}
106
107impl SemanticMemoryBackend {
108 #[must_use]
110 pub fn new(inner: std::sync::Arc<SemanticMemory>) -> Self {
111 Self { inner }
112 }
113}
114
/// Pinned, boxed, `Send` future bounded by `'a` whose output is
/// `Result<T, Box<dyn std::error::Error + Send + Sync>>`.
type BoxFut<'a, T> = Pin<
    Box<
        dyn std::future::Future<
                Output = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>,
            > + Send
            + 'a,
    >,
>;
122
123impl ContextMemoryBackend for SemanticMemoryBackend {
124 fn load_persona_facts(&self, min_confidence: f64) -> BoxFut<'_, Vec<MemPersonaFact>> {
125 Box::pin(async move {
126 let rows = self
127 .inner
128 .sqlite()
129 .load_persona_facts(min_confidence)
130 .await
131 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
132 Ok(rows.into_iter().map(map_persona_fact).collect())
133 })
134 }
135
136 fn load_trajectory_entries<'a>(
137 &'a self,
138 tier: Option<&'a str>,
139 top_k: usize,
140 ) -> BoxFut<'a, Vec<MemTrajectoryEntry>> {
141 Box::pin(async move {
142 let rows = self
143 .inner
144 .sqlite()
145 .load_trajectory_entries(tier, top_k)
146 .await
147 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
148 Ok(rows.into_iter().map(map_trajectory_entry).collect())
149 })
150 }
151
152 fn load_tree_nodes(&self, level: u32, top_k: usize) -> BoxFut<'_, Vec<MemTreeNode>> {
153 Box::pin(async move {
154 let rows = self
155 .inner
156 .sqlite()
157 .load_tree_level(level.into(), top_k)
158 .await
159 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
160 Ok(rows.into_iter().map(map_tree_node).collect())
161 })
162 }
163
164 fn load_summaries(&self, conversation_id: i64) -> BoxFut<'_, Vec<MemSummary>> {
165 Box::pin(async move {
166 let cid = ConversationId(conversation_id);
167 let rows = self
168 .inner
169 .load_summaries(cid)
170 .await
171 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
172 Ok(rows.into_iter().map(map_summary).collect())
173 })
174 }
175
176 fn retrieve_reasoning_strategies<'a>(
177 &'a self,
178 query: &'a str,
179 top_k: usize,
180 ) -> BoxFut<'a, Vec<MemReasoningStrategy>> {
181 Box::pin(async move {
182 let strategies = self
183 .inner
184 .retrieve_reasoning_strategies(query, top_k)
185 .await
186 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
187 Ok(strategies.into_iter().map(map_reasoning_strategy).collect())
188 })
189 }
190
191 fn mark_reasoning_used<'a>(&'a self, ids: &'a [String]) -> BoxFut<'a, ()> {
192 Box::pin(async move {
193 if let Some(ref reasoning) = self.inner.reasoning {
194 reasoning
195 .mark_used(ids)
196 .await
197 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
198 }
199 Ok(())
200 })
201 }
202
203 fn retrieve_corrections<'a>(
204 &'a self,
205 query: &'a str,
206 limit: usize,
207 min_score: f32,
208 ) -> BoxFut<'a, Vec<MemCorrection>> {
209 Box::pin(async move {
210 let corrections = self
211 .inner
212 .retrieve_similar_corrections(query, limit, min_score)
213 .await
214 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
215 Ok(corrections.into_iter().map(map_correction).collect())
216 })
217 }
218
219 fn recall<'a>(
220 &'a self,
221 query: &'a str,
222 limit: usize,
223 router: Option<&'a dyn AsyncMemoryRouter>,
224 ) -> BoxFut<'a, Vec<MemRecalledMessage>> {
225 Box::pin(async move {
226 let recalled = if let Some(r) = router {
227 self.inner
228 .recall_routed_async(query, limit, None, r)
229 .await
230 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?
231 } else {
232 self.inner
233 .recall(query, limit, None)
234 .await
235 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?
236 };
237 Ok(recalled.into_iter().map(map_recalled_message).collect())
238 })
239 }
240
241 fn recall_graph_facts<'a>(
242 &'a self,
243 query: &'a str,
244 params: GraphRecallParams<'a>,
245 ) -> BoxFut<'a, Vec<MemGraphFact>> {
246 Box::pin(async move {
247 let mem_view = match params.view {
248 RecallView::Head => MemRecallView::Head,
249 RecallView::ZoomIn => MemRecallView::ZoomIn,
250 RecallView::ZoomOut => MemRecallView::ZoomOut,
251 };
252 let mem_edge_types: Vec<zeph_memory::EdgeType> = params
253 .edge_types
254 .iter()
255 .map(|e| {
256 use zeph_common::memory::EdgeType as CE;
257 use zeph_memory::EdgeType as ME;
258 match e {
259 CE::Semantic => ME::Semantic,
260 CE::Temporal => ME::Temporal,
261 CE::Causal => ME::Causal,
262 CE::Entity => ME::Entity,
263 }
264 })
265 .collect();
266 let sa_params = params.spreading_activation.map(|p| {
267 zeph_memory::graph::SpreadingActivationParams {
268 decay_lambda: p.decay_lambda,
269 max_hops: p.max_hops,
270 activation_threshold: p.activation_threshold,
271 inhibition_threshold: p.inhibition_threshold,
272 max_activated_nodes: p.max_activated_nodes,
273 temporal_decay_rate: p.temporal_decay_rate,
274 seed_structural_weight: p.seed_structural_weight,
275 seed_community_cap: p.seed_community_cap,
276 }
277 });
278 let recalled = self
279 .inner
280 .recall_graph_view(
281 query,
282 params.limit,
283 mem_view,
284 params.zoom_out_neighbor_cap,
285 params.max_hops,
286 params.temporal_decay_rate,
287 &mem_edge_types,
288 sa_params,
289 )
290 .await
291 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
292 Ok(recalled.into_iter().map(map_graph_fact).collect())
293 })
294 }
295
296 fn search_session_summaries<'a>(
297 &'a self,
298 query: &'a str,
299 limit: usize,
300 current_conversation_id: Option<i64>,
301 ) -> BoxFut<'a, Vec<MemSessionSummary>> {
302 Box::pin(async move {
303 let cid = current_conversation_id.map(ConversationId);
304 let results = self
305 .inner
306 .search_session_summaries(query, limit, cid)
307 .await
308 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
309 Ok(results.into_iter().map(map_session_summary).collect())
310 })
311 }
312
313 fn search_document_collection<'a>(
314 &'a self,
315 collection: &'a str,
316 query: &'a str,
317 top_k: usize,
318 ) -> BoxFut<'a, Vec<MemDocumentChunk>> {
319 Box::pin(async move {
320 let points = self
321 .inner
322 .search_document_collection(collection, query, top_k)
323 .await
324 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
325 Ok(points
326 .into_iter()
327 .map(|p| {
328 let text = p
329 .payload
330 .get("text")
331 .and_then(|v| v.as_str())
332 .unwrap_or_default()
333 .to_owned();
334 MemDocumentChunk { text }
335 })
336 .collect())
337 })
338 }
339}
340
341pub struct TokenCounterAdapter(std::sync::Arc<zeph_memory::TokenCounter>);
344
345impl TokenCounterAdapter {
346 #[must_use]
348 pub fn new(inner: std::sync::Arc<zeph_memory::TokenCounter>) -> Self {
349 Self(inner)
350 }
351}
352
353impl zeph_context::summarization::MessageTokenCounter for TokenCounterAdapter {
354 fn count_message_tokens(&self, msg: &zeph_llm::provider::Message) -> usize {
355 self.0.count_message_tokens(msg)
356 }
357}
358
359#[must_use]
367pub fn build_memory_router(
368 manager: &zeph_context::manager::ContextManager,
369) -> Box<dyn zeph_common::memory::AsyncMemoryRouter + Send + Sync> {
370 use zeph_common::memory::parse_route_str;
371 use zeph_config::StoreRoutingStrategy;
372
373 if !manager.routing.enabled {
374 return Box::new(zeph_memory::HeuristicRouter);
375 }
376 let fallback = parse_route_str(
377 &manager.routing.fallback_route,
378 zeph_common::memory::MemoryRoute::Hybrid,
379 );
380 match manager.routing.strategy {
381 StoreRoutingStrategy::Heuristic => Box::new(zeph_memory::HeuristicRouter),
382 StoreRoutingStrategy::Llm => {
383 let Some(provider) = manager.store_routing_provider.clone() else {
384 tracing::warn!(
385 "store_routing: strategy=llm but no provider resolved; \
386 falling back to heuristic"
387 );
388 return Box::new(zeph_memory::HeuristicRouter);
389 };
390 Box::new(zeph_memory::LlmRouter::new(provider, fallback))
391 }
392 StoreRoutingStrategy::Hybrid => {
393 let Some(provider) = manager.store_routing_provider.clone() else {
394 tracing::warn!(
395 "store_routing: strategy=hybrid but no provider resolved; \
396 falling back to heuristic"
397 );
398 return Box::new(zeph_memory::HeuristicRouter);
399 };
400 Box::new(zeph_memory::HybridRouter::new(
401 provider,
402 fallback,
403 manager.routing.confidence_threshold,
404 ))
405 }
406 }
407}
408
/// Unit tests for the row → DTO mapping functions of this module.
#[cfg(test)]
mod tests {
    use zeph_llm::provider::{Message, Role};
    use zeph_memory::graph::types::{EdgeType, GraphFact};
    use zeph_memory::semantic::{SessionSummaryResult, Summary};
    use zeph_memory::types::{ConversationId, MessageId};
    use zeph_memory::{
        MemoryTreeRow, Outcome, PersonaFactRow, ReasoningStrategy, RecalledFact, RecalledMessage,
        TrajectoryEntryRow, UserCorrectionRow,
    };

    use super::*;

    /// `true` when the two `f32` values differ by less than `f32::EPSILON`.
    fn close_f32(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    // ---- fixtures -------------------------------------------------------

    fn make_persona_row() -> PersonaFactRow {
        PersonaFactRow {
            id: 1,
            category: String::from("preference"),
            content: String::from("prefers short answers"),
            confidence: 0.9,
            evidence_count: 3,
            source_conversation_id: None,
            supersedes_id: None,
            created_at: String::from("2026-01-01"),
            updated_at: String::from("2026-01-02"),
        }
    }

    fn make_trajectory_row() -> TrajectoryEntryRow {
        TrajectoryEntryRow {
            id: 1,
            conversation_id: Some(42),
            turn_index: 5,
            kind: String::from("procedural"),
            intent: String::from("read a file"),
            outcome: String::from("file read successfully"),
            tools_used: String::from("read_file"),
            confidence: 0.85,
            created_at: String::from("2026-01-01"),
            updated_at: String::from("2026-01-01"),
        }
    }

    fn make_tree_row() -> MemoryTreeRow {
        MemoryTreeRow {
            id: 1,
            level: 0,
            parent_id: None,
            content: String::from("node content here"),
            source_ids: String::from("1,2,3"),
            token_count: 10,
            consolidated_at: None,
            created_at: String::from("2026-01-01"),
        }
    }

    fn make_summary() -> Summary {
        Summary {
            id: 1,
            conversation_id: ConversationId(10),
            content: String::from("summary of the conversation"),
            first_message_id: Some(MessageId(5)),
            last_message_id: Some(MessageId(20)),
            token_estimate: 100,
        }
    }

    fn make_reasoning_strategy() -> ReasoningStrategy {
        ReasoningStrategy {
            id: String::from("strat-uuid-1"),
            summary: String::from("break the problem into parts"),
            outcome: Outcome::Success,
            task_hint: String::from("code refactoring task"),
            created_at: 1_700_000_000,
            last_used_at: 1_700_000_100,
            use_count: 3,
            embedded_at: Some(1_700_000_050),
        }
    }

    fn make_correction_row() -> UserCorrectionRow {
        UserCorrectionRow {
            id: 1,
            session_id: Some(7),
            original_output: String::from("wrong output"),
            correction_text: String::from("use bullet points"),
            skill_name: Some(String::from("formatting")),
            correction_kind: String::from("explicit_rejection"),
            created_at: String::from("2026-01-01"),
        }
    }

    fn make_recalled_message(role: Role) -> RecalledMessage {
        let message = Message {
            role,
            content: String::from("hello world"),
            ..Default::default()
        };
        RecalledMessage {
            message,
            score: 0.75,
        }
    }

    fn make_graph_fact() -> GraphFact {
        GraphFact {
            entity_name: String::from("Rust"),
            relation: String::from("uses"),
            target_name: String::from("LLVM"),
            fact: String::from("Rust uses LLVM"),
            entity_match_score: 0.9,
            hop_distance: 0,
            confidence: 0.95,
            valid_from: None,
            edge_type: EdgeType::Semantic,
            retrieval_count: 1,
            edge_id: Some(10),
        }
    }

    fn make_recalled_fact() -> RecalledFact {
        RecalledFact::from_graph_fact(make_graph_fact())
    }

    fn make_session_summary() -> SessionSummaryResult {
        SessionSummaryResult {
            summary_text: String::from("yesterday's session about Rust"),
            score: 0.88,
            conversation_id: ConversationId(99),
        }
    }

    // ---- persona facts --------------------------------------------------

    #[test]
    fn persona_fact_maps_fields() {
        let mapped = map_persona_fact(make_persona_row());
        assert_eq!(mapped.category, "preference");
        assert_eq!(mapped.content, "prefers short answers");
    }

    // ---- trajectory entries ---------------------------------------------

    #[test]
    fn trajectory_entry_maps_fields() {
        let mapped = map_trajectory_entry(make_trajectory_row());
        assert_eq!(mapped.intent, "read a file");
        assert_eq!(mapped.outcome, "file read successfully");
        assert!((mapped.confidence - 0.85).abs() < f64::EPSILON);
    }

    // ---- tree nodes -----------------------------------------------------

    #[test]
    fn tree_node_maps_content() {
        let mapped = map_tree_node(make_tree_row());
        assert_eq!(mapped.content, "node content here");
    }

    // ---- summaries ------------------------------------------------------

    #[test]
    fn summary_maps_all_fields() {
        let mapped = map_summary(make_summary());
        assert_eq!(mapped.first_message_id, Some(5));
        assert_eq!(mapped.last_message_id, Some(20));
        assert_eq!(mapped.content, "summary of the conversation");
    }

    #[test]
    fn summary_none_message_ids_stay_none() {
        let mapped = map_summary(Summary {
            id: 2,
            conversation_id: ConversationId(1),
            content: String::from("shutdown summary"),
            first_message_id: None,
            last_message_id: None,
            token_estimate: 50,
        });
        assert!(mapped.first_message_id.is_none());
        assert!(mapped.last_message_id.is_none());
    }

    // ---- reasoning strategies -------------------------------------------

    #[test]
    fn reasoning_strategy_maps_success_outcome() {
        let mapped = map_reasoning_strategy(make_reasoning_strategy());
        assert_eq!(mapped.id, "strat-uuid-1");
        assert_eq!(mapped.outcome, "success");
        assert_eq!(mapped.summary, "break the problem into parts");
    }

    #[test]
    fn reasoning_strategy_maps_failure_outcome() {
        let mut strategy = make_reasoning_strategy();
        strategy.outcome = Outcome::Failure;
        let mapped = map_reasoning_strategy(strategy);
        assert_eq!(mapped.outcome, "failure");
    }

    // ---- corrections ----------------------------------------------------

    #[test]
    fn correction_maps_text() {
        let mapped = map_correction(make_correction_row());
        assert_eq!(mapped.correction_text, "use bullet points");
    }

    // ---- recalled messages ----------------------------------------------

    #[test]
    fn recalled_message_maps_user_role() {
        let mapped = map_recalled_message(make_recalled_message(Role::User));
        assert_eq!(mapped.role, "user");
        assert_eq!(mapped.content, "hello world");
        assert!(close_f32(mapped.score, 0.75));
    }

    #[test]
    fn recalled_message_maps_assistant_role() {
        let mapped = map_recalled_message(make_recalled_message(Role::Assistant));
        assert_eq!(mapped.role, "assistant");
        assert!(close_f32(mapped.score, 0.75));
    }

    #[test]
    fn recalled_message_maps_system_role() {
        let mapped = map_recalled_message(make_recalled_message(Role::System));
        assert_eq!(mapped.role, "system");
        assert!(close_f32(mapped.score, 0.75));
    }

    // ---- graph facts ----------------------------------------------------

    #[test]
    fn graph_fact_maps_basic_fields() {
        let mapped = map_graph_fact(make_recalled_fact());
        assert_eq!(mapped.fact, "Rust uses LLVM");
        assert!(close_f32(mapped.confidence, 0.95));
        assert!(mapped.activation_score.is_none());
        assert!(mapped.neighbors.is_empty());
        assert!(mapped.provenance_snippet.is_none());
    }

    #[test]
    fn graph_fact_maps_activation_score() {
        let mut recalled = make_recalled_fact();
        recalled.activation_score = Some(0.82);
        let mapped = map_graph_fact(recalled);
        assert!(
            mapped
                .activation_score
                .is_some_and(|score| close_f32(score, 0.82_f32))
        );
    }

    #[test]
    fn graph_fact_maps_neighbors() {
        let neighbor = GraphFact {
            entity_name: String::from("LLVM"),
            relation: String::from("supports"),
            target_name: String::from("WebAssembly"),
            fact: String::from("LLVM supports WebAssembly"),
            entity_match_score: 0.5,
            hop_distance: 1,
            confidence: 0.8,
            valid_from: None,
            edge_type: EdgeType::Semantic,
            retrieval_count: 0,
            edge_id: None,
        };
        let mut recalled = make_recalled_fact();
        recalled.neighbors.push(neighbor);

        let mapped = map_graph_fact(recalled);
        assert_eq!(mapped.neighbors.len(), 1);
        assert_eq!(mapped.neighbors[0].fact, "LLVM supports WebAssembly");
        assert!(close_f32(mapped.neighbors[0].confidence, 0.8));
    }

    #[test]
    fn graph_fact_maps_provenance_snippet() {
        let mut recalled = make_recalled_fact();
        recalled.provenance_snippet = Some(String::from("Rust compiler snippet"));
        let mapped = map_graph_fact(recalled);
        assert_eq!(
            mapped.provenance_snippet.as_deref(),
            Some("Rust compiler snippet")
        );
    }

    // ---- session summaries ----------------------------------------------

    #[test]
    fn session_summary_maps_fields() {
        let mapped = map_session_summary(make_session_summary());
        assert_eq!(mapped.summary_text, "yesterday's session about Rust");
        assert!(close_f32(mapped.score, 0.88));
    }

    #[test]
    fn session_summary_score_zero() {
        let mapped = map_session_summary(SessionSummaryResult {
            summary_text: String::from("empty session"),
            score: 0.0,
            conversation_id: ConversationId(1),
        });
        assert!(mapped.score.abs() < f32::EPSILON);
    }

    #[test]
    fn session_summary_score_one() {
        let mapped = map_session_summary(SessionSummaryResult {
            summary_text: String::from("perfect match"),
            score: 1.0,
            conversation_id: ConversationId(1),
        });
        assert!(close_f32(mapped.score, 1.0_f32));
    }
}