1use crate::channel::Channel;
5use zeph_agent_persistence::graph::{build_graph_extraction_config, collect_context_messages};
6use zeph_agent_persistence::{
7 MemoryPersistenceView, MetricsView, PersistMessageRequest, PersistenceService, SecurityView,
8};
9use zeph_llm::provider::{LlmProvider as _, MessagePart, Role};
10
11use super::Agent;
12
13impl<C: Channel> Agent<C> {
14 #[tracing::instrument(name = "core.persist.load_history", skip_all, level = "debug", err)]
30 pub async fn load_history(&mut self) -> Result<(), super::error::AgentError> {
31 let (Some(memory), Some(cid)) = (
32 self.services.memory.persistence.memory.as_ref(),
33 self.services.memory.persistence.conversation_id,
34 ) else {
35 return Ok(());
36 };
37
38 let memory = memory.clone();
40
41 let mut unsummarized = self.services.memory.persistence.unsummarized_count;
42 let memory_view = MemoryPersistenceView {
45 memory: Some(&memory),
46 conversation_id: self.services.memory.persistence.conversation_id,
47 autosave_assistant: self.services.memory.persistence.autosave_assistant,
48 autosave_min_length: self.services.memory.persistence.autosave_min_length,
49 unsummarized_count: &mut unsummarized,
50 goal_text: self.services.memory.extraction.goal_text.clone(),
51 };
52 let mut sqlite_delta = 0u64;
53 let mut embed_delta = 0u64;
54 let mut guard_delta = 0u64;
55 let mut metrics_view = MetricsView {
56 sqlite_message_count: &mut sqlite_delta,
57 embeddings_generated: &mut embed_delta,
58 exfiltration_memory_guards: &mut guard_delta,
59 };
60
61 let svc = PersistenceService::new();
62 let outcome = svc
63 .load_history(
64 &mut self.msg.messages,
65 &mut self.msg.last_persisted_message_id,
66 &mut self.msg.deferred_db_hide_ids,
67 &mut self.msg.deferred_db_summaries,
68 &memory_view,
69 &zeph_config::Config::default(),
70 &mut metrics_view,
71 )
72 .await
73 .map_err(|e| {
74 super::error::AgentError::Memory(zeph_memory::MemoryError::Other(e.to_string()))
75 })?;
76
77 self.services.memory.persistence.unsummarized_count = unsummarized;
79
80 if outcome.messages_loaded > 0 {
81 let _ = memory
83 .sqlite()
84 .increment_session_counts_for_conversation(cid)
85 .await
86 .inspect_err(|e| {
87 tracing::warn!(error = %e, "failed to increment tier session counts");
88 });
89 }
90
91 self.update_metrics(|m| {
93 m.sqlite_message_count = outcome.sqlite_total_messages;
94 });
95 if let Ok(count) = memory.sqlite().count_semantic_facts().await {
96 let count_u64 = u64::try_from(count).unwrap_or(0);
97 self.update_metrics(|m| {
98 m.semantic_fact_count = count_u64;
99 });
100 }
101 if let Ok(count) = memory.unsummarized_message_count(cid).await {
102 self.services.memory.persistence.unsummarized_count =
103 usize::try_from(count).unwrap_or(0);
104 }
105
106 self.recompute_prompt_tokens();
107 Ok(())
108 }
109
110 #[tracing::instrument(name = "core.persist.persist_message", skip_all, level = "debug")]
116 pub(crate) async fn persist_message(
117 &mut self,
118 role: Role,
119 content: &str,
120 parts: &[MessagePart],
121 has_injection_flags: bool,
122 ) {
123 let guard_event = self
127 .services
128 .security
129 .exfiltration_guard
130 .should_guard_memory_write(has_injection_flags);
131 if let Some(ref event) = guard_event {
132 tracing::warn!(
133 ?event,
134 "exfiltration guard: skipping Qdrant embedding for flagged content"
135 );
136 self.push_security_event(
137 zeph_common::SecurityEventCategory::ExfiltrationBlock,
138 "memory_write",
139 "Qdrant embedding skipped: flagged content",
140 );
141 }
142
143 let req = PersistMessageRequest::from_borrowed(role, content, parts, has_injection_flags);
144
145 let mut unsummarized = self.services.memory.persistence.unsummarized_count;
146 let memory_arc = self.services.memory.persistence.memory.clone();
147 let mut memory_view = MemoryPersistenceView {
148 memory: memory_arc.as_ref(),
149 conversation_id: self.services.memory.persistence.conversation_id,
150 autosave_assistant: self.services.memory.persistence.autosave_assistant,
151 autosave_min_length: self.services.memory.persistence.autosave_min_length,
152 unsummarized_count: &mut unsummarized,
153 goal_text: self.services.memory.extraction.goal_text.clone(),
154 };
155 let security = SecurityView {
156 guard_memory_writes: guard_event.is_some(),
157 _phantom: std::marker::PhantomData,
158 };
159 let mut sqlite_delta = 0u64;
160 let mut embed_delta = 0u64;
161 let mut guard_delta = 0u64;
162 let mut metrics_view = MetricsView {
163 sqlite_message_count: &mut sqlite_delta,
164 embeddings_generated: &mut embed_delta,
165 exfiltration_memory_guards: &mut guard_delta,
166 };
167
168 let svc = PersistenceService::new();
169 let outcome = svc
170 .persist_message(
171 req,
172 &mut self.msg.last_persisted_message_id,
173 &mut memory_view,
174 &security,
175 &zeph_config::Config::default(),
176 &mut metrics_view,
177 )
178 .await;
179
180 self.services.memory.persistence.unsummarized_count = unsummarized;
182
183 self.update_metrics(|m| {
185 m.sqlite_message_count += sqlite_delta;
186 m.embeddings_generated += embed_delta;
187 m.exfiltration_memory_guards += guard_delta;
189 });
190
191 if outcome.message_id.is_none() {
192 return;
193 }
194
195 self.enqueue_summarization_task();
199
200 let has_tool_result_parts = parts
203 .iter()
204 .any(|p| matches!(p, MessagePart::ToolResult { .. }));
205
206 self.enqueue_graph_extraction_task(content, has_injection_flags, has_tool_result_parts)
207 .await;
208
209 if role == Role::User && !has_tool_result_parts && !has_injection_flags {
211 self.enqueue_persona_extraction_task();
212 }
213
214 if has_tool_result_parts {
216 self.enqueue_trajectory_extraction_task();
217 }
218
219 let has_tool_use_parts = parts
224 .iter()
225 .any(|p| matches!(p, MessagePart::ToolUse { .. }));
226 if role == Role::Assistant && !has_tool_use_parts && !has_injection_flags {
227 self.enqueue_reasoning_extraction_task();
228 self.enqueue_memcot_distill_task(content);
230 }
231 }
232
233 fn enqueue_memcot_distill_task(&mut self, assistant_content: &str) {
238 let Some(accumulator) = &self.services.memory.extraction.memcot_accumulator else {
239 return;
240 };
241 let distill_provider_name = self
242 .services
243 .memory
244 .extraction
245 .memcot_config
246 .distill_provider
247 .as_str();
248 let provider = self.resolve_background_provider(distill_provider_name);
249
250 let content = assistant_content.to_owned();
251 let supervisor = &mut self.runtime.lifecycle.supervisor;
252
253 accumulator.maybe_enqueue_distill(&content, provider, |name, fut| {
254 supervisor.spawn(super::agent_supervisor::TaskClass::Enrichment, name, fut);
255 });
256 }
257
258 fn enqueue_summarization_task(&mut self) {
260 let (Some(memory), Some(cid)) = (
261 self.services.memory.persistence.memory.clone(),
262 self.services.memory.persistence.conversation_id,
263 ) else {
264 return;
265 };
266
267 if self.services.memory.persistence.unsummarized_count
268 <= self.services.memory.compaction.summarization_threshold
269 {
270 return;
271 }
272
273 let batch_size = self.services.memory.compaction.summarization_threshold / 2;
274
275 self.runtime.lifecycle.supervisor.spawn_summarization("summarization", async move {
276 match tokio::time::timeout(
277 std::time::Duration::from_secs(30),
278 memory.summarize(cid, batch_size),
279 )
280 .await
281 {
282 Ok(Ok(Some(summary_id))) => {
283 tracing::info!(
284 "background summarization: created summary {summary_id} for conversation {cid}"
285 );
286 true
287 }
288 Ok(Ok(None)) => {
289 tracing::debug!("background summarization: no summarization needed");
290 false
291 }
292 Ok(Err(e)) => {
293 tracing::error!("background summarization failed: {e:#}");
294 false
295 }
296 Err(_) => {
297 tracing::warn!("background summarization timed out after 30s");
298 false
299 }
300 }
301 });
302 }
303
304 #[tracing::instrument(
309 name = "core.persist.enqueue_graph_extraction",
310 skip_all,
311 level = "debug"
312 )]
313 async fn enqueue_graph_extraction_task(
314 &mut self,
315 content: &str,
316 has_injection_flags: bool,
317 has_tool_result_parts: bool,
318 ) {
319 if self.services.memory.persistence.memory.is_none()
320 || self.services.memory.persistence.conversation_id.is_none()
321 {
322 return;
323 }
324 if has_tool_result_parts {
325 tracing::debug!("graph extraction skipped: message contains ToolResult parts");
326 return;
327 }
328 if has_injection_flags {
329 tracing::warn!("graph extraction skipped: injection patterns detected in content");
330 return;
331 }
332
333 let cfg = &self.services.memory.extraction.graph_config;
334 if !cfg.enabled {
335 return;
336 }
337 let extraction_cfg = build_graph_extraction_config(
338 cfg,
339 self.services
340 .memory
341 .persistence
342 .conversation_id
343 .map(|c| c.0),
344 );
345 let extract_provider_name = cfg.extract_provider.as_str().to_owned();
348
349 if self.rpe_should_skip(content).await {
352 tracing::debug!("D-MEM RPE: low-surprise turn, skipping graph extraction");
353 return;
354 }
355
356 let context_messages = collect_context_messages(&self.msg.messages);
357
358 let Some(memory) = self.services.memory.persistence.memory.clone() else {
359 return;
360 };
361
362 let validator: zeph_memory::semantic::PostExtractValidator =
363 if self.services.security.memory_validator.is_enabled() {
364 let v = self.services.security.memory_validator.clone();
365 Some(Box::new(move |result| {
366 v.validate_graph_extraction(result)
367 .map_err(|e| e.to_string())
368 }))
369 } else {
370 None
371 };
372
373 let provider_override = if extract_provider_name.is_empty() {
374 None
375 } else {
376 Some(self.resolve_background_provider(&extract_provider_name))
377 };
378
379 self.spawn_graph_extraction_task(
380 memory,
381 content,
382 context_messages,
383 extraction_cfg,
384 validator,
385 provider_override,
386 );
387
388 self.sync_community_detection_failures();
390 self.sync_graph_extraction_metrics();
391 self.enqueue_graph_count_sync_task();
392 }
393
394 fn spawn_graph_extraction_task(
395 &mut self,
396 memory: std::sync::Arc<zeph_memory::semantic::SemanticMemory>,
397 content: &str,
398 context_messages: Vec<String>,
399 extraction_cfg: zeph_memory::semantic::GraphExtractionConfig,
400 validator: zeph_memory::semantic::PostExtractValidator,
401 provider_override: Option<zeph_llm::any::AnyProvider>,
402 ) {
403 let content_owned = content.to_owned();
404 let graph_store = memory.graph_store.clone();
405 let metrics_tx = self.runtime.metrics.metrics_tx.clone();
406 let start_time = self.runtime.lifecycle.start_time;
407
408 self.runtime.lifecycle.supervisor.spawn(
409 super::agent_supervisor::TaskClass::Enrichment,
410 "graph_extraction",
411 async move {
412 let extraction_handle = memory.spawn_graph_extraction(
413 content_owned,
414 context_messages,
415 extraction_cfg,
416 validator,
417 provider_override,
418 );
419
420 if let (Some(store), Some(tx)) = (graph_store, metrics_tx) {
422 let _ = extraction_handle.await;
423 let (entities, edges, communities) = tokio::join!(
424 store.entity_count(),
425 store.active_edge_count(),
426 store.community_count()
427 );
428 let elapsed = start_time.elapsed().as_secs();
429 tx.send_modify(|m| {
430 m.uptime_seconds = elapsed;
431 m.graph_entities_total = entities.unwrap_or(0).cast_unsigned();
432 m.graph_edges_total = edges.unwrap_or(0).cast_unsigned();
433 m.graph_communities_total = communities.unwrap_or(0).cast_unsigned();
434 });
435 } else {
436 let _ = extraction_handle.await;
437 }
438
439 tracing::debug!("background graph extraction complete");
440 },
441 );
442 }
443
444 fn enqueue_graph_count_sync_task(&mut self) {
446 let memory_for_sync = self.services.memory.persistence.memory.clone();
447 let metrics_tx_sync = self.runtime.metrics.metrics_tx.clone();
448 let start_time_sync = self.runtime.lifecycle.start_time;
449 let cid_sync = self.services.memory.persistence.conversation_id;
450 let graph_store_sync = memory_for_sync.as_ref().and_then(|m| m.graph_store.clone());
451 let sqlite_sync = memory_for_sync.as_ref().map(|m| m.sqlite().clone());
452 let guidelines_enabled = self.services.memory.extraction.graph_config.enabled;
453
454 self.runtime.lifecycle.supervisor.spawn(
455 super::agent_supervisor::TaskClass::Telemetry,
456 "graph_count_sync",
457 async move {
458 let Some(store) = graph_store_sync else {
459 return;
460 };
461 let Some(tx) = metrics_tx_sync else { return };
462
463 let (entities, edges, communities) = tokio::join!(
464 store.entity_count(),
465 store.active_edge_count(),
466 store.community_count()
467 );
468 let elapsed = start_time_sync.elapsed().as_secs();
469 tx.send_modify(|m| {
470 m.uptime_seconds = elapsed;
471 m.graph_entities_total = entities.unwrap_or(0).cast_unsigned();
472 m.graph_edges_total = edges.unwrap_or(0).cast_unsigned();
473 m.graph_communities_total = communities.unwrap_or(0).cast_unsigned();
474 });
475
476 if guidelines_enabled && let Some(sqlite) = sqlite_sync {
478 match tokio::time::timeout(
479 std::time::Duration::from_secs(10),
480 sqlite.load_compression_guidelines_meta(cid_sync),
481 )
482 .await
483 {
484 Ok(Ok((version, created_at))) => {
485 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
486 let version_u32 = u32::try_from(version).unwrap_or(0);
487 tx.send_modify(|m| {
488 m.guidelines_version = version_u32;
489 m.guidelines_updated_at = created_at;
490 });
491 }
492 Ok(Err(e)) => {
493 tracing::debug!("guidelines status sync failed: {e:#}");
494 }
495 Err(_) => {
496 tracing::debug!("guidelines status sync timed out");
497 }
498 }
499 }
500 },
501 );
502 }
503
504 fn enqueue_persona_extraction_task(&mut self) {
506 use zeph_memory::semantic::{PersonaExtractionConfig, extract_persona_facts};
507
508 let cfg = &self.services.memory.extraction.persona_config;
509 if !cfg.enabled {
510 return;
511 }
512
513 let Some(memory) = &self.services.memory.persistence.memory else {
514 return;
515 };
516
517 let user_messages: Vec<String> = self
518 .msg
519 .messages
520 .iter()
521 .filter(|m| {
522 m.role == Role::User
523 && !m
524 .parts
525 .iter()
526 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
527 })
528 .take(8)
529 .map(|m| {
530 if m.content.len() > 2048 {
531 m.content[..m.content.floor_char_boundary(2048)].to_owned()
532 } else {
533 m.content.clone()
534 }
535 })
536 .collect();
537
538 if user_messages.len() < cfg.min_messages {
539 return;
540 }
541
542 let timeout_secs = cfg.extraction_timeout_secs;
543 let extraction_cfg = PersonaExtractionConfig {
544 enabled: cfg.enabled,
545 min_messages: cfg.min_messages,
546 max_messages: cfg.max_messages,
547 extraction_timeout_secs: timeout_secs,
548 };
549
550 let provider = self.resolve_background_provider(cfg.persona_provider.as_str());
551 let store = memory.sqlite().clone();
552 let conversation_id = self
553 .services
554 .memory
555 .persistence
556 .conversation_id
557 .map(|c| c.0);
558
559 self.runtime.lifecycle.supervisor.spawn(
560 super::agent_supervisor::TaskClass::Enrichment,
561 "persona_extraction",
562 async move {
563 let user_message_refs: Vec<&str> =
564 user_messages.iter().map(String::as_str).collect();
565 let fut = extract_persona_facts(
566 &store,
567 &provider,
568 &user_message_refs,
569 &extraction_cfg,
570 conversation_id,
571 );
572 match tokio::time::timeout(std::time::Duration::from_secs(timeout_secs), fut).await
573 {
574 Ok(Ok(n)) => tracing::debug!(upserted = n, "persona extraction complete"),
575 Ok(Err(e)) => tracing::warn!(error = %e, "persona extraction failed"),
576 Err(_) => tracing::warn!(
577 timeout_secs,
578 "persona extraction timed out — no facts written this turn"
579 ),
580 }
581 },
582 );
583 }
584
585 fn enqueue_trajectory_extraction_task(&mut self) {
587 use zeph_memory::semantic::{TrajectoryExtractionConfig, extract_trajectory_entries};
588
589 let cfg = self.services.memory.extraction.trajectory_config.clone();
590 if !cfg.enabled {
591 return;
592 }
593
594 let Some(memory) = &self.services.memory.persistence.memory else {
595 return;
596 };
597
598 let conversation_id = match self.services.memory.persistence.conversation_id {
599 Some(cid) => cid.0,
600 None => return,
601 };
602
603 let tail_start = self.msg.messages.len().saturating_sub(cfg.max_messages);
604 let turn_messages: Vec<zeph_llm::provider::Message> =
605 self.msg.messages[tail_start..].to_vec();
606
607 if turn_messages.is_empty() {
608 return;
609 }
610
611 let extraction_cfg = TrajectoryExtractionConfig {
612 enabled: cfg.enabled,
613 max_messages: cfg.max_messages,
614 extraction_timeout_secs: cfg.extraction_timeout_secs,
615 };
616
617 let provider = self.resolve_background_provider(cfg.trajectory_provider.as_str());
618 let store = memory.sqlite().clone();
619 let min_confidence = cfg.min_confidence;
620
621 self.runtime.lifecycle.supervisor.spawn(
622 super::agent_supervisor::TaskClass::Enrichment,
623 "trajectory_extraction",
624 async move {
625 let entries =
626 match extract_trajectory_entries(&provider, &turn_messages, &extraction_cfg)
627 .await
628 {
629 Ok(e) => e,
630 Err(e) => {
631 tracing::warn!(error = %e, "trajectory extraction failed");
632 return;
633 }
634 };
635
636 let last_id = store
637 .trajectory_last_extracted_message_id(conversation_id)
638 .await
639 .unwrap_or(0);
640
641 let mut max_id = last_id;
642 for entry in &entries {
643 if entry.confidence < min_confidence {
644 continue;
645 }
646 let tools_json = serde_json::to_string(&entry.tools_used)
647 .unwrap_or_else(|_| "[]".to_string());
648 match store
649 .insert_trajectory_entry(zeph_memory::NewTrajectoryEntry {
650 conversation_id: Some(conversation_id),
651 turn_index: 0,
652 kind: &entry.kind,
653 intent: &entry.intent,
654 outcome: &entry.outcome,
655 tools_used: &tools_json,
656 confidence: entry.confidence,
657 })
658 .await
659 {
660 Ok(id) => {
661 if id > max_id {
662 max_id = id;
663 }
664 }
665 Err(e) => tracing::warn!(error = %e, "failed to insert trajectory entry"),
666 }
667 }
668
669 if max_id > last_id {
670 let _ = store
671 .set_trajectory_last_extracted_message_id(conversation_id, max_id)
672 .await;
673 }
674
675 tracing::debug!(
676 count = entries.len(),
677 conversation_id,
678 "trajectory extraction complete"
679 );
680 },
681 );
682 }
683
684 fn enqueue_reasoning_extraction_task(&mut self) {
689 let cfg = self.services.memory.extraction.reasoning_config.clone();
690 if !cfg.enabled {
691 return;
692 }
693
694 let Some(memory) = &self.services.memory.persistence.memory else {
695 return;
696 };
697
698 let Some(reasoning) = memory.reasoning.clone() else {
699 return;
700 };
701
702 let tail_start = self.msg.messages.len().saturating_sub(cfg.max_messages);
703 let turn_messages: Vec<zeph_llm::provider::Message> =
704 self.msg.messages[tail_start..].to_vec();
705
706 if turn_messages.len() < cfg.min_messages {
707 return;
708 }
709
710 let extract_provider = self.resolve_background_provider(cfg.extract_provider.as_str());
711 let distill_provider = self.resolve_background_provider(cfg.distill_provider.as_str());
712 let embed_provider = memory.effective_embed_provider().clone();
713 let store_limit = cfg.store_limit;
714 let extraction_timeout = std::time::Duration::from_secs(cfg.extraction_timeout_secs);
715 let distill_timeout = std::time::Duration::from_secs(cfg.distill_timeout_secs);
716 let self_judge_window = cfg.self_judge_window;
717 let min_assistant_chars = cfg.min_assistant_chars;
718
719 self.runtime.lifecycle.supervisor.spawn(
720 super::agent_supervisor::TaskClass::Enrichment,
721 "reasoning_extraction",
722 async move {
723 if let Err(e) = zeph_memory::process_reasoning_turn(
724 &reasoning,
725 &extract_provider,
726 &distill_provider,
727 &embed_provider,
728 &turn_messages,
729 zeph_memory::ProcessTurnConfig {
730 store_limit,
731 extraction_timeout,
732 distill_timeout,
733 self_judge_window,
734 min_assistant_chars,
735 },
736 )
737 .await
738 {
739 tracing::warn!(error = %e, "reasoning: process_turn failed");
740 }
741
742 tracing::debug!("reasoning extraction complete");
743 },
744 );
745 }
746
747 async fn rpe_should_skip(&mut self, content: &str) -> bool {
752 let Some(ref rpe_mutex) = self.services.memory.extraction.rpe_router else {
753 return false;
754 };
755 let Some(memory) = &self.services.memory.persistence.memory else {
756 return false;
757 };
758 let candidates = zeph_memory::extract_candidate_entities(content);
759 let provider = memory.provider();
760 let Ok(Ok(emb_vec)) =
761 tokio::time::timeout(std::time::Duration::from_secs(5), provider.embed(content)).await
762 else {
763 return false; };
765 if let Ok(mut router) = rpe_mutex.lock() {
766 let signal = router.compute(&emb_vec, &candidates);
767 router.push_embedding(emb_vec);
768 router.push_entities(&candidates);
769 !signal.should_extract
770 } else {
771 tracing::warn!("rpe_router mutex poisoned; falling through to extract");
772 false
773 }
774 }
775}
776
777#[cfg(test)]
778mod tests {
779 use super::super::agent_tests::{
780 MetricsSnapshot, MockChannel, MockToolExecutor, create_test_registry, mock_provider,
781 };
782 use super::*;
783 use zeph_llm::any::AnyProvider;
784 use zeph_llm::provider::Message;
785 use zeph_memory::semantic::SemanticMemory;
786
787 async fn test_memory(provider: &AnyProvider) -> SemanticMemory {
788 SemanticMemory::new(
789 ":memory:",
790 "http://127.0.0.1:1",
791 None,
792 provider.clone(),
793 "test-model",
794 )
795 .await
796 .unwrap()
797 }
798
799 #[tokio::test]
800 async fn load_history_without_memory_returns_ok() {
801 let provider = mock_provider(vec![]);
802 let channel = MockChannel::new(vec![]);
803 let registry = create_test_registry();
804 let executor = MockToolExecutor::no_tools();
805 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
806
807 let result = agent.load_history().await;
808 assert!(result.is_ok());
809 assert_eq!(agent.msg.messages.len(), 1); }
812
813 #[tokio::test]
814 async fn load_history_with_messages_injects_into_agent() {
815 let provider = mock_provider(vec![]);
816 let channel = MockChannel::new(vec![]);
817 let registry = create_test_registry();
818 let executor = MockToolExecutor::no_tools();
819
820 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
821 let cid = memory.sqlite().create_conversation().await.unwrap();
822
823 memory
824 .sqlite()
825 .save_message(cid, "user", "hello from history")
826 .await
827 .unwrap();
828 memory
829 .sqlite()
830 .save_message(cid, "assistant", "hi back")
831 .await
832 .unwrap();
833
834 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
835 std::sync::Arc::new(memory),
836 cid,
837 50,
838 5,
839 100,
840 );
841
842 let messages_before = agent.msg.messages.len();
843 agent.load_history().await.unwrap();
844 assert_eq!(agent.msg.messages.len(), messages_before + 2);
846 }
847
848 #[tokio::test]
849 async fn load_history_skips_empty_messages() {
850 let provider = mock_provider(vec![]);
851 let channel = MockChannel::new(vec![]);
852 let registry = create_test_registry();
853 let executor = MockToolExecutor::no_tools();
854
855 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
856 let cid = memory.sqlite().create_conversation().await.unwrap();
857
858 memory
860 .sqlite()
861 .save_message(cid, "user", " ")
862 .await
863 .unwrap();
864 memory
865 .sqlite()
866 .save_message(cid, "user", "real message")
867 .await
868 .unwrap();
869
870 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
871 std::sync::Arc::new(memory),
872 cid,
873 50,
874 5,
875 100,
876 );
877
878 let messages_before = agent.msg.messages.len();
879 agent.load_history().await.unwrap();
880 assert_eq!(agent.msg.messages.len(), messages_before + 1);
882 }
883
884 #[tokio::test]
885 async fn load_history_with_empty_store_returns_ok() {
886 let provider = mock_provider(vec![]);
887 let channel = MockChannel::new(vec![]);
888 let registry = create_test_registry();
889 let executor = MockToolExecutor::no_tools();
890
891 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
892 let cid = memory.sqlite().create_conversation().await.unwrap();
893
894 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
895 std::sync::Arc::new(memory),
896 cid,
897 50,
898 5,
899 100,
900 );
901
902 let messages_before = agent.msg.messages.len();
903 agent.load_history().await.unwrap();
904 assert_eq!(agent.msg.messages.len(), messages_before);
906 }
907
908 #[tokio::test]
909 async fn load_history_increments_session_count_for_existing_messages() {
910 let provider = mock_provider(vec![]);
911 let channel = MockChannel::new(vec![]);
912 let registry = create_test_registry();
913 let executor = MockToolExecutor::no_tools();
914
915 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
916 let cid = memory.sqlite().create_conversation().await.unwrap();
917
918 let id1 = memory
920 .sqlite()
921 .save_message(cid, "user", "hello")
922 .await
923 .unwrap();
924 let id2 = memory
925 .sqlite()
926 .save_message(cid, "assistant", "hi")
927 .await
928 .unwrap();
929
930 let memory_arc = std::sync::Arc::new(memory);
931 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
932 memory_arc.clone(),
933 cid,
934 50,
935 5,
936 100,
937 );
938
939 agent.load_history().await.unwrap();
940
941 let counts: Vec<i64> = zeph_db::query_scalar(
943 "SELECT session_count FROM messages WHERE id IN (?, ?) ORDER BY id",
944 )
945 .bind(id1)
946 .bind(id2)
947 .fetch_all(memory_arc.sqlite().pool())
948 .await
949 .unwrap();
950 assert_eq!(
951 counts,
952 vec![1, 1],
953 "session_count must be 1 after first restore"
954 );
955 }
956
957 #[tokio::test]
958 async fn load_history_does_not_increment_session_count_for_new_conversation() {
959 let provider = mock_provider(vec![]);
960 let channel = MockChannel::new(vec![]);
961 let registry = create_test_registry();
962 let executor = MockToolExecutor::no_tools();
963
964 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
965 let cid = memory.sqlite().create_conversation().await.unwrap();
966
967 let memory_arc = std::sync::Arc::new(memory);
969 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
970 memory_arc.clone(),
971 cid,
972 50,
973 5,
974 100,
975 );
976
977 agent.load_history().await.unwrap();
978
979 let counts: Vec<i64> =
981 zeph_db::query_scalar("SELECT session_count FROM messages WHERE conversation_id = ?")
982 .bind(cid)
983 .fetch_all(memory_arc.sqlite().pool())
984 .await
985 .unwrap();
986 assert!(counts.is_empty(), "new conversation must have no messages");
987 }
988
989 #[tokio::test]
990 async fn persist_message_without_memory_silently_returns() {
991 let provider = mock_provider(vec![]);
993 let channel = MockChannel::new(vec![]);
994 let registry = create_test_registry();
995 let executor = MockToolExecutor::no_tools();
996 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
997
998 agent.persist_message(Role::User, "hello", &[], false).await;
1000 }
1001
1002 #[tokio::test]
1003 async fn persist_message_assistant_autosave_false_uses_save_only() {
1004 let provider = mock_provider(vec![]);
1005 let channel = MockChannel::new(vec![]);
1006 let registry = create_test_registry();
1007 let executor = MockToolExecutor::no_tools();
1008
1009 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1010 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1011 let cid = memory.sqlite().create_conversation().await.unwrap();
1012
1013 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1014 .with_metrics(tx)
1015 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100);
1016 agent.services.memory.persistence.autosave_assistant = false;
1017 agent.services.memory.persistence.autosave_min_length = 20;
1018
1019 agent
1020 .persist_message(Role::Assistant, "short assistant reply", &[], false)
1021 .await;
1022
1023 let history = agent
1024 .services
1025 .memory
1026 .persistence
1027 .memory
1028 .as_ref()
1029 .unwrap()
1030 .sqlite()
1031 .load_history(cid, 50)
1032 .await
1033 .unwrap();
1034 assert_eq!(history.len(), 1, "message must be saved");
1035 assert_eq!(history[0].content, "short assistant reply");
1036 assert_eq!(rx.borrow().embeddings_generated, 0);
1038 }
1039
1040 #[tokio::test]
1041 async fn persist_message_assistant_below_min_length_uses_save_only() {
1042 let provider = mock_provider(vec![]);
1043 let channel = MockChannel::new(vec![]);
1044 let registry = create_test_registry();
1045 let executor = MockToolExecutor::no_tools();
1046
1047 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1048 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1049 let cid = memory.sqlite().create_conversation().await.unwrap();
1050
1051 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1053 .with_metrics(tx)
1054 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100);
1055 agent.services.memory.persistence.autosave_assistant = true;
1056 agent.services.memory.persistence.autosave_min_length = 1000;
1057
1058 agent
1059 .persist_message(Role::Assistant, "too short", &[], false)
1060 .await;
1061
1062 let history = agent
1063 .services
1064 .memory
1065 .persistence
1066 .memory
1067 .as_ref()
1068 .unwrap()
1069 .sqlite()
1070 .load_history(cid, 50)
1071 .await
1072 .unwrap();
1073 assert_eq!(history.len(), 1, "message must be saved");
1074 assert_eq!(history[0].content, "too short");
1075 assert_eq!(rx.borrow().embeddings_generated, 0);
1076 }
1077
1078 #[tokio::test]
1079 async fn persist_message_assistant_at_min_length_boundary_uses_embed() {
1080 let provider = mock_provider(vec![]);
1082 let channel = MockChannel::new(vec![]);
1083 let registry = create_test_registry();
1084 let executor = MockToolExecutor::no_tools();
1085
1086 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1087 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1088 let cid = memory.sqlite().create_conversation().await.unwrap();
1089
1090 let min_length = 10usize;
1091 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1092 .with_metrics(tx)
1093 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100);
1094 agent.services.memory.persistence.autosave_assistant = true;
1095 agent.services.memory.persistence.autosave_min_length = min_length;
1096
1097 let content_at_boundary = "A".repeat(min_length);
1099 assert_eq!(content_at_boundary.len(), min_length);
1100 agent
1101 .persist_message(Role::Assistant, &content_at_boundary, &[], false)
1102 .await;
1103
1104 assert_eq!(rx.borrow().sqlite_message_count, 1);
1106 }
1107
1108 #[tokio::test]
1109 async fn persist_message_assistant_one_below_min_length_uses_save_only() {
1110 let provider = mock_provider(vec![]);
1112 let channel = MockChannel::new(vec![]);
1113 let registry = create_test_registry();
1114 let executor = MockToolExecutor::no_tools();
1115
1116 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1117 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1118 let cid = memory.sqlite().create_conversation().await.unwrap();
1119
1120 let min_length = 10usize;
1121 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1122 .with_metrics(tx)
1123 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100);
1124 agent.services.memory.persistence.autosave_assistant = true;
1125 agent.services.memory.persistence.autosave_min_length = min_length;
1126
1127 let content_below_boundary = "A".repeat(min_length - 1);
1129 assert_eq!(content_below_boundary.len(), min_length - 1);
1130 agent
1131 .persist_message(Role::Assistant, &content_below_boundary, &[], false)
1132 .await;
1133
1134 let history = agent
1135 .services
1136 .memory
1137 .persistence
1138 .memory
1139 .as_ref()
1140 .unwrap()
1141 .sqlite()
1142 .load_history(cid, 50)
1143 .await
1144 .unwrap();
1145 assert_eq!(history.len(), 1, "message must still be saved");
1146 assert_eq!(rx.borrow().embeddings_generated, 0);
1148 }
1149
1150 #[tokio::test]
1151 async fn persist_message_increments_unsummarized_count() {
1152 let provider = mock_provider(vec![]);
1153 let channel = MockChannel::new(vec![]);
1154 let registry = create_test_registry();
1155 let executor = MockToolExecutor::no_tools();
1156
1157 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1158 let cid = memory.sqlite().create_conversation().await.unwrap();
1159
1160 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1162 std::sync::Arc::new(memory),
1163 cid,
1164 50,
1165 5,
1166 100,
1167 );
1168
1169 assert_eq!(agent.services.memory.persistence.unsummarized_count, 0);
1170
1171 agent.persist_message(Role::User, "first", &[], false).await;
1172 assert_eq!(agent.services.memory.persistence.unsummarized_count, 1);
1173
1174 agent
1175 .persist_message(Role::User, "second", &[], false)
1176 .await;
1177 assert_eq!(agent.services.memory.persistence.unsummarized_count, 2);
1178 }
1179
1180 #[tokio::test]
1181 async fn check_summarization_resets_counter_on_success() {
1182 let provider = mock_provider(vec![]);
1183 let channel = MockChannel::new(vec![]);
1184 let registry = create_test_registry();
1185 let executor = MockToolExecutor::no_tools();
1186
1187 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1188 let cid = memory.sqlite().create_conversation().await.unwrap();
1189
1190 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1192 std::sync::Arc::new(memory),
1193 cid,
1194 50,
1195 5,
1196 1,
1197 );
1198
1199 agent.persist_message(Role::User, "msg1", &[], false).await;
1200 agent.persist_message(Role::User, "msg2", &[], false).await;
1201
1202 assert!(agent.services.memory.persistence.unsummarized_count <= 2);
1207 }
1208
1209 #[tokio::test]
1210 async fn unsummarized_count_not_incremented_without_memory() {
1211 let provider = mock_provider(vec![]);
1212 let channel = MockChannel::new(vec![]);
1213 let registry = create_test_registry();
1214 let executor = MockToolExecutor::no_tools();
1215 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
1216
1217 agent.persist_message(Role::User, "hello", &[], false).await;
1218 assert_eq!(agent.services.memory.persistence.unsummarized_count, 0);
1220 }
1221
1222 mod graph_extraction_guards {
1224 use super::*;
1225 use crate::config::GraphConfig;
1226 use zeph_llm::provider::MessageMetadata;
1227 use zeph_memory::graph::GraphStore;
1228
1229 fn enabled_graph_config() -> GraphConfig {
1230 GraphConfig {
1231 enabled: true,
1232 ..GraphConfig::default()
1233 }
1234 }
1235
1236 async fn agent_with_graph(
1237 provider: &AnyProvider,
1238 config: GraphConfig,
1239 ) -> Agent<MockChannel> {
1240 let memory =
1241 test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1242 let cid = memory.sqlite().create_conversation().await.unwrap();
1243 Agent::new(
1244 provider.clone(),
1245 MockChannel::new(vec![]),
1246 create_test_registry(),
1247 None,
1248 5,
1249 MockToolExecutor::no_tools(),
1250 )
1251 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1252 .with_graph_config(config)
1253 }
1254
1255 #[tokio::test]
1256 async fn injection_flag_guard_skips_extraction() {
1257 let provider = mock_provider(vec![]);
1259 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1260 let pool = agent
1261 .services
1262 .memory
1263 .persistence
1264 .memory
1265 .as_ref()
1266 .unwrap()
1267 .sqlite()
1268 .pool()
1269 .clone();
1270
1271 agent
1272 .enqueue_graph_extraction_task("I use Rust", true, false)
1273 .await;
1274
1275 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1277
1278 let store = GraphStore::new(pool);
1279 let count = store.get_metadata("extraction_count").await.unwrap();
1280 assert!(
1281 count.is_none(),
1282 "injection flag must prevent extraction_count from being written"
1283 );
1284 }
1285
1286 #[tokio::test]
1287 async fn disabled_config_guard_skips_extraction() {
1288 let provider = mock_provider(vec![]);
1290 let disabled_cfg = GraphConfig {
1291 enabled: false,
1292 ..GraphConfig::default()
1293 };
1294 let mut agent = agent_with_graph(&provider, disabled_cfg).await;
1295 let pool = agent
1296 .services
1297 .memory
1298 .persistence
1299 .memory
1300 .as_ref()
1301 .unwrap()
1302 .sqlite()
1303 .pool()
1304 .clone();
1305
1306 agent
1307 .enqueue_graph_extraction_task("I use Rust", false, false)
1308 .await;
1309
1310 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1311
1312 let store = GraphStore::new(pool);
1313 let count = store.get_metadata("extraction_count").await.unwrap();
1314 assert!(
1315 count.is_none(),
1316 "disabled graph config must prevent extraction"
1317 );
1318 }
1319
1320 #[tokio::test]
1321 async fn happy_path_fires_extraction() {
1322 let provider = mock_provider(vec![]);
1325 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1326 let pool = agent
1327 .services
1328 .memory
1329 .persistence
1330 .memory
1331 .as_ref()
1332 .unwrap()
1333 .sqlite()
1334 .pool()
1335 .clone();
1336
1337 agent
1338 .enqueue_graph_extraction_task("I use Rust for systems programming", false, false)
1339 .await;
1340
1341 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1343
1344 let store = GraphStore::new(pool);
1345 let count = store.get_metadata("extraction_count").await.unwrap();
1346 assert!(
1347 count.is_some(),
1348 "happy-path extraction must increment extraction_count"
1349 );
1350 }
1351
1352 #[tokio::test]
1353 async fn tool_result_parts_guard_skips_extraction() {
1354 let provider = mock_provider(vec![]);
1358 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1359 let pool = agent
1360 .services
1361 .memory
1362 .persistence
1363 .memory
1364 .as_ref()
1365 .unwrap()
1366 .sqlite()
1367 .pool()
1368 .clone();
1369
1370 agent
1371 .enqueue_graph_extraction_task(
1372 "[tool_result: abc123]\nprovider_type = \"claude\"\nallowed_commands = []",
1373 false,
1374 true, )
1376 .await;
1377
1378 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1379
1380 let store = GraphStore::new(pool);
1381 let count = store.get_metadata("extraction_count").await.unwrap();
1382 assert!(
1383 count.is_none(),
1384 "tool result message must not trigger graph extraction"
1385 );
1386 }
1387
1388 #[tokio::test]
1389 async fn context_filter_excludes_tool_result_messages() {
1390 let provider = mock_provider(vec![]);
1401 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1402
1403 agent.msg.messages.push(Message {
1406 role: Role::User,
1407 content: "[tool_result: abc]\nprovider_type = \"openai\"".to_owned(),
1408 parts: vec![MessagePart::ToolResult {
1409 tool_use_id: "abc".to_owned(),
1410 content: "provider_type = \"openai\"".to_owned(),
1411 is_error: false,
1412 }],
1413 metadata: MessageMetadata::default(),
1414 });
1415
1416 let pool = agent
1417 .services
1418 .memory
1419 .persistence
1420 .memory
1421 .as_ref()
1422 .unwrap()
1423 .sqlite()
1424 .pool()
1425 .clone();
1426
1427 agent
1429 .enqueue_graph_extraction_task(
1430 "I prefer Rust for systems programming",
1431 false,
1432 false,
1433 )
1434 .await;
1435
1436 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1437
1438 let store = GraphStore::new(pool);
1440 let count = store.get_metadata("extraction_count").await.unwrap();
1441 assert!(
1442 count.is_some(),
1443 "conversational message must trigger extraction even with prior tool result in history"
1444 );
1445 }
1446 }
1447
1448 mod persona_extraction_guards {
1450 use super::*;
1451 use zeph_config::PersonaConfig;
1452 use zeph_llm::provider::MessageMetadata;
1453
1454 fn enabled_persona_config() -> PersonaConfig {
1455 PersonaConfig {
1456 enabled: true,
1457 min_messages: 1,
1458 ..PersonaConfig::default()
1459 }
1460 }
1461
1462 async fn agent_with_persona(
1463 provider: &AnyProvider,
1464 config: PersonaConfig,
1465 ) -> Agent<MockChannel> {
1466 let memory =
1467 test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1468 let cid = memory.sqlite().create_conversation().await.unwrap();
1469 let mut agent = Agent::new(
1470 provider.clone(),
1471 MockChannel::new(vec![]),
1472 create_test_registry(),
1473 None,
1474 5,
1475 MockToolExecutor::no_tools(),
1476 )
1477 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100);
1478 agent.services.memory.extraction.persona_config = config;
1479 agent
1480 }
1481
1482 #[tokio::test]
1483 async fn disabled_config_skips_spawn() {
1484 let provider = mock_provider(vec![]);
1486 let mut agent = agent_with_persona(
1487 &provider,
1488 PersonaConfig {
1489 enabled: false,
1490 ..PersonaConfig::default()
1491 },
1492 )
1493 .await;
1494
1495 agent.msg.messages.push(zeph_llm::provider::Message {
1497 role: Role::User,
1498 content: "I prefer Rust for systems programming".to_owned(),
1499 parts: vec![],
1500 metadata: MessageMetadata::default(),
1501 });
1502
1503 agent.enqueue_persona_extraction_task();
1504
1505 let store = agent
1506 .services
1507 .memory
1508 .persistence
1509 .memory
1510 .as_ref()
1511 .unwrap()
1512 .sqlite()
1513 .clone();
1514 let count = store.count_persona_facts().await.unwrap();
1515 assert_eq!(count, 0, "disabled persona config must not write any facts");
1516 }
1517
1518 #[tokio::test]
1519 async fn below_min_messages_skips_spawn() {
1520 let provider = mock_provider(vec![]);
1522 let mut agent = agent_with_persona(
1523 &provider,
1524 PersonaConfig {
1525 enabled: true,
1526 min_messages: 3,
1527 ..PersonaConfig::default()
1528 },
1529 )
1530 .await;
1531
1532 for text in ["I use Rust", "I prefer async code"] {
1533 agent.msg.messages.push(zeph_llm::provider::Message {
1534 role: Role::User,
1535 content: text.to_owned(),
1536 parts: vec![],
1537 metadata: MessageMetadata::default(),
1538 });
1539 }
1540
1541 agent.enqueue_persona_extraction_task();
1542
1543 let store = agent
1544 .services
1545 .memory
1546 .persistence
1547 .memory
1548 .as_ref()
1549 .unwrap()
1550 .sqlite()
1551 .clone();
1552 let count = store.count_persona_facts().await.unwrap();
1553 assert_eq!(
1554 count, 0,
1555 "below min_messages threshold must not trigger extraction"
1556 );
1557 }
1558
1559 #[tokio::test]
1560 async fn no_memory_skips_spawn() {
1561 let provider = mock_provider(vec![]);
1563 let channel = MockChannel::new(vec![]);
1564 let registry = create_test_registry();
1565 let executor = MockToolExecutor::no_tools();
1566 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
1567 agent.services.memory.extraction.persona_config = enabled_persona_config();
1568 agent.msg.messages.push(zeph_llm::provider::Message {
1569 role: Role::User,
1570 content: "I like Rust".to_owned(),
1571 parts: vec![],
1572 metadata: MessageMetadata::default(),
1573 });
1574
1575 agent.enqueue_persona_extraction_task();
1577 }
1578
1579 #[tokio::test]
1580 async fn enabled_enough_messages_spawns_extraction() {
1581 use zeph_llm::mock::MockProvider;
1584 let (mock, recorded) = MockProvider::default().with_recording();
1585 let provider = AnyProvider::Mock(mock);
1586 let mut agent = agent_with_persona(&provider, enabled_persona_config()).await;
1587
1588 agent.msg.messages.push(zeph_llm::provider::Message {
1589 role: Role::User,
1590 content: "I prefer Rust for systems programming".to_owned(),
1591 parts: vec![],
1592 metadata: MessageMetadata::default(),
1593 });
1594
1595 agent.enqueue_persona_extraction_task();
1596
1597 agent.runtime.lifecycle.supervisor.join_all_for_test().await;
1599
1600 let calls = recorded.lock().unwrap();
1601 assert!(
1602 !calls.is_empty(),
1603 "happy-path: provider.chat() must be called when extraction completes"
1604 );
1605 }
1606
1607 #[tokio::test]
1608 async fn messages_capped_at_eight() {
1609 use zeph_llm::mock::MockProvider;
1612 let (mock, recorded) = MockProvider::default().with_recording();
1613 let provider = AnyProvider::Mock(mock);
1614 let mut agent = agent_with_persona(&provider, enabled_persona_config()).await;
1615
1616 for i in 0..12u32 {
1617 agent.msg.messages.push(zeph_llm::provider::Message {
1618 role: Role::User,
1619 content: format!("I like message {i}"),
1620 parts: vec![],
1621 metadata: MessageMetadata::default(),
1622 });
1623 }
1624
1625 agent.enqueue_persona_extraction_task();
1626
1627 agent.runtime.lifecycle.supervisor.join_all_for_test().await;
1629
1630 let calls = recorded.lock().unwrap();
1632 assert!(
1633 !calls.is_empty(),
1634 "extraction must run when enough messages present"
1635 );
1636 let prompt = &calls[0];
1638 let user_text = prompt
1639 .iter()
1640 .filter(|m| m.role == Role::User)
1641 .map(|m| m.content.as_str())
1642 .collect::<Vec<_>>()
1643 .join(" ");
1644 assert!(
1646 !user_text.contains("I like message 8"),
1647 "message index 8 must be excluded from extraction input"
1648 );
1649 }
1650
1651 #[test]
1652 fn long_message_truncated_at_char_boundary() {
1653 let long_content = "x".repeat(3000);
1657 let truncated = if long_content.len() > 2048 {
1658 long_content[..long_content.floor_char_boundary(2048)].to_owned()
1659 } else {
1660 long_content.clone()
1661 };
1662 assert_eq!(
1663 truncated.len(),
1664 2048,
1665 "ASCII content must be truncated to exactly 2048 bytes"
1666 );
1667
1668 let multi = "é".repeat(1500); let truncated_multi = if multi.len() > 2048 {
1671 multi[..multi.floor_char_boundary(2048)].to_owned()
1672 } else {
1673 multi.clone()
1674 };
1675 assert!(
1676 truncated_multi.len() <= 2048,
1677 "multi-byte content must not exceed 2048 bytes"
1678 );
1679 assert!(truncated_multi.is_char_boundary(truncated_multi.len()));
1680 }
1681 }
1682
1683 #[tokio::test]
1684 async fn persist_message_user_always_embeds_regardless_of_autosave_flag() {
1685 let provider = mock_provider(vec![]);
1686 let channel = MockChannel::new(vec![]);
1687 let registry = create_test_registry();
1688 let executor = MockToolExecutor::no_tools();
1689
1690 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1691 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1692 let cid = memory.sqlite().create_conversation().await.unwrap();
1693
1694 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1696 .with_metrics(tx)
1697 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100);
1698 agent.services.memory.persistence.autosave_assistant = false;
1699 agent.services.memory.persistence.autosave_min_length = 20;
1700
1701 let long_user_msg = "A".repeat(100);
1702 agent
1703 .persist_message(Role::User, &long_user_msg, &[], false)
1704 .await;
1705
1706 let history = agent
1707 .services
1708 .memory
1709 .persistence
1710 .memory
1711 .as_ref()
1712 .unwrap()
1713 .sqlite()
1714 .load_history(cid, 50)
1715 .await
1716 .unwrap();
1717 assert_eq!(history.len(), 1, "user message must be saved");
1718 assert_eq!(rx.borrow().sqlite_message_count, 1);
1721 }
1722
1723 #[tokio::test]
1727 async fn persist_message_saves_correct_tool_use_parts() {
1728 use zeph_llm::provider::MessagePart;
1729
1730 let provider = mock_provider(vec![]);
1731 let channel = MockChannel::new(vec![]);
1732 let registry = create_test_registry();
1733 let executor = MockToolExecutor::no_tools();
1734
1735 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1736 let cid = memory.sqlite().create_conversation().await.unwrap();
1737
1738 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1739 std::sync::Arc::new(memory),
1740 cid,
1741 50,
1742 5,
1743 100,
1744 );
1745
1746 let parts = vec![MessagePart::ToolUse {
1747 id: "call_abc123".to_string(),
1748 name: "read_file".to_string(),
1749 input: serde_json::json!({"path": "/tmp/test.txt"}),
1750 }];
1751 let content = "[tool_use: read_file(call_abc123)]";
1752
1753 agent
1754 .persist_message(Role::Assistant, content, &parts, false)
1755 .await;
1756
1757 let history = agent
1758 .services
1759 .memory
1760 .persistence
1761 .memory
1762 .as_ref()
1763 .unwrap()
1764 .sqlite()
1765 .load_history(cid, 50)
1766 .await
1767 .unwrap();
1768
1769 assert_eq!(history.len(), 1);
1770 assert_eq!(history[0].role, Role::Assistant);
1771 assert_eq!(history[0].content, content);
1772 assert_eq!(history[0].parts.len(), 1);
1773 match &history[0].parts[0] {
1774 MessagePart::ToolUse { id, name, .. } => {
1775 assert_eq!(id, "call_abc123");
1776 assert_eq!(name, "read_file");
1777 }
1778 other => panic!("expected ToolUse part, got {other:?}"),
1779 }
1780 assert!(
1782 !history[0]
1783 .parts
1784 .iter()
1785 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1786 "assistant message must not contain ToolResult parts"
1787 );
1788 }
1789
1790 #[tokio::test]
1791 async fn persist_message_saves_correct_tool_result_parts() {
1792 use zeph_llm::provider::MessagePart;
1793
1794 let provider = mock_provider(vec![]);
1795 let channel = MockChannel::new(vec![]);
1796 let registry = create_test_registry();
1797 let executor = MockToolExecutor::no_tools();
1798
1799 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1800 let cid = memory.sqlite().create_conversation().await.unwrap();
1801
1802 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1803 std::sync::Arc::new(memory),
1804 cid,
1805 50,
1806 5,
1807 100,
1808 );
1809
1810 let parts = vec![MessagePart::ToolResult {
1811 tool_use_id: "call_abc123".to_string(),
1812 content: "file contents here".to_string(),
1813 is_error: false,
1814 }];
1815 let content = "[tool_result: call_abc123]\nfile contents here";
1816
1817 agent
1818 .persist_message(Role::User, content, &parts, false)
1819 .await;
1820
1821 let history = agent
1822 .services
1823 .memory
1824 .persistence
1825 .memory
1826 .as_ref()
1827 .unwrap()
1828 .sqlite()
1829 .load_history(cid, 50)
1830 .await
1831 .unwrap();
1832
1833 assert_eq!(history.len(), 1);
1834 assert_eq!(history[0].role, Role::User);
1835 assert_eq!(history[0].content, content);
1836 assert_eq!(history[0].parts.len(), 1);
1837 match &history[0].parts[0] {
1838 MessagePart::ToolResult {
1839 tool_use_id,
1840 content: result_content,
1841 is_error,
1842 } => {
1843 assert_eq!(tool_use_id, "call_abc123");
1844 assert_eq!(result_content, "file contents here");
1845 assert!(!is_error);
1846 }
1847 other => panic!("expected ToolResult part, got {other:?}"),
1848 }
1849 assert!(
1851 !history[0]
1852 .parts
1853 .iter()
1854 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1855 "user ToolResult message must not contain ToolUse parts"
1856 );
1857 }
1858
1859 #[tokio::test]
1860 async fn persist_message_roundtrip_preserves_role_part_alignment() {
1861 use zeph_llm::provider::MessagePart;
1862
1863 let provider = mock_provider(vec![]);
1864 let channel = MockChannel::new(vec![]);
1865 let registry = create_test_registry();
1866 let executor = MockToolExecutor::no_tools();
1867
1868 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1869 let cid = memory.sqlite().create_conversation().await.unwrap();
1870
1871 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1872 std::sync::Arc::new(memory),
1873 cid,
1874 50,
1875 5,
1876 100,
1877 );
1878
1879 let assistant_parts = vec![MessagePart::ToolUse {
1881 id: "id_1".to_string(),
1882 name: "list_dir".to_string(),
1883 input: serde_json::json!({"path": "/tmp"}),
1884 }];
1885 agent
1886 .persist_message(
1887 Role::Assistant,
1888 "[tool_use: list_dir(id_1)]",
1889 &assistant_parts,
1890 false,
1891 )
1892 .await;
1893
1894 let user_parts = vec![MessagePart::ToolResult {
1896 tool_use_id: "id_1".to_string(),
1897 content: "file1.txt\nfile2.txt".to_string(),
1898 is_error: false,
1899 }];
1900 agent
1901 .persist_message(
1902 Role::User,
1903 "[tool_result: id_1]\nfile1.txt\nfile2.txt",
1904 &user_parts,
1905 false,
1906 )
1907 .await;
1908
1909 let history = agent
1910 .services
1911 .memory
1912 .persistence
1913 .memory
1914 .as_ref()
1915 .unwrap()
1916 .sqlite()
1917 .load_history(cid, 50)
1918 .await
1919 .unwrap();
1920
1921 assert_eq!(history.len(), 2);
1922
1923 assert_eq!(history[0].role, Role::Assistant);
1925 assert_eq!(history[0].content, "[tool_use: list_dir(id_1)]");
1926 assert!(
1927 matches!(&history[0].parts[0], MessagePart::ToolUse { id, .. } if id == "id_1"),
1928 "first message must be assistant ToolUse"
1929 );
1930
1931 assert_eq!(history[1].role, Role::User);
1933 assert_eq!(
1934 history[1].content,
1935 "[tool_result: id_1]\nfile1.txt\nfile2.txt"
1936 );
1937 assert!(
1938 matches!(&history[1].parts[0], MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "id_1"),
1939 "second message must be user ToolResult"
1940 );
1941
1942 assert!(
1944 !history[0]
1945 .parts
1946 .iter()
1947 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1948 "assistant message must not have ToolResult parts"
1949 );
1950 assert!(
1951 !history[1]
1952 .parts
1953 .iter()
1954 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1955 "user message must not have ToolUse parts"
1956 );
1957 }
1958
1959 #[tokio::test]
1960 async fn persist_message_saves_correct_tool_output_parts() {
1961 use zeph_llm::provider::MessagePart;
1962
1963 let provider = mock_provider(vec![]);
1964 let channel = MockChannel::new(vec![]);
1965 let registry = create_test_registry();
1966 let executor = MockToolExecutor::no_tools();
1967
1968 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1969 let cid = memory.sqlite().create_conversation().await.unwrap();
1970
1971 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1972 std::sync::Arc::new(memory),
1973 cid,
1974 50,
1975 5,
1976 100,
1977 );
1978
1979 let parts = vec![MessagePart::ToolOutput {
1980 tool_name: "shell".into(),
1981 body: "hello from shell".to_string(),
1982 compacted_at: None,
1983 }];
1984 let content = "[tool: shell]\nhello from shell";
1985
1986 agent
1987 .persist_message(Role::User, content, &parts, false)
1988 .await;
1989
1990 let history = agent
1991 .services
1992 .memory
1993 .persistence
1994 .memory
1995 .as_ref()
1996 .unwrap()
1997 .sqlite()
1998 .load_history(cid, 50)
1999 .await
2000 .unwrap();
2001
2002 assert_eq!(history.len(), 1);
2003 assert_eq!(history[0].role, Role::User);
2004 assert_eq!(history[0].content, content);
2005 assert_eq!(history[0].parts.len(), 1);
2006 match &history[0].parts[0] {
2007 MessagePart::ToolOutput {
2008 tool_name,
2009 body,
2010 compacted_at,
2011 } => {
2012 assert_eq!(tool_name, "shell");
2013 assert_eq!(body, "hello from shell");
2014 assert!(compacted_at.is_none());
2015 }
2016 other => panic!("expected ToolOutput part, got {other:?}"),
2017 }
2018 }
2019
2020 #[tokio::test]
2023 async fn load_history_removes_trailing_orphan_tool_use() {
2024 use zeph_llm::provider::MessagePart;
2025
2026 let provider = mock_provider(vec![]);
2027 let channel = MockChannel::new(vec![]);
2028 let registry = create_test_registry();
2029 let executor = MockToolExecutor::no_tools();
2030
2031 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2032 let cid = memory.sqlite().create_conversation().await.unwrap();
2033 let sqlite = memory.sqlite();
2034
2035 sqlite
2037 .save_message(cid, "user", "do something with a tool")
2038 .await
2039 .unwrap();
2040
2041 let parts = serde_json::to_string(&[MessagePart::ToolUse {
2043 id: "call_orphan".to_string(),
2044 name: "shell".to_string(),
2045 input: serde_json::json!({"command": "ls"}),
2046 }])
2047 .unwrap();
2048 sqlite
2049 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_orphan)]", &parts)
2050 .await
2051 .unwrap();
2052
2053 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2054 std::sync::Arc::new(memory),
2055 cid,
2056 50,
2057 5,
2058 100,
2059 );
2060
2061 let messages_before = agent.msg.messages.len();
2062 agent.load_history().await.unwrap();
2063
2064 assert_eq!(
2066 agent.msg.messages.len(),
2067 messages_before + 1,
2068 "orphaned trailing tool_use must be removed"
2069 );
2070 assert_eq!(agent.msg.messages.last().unwrap().role, Role::User);
2071 }
2072
2073 #[tokio::test]
2074 async fn load_history_removes_leading_orphan_tool_result() {
2075 use zeph_llm::provider::MessagePart;
2076
2077 let provider = mock_provider(vec![]);
2078 let channel = MockChannel::new(vec![]);
2079 let registry = create_test_registry();
2080 let executor = MockToolExecutor::no_tools();
2081
2082 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2083 let cid = memory.sqlite().create_conversation().await.unwrap();
2084 let sqlite = memory.sqlite();
2085
2086 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2088 tool_use_id: "call_missing".to_string(),
2089 content: "result data".to_string(),
2090 is_error: false,
2091 }])
2092 .unwrap();
2093 sqlite
2094 .save_message_with_parts(
2095 cid,
2096 "user",
2097 "[tool_result: call_missing]\nresult data",
2098 &result_parts,
2099 )
2100 .await
2101 .unwrap();
2102
2103 sqlite
2105 .save_message(cid, "assistant", "here is my response")
2106 .await
2107 .unwrap();
2108
2109 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2110 std::sync::Arc::new(memory),
2111 cid,
2112 50,
2113 5,
2114 100,
2115 );
2116
2117 let messages_before = agent.msg.messages.len();
2118 agent.load_history().await.unwrap();
2119
2120 assert_eq!(
2122 agent.msg.messages.len(),
2123 messages_before + 1,
2124 "orphaned leading tool_result must be removed"
2125 );
2126 assert_eq!(agent.msg.messages.last().unwrap().role, Role::Assistant);
2127 }
2128
2129 #[tokio::test]
2130 async fn load_history_preserves_complete_tool_pairs() {
2131 use zeph_llm::provider::MessagePart;
2132
2133 let provider = mock_provider(vec![]);
2134 let channel = MockChannel::new(vec![]);
2135 let registry = create_test_registry();
2136 let executor = MockToolExecutor::no_tools();
2137
2138 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2139 let cid = memory.sqlite().create_conversation().await.unwrap();
2140 let sqlite = memory.sqlite();
2141
2142 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2144 id: "call_ok".to_string(),
2145 name: "shell".to_string(),
2146 input: serde_json::json!({"command": "pwd"}),
2147 }])
2148 .unwrap();
2149 sqlite
2150 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_ok)]", &use_parts)
2151 .await
2152 .unwrap();
2153
2154 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2155 tool_use_id: "call_ok".to_string(),
2156 content: "/home/user".to_string(),
2157 is_error: false,
2158 }])
2159 .unwrap();
2160 sqlite
2161 .save_message_with_parts(
2162 cid,
2163 "user",
2164 "[tool_result: call_ok]\n/home/user",
2165 &result_parts,
2166 )
2167 .await
2168 .unwrap();
2169
2170 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2171 std::sync::Arc::new(memory),
2172 cid,
2173 50,
2174 5,
2175 100,
2176 );
2177
2178 let messages_before = agent.msg.messages.len();
2179 agent.load_history().await.unwrap();
2180
2181 assert_eq!(
2183 agent.msg.messages.len(),
2184 messages_before + 2,
2185 "complete tool_use/tool_result pair must be preserved"
2186 );
2187 assert_eq!(agent.msg.messages[messages_before].role, Role::Assistant);
2188 assert_eq!(agent.msg.messages[messages_before + 1].role, Role::User);
2189 }
2190
2191 #[tokio::test]
2192 async fn load_history_handles_multiple_trailing_orphans() {
2193 use zeph_llm::provider::MessagePart;
2194
2195 let provider = mock_provider(vec![]);
2196 let channel = MockChannel::new(vec![]);
2197 let registry = create_test_registry();
2198 let executor = MockToolExecutor::no_tools();
2199
2200 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2201 let cid = memory.sqlite().create_conversation().await.unwrap();
2202 let sqlite = memory.sqlite();
2203
2204 sqlite.save_message(cid, "user", "start").await.unwrap();
2206
2207 let parts1 = serde_json::to_string(&[MessagePart::ToolUse {
2209 id: "call_1".to_string(),
2210 name: "shell".to_string(),
2211 input: serde_json::json!({}),
2212 }])
2213 .unwrap();
2214 sqlite
2215 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_1)]", &parts1)
2216 .await
2217 .unwrap();
2218
2219 let parts2 = serde_json::to_string(&[MessagePart::ToolUse {
2221 id: "call_2".to_string(),
2222 name: "read_file".to_string(),
2223 input: serde_json::json!({}),
2224 }])
2225 .unwrap();
2226 sqlite
2227 .save_message_with_parts(cid, "assistant", "[tool_use: read_file(call_2)]", &parts2)
2228 .await
2229 .unwrap();
2230
2231 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2232 std::sync::Arc::new(memory),
2233 cid,
2234 50,
2235 5,
2236 100,
2237 );
2238
2239 let messages_before = agent.msg.messages.len();
2240 agent.load_history().await.unwrap();
2241
2242 assert_eq!(
2244 agent.msg.messages.len(),
2245 messages_before + 1,
2246 "all trailing orphaned tool_use messages must be removed"
2247 );
2248 assert_eq!(agent.msg.messages.last().unwrap().role, Role::User);
2249 }
2250
2251 #[tokio::test]
2252 async fn load_history_no_tool_messages_unchanged() {
2253 let provider = mock_provider(vec![]);
2254 let channel = MockChannel::new(vec![]);
2255 let registry = create_test_registry();
2256 let executor = MockToolExecutor::no_tools();
2257
2258 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2259 let cid = memory.sqlite().create_conversation().await.unwrap();
2260 let sqlite = memory.sqlite();
2261
2262 sqlite.save_message(cid, "user", "hello").await.unwrap();
2263 sqlite
2264 .save_message(cid, "assistant", "hi there")
2265 .await
2266 .unwrap();
2267 sqlite
2268 .save_message(cid, "user", "how are you?")
2269 .await
2270 .unwrap();
2271
2272 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2273 std::sync::Arc::new(memory),
2274 cid,
2275 50,
2276 5,
2277 100,
2278 );
2279
2280 let messages_before = agent.msg.messages.len();
2281 agent.load_history().await.unwrap();
2282
2283 assert_eq!(
2285 agent.msg.messages.len(),
2286 messages_before + 3,
2287 "plain messages without tool parts must pass through unchanged"
2288 );
2289 }
2290
2291 #[tokio::test]
2292 async fn load_history_removes_both_leading_and_trailing_orphans() {
2293 use zeph_llm::provider::MessagePart;
2294
2295 let provider = mock_provider(vec![]);
2296 let channel = MockChannel::new(vec![]);
2297 let registry = create_test_registry();
2298 let executor = MockToolExecutor::no_tools();
2299
2300 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2301 let cid = memory.sqlite().create_conversation().await.unwrap();
2302 let sqlite = memory.sqlite();
2303
2304 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2306 tool_use_id: "call_leading".to_string(),
2307 content: "orphaned result".to_string(),
2308 is_error: false,
2309 }])
2310 .unwrap();
2311 sqlite
2312 .save_message_with_parts(
2313 cid,
2314 "user",
2315 "[tool_result: call_leading]\norphaned result",
2316 &result_parts,
2317 )
2318 .await
2319 .unwrap();
2320
2321 sqlite
2323 .save_message(cid, "user", "what is 2+2?")
2324 .await
2325 .unwrap();
2326 sqlite.save_message(cid, "assistant", "4").await.unwrap();
2327
2328 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2330 id: "call_trailing".to_string(),
2331 name: "shell".to_string(),
2332 input: serde_json::json!({"command": "date"}),
2333 }])
2334 .unwrap();
2335 sqlite
2336 .save_message_with_parts(
2337 cid,
2338 "assistant",
2339 "[tool_use: shell(call_trailing)]",
2340 &use_parts,
2341 )
2342 .await
2343 .unwrap();
2344
2345 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2346 std::sync::Arc::new(memory),
2347 cid,
2348 50,
2349 5,
2350 100,
2351 );
2352
2353 let messages_before = agent.msg.messages.len();
2354 agent.load_history().await.unwrap();
2355
2356 assert_eq!(
2358 agent.msg.messages.len(),
2359 messages_before + 2,
2360 "both leading and trailing orphans must be removed"
2361 );
2362 assert_eq!(agent.msg.messages[messages_before].role, Role::User);
2363 assert_eq!(agent.msg.messages[messages_before].content, "what is 2+2?");
2364 assert_eq!(
2365 agent.msg.messages[messages_before + 1].role,
2366 Role::Assistant
2367 );
2368 assert_eq!(agent.msg.messages[messages_before + 1].content, "4");
2369 }
2370
2371 #[tokio::test]
2376 async fn sanitize_tool_pairs_strips_mid_history_orphan_tool_use() {
2377 use zeph_llm::provider::MessagePart;
2378
2379 let provider = mock_provider(vec![]);
2380 let channel = MockChannel::new(vec![]);
2381 let registry = create_test_registry();
2382 let executor = MockToolExecutor::no_tools();
2383
2384 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2385 let cid = memory.sqlite().create_conversation().await.unwrap();
2386 let sqlite = memory.sqlite();
2387
2388 sqlite
2390 .save_message(cid, "user", "first question")
2391 .await
2392 .unwrap();
2393 sqlite
2394 .save_message(cid, "assistant", "first answer")
2395 .await
2396 .unwrap();
2397
2398 let use_parts = serde_json::to_string(&[
2402 MessagePart::ToolUse {
2403 id: "call_mid_1".to_string(),
2404 name: "shell".to_string(),
2405 input: serde_json::json!({"command": "ls"}),
2406 },
2407 MessagePart::Text {
2408 text: "Let me check the files.".to_string(),
2409 },
2410 ])
2411 .unwrap();
2412 sqlite
2413 .save_message_with_parts(cid, "assistant", "Let me check the files.", &use_parts)
2414 .await
2415 .unwrap();
2416
2417 sqlite
2419 .save_message(cid, "user", "second question")
2420 .await
2421 .unwrap();
2422 sqlite
2423 .save_message(cid, "assistant", "second answer")
2424 .await
2425 .unwrap();
2426
2427 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2428 std::sync::Arc::new(memory),
2429 cid,
2430 50,
2431 5,
2432 100,
2433 );
2434
2435 let messages_before = agent.msg.messages.len();
2436 agent.load_history().await.unwrap();
2437
2438 assert_eq!(
2441 agent.msg.messages.len(),
2442 messages_before + 5,
2443 "message count must be 5 (orphan message kept — has text content)"
2444 );
2445
2446 let orphan = &agent.msg.messages[messages_before + 2];
2448 assert_eq!(orphan.role, Role::Assistant);
2449 assert!(
2450 !orphan
2451 .parts
2452 .iter()
2453 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
2454 "orphaned ToolUse parts must be stripped from mid-history message"
2455 );
2456 assert!(
2458 orphan.parts.iter().any(
2459 |p| matches!(p, MessagePart::Text { text } if text == "Let me check the files.")
2460 ),
2461 "text content of orphaned assistant message must be preserved"
2462 );
2463 }
2464
2465 #[tokio::test]
2470 async fn load_history_keeps_tool_only_user_message() {
2471 use zeph_llm::provider::MessagePart;
2472
2473 let provider = mock_provider(vec![]);
2474 let channel = MockChannel::new(vec![]);
2475 let registry = create_test_registry();
2476 let executor = MockToolExecutor::no_tools();
2477
2478 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2479 let cid = memory.sqlite().create_conversation().await.unwrap();
2480 let sqlite = memory.sqlite();
2481
2482 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2484 id: "call_rc3".to_string(),
2485 name: "memory_save".to_string(),
2486 input: serde_json::json!({"content": "something"}),
2487 }])
2488 .unwrap();
2489 sqlite
2490 .save_message_with_parts(cid, "assistant", "[tool_use: memory_save]", &use_parts)
2491 .await
2492 .unwrap();
2493
2494 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2496 tool_use_id: "call_rc3".to_string(),
2497 content: "saved".to_string(),
2498 is_error: false,
2499 }])
2500 .unwrap();
2501 sqlite
2502 .save_message_with_parts(cid, "user", "", &result_parts)
2503 .await
2504 .unwrap();
2505
2506 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2507
2508 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2509 std::sync::Arc::new(memory),
2510 cid,
2511 50,
2512 5,
2513 100,
2514 );
2515
2516 let messages_before = agent.msg.messages.len();
2517 agent.load_history().await.unwrap();
2518
2519 assert_eq!(
2522 agent.msg.messages.len(),
2523 messages_before + 3,
2524 "user message with empty content but ToolResult parts must not be dropped"
2525 );
2526
2527 let user_msg = &agent.msg.messages[messages_before + 1];
2529 assert_eq!(user_msg.role, Role::User);
2530 assert!(
2531 user_msg.parts.iter().any(
2532 |p| matches!(p, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_rc3")
2533 ),
2534 "ToolResult part must be preserved on user message with empty content"
2535 );
2536 }
2537
2538 #[tokio::test]
2542 async fn strip_orphans_removes_orphaned_tool_result() {
2543 use zeph_llm::provider::MessagePart;
2544
2545 let provider = mock_provider(vec![]);
2546 let channel = MockChannel::new(vec![]);
2547 let registry = create_test_registry();
2548 let executor = MockToolExecutor::no_tools();
2549
2550 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2551 let cid = memory.sqlite().create_conversation().await.unwrap();
2552 let sqlite = memory.sqlite();
2553
2554 sqlite.save_message(cid, "user", "hello").await.unwrap();
2556 sqlite.save_message(cid, "assistant", "hi").await.unwrap();
2557
2558 sqlite
2560 .save_message(cid, "assistant", "plain answer")
2561 .await
2562 .unwrap();
2563
2564 let orphan_result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2566 tool_use_id: "call_nonexistent".to_string(),
2567 content: "stale result".to_string(),
2568 is_error: false,
2569 }])
2570 .unwrap();
2571 sqlite
2572 .save_message_with_parts(
2573 cid,
2574 "user",
2575 "[tool_result: call_nonexistent]\nstale result",
2576 &orphan_result_parts,
2577 )
2578 .await
2579 .unwrap();
2580
2581 sqlite
2582 .save_message(cid, "assistant", "final")
2583 .await
2584 .unwrap();
2585
2586 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2587 std::sync::Arc::new(memory),
2588 cid,
2589 50,
2590 5,
2591 100,
2592 );
2593
2594 let messages_before = agent.msg.messages.len();
2595 agent.load_history().await.unwrap();
2596
2597 let loaded = &agent.msg.messages[messages_before..];
2601 for msg in loaded {
2602 assert!(
2603 !msg.parts.iter().any(|p| matches!(
2604 p,
2605 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_nonexistent"
2606 )),
2607 "orphaned ToolResult part must be stripped from history"
2608 );
2609 }
2610 }
2611
2612 #[tokio::test]
2615 async fn strip_orphans_keeps_complete_pair() {
2616 use zeph_llm::provider::MessagePart;
2617
2618 let provider = mock_provider(vec![]);
2619 let channel = MockChannel::new(vec![]);
2620 let registry = create_test_registry();
2621 let executor = MockToolExecutor::no_tools();
2622
2623 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2624 let cid = memory.sqlite().create_conversation().await.unwrap();
2625 let sqlite = memory.sqlite();
2626
2627 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2628 id: "call_valid".to_string(),
2629 name: "shell".to_string(),
2630 input: serde_json::json!({"command": "ls"}),
2631 }])
2632 .unwrap();
2633 sqlite
2634 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
2635 .await
2636 .unwrap();
2637
2638 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2639 tool_use_id: "call_valid".to_string(),
2640 content: "file.rs".to_string(),
2641 is_error: false,
2642 }])
2643 .unwrap();
2644 sqlite
2645 .save_message_with_parts(cid, "user", "", &result_parts)
2646 .await
2647 .unwrap();
2648
2649 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2650 std::sync::Arc::new(memory),
2651 cid,
2652 50,
2653 5,
2654 100,
2655 );
2656
2657 let messages_before = agent.msg.messages.len();
2658 agent.load_history().await.unwrap();
2659
2660 assert_eq!(
2661 agent.msg.messages.len(),
2662 messages_before + 2,
2663 "complete tool_use/tool_result pair must be preserved"
2664 );
2665
2666 let user_msg = &agent.msg.messages[messages_before + 1];
2667 assert!(
2668 user_msg.parts.iter().any(|p| matches!(
2669 p,
2670 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_valid"
2671 )),
2672 "ToolResult part for a matched tool_use must not be stripped"
2673 );
2674 }
2675
2676 #[tokio::test]
2679 async fn strip_orphans_mixed_history() {
2680 use zeph_llm::provider::MessagePart;
2681
2682 let provider = mock_provider(vec![]);
2683 let channel = MockChannel::new(vec![]);
2684 let registry = create_test_registry();
2685 let executor = MockToolExecutor::no_tools();
2686
2687 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2688 let cid = memory.sqlite().create_conversation().await.unwrap();
2689 let sqlite = memory.sqlite();
2690
2691 let use_parts_ok = serde_json::to_string(&[MessagePart::ToolUse {
2693 id: "call_good".to_string(),
2694 name: "shell".to_string(),
2695 input: serde_json::json!({"command": "pwd"}),
2696 }])
2697 .unwrap();
2698 sqlite
2699 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts_ok)
2700 .await
2701 .unwrap();
2702
2703 let result_parts_ok = serde_json::to_string(&[MessagePart::ToolResult {
2704 tool_use_id: "call_good".to_string(),
2705 content: "/home".to_string(),
2706 is_error: false,
2707 }])
2708 .unwrap();
2709 sqlite
2710 .save_message_with_parts(cid, "user", "", &result_parts_ok)
2711 .await
2712 .unwrap();
2713
2714 sqlite
2716 .save_message(cid, "assistant", "text only")
2717 .await
2718 .unwrap();
2719
2720 let orphan_parts = serde_json::to_string(&[MessagePart::ToolResult {
2721 tool_use_id: "call_ghost".to_string(),
2722 content: "ghost result".to_string(),
2723 is_error: false,
2724 }])
2725 .unwrap();
2726 sqlite
2727 .save_message_with_parts(
2728 cid,
2729 "user",
2730 "[tool_result: call_ghost]\nghost result",
2731 &orphan_parts,
2732 )
2733 .await
2734 .unwrap();
2735
2736 sqlite
2737 .save_message(cid, "assistant", "final reply")
2738 .await
2739 .unwrap();
2740
2741 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2742 std::sync::Arc::new(memory),
2743 cid,
2744 50,
2745 5,
2746 100,
2747 );
2748
2749 let messages_before = agent.msg.messages.len();
2750 agent.load_history().await.unwrap();
2751
2752 let loaded = &agent.msg.messages[messages_before..];
2753
2754 for msg in loaded {
2756 assert!(
2757 !msg.parts.iter().any(|p| matches!(
2758 p,
2759 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_ghost"
2760 )),
2761 "orphaned ToolResult (call_ghost) must be stripped from history"
2762 );
2763 }
2764
2765 let has_good_result = loaded.iter().any(|msg| {
2768 msg.role == Role::User
2769 && msg.parts.iter().any(|p| {
2770 matches!(
2771 p,
2772 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_good"
2773 )
2774 })
2775 });
2776 assert!(
2777 has_good_result,
2778 "matched ToolResult (call_good) must be preserved in history"
2779 );
2780 }
2781
2782 #[tokio::test]
2785 async fn sanitize_tool_pairs_preserves_matched_tool_pair() {
2786 use zeph_llm::provider::MessagePart;
2787
2788 let provider = mock_provider(vec![]);
2789 let channel = MockChannel::new(vec![]);
2790 let registry = create_test_registry();
2791 let executor = MockToolExecutor::no_tools();
2792
2793 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2794 let cid = memory.sqlite().create_conversation().await.unwrap();
2795 let sqlite = memory.sqlite();
2796
2797 sqlite
2798 .save_message(cid, "user", "run a command")
2799 .await
2800 .unwrap();
2801
2802 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2804 id: "call_ok".to_string(),
2805 name: "shell".to_string(),
2806 input: serde_json::json!({"command": "echo hi"}),
2807 }])
2808 .unwrap();
2809 sqlite
2810 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
2811 .await
2812 .unwrap();
2813
2814 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2816 tool_use_id: "call_ok".to_string(),
2817 content: "hi".to_string(),
2818 is_error: false,
2819 }])
2820 .unwrap();
2821 sqlite
2822 .save_message_with_parts(cid, "user", "[tool_result: call_ok]\nhi", &result_parts)
2823 .await
2824 .unwrap();
2825
2826 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2827
2828 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2829 std::sync::Arc::new(memory),
2830 cid,
2831 50,
2832 5,
2833 100,
2834 );
2835
2836 let messages_before = agent.msg.messages.len();
2837 agent.load_history().await.unwrap();
2838
2839 assert_eq!(
2841 agent.msg.messages.len(),
2842 messages_before + 4,
2843 "matched tool pair must not be removed"
2844 );
2845 let tool_msg = &agent.msg.messages[messages_before + 1];
2846 assert!(
2847 tool_msg
2848 .parts
2849 .iter()
2850 .any(|p| matches!(p, MessagePart::ToolUse { id, .. } if id == "call_ok")),
2851 "matched ToolUse parts must be preserved"
2852 );
2853 }
2854
2855 #[tokio::test]
2859 async fn persist_cancelled_tool_results_pairs_tool_use() {
2860 use zeph_llm::provider::MessagePart;
2861
2862 let provider = mock_provider(vec![]);
2863 let channel = MockChannel::new(vec![]);
2864 let registry = create_test_registry();
2865 let executor = MockToolExecutor::no_tools();
2866
2867 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2868 let cid = memory.sqlite().create_conversation().await.unwrap();
2869
2870 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2871 std::sync::Arc::new(memory),
2872 cid,
2873 50,
2874 5,
2875 100,
2876 );
2877
2878 let tool_calls = vec![
2880 zeph_llm::provider::ToolUseRequest {
2881 id: "cancel_id_1".to_string(),
2882 name: "shell".to_string().into(),
2883 input: serde_json::json!({}),
2884 },
2885 zeph_llm::provider::ToolUseRequest {
2886 id: "cancel_id_2".to_string(),
2887 name: "read_file".to_string().into(),
2888 input: serde_json::json!({}),
2889 },
2890 ];
2891
2892 agent.persist_cancelled_tool_results(&tool_calls).await;
2893
2894 let history = agent
2895 .services
2896 .memory
2897 .persistence
2898 .memory
2899 .as_ref()
2900 .unwrap()
2901 .sqlite()
2902 .load_history(cid, 50)
2903 .await
2904 .unwrap();
2905
2906 assert_eq!(history.len(), 1);
2908 assert_eq!(history[0].role, Role::User);
2909
2910 for tc in &tool_calls {
2912 assert!(
2913 history[0].parts.iter().any(|p| matches!(
2914 p,
2915 MessagePart::ToolResult { tool_use_id, is_error, .. }
2916 if tool_use_id == &tc.id && *is_error
2917 )),
2918 "tombstone ToolResult for {} must be present and is_error=true",
2919 tc.id
2920 );
2921 }
2922 }
2923
2924 #[tokio::test]
2935 async fn issue_2529_orphaned_legacy_content_pair_is_soft_deleted() {
2936 use zeph_llm::provider::MessagePart;
2937
2938 let provider = mock_provider(vec![]);
2939 let channel = MockChannel::new(vec![]);
2940 let registry = create_test_registry();
2941 let executor = MockToolExecutor::no_tools();
2942
2943 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2944 let cid = memory.sqlite().create_conversation().await.unwrap();
2945 let sqlite = memory.sqlite();
2946
2947 sqlite
2949 .save_message(cid, "user", "save this for me")
2950 .await
2951 .unwrap();
2952
2953 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2956 id: "call_2529".to_string(),
2957 name: "memory_save".to_string(),
2958 input: serde_json::json!({"content": "save this"}),
2959 }])
2960 .unwrap();
2961 let orphan_assistant_id = sqlite
2962 .save_message_with_parts(
2963 cid,
2964 "assistant",
2965 "[tool_use: memory_save(call_2529)]",
2966 &use_parts,
2967 )
2968 .await
2969 .unwrap();
2970
2971 sqlite
2976 .save_message(cid, "assistant", "here is a plain reply")
2977 .await
2978 .unwrap();
2979
2980 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2981 tool_use_id: "call_2529".to_string(),
2982 content: "saved".to_string(),
2983 is_error: false,
2984 }])
2985 .unwrap();
2986 let orphan_user_id = sqlite
2987 .save_message_with_parts(
2988 cid,
2989 "user",
2990 "[tool_result: call_2529]\nsaved",
2991 &result_parts,
2992 )
2993 .await
2994 .unwrap();
2995
2996 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2998
2999 let memory_arc = std::sync::Arc::new(memory);
3000 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
3001 memory_arc.clone(),
3002 cid,
3003 50,
3004 5,
3005 100,
3006 );
3007
3008 agent.load_history().await.unwrap();
3009
3010 let assistant_deleted_count: Vec<i64> = zeph_db::query_scalar(
3013 "SELECT COUNT(*) FROM messages WHERE id = ? AND deleted_at IS NOT NULL",
3014 )
3015 .bind(orphan_assistant_id)
3016 .fetch_all(memory_arc.sqlite().pool())
3017 .await
3018 .unwrap();
3019
3020 let user_deleted_count: Vec<i64> = zeph_db::query_scalar(
3021 "SELECT COUNT(*) FROM messages WHERE id = ? AND deleted_at IS NOT NULL",
3022 )
3023 .bind(orphan_user_id)
3024 .fetch_all(memory_arc.sqlite().pool())
3025 .await
3026 .unwrap();
3027
3028 assert_eq!(
3029 assistant_deleted_count.first().copied().unwrap_or(0),
3030 1,
3031 "orphaned assistant[ToolUse] with legacy-only content must be soft-deleted (deleted_at IS NOT NULL)"
3032 );
3033 assert_eq!(
3034 user_deleted_count.first().copied().unwrap_or(0),
3035 1,
3036 "orphaned user[ToolResult] with legacy-only content must be soft-deleted (deleted_at IS NOT NULL)"
3037 );
3038 }
3039
3040 #[tokio::test]
3044 async fn issue_2529_soft_delete_is_idempotent_across_sessions() {
3045 use zeph_llm::provider::MessagePart;
3046
3047 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3048 let cid = memory.sqlite().create_conversation().await.unwrap();
3049 let sqlite = memory.sqlite();
3050
3051 sqlite
3053 .save_message(cid, "user", "do something")
3054 .await
3055 .unwrap();
3056
3057 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
3059 id: "call_idem".to_string(),
3060 name: "shell".to_string(),
3061 input: serde_json::json!({"command": "ls"}),
3062 }])
3063 .unwrap();
3064 sqlite
3065 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_idem)]", &use_parts)
3066 .await
3067 .unwrap();
3068
3069 sqlite
3071 .save_message(cid, "assistant", "continuing")
3072 .await
3073 .unwrap();
3074
3075 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
3077 tool_use_id: "call_idem".to_string(),
3078 content: "output".to_string(),
3079 is_error: false,
3080 }])
3081 .unwrap();
3082 sqlite
3083 .save_message_with_parts(
3084 cid,
3085 "user",
3086 "[tool_result: call_idem]\noutput",
3087 &result_parts,
3088 )
3089 .await
3090 .unwrap();
3091
3092 sqlite
3093 .save_message(cid, "assistant", "final")
3094 .await
3095 .unwrap();
3096
3097 let memory_arc = std::sync::Arc::new(memory);
3098
3099 let mut agent1 = Agent::new(
3101 mock_provider(vec![]),
3102 MockChannel::new(vec![]),
3103 create_test_registry(),
3104 None,
3105 5,
3106 MockToolExecutor::no_tools(),
3107 )
3108 .with_memory(memory_arc.clone(), cid, 50, 5, 100);
3109 agent1.load_history().await.unwrap();
3110 let count_after_first = agent1.msg.messages.len();
3111
3112 let mut agent2 = Agent::new(
3115 mock_provider(vec![]),
3116 MockChannel::new(vec![]),
3117 create_test_registry(),
3118 None,
3119 5,
3120 MockToolExecutor::no_tools(),
3121 )
3122 .with_memory(memory_arc.clone(), cid, 50, 5, 100);
3123 agent2.load_history().await.unwrap();
3124 let count_after_second = agent2.msg.messages.len();
3125
3126 assert_eq!(
3128 count_after_first, count_after_second,
3129 "second load_history must load the same message count as the first (soft-deleted orphans excluded)"
3130 );
3131 }
3132
3133 #[tokio::test]
3137 async fn issue_2529_message_with_text_and_tool_tag_is_kept_after_part_strip() {
3138 use zeph_llm::provider::MessagePart;
3139
3140 let provider = mock_provider(vec![]);
3141 let channel = MockChannel::new(vec![]);
3142 let registry = create_test_registry();
3143 let executor = MockToolExecutor::no_tools();
3144
3145 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3146 let cid = memory.sqlite().create_conversation().await.unwrap();
3147 let sqlite = memory.sqlite();
3148
3149 sqlite
3151 .save_message(cid, "user", "check the files")
3152 .await
3153 .unwrap();
3154
3155 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
3158 id: "call_mixed".to_string(),
3159 name: "shell".to_string(),
3160 input: serde_json::json!({"command": "ls"}),
3161 }])
3162 .unwrap();
3163 sqlite
3164 .save_message_with_parts(
3165 cid,
3166 "assistant",
3167 "Let me list the directory. [tool_use: shell(call_mixed)]",
3168 &use_parts,
3169 )
3170 .await
3171 .unwrap();
3172
3173 sqlite.save_message(cid, "user", "thanks").await.unwrap();
3175 sqlite
3176 .save_message(cid, "assistant", "you are welcome")
3177 .await
3178 .unwrap();
3179
3180 let memory_arc = std::sync::Arc::new(memory);
3181 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
3182 memory_arc.clone(),
3183 cid,
3184 50,
3185 5,
3186 100,
3187 );
3188
3189 let messages_before = agent.msg.messages.len();
3190 agent.load_history().await.unwrap();
3191
3192 assert_eq!(
3194 agent.msg.messages.len(),
3195 messages_before + 4,
3196 "assistant message with text + tool tag must not be removed after ToolUse strip"
3197 );
3198
3199 let mixed_msg = agent
3201 .msg
3202 .messages
3203 .iter()
3204 .find(|m| m.content.contains("Let me list the directory"))
3205 .expect("mixed-content assistant message must still be in history");
3206 assert!(
3207 !mixed_msg
3208 .parts
3209 .iter()
3210 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
3211 "orphaned ToolUse parts must be stripped even when message has meaningful text"
3212 );
3213 assert_eq!(
3214 mixed_msg.content, "Let me list the directory. [tool_use: shell(call_mixed)]",
3215 "content field must be unchanged — only parts are stripped"
3216 );
3217 }
3218
3219 #[tokio::test]
3222 async fn persist_message_skipped_tool_result_does_not_embed() {
3223 use zeph_llm::provider::MessagePart;
3224
3225 let provider = mock_provider(vec![]);
3226 let channel = MockChannel::new(vec![]);
3227 let registry = create_test_registry();
3228 let executor = MockToolExecutor::no_tools();
3229
3230 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
3231 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3232 let cid = memory.sqlite().create_conversation().await.unwrap();
3233
3234 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
3235 .with_metrics(tx)
3236 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100);
3237 agent.services.memory.persistence.autosave_assistant = true;
3238 agent.services.memory.persistence.autosave_min_length = 0;
3239
3240 let parts = vec![MessagePart::ToolResult {
3241 tool_use_id: "tu1".into(),
3242 content: "[skipped] bash tool was blocked by utility gate".into(),
3243 is_error: false,
3244 }];
3245
3246 agent
3247 .persist_message(
3248 Role::User,
3249 "[skipped] bash tool was blocked by utility gate",
3250 &parts,
3251 false,
3252 )
3253 .await;
3254
3255 assert_eq!(
3256 rx.borrow().embeddings_generated,
3257 0,
3258 "[skipped] ToolResult must not be embedded into Qdrant"
3259 );
3260 }
3261
3262 #[tokio::test]
3263 async fn persist_message_stopped_tool_result_does_not_embed() {
3264 use zeph_llm::provider::MessagePart;
3265
3266 let provider = mock_provider(vec![]);
3267 let channel = MockChannel::new(vec![]);
3268 let registry = create_test_registry();
3269 let executor = MockToolExecutor::no_tools();
3270
3271 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
3272 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3273 let cid = memory.sqlite().create_conversation().await.unwrap();
3274
3275 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
3276 .with_metrics(tx)
3277 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100);
3278 agent.services.memory.persistence.autosave_assistant = true;
3279 agent.services.memory.persistence.autosave_min_length = 0;
3280
3281 let parts = vec![MessagePart::ToolResult {
3282 tool_use_id: "tu2".into(),
3283 content: "[stopped] execution limit reached".into(),
3284 is_error: false,
3285 }];
3286
3287 agent
3288 .persist_message(
3289 Role::User,
3290 "[stopped] execution limit reached",
3291 &parts,
3292 false,
3293 )
3294 .await;
3295
3296 assert_eq!(
3297 rx.borrow().embeddings_generated,
3298 0,
3299 "[stopped] ToolResult must not be embedded into Qdrant"
3300 );
3301 }
3302
3303 #[tokio::test]
3304 async fn persist_message_normal_tool_result_is_saved_not_blocked_by_guard() {
3305 use zeph_llm::provider::MessagePart;
3308
3309 let provider = mock_provider(vec![]);
3310 let channel = MockChannel::new(vec![]);
3311 let registry = create_test_registry();
3312 let executor = MockToolExecutor::no_tools();
3313
3314 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3315 let cid = memory.sqlite().create_conversation().await.unwrap();
3316 let memory_arc = std::sync::Arc::new(memory);
3317
3318 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
3319 memory_arc.clone(),
3320 cid,
3321 50,
3322 5,
3323 100,
3324 );
3325 agent.services.memory.persistence.autosave_assistant = true;
3326 agent.services.memory.persistence.autosave_min_length = 0;
3327
3328 let content = "total 42\ndrwxr-xr-x 5 user group";
3329 let parts = vec![MessagePart::ToolResult {
3330 tool_use_id: "tu3".into(),
3331 content: content.into(),
3332 is_error: false,
3333 }];
3334
3335 agent
3336 .persist_message(Role::User, content, &parts, false)
3337 .await;
3338
3339 let history = memory_arc.sqlite().load_history(cid, 50).await.unwrap();
3341 assert_eq!(
3342 history.len(),
3343 1,
3344 "normal ToolResult must be saved to SQLite"
3345 );
3346 assert_eq!(history[0].content, content);
3347 }
3348
3349 #[test]
3354 fn trajectory_extraction_slice_bounds_messages() {
3355 let max_messages: usize = 20;
3357 let total_messages = 100usize;
3358
3359 let tail_start = total_messages.saturating_sub(max_messages);
3360 let window = total_messages - tail_start;
3361
3362 assert_eq!(
3363 window, 20,
3364 "slice should contain exactly max_messages items"
3365 );
3366 assert_eq!(tail_start, 80, "slice should start at len - max_messages");
3367 }
3368
3369 #[test]
3370 fn trajectory_extraction_slice_handles_few_messages() {
3371 let max_messages: usize = 20;
3372 let total_messages = 5usize;
3373
3374 let tail_start = total_messages.saturating_sub(max_messages);
3375 let window = total_messages - tail_start;
3376
3377 assert_eq!(window, 5, "should return all messages when fewer than max");
3378 assert_eq!(tail_start, 0, "slice should start from the beginning");
3379 }
3380
3381 #[tokio::test]
3387 async fn regression_3168_complete_tool_pair_survives_round_trip() {
3388 use zeph_llm::provider::MessagePart;
3389
3390 let provider = mock_provider(vec![]);
3391 let channel = MockChannel::new(vec![]);
3392 let registry = create_test_registry();
3393 let executor = MockToolExecutor::no_tools();
3394
3395 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3396 let cid = memory.sqlite().create_conversation().await.unwrap();
3397 let sqlite = memory.sqlite();
3398
3399 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
3400 id: "r3168_call".to_string(),
3401 name: "shell".to_string(),
3402 input: serde_json::json!({"command": "echo hi"}),
3403 }])
3404 .unwrap();
3405 sqlite
3406 .save_message_with_parts(
3407 cid,
3408 "assistant",
3409 "[tool_use: shell(r3168_call)]",
3410 &use_parts,
3411 )
3412 .await
3413 .unwrap();
3414
3415 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
3416 tool_use_id: "r3168_call".to_string(),
3417 content: "[skipped]".to_string(),
3418 is_error: false,
3419 }])
3420 .unwrap();
3421 sqlite
3422 .save_message_with_parts(cid, "user", "[tool_result: r3168_call]", &result_parts)
3423 .await
3424 .unwrap();
3425
3426 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
3427 std::sync::Arc::new(memory),
3428 cid,
3429 50,
3430 5,
3431 100,
3432 );
3433
3434 let base = agent.msg.messages.len();
3435 agent.load_history().await.unwrap();
3436
3437 assert_eq!(
3438 agent.msg.messages.len(),
3439 base + 2,
3440 "both messages of the complete pair must survive load_history"
3441 );
3442
3443 let assistant_msg = agent
3444 .msg
3445 .messages
3446 .iter()
3447 .find(|m| m.role == Role::Assistant)
3448 .expect("assistant message missing after load_history");
3449 assert!(
3450 assistant_msg
3451 .parts
3452 .iter()
3453 .any(|p| matches!(p, MessagePart::ToolUse { id, .. } if id == "r3168_call")),
3454 "ToolUse part must be preserved in assistant message"
3455 );
3456
3457 let user_msg = agent
3458 .msg
3459 .messages
3460 .iter()
3461 .rev()
3462 .find(|m| m.role == Role::User)
3463 .expect("user message missing after load_history");
3464 assert!(
3465 user_msg.parts.iter().any(|p| matches!(p, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "r3168_call")),
3466 "ToolResult part must be preserved in user message"
3467 );
3468 }
3469
3470 #[tokio::test]
3475 async fn regression_3168_corrupt_parts_row_skipped_on_load() {
3476 use zeph_llm::provider::MessagePart;
3477
3478 let provider = mock_provider(vec![]);
3479 let channel = MockChannel::new(vec![]);
3480 let registry = create_test_registry();
3481 let executor = MockToolExecutor::no_tools();
3482
3483 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3484 let cid = memory.sqlite().create_conversation().await.unwrap();
3485 let sqlite = memory.sqlite();
3486
3487 sqlite
3491 .save_message_with_parts(cid, "assistant", "[tool_use: shell(corrupt)]", "[]")
3492 .await
3493 .unwrap();
3494
3495 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
3497 tool_use_id: "corrupt".to_string(),
3498 content: "result".to_string(),
3499 is_error: false,
3500 }])
3501 .unwrap();
3502 sqlite
3503 .save_message_with_parts(cid, "user", "[tool_result: corrupt]", &result_parts)
3504 .await
3505 .unwrap();
3506
3507 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
3508 std::sync::Arc::new(memory),
3509 cid,
3510 50,
3511 5,
3512 100,
3513 );
3514
3515 let base = agent.msg.messages.len();
3516 agent.load_history().await.unwrap();
3517
3518 let loaded = agent.msg.messages.len() - base;
3523 let orphan_present = agent.msg.messages.iter().any(|m| {
3526 m.role == Role::User
3527 && m.parts.iter().any(|p| {
3528 matches!(p, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "corrupt")
3529 })
3530 });
3531 assert!(
3532 !orphan_present,
3533 "orphaned ToolResult must not survive load_history; loaded={loaded}"
3534 );
3535 }
3536}