1use std::collections::HashSet;
5
6use crate::channel::Channel;
7use zeph_llm::provider::{LlmProvider as _, Message, MessagePart, Role};
8use zeph_memory::store::role_str;
9
10use super::Agent;
11
12fn sanitize_tool_pairs(messages: &mut Vec<Message>) -> (usize, Vec<i64>) {
28 let mut removed = 0;
29 let mut db_ids: Vec<i64> = Vec::new();
30
31 loop {
32 if let Some(last) = messages.last()
34 && last.role == Role::Assistant
35 && last
36 .parts
37 .iter()
38 .any(|p| matches!(p, MessagePart::ToolUse { .. }))
39 {
40 let ids: Vec<String> = last
41 .parts
42 .iter()
43 .filter_map(|p| {
44 if let MessagePart::ToolUse { id, .. } = p {
45 Some(id.clone())
46 } else {
47 None
48 }
49 })
50 .collect();
51 tracing::warn!(
52 tool_ids = ?ids,
53 "removing orphaned trailing tool_use message from restored history"
54 );
55 if let Some(db_id) = messages.last().and_then(|m| m.metadata.db_id) {
56 db_ids.push(db_id);
57 }
58 messages.pop();
59 removed += 1;
60 continue;
61 }
62
63 if let Some(first) = messages.first()
65 && first.role == Role::User
66 && first
67 .parts
68 .iter()
69 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
70 {
71 let ids: Vec<String> = first
72 .parts
73 .iter()
74 .filter_map(|p| {
75 if let MessagePart::ToolResult { tool_use_id, .. } = p {
76 Some(tool_use_id.clone())
77 } else {
78 None
79 }
80 })
81 .collect();
82 tracing::warn!(
83 tool_use_ids = ?ids,
84 "removing orphaned leading tool_result message from restored history"
85 );
86 if let Some(db_id) = messages.first().and_then(|m| m.metadata.db_id) {
87 db_ids.push(db_id);
88 }
89 messages.remove(0);
90 removed += 1;
91 continue;
92 }
93
94 break;
95 }
96
97 let (mid_removed, mid_db_ids) = strip_mid_history_orphans(messages);
100 removed += mid_removed;
101 db_ids.extend(mid_db_ids);
102
103 (removed, db_ids)
104}
105
106fn orphaned_tool_use_ids(msg: &Message, next_msg: Option<&Message>) -> HashSet<String> {
108 let matched: HashSet<String> = next_msg
109 .filter(|n| n.role == Role::User)
110 .map(|n| {
111 msg.parts
112 .iter()
113 .filter_map(|p| if let MessagePart::ToolUse { id, .. } = p { Some(id.clone()) } else { None })
114 .filter(|uid| n.parts.iter().any(|np| matches!(np, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == uid)))
115 .collect()
116 })
117 .unwrap_or_default();
118 msg.parts
119 .iter()
120 .filter_map(|p| {
121 if let MessagePart::ToolUse { id, .. } = p
122 && !matched.contains(id)
123 {
124 Some(id.clone())
125 } else {
126 None
127 }
128 })
129 .collect()
130}
131
132fn orphaned_tool_result_ids(msg: &Message, prev_msg: Option<&Message>) -> HashSet<String> {
134 let avail: HashSet<&str> = prev_msg
135 .filter(|p| p.role == Role::Assistant)
136 .map(|p| {
137 p.parts
138 .iter()
139 .filter_map(|part| {
140 if let MessagePart::ToolUse { id, .. } = part {
141 Some(id.as_str())
142 } else {
143 None
144 }
145 })
146 .collect()
147 })
148 .unwrap_or_default();
149 msg.parts
150 .iter()
151 .filter_map(|p| {
152 if let MessagePart::ToolResult { tool_use_id, .. } = p
153 && !avail.contains(tool_use_id.as_str())
154 {
155 Some(tool_use_id.clone())
156 } else {
157 None
158 }
159 })
160 .collect()
161}
162
/// Reports whether `content` holds any text besides tool markers.
///
/// Recognised markers are `[tool_use: …]`, `[tool_result: …]` and
/// `[tool output: …]`. A `[tool_use: …]` tag is skipped on its own; for the two
/// result-style markers everything up to the next marker (or the end of the
/// string) is skipped as well. Any non-whitespace text found before a marker
/// or left over at the end makes the content meaningful. A marker without a
/// closing `]` is treated as meaningful text.
fn has_meaningful_content(content: &str) -> bool {
    const MARKERS: [&str; 3] = ["[tool_use: ", "[tool_result: ", "[tool output: "];

    /// Byte offset and text of whichever marker occurs first in `text`.
    fn first_marker(text: &str) -> Option<(usize, &'static str)> {
        MARKERS
            .iter()
            .filter_map(|marker| text.find(*marker).map(|at| (at, *marker)))
            .min_by_key(|&(at, _)| at)
    }

    let mut rest = content.trim();

    while let Some((at, marker)) = first_marker(rest) {
        let (before, tagged) = rest.split_at(at);
        if !before.trim().is_empty() {
            // Real text precedes the marker.
            return true;
        }

        let after_marker = &tagged[marker.len()..];
        let Some(close) = after_marker.find(']') else {
            // Unterminated marker: not a well-formed tag, so count it as text.
            return true;
        };
        let after_tag = &after_marker[close + 1..];

        rest = if marker == "[tool_use: " {
            after_tag
        } else {
            // Result-style markers own the body that follows them, up to the
            // next marker of any kind.
            let body = after_tag.trim_start_matches('\n');
            match first_marker(body) {
                Some((next_at, _)) => &body[next_at..],
                None => "",
            }
        };
    }

    !rest.trim().is_empty()
}
232
233fn strip_mid_history_orphans(messages: &mut Vec<Message>) -> (usize, Vec<i64>) {
246 let mut removed = 0;
247 let mut db_ids: Vec<i64> = Vec::new();
248 let mut i = 0;
249 while i < messages.len() {
250 if messages[i].role == Role::Assistant
254 && messages[i]
255 .parts
256 .iter()
257 .any(|p| matches!(p, MessagePart::ToolUse { .. }))
258 {
259 let orphaned_ids = orphaned_tool_use_ids(&messages[i], messages.get(i + 1));
260 if !orphaned_ids.is_empty() {
261 tracing::warn!(
262 tool_ids = ?orphaned_ids,
263 index = i,
264 "stripping orphaned mid-history tool_use parts from assistant message"
265 );
266 messages[i].parts.retain(
267 |p| !matches!(p, MessagePart::ToolUse { id, .. } if orphaned_ids.contains(id)),
268 );
269 let is_empty =
270 !has_meaningful_content(&messages[i].content) && messages[i].parts.is_empty();
271 if is_empty {
272 if let Some(db_id) = messages[i].metadata.db_id {
273 db_ids.push(db_id);
274 }
275 messages.remove(i);
276 removed += 1;
277 continue; }
279 }
280 }
281
282 if messages[i].role == Role::User
284 && messages[i]
285 .parts
286 .iter()
287 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
288 {
289 let orphaned_ids = orphaned_tool_result_ids(
290 &messages[i],
291 if i > 0 { messages.get(i - 1) } else { None },
292 );
293 if !orphaned_ids.is_empty() {
294 tracing::warn!(
295 tool_use_ids = ?orphaned_ids,
296 index = i,
297 "stripping orphaned mid-history tool_result parts from user message"
298 );
299 messages[i].parts.retain(|p| {
300 !matches!(p, MessagePart::ToolResult { tool_use_id, .. } if orphaned_ids.contains(tool_use_id.as_str()))
301 });
302
303 let is_empty =
304 !has_meaningful_content(&messages[i].content) && messages[i].parts.is_empty();
305 if is_empty {
306 if let Some(db_id) = messages[i].metadata.db_id {
307 db_ids.push(db_id);
308 }
309 messages.remove(i);
310 removed += 1;
311 continue;
313 }
314 }
315 }
316
317 i += 1;
318 }
319 (removed, db_ids)
320}
321
322impl<C: Channel> Agent<C> {
323 pub async fn load_history(&mut self) -> Result<(), super::error::AgentError> {
329 let (Some(memory), Some(cid)) =
330 (&self.memory_state.memory, self.memory_state.conversation_id)
331 else {
332 return Ok(());
333 };
334
335 let history = memory
336 .sqlite()
337 .load_history_filtered(cid, self.memory_state.history_limit, Some(true), None)
338 .await?;
339 if !history.is_empty() {
340 let mut loaded = 0;
341 let mut skipped = 0;
342
343 for msg in history {
344 if !has_meaningful_content(&msg.content) && msg.parts.is_empty() {
349 tracing::warn!("skipping empty message from history (role: {:?})", msg.role);
350 skipped += 1;
351 continue;
352 }
353 self.msg.messages.push(msg);
354 loaded += 1;
355 }
356
357 let history_start = self.msg.messages.len() - loaded;
359 let mut restored_slice = self.msg.messages.split_off(history_start);
360 let (orphans, orphan_db_ids) = sanitize_tool_pairs(&mut restored_slice);
361 skipped += orphans;
362 loaded = loaded.saturating_sub(orphans);
363 self.msg.messages.append(&mut restored_slice);
364
365 if !orphan_db_ids.is_empty() {
366 let ids: Vec<zeph_memory::types::MessageId> = orphan_db_ids
367 .iter()
368 .map(|&id| zeph_memory::types::MessageId(id))
369 .collect();
370 if let Err(e) = memory.sqlite().soft_delete_messages(&ids).await {
371 tracing::warn!(
372 count = ids.len(),
373 error = %e,
374 "failed to soft-delete orphaned tool-pair messages from DB"
375 );
376 } else {
377 tracing::debug!(
378 count = ids.len(),
379 "soft-deleted orphaned tool-pair messages from DB"
380 );
381 }
382 }
383
384 tracing::info!("restored {loaded} message(s) from conversation {cid}");
385 if skipped > 0 {
386 tracing::warn!("skipped {skipped} empty/orphaned message(s) from history");
387 }
388
389 if loaded > 0 {
390 let _ = memory
393 .sqlite()
394 .increment_session_counts_for_conversation(cid)
395 .await
396 .inspect_err(|e| {
397 tracing::warn!(error = %e, "failed to increment tier session counts");
398 });
399 }
400 }
401
402 if let Ok(count) = memory.message_count(cid).await {
403 let count_u64 = u64::try_from(count).unwrap_or(0);
404 self.update_metrics(|m| {
405 m.sqlite_message_count = count_u64;
406 });
407 }
408
409 if let Ok(count) = memory.sqlite().count_semantic_facts().await {
410 let count_u64 = u64::try_from(count).unwrap_or(0);
411 self.update_metrics(|m| {
412 m.semantic_fact_count = count_u64;
413 });
414 }
415
416 if let Ok(count) = memory.unsummarized_message_count(cid).await {
417 self.memory_state.unsummarized_count = usize::try_from(count).unwrap_or(0);
418 }
419
420 self.recompute_prompt_tokens();
421 Ok(())
422 }
423
424 pub(crate) async fn persist_message(
430 &mut self,
431 role: Role,
432 content: &str,
433 parts: &[MessagePart],
434 has_injection_flags: bool,
435 ) {
436 let (Some(memory), Some(cid)) =
437 (&self.memory_state.memory, self.memory_state.conversation_id)
438 else {
439 return;
440 };
441
442 let parts_json = if parts.is_empty() {
443 "[]".to_string()
444 } else {
445 serde_json::to_string(parts).unwrap_or_else(|e| {
446 tracing::warn!("failed to serialize message parts, storing empty: {e}");
447 "[]".to_string()
448 })
449 };
450
451 let guard_event = self
455 .security
456 .exfiltration_guard
457 .should_guard_memory_write(has_injection_flags);
458 if let Some(ref event) = guard_event {
459 tracing::warn!(
460 ?event,
461 "exfiltration guard: skipping Qdrant embedding for flagged content"
462 );
463 self.update_metrics(|m| m.exfiltration_memory_guards += 1);
464 self.push_security_event(
465 crate::metrics::SecurityEventCategory::ExfiltrationBlock,
466 "memory_write",
467 "Qdrant embedding skipped: flagged content",
468 );
469 }
470
471 let skip_embedding = guard_event.is_some();
472
473 let should_embed = if skip_embedding {
474 false
475 } else {
476 match role {
477 Role::Assistant => {
478 self.memory_state.autosave_assistant
479 && content.len() >= self.memory_state.autosave_min_length
480 }
481 _ => true,
482 }
483 };
484
485 let goal_text = self.memory_state.goal_text.clone();
486
487 let (embedding_stored, was_persisted) = if should_embed {
488 match memory
489 .remember_with_parts(
490 cid,
491 role_str(role),
492 content,
493 &parts_json,
494 goal_text.as_deref(),
495 )
496 .await
497 {
498 Ok((Some(message_id), stored)) => {
499 self.last_persisted_message_id = Some(message_id.0);
500 (stored, true)
501 }
502 Ok((None, _)) => {
503 return;
505 }
506 Err(e) => {
507 tracing::error!("failed to persist message: {e:#}");
508 return;
509 }
510 }
511 } else {
512 match memory
513 .save_only(cid, role_str(role), content, &parts_json)
514 .await
515 {
516 Ok(message_id) => {
517 self.last_persisted_message_id = Some(message_id.0);
518 (false, true)
519 }
520 Err(e) => {
521 tracing::error!("failed to persist message: {e:#}");
522 return;
523 }
524 }
525 };
526
527 if !was_persisted {
528 return;
529 }
530
531 self.memory_state.unsummarized_count += 1;
532
533 self.update_metrics(|m| {
534 m.sqlite_message_count += 1;
535 if embedding_stored {
536 m.embeddings_generated += 1;
537 }
538 });
539
540 self.check_summarization().await;
541
542 let has_tool_result_parts = parts
545 .iter()
546 .any(|p| matches!(p, MessagePart::ToolResult { .. }));
547
548 self.maybe_spawn_graph_extraction(content, has_injection_flags, has_tool_result_parts)
549 .await;
550 }
551
552 #[allow(clippy::too_many_lines)]
553 async fn maybe_spawn_graph_extraction(
554 &mut self,
555 content: &str,
556 has_injection_flags: bool,
557 has_tool_result_parts: bool,
558 ) {
559 use zeph_memory::semantic::GraphExtractionConfig;
560
561 if self.memory_state.memory.is_none() || self.memory_state.conversation_id.is_none() {
562 return;
563 }
564
565 if has_tool_result_parts {
568 tracing::debug!("graph extraction skipped: message contains ToolResult parts");
569 return;
570 }
571
572 if has_injection_flags {
574 tracing::warn!("graph extraction skipped: injection patterns detected in content");
575 return;
576 }
577
578 let extraction_cfg = {
580 let cfg = &self.memory_state.graph_config;
581 if !cfg.enabled {
582 return;
583 }
584 GraphExtractionConfig {
585 max_entities: cfg.max_entities_per_message,
586 max_edges: cfg.max_edges_per_message,
587 extraction_timeout_secs: cfg.extraction_timeout_secs,
588 community_refresh_interval: cfg.community_refresh_interval,
589 expired_edge_retention_days: cfg.expired_edge_retention_days,
590 max_entities_cap: cfg.max_entities,
591 community_summary_max_prompt_bytes: cfg.community_summary_max_prompt_bytes,
592 community_summary_concurrency: cfg.community_summary_concurrency,
593 lpa_edge_chunk_size: cfg.lpa_edge_chunk_size,
594 note_linking: zeph_memory::NoteLinkingConfig {
595 enabled: cfg.note_linking.enabled,
596 similarity_threshold: cfg.note_linking.similarity_threshold,
597 top_k: cfg.note_linking.top_k,
598 timeout_secs: cfg.note_linking.timeout_secs,
599 },
600 link_weight_decay_lambda: cfg.link_weight_decay_lambda,
601 link_weight_decay_interval_secs: cfg.link_weight_decay_interval_secs,
602 belief_revision_enabled: cfg.belief_revision.enabled,
603 belief_revision_similarity_threshold: cfg.belief_revision.similarity_threshold,
604 }
605 };
606
607 if self.rpe_should_skip(content).await {
609 tracing::debug!("D-MEM RPE: low-surprise turn, skipping graph extraction");
610 return;
611 }
612
613 let context_messages: Vec<String> = self
617 .msg
618 .messages
619 .iter()
620 .rev()
621 .filter(|m| {
622 m.role == Role::User
623 && !m
624 .parts
625 .iter()
626 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
627 })
628 .take(4)
629 .map(|m| m.content.clone())
630 .collect();
631
632 let _ = self.channel.send_status("saving to graph...").await;
633
634 if let Some(memory) = &self.memory_state.memory {
635 let validator: zeph_memory::semantic::PostExtractValidator =
638 if self.security.memory_validator.is_enabled() {
639 let v = self.security.memory_validator.clone();
640 Some(Box::new(move |result| {
641 v.validate_graph_extraction(result)
642 .map_err(|e| e.to_string())
643 }))
644 } else {
645 None
646 };
647 let extraction_handle = memory.spawn_graph_extraction(
648 content.to_owned(),
649 context_messages,
650 extraction_cfg,
651 validator,
652 );
653 if let (Some(store), Some(tx)) =
656 (memory.graph_store.clone(), self.metrics.metrics_tx.clone())
657 {
658 let start = self.lifecycle.start_time;
659 tokio::spawn(async move {
660 let _ = extraction_handle.await;
661 let (entities, edges, communities) = tokio::join!(
662 store.entity_count(),
663 store.active_edge_count(),
664 store.community_count()
665 );
666 let elapsed = start.elapsed().as_secs();
667 tx.send_modify(|m| {
668 m.uptime_seconds = elapsed;
669 m.graph_entities_total = entities.unwrap_or(0).cast_unsigned();
670 m.graph_edges_total = edges.unwrap_or(0).cast_unsigned();
671 m.graph_communities_total = communities.unwrap_or(0).cast_unsigned();
672 });
673 });
674 }
675 }
676 let _ = self.channel.send_status("").await;
677 self.sync_community_detection_failures();
678 self.sync_graph_extraction_metrics();
679 self.sync_graph_counts().await;
680 #[cfg(feature = "compression-guidelines")]
681 self.sync_guidelines_status().await;
682 }
683
684 pub(crate) async fn check_summarization(&mut self) {
685 let (Some(memory), Some(cid)) =
686 (&self.memory_state.memory, self.memory_state.conversation_id)
687 else {
688 return;
689 };
690
691 if self.memory_state.unsummarized_count > self.memory_state.summarization_threshold {
692 let _ = self.channel.send_status("summarizing...").await;
693 let batch_size = self.memory_state.summarization_threshold / 2;
694 match memory.summarize(cid, batch_size).await {
695 Ok(Some(summary_id)) => {
696 tracing::info!("created summary {summary_id} for conversation {cid}");
697 self.memory_state.unsummarized_count = 0;
698 self.update_metrics(|m| {
699 m.summaries_count += 1;
700 });
701 }
702 Ok(None) => {
703 tracing::debug!("no summarization needed");
704 }
705 Err(e) => {
706 tracing::error!("summarization failed: {e:#}");
707 }
708 }
709 let _ = self.channel.send_status("").await;
710 }
711 }
712
713 async fn rpe_should_skip(&mut self, content: &str) -> bool {
718 let Some(ref rpe_mutex) = self.memory_state.rpe_router else {
719 return false;
720 };
721 let Some(memory) = &self.memory_state.memory else {
722 return false;
723 };
724 let candidates = zeph_memory::extract_candidate_entities(content);
725 let provider = memory.provider();
726 let Ok(Ok(emb_vec)) =
727 tokio::time::timeout(std::time::Duration::from_secs(5), provider.embed(content)).await
728 else {
729 return false; };
731 if let Ok(mut router) = rpe_mutex.lock() {
732 let signal = router.compute(&emb_vec, &candidates);
733 router.push_embedding(emb_vec);
734 router.push_entities(&candidates);
735 !signal.should_extract
736 } else {
737 tracing::warn!("rpe_router mutex poisoned; falling through to extract");
738 false
739 }
740 }
741}
742
743#[cfg(test)]
744mod tests {
745 use super::super::agent_tests::{
746 MetricsSnapshot, MockChannel, MockToolExecutor, create_test_registry, mock_provider,
747 };
748 use super::*;
749 use zeph_llm::any::AnyProvider;
750 use zeph_memory::semantic::SemanticMemory;
751
752 async fn test_memory(provider: &AnyProvider) -> SemanticMemory {
753 SemanticMemory::new(
754 ":memory:",
755 "http://127.0.0.1:1",
756 provider.clone(),
757 "test-model",
758 )
759 .await
760 .unwrap()
761 }
762
763 #[tokio::test]
764 async fn load_history_without_memory_returns_ok() {
765 let provider = mock_provider(vec![]);
766 let channel = MockChannel::new(vec![]);
767 let registry = create_test_registry();
768 let executor = MockToolExecutor::no_tools();
769 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
770
771 let result = agent.load_history().await;
772 assert!(result.is_ok());
773 assert_eq!(agent.msg.messages.len(), 1); }
776
777 #[tokio::test]
778 async fn load_history_with_messages_injects_into_agent() {
779 let provider = mock_provider(vec![]);
780 let channel = MockChannel::new(vec![]);
781 let registry = create_test_registry();
782 let executor = MockToolExecutor::no_tools();
783
784 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
785 let cid = memory.sqlite().create_conversation().await.unwrap();
786
787 memory
788 .sqlite()
789 .save_message(cid, "user", "hello from history")
790 .await
791 .unwrap();
792 memory
793 .sqlite()
794 .save_message(cid, "assistant", "hi back")
795 .await
796 .unwrap();
797
798 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
799 std::sync::Arc::new(memory),
800 cid,
801 50,
802 5,
803 100,
804 );
805
806 let messages_before = agent.msg.messages.len();
807 agent.load_history().await.unwrap();
808 assert_eq!(agent.msg.messages.len(), messages_before + 2);
810 }
811
812 #[tokio::test]
813 async fn load_history_skips_empty_messages() {
814 let provider = mock_provider(vec![]);
815 let channel = MockChannel::new(vec![]);
816 let registry = create_test_registry();
817 let executor = MockToolExecutor::no_tools();
818
819 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
820 let cid = memory.sqlite().create_conversation().await.unwrap();
821
822 memory
824 .sqlite()
825 .save_message(cid, "user", " ")
826 .await
827 .unwrap();
828 memory
829 .sqlite()
830 .save_message(cid, "user", "real message")
831 .await
832 .unwrap();
833
834 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
835 std::sync::Arc::new(memory),
836 cid,
837 50,
838 5,
839 100,
840 );
841
842 let messages_before = agent.msg.messages.len();
843 agent.load_history().await.unwrap();
844 assert_eq!(agent.msg.messages.len(), messages_before + 1);
846 }
847
848 #[tokio::test]
849 async fn load_history_with_empty_store_returns_ok() {
850 let provider = mock_provider(vec![]);
851 let channel = MockChannel::new(vec![]);
852 let registry = create_test_registry();
853 let executor = MockToolExecutor::no_tools();
854
855 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
856 let cid = memory.sqlite().create_conversation().await.unwrap();
857
858 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
859 std::sync::Arc::new(memory),
860 cid,
861 50,
862 5,
863 100,
864 );
865
866 let messages_before = agent.msg.messages.len();
867 agent.load_history().await.unwrap();
868 assert_eq!(agent.msg.messages.len(), messages_before);
870 }
871
872 #[tokio::test]
873 async fn load_history_increments_session_count_for_existing_messages() {
874 let provider = mock_provider(vec![]);
875 let channel = MockChannel::new(vec![]);
876 let registry = create_test_registry();
877 let executor = MockToolExecutor::no_tools();
878
879 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
880 let cid = memory.sqlite().create_conversation().await.unwrap();
881
882 let id1 = memory
884 .sqlite()
885 .save_message(cid, "user", "hello")
886 .await
887 .unwrap();
888 let id2 = memory
889 .sqlite()
890 .save_message(cid, "assistant", "hi")
891 .await
892 .unwrap();
893
894 let memory_arc = std::sync::Arc::new(memory);
895 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
896 memory_arc.clone(),
897 cid,
898 50,
899 5,
900 100,
901 );
902
903 agent.load_history().await.unwrap();
904
905 let counts: Vec<i64> = zeph_db::query_scalar(
907 "SELECT session_count FROM messages WHERE id IN (?, ?) ORDER BY id",
908 )
909 .bind(id1)
910 .bind(id2)
911 .fetch_all(memory_arc.sqlite().pool())
912 .await
913 .unwrap();
914 assert_eq!(
915 counts,
916 vec![1, 1],
917 "session_count must be 1 after first restore"
918 );
919 }
920
921 #[tokio::test]
922 async fn load_history_does_not_increment_session_count_for_new_conversation() {
923 let provider = mock_provider(vec![]);
924 let channel = MockChannel::new(vec![]);
925 let registry = create_test_registry();
926 let executor = MockToolExecutor::no_tools();
927
928 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
929 let cid = memory.sqlite().create_conversation().await.unwrap();
930
931 let memory_arc = std::sync::Arc::new(memory);
933 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
934 memory_arc.clone(),
935 cid,
936 50,
937 5,
938 100,
939 );
940
941 agent.load_history().await.unwrap();
942
943 let counts: Vec<i64> =
945 zeph_db::query_scalar("SELECT session_count FROM messages WHERE conversation_id = ?")
946 .bind(cid)
947 .fetch_all(memory_arc.sqlite().pool())
948 .await
949 .unwrap();
950 assert!(counts.is_empty(), "new conversation must have no messages");
951 }
952
953 #[tokio::test]
954 async fn persist_message_without_memory_silently_returns() {
955 let provider = mock_provider(vec![]);
957 let channel = MockChannel::new(vec![]);
958 let registry = create_test_registry();
959 let executor = MockToolExecutor::no_tools();
960 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
961
962 agent.persist_message(Role::User, "hello", &[], false).await;
964 }
965
966 #[tokio::test]
967 async fn persist_message_assistant_autosave_false_uses_save_only() {
968 let provider = mock_provider(vec![]);
969 let channel = MockChannel::new(vec![]);
970 let registry = create_test_registry();
971 let executor = MockToolExecutor::no_tools();
972
973 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
974 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
975 let cid = memory.sqlite().create_conversation().await.unwrap();
976
977 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
978 .with_metrics(tx)
979 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
980 .with_autosave_config(false, 20);
981
982 agent
983 .persist_message(Role::Assistant, "short assistant reply", &[], false)
984 .await;
985
986 let history = agent
987 .memory_state
988 .memory
989 .as_ref()
990 .unwrap()
991 .sqlite()
992 .load_history(cid, 50)
993 .await
994 .unwrap();
995 assert_eq!(history.len(), 1, "message must be saved");
996 assert_eq!(history[0].content, "short assistant reply");
997 assert_eq!(rx.borrow().embeddings_generated, 0);
999 }
1000
1001 #[tokio::test]
1002 async fn persist_message_assistant_below_min_length_uses_save_only() {
1003 let provider = mock_provider(vec![]);
1004 let channel = MockChannel::new(vec![]);
1005 let registry = create_test_registry();
1006 let executor = MockToolExecutor::no_tools();
1007
1008 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1009 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1010 let cid = memory.sqlite().create_conversation().await.unwrap();
1011
1012 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1014 .with_metrics(tx)
1015 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1016 .with_autosave_config(true, 1000);
1017
1018 agent
1019 .persist_message(Role::Assistant, "too short", &[], false)
1020 .await;
1021
1022 let history = agent
1023 .memory_state
1024 .memory
1025 .as_ref()
1026 .unwrap()
1027 .sqlite()
1028 .load_history(cid, 50)
1029 .await
1030 .unwrap();
1031 assert_eq!(history.len(), 1, "message must be saved");
1032 assert_eq!(history[0].content, "too short");
1033 assert_eq!(rx.borrow().embeddings_generated, 0);
1034 }
1035
1036 #[tokio::test]
1037 async fn persist_message_assistant_at_min_length_boundary_uses_embed() {
1038 let provider = mock_provider(vec![]);
1040 let channel = MockChannel::new(vec![]);
1041 let registry = create_test_registry();
1042 let executor = MockToolExecutor::no_tools();
1043
1044 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1045 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1046 let cid = memory.sqlite().create_conversation().await.unwrap();
1047
1048 let min_length = 10usize;
1049 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1050 .with_metrics(tx)
1051 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1052 .with_autosave_config(true, min_length);
1053
1054 let content_at_boundary = "A".repeat(min_length);
1056 assert_eq!(content_at_boundary.len(), min_length);
1057 agent
1058 .persist_message(Role::Assistant, &content_at_boundary, &[], false)
1059 .await;
1060
1061 assert_eq!(rx.borrow().sqlite_message_count, 1);
1063 }
1064
1065 #[tokio::test]
1066 async fn persist_message_assistant_one_below_min_length_uses_save_only() {
1067 let provider = mock_provider(vec![]);
1069 let channel = MockChannel::new(vec![]);
1070 let registry = create_test_registry();
1071 let executor = MockToolExecutor::no_tools();
1072
1073 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1074 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1075 let cid = memory.sqlite().create_conversation().await.unwrap();
1076
1077 let min_length = 10usize;
1078 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1079 .with_metrics(tx)
1080 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1081 .with_autosave_config(true, min_length);
1082
1083 let content_below_boundary = "A".repeat(min_length - 1);
1085 assert_eq!(content_below_boundary.len(), min_length - 1);
1086 agent
1087 .persist_message(Role::Assistant, &content_below_boundary, &[], false)
1088 .await;
1089
1090 let history = agent
1091 .memory_state
1092 .memory
1093 .as_ref()
1094 .unwrap()
1095 .sqlite()
1096 .load_history(cid, 50)
1097 .await
1098 .unwrap();
1099 assert_eq!(history.len(), 1, "message must still be saved");
1100 assert_eq!(rx.borrow().embeddings_generated, 0);
1102 }
1103
1104 #[tokio::test]
1105 async fn persist_message_increments_unsummarized_count() {
1106 let provider = mock_provider(vec![]);
1107 let channel = MockChannel::new(vec![]);
1108 let registry = create_test_registry();
1109 let executor = MockToolExecutor::no_tools();
1110
1111 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1112 let cid = memory.sqlite().create_conversation().await.unwrap();
1113
1114 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1116 std::sync::Arc::new(memory),
1117 cid,
1118 50,
1119 5,
1120 100,
1121 );
1122
1123 assert_eq!(agent.memory_state.unsummarized_count, 0);
1124
1125 agent.persist_message(Role::User, "first", &[], false).await;
1126 assert_eq!(agent.memory_state.unsummarized_count, 1);
1127
1128 agent
1129 .persist_message(Role::User, "second", &[], false)
1130 .await;
1131 assert_eq!(agent.memory_state.unsummarized_count, 2);
1132 }
1133
1134 #[tokio::test]
1135 async fn check_summarization_resets_counter_on_success() {
1136 let provider = mock_provider(vec![]);
1137 let channel = MockChannel::new(vec![]);
1138 let registry = create_test_registry();
1139 let executor = MockToolExecutor::no_tools();
1140
1141 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1142 let cid = memory.sqlite().create_conversation().await.unwrap();
1143
1144 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1146 std::sync::Arc::new(memory),
1147 cid,
1148 50,
1149 5,
1150 1,
1151 );
1152
1153 agent.persist_message(Role::User, "msg1", &[], false).await;
1154 agent.persist_message(Role::User, "msg2", &[], false).await;
1155
1156 assert!(agent.memory_state.unsummarized_count <= 2);
1161 }
1162
1163 #[tokio::test]
1164 async fn unsummarized_count_not_incremented_without_memory() {
1165 let provider = mock_provider(vec![]);
1166 let channel = MockChannel::new(vec![]);
1167 let registry = create_test_registry();
1168 let executor = MockToolExecutor::no_tools();
1169 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
1170
1171 agent.persist_message(Role::User, "hello", &[], false).await;
1172 assert_eq!(agent.memory_state.unsummarized_count, 0);
1174 }
1175
1176 mod graph_extraction_guards {
1178 use super::*;
1179 use crate::config::GraphConfig;
1180 use zeph_llm::provider::MessageMetadata;
1181 use zeph_memory::graph::GraphStore;
1182
1183 fn enabled_graph_config() -> GraphConfig {
1184 GraphConfig {
1185 enabled: true,
1186 ..GraphConfig::default()
1187 }
1188 }
1189
1190 async fn agent_with_graph(
1191 provider: &AnyProvider,
1192 config: GraphConfig,
1193 ) -> Agent<MockChannel> {
1194 let memory =
1195 test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1196 let cid = memory.sqlite().create_conversation().await.unwrap();
1197 Agent::new(
1198 provider.clone(),
1199 MockChannel::new(vec![]),
1200 create_test_registry(),
1201 None,
1202 5,
1203 MockToolExecutor::no_tools(),
1204 )
1205 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1206 .with_graph_config(config)
1207 }
1208
1209 #[tokio::test]
1210 async fn injection_flag_guard_skips_extraction() {
1211 let provider = mock_provider(vec![]);
1213 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1214 let pool = agent
1215 .memory_state
1216 .memory
1217 .as_ref()
1218 .unwrap()
1219 .sqlite()
1220 .pool()
1221 .clone();
1222
1223 agent
1224 .maybe_spawn_graph_extraction("I use Rust", true, false)
1225 .await;
1226
1227 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1229
1230 let store = GraphStore::new(pool);
1231 let count = store.get_metadata("extraction_count").await.unwrap();
1232 assert!(
1233 count.is_none(),
1234 "injection flag must prevent extraction_count from being written"
1235 );
1236 }
1237
1238 #[tokio::test]
1239 async fn disabled_config_guard_skips_extraction() {
1240 let provider = mock_provider(vec![]);
1242 let disabled_cfg = GraphConfig {
1243 enabled: false,
1244 ..GraphConfig::default()
1245 };
1246 let mut agent = agent_with_graph(&provider, disabled_cfg).await;
1247 let pool = agent
1248 .memory_state
1249 .memory
1250 .as_ref()
1251 .unwrap()
1252 .sqlite()
1253 .pool()
1254 .clone();
1255
1256 agent
1257 .maybe_spawn_graph_extraction("I use Rust", false, false)
1258 .await;
1259
1260 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1261
1262 let store = GraphStore::new(pool);
1263 let count = store.get_metadata("extraction_count").await.unwrap();
1264 assert!(
1265 count.is_none(),
1266 "disabled graph config must prevent extraction"
1267 );
1268 }
1269
1270 #[tokio::test]
1271 async fn happy_path_fires_extraction() {
1272 let provider = mock_provider(vec![]);
1275 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1276 let pool = agent
1277 .memory_state
1278 .memory
1279 .as_ref()
1280 .unwrap()
1281 .sqlite()
1282 .pool()
1283 .clone();
1284
1285 agent
1286 .maybe_spawn_graph_extraction("I use Rust for systems programming", false, false)
1287 .await;
1288
1289 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1291
1292 let store = GraphStore::new(pool);
1293 let count = store.get_metadata("extraction_count").await.unwrap();
1294 assert!(
1295 count.is_some(),
1296 "happy-path extraction must increment extraction_count"
1297 );
1298 }
1299
1300 #[tokio::test]
1301 async fn tool_result_parts_guard_skips_extraction() {
1302 let provider = mock_provider(vec![]);
1306 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1307 let pool = agent
1308 .memory_state
1309 .memory
1310 .as_ref()
1311 .unwrap()
1312 .sqlite()
1313 .pool()
1314 .clone();
1315
1316 agent
1317 .maybe_spawn_graph_extraction(
1318 "[tool_result: abc123]\nprovider_type = \"claude\"\nallowed_commands = []",
1319 false,
1320 true, )
1322 .await;
1323
1324 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1325
1326 let store = GraphStore::new(pool);
1327 let count = store.get_metadata("extraction_count").await.unwrap();
1328 assert!(
1329 count.is_none(),
1330 "tool result message must not trigger graph extraction"
1331 );
1332 }
1333
1334 #[tokio::test]
1335 async fn context_filter_excludes_tool_result_messages() {
1336 let provider = mock_provider(vec![]);
1347 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1348
1349 agent.msg.messages.push(Message {
1352 role: Role::User,
1353 content: "[tool_result: abc]\nprovider_type = \"openai\"".to_owned(),
1354 parts: vec![MessagePart::ToolResult {
1355 tool_use_id: "abc".to_owned(),
1356 content: "provider_type = \"openai\"".to_owned(),
1357 is_error: false,
1358 }],
1359 metadata: MessageMetadata::default(),
1360 });
1361
1362 let pool = agent
1363 .memory_state
1364 .memory
1365 .as_ref()
1366 .unwrap()
1367 .sqlite()
1368 .pool()
1369 .clone();
1370
1371 agent
1373 .maybe_spawn_graph_extraction("I prefer Rust for systems programming", false, false)
1374 .await;
1375
1376 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1377
1378 let store = GraphStore::new(pool);
1380 let count = store.get_metadata("extraction_count").await.unwrap();
1381 assert!(
1382 count.is_some(),
1383 "conversational message must trigger extraction even with prior tool result in history"
1384 );
1385 }
1386 }
1387
1388 #[tokio::test]
1389 async fn persist_message_user_always_embeds_regardless_of_autosave_flag() {
1390 let provider = mock_provider(vec![]);
1391 let channel = MockChannel::new(vec![]);
1392 let registry = create_test_registry();
1393 let executor = MockToolExecutor::no_tools();
1394
1395 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1396 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1397 let cid = memory.sqlite().create_conversation().await.unwrap();
1398
1399 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1401 .with_metrics(tx)
1402 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1403 .with_autosave_config(false, 20);
1404
1405 let long_user_msg = "A".repeat(100);
1406 agent
1407 .persist_message(Role::User, &long_user_msg, &[], false)
1408 .await;
1409
1410 let history = agent
1411 .memory_state
1412 .memory
1413 .as_ref()
1414 .unwrap()
1415 .sqlite()
1416 .load_history(cid, 50)
1417 .await
1418 .unwrap();
1419 assert_eq!(history.len(), 1, "user message must be saved");
1420 assert_eq!(rx.borrow().sqlite_message_count, 1);
1423 }
1424
1425 #[tokio::test]
1429 async fn persist_message_saves_correct_tool_use_parts() {
1430 use zeph_llm::provider::MessagePart;
1431
1432 let provider = mock_provider(vec![]);
1433 let channel = MockChannel::new(vec![]);
1434 let registry = create_test_registry();
1435 let executor = MockToolExecutor::no_tools();
1436
1437 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1438 let cid = memory.sqlite().create_conversation().await.unwrap();
1439
1440 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1441 std::sync::Arc::new(memory),
1442 cid,
1443 50,
1444 5,
1445 100,
1446 );
1447
1448 let parts = vec![MessagePart::ToolUse {
1449 id: "call_abc123".to_string(),
1450 name: "read_file".to_string(),
1451 input: serde_json::json!({"path": "/tmp/test.txt"}),
1452 }];
1453 let content = "[tool_use: read_file(call_abc123)]";
1454
1455 agent
1456 .persist_message(Role::Assistant, content, &parts, false)
1457 .await;
1458
1459 let history = agent
1460 .memory_state
1461 .memory
1462 .as_ref()
1463 .unwrap()
1464 .sqlite()
1465 .load_history(cid, 50)
1466 .await
1467 .unwrap();
1468
1469 assert_eq!(history.len(), 1);
1470 assert_eq!(history[0].role, Role::Assistant);
1471 assert_eq!(history[0].content, content);
1472 assert_eq!(history[0].parts.len(), 1);
1473 match &history[0].parts[0] {
1474 MessagePart::ToolUse { id, name, .. } => {
1475 assert_eq!(id, "call_abc123");
1476 assert_eq!(name, "read_file");
1477 }
1478 other => panic!("expected ToolUse part, got {other:?}"),
1479 }
1480 assert!(
1482 !history[0]
1483 .parts
1484 .iter()
1485 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1486 "assistant message must not contain ToolResult parts"
1487 );
1488 }
1489
1490 #[tokio::test]
1491 async fn persist_message_saves_correct_tool_result_parts() {
1492 use zeph_llm::provider::MessagePart;
1493
1494 let provider = mock_provider(vec![]);
1495 let channel = MockChannel::new(vec![]);
1496 let registry = create_test_registry();
1497 let executor = MockToolExecutor::no_tools();
1498
1499 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1500 let cid = memory.sqlite().create_conversation().await.unwrap();
1501
1502 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1503 std::sync::Arc::new(memory),
1504 cid,
1505 50,
1506 5,
1507 100,
1508 );
1509
1510 let parts = vec![MessagePart::ToolResult {
1511 tool_use_id: "call_abc123".to_string(),
1512 content: "file contents here".to_string(),
1513 is_error: false,
1514 }];
1515 let content = "[tool_result: call_abc123]\nfile contents here";
1516
1517 agent
1518 .persist_message(Role::User, content, &parts, false)
1519 .await;
1520
1521 let history = agent
1522 .memory_state
1523 .memory
1524 .as_ref()
1525 .unwrap()
1526 .sqlite()
1527 .load_history(cid, 50)
1528 .await
1529 .unwrap();
1530
1531 assert_eq!(history.len(), 1);
1532 assert_eq!(history[0].role, Role::User);
1533 assert_eq!(history[0].content, content);
1534 assert_eq!(history[0].parts.len(), 1);
1535 match &history[0].parts[0] {
1536 MessagePart::ToolResult {
1537 tool_use_id,
1538 content: result_content,
1539 is_error,
1540 } => {
1541 assert_eq!(tool_use_id, "call_abc123");
1542 assert_eq!(result_content, "file contents here");
1543 assert!(!is_error);
1544 }
1545 other => panic!("expected ToolResult part, got {other:?}"),
1546 }
1547 assert!(
1549 !history[0]
1550 .parts
1551 .iter()
1552 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1553 "user ToolResult message must not contain ToolUse parts"
1554 );
1555 }
1556
1557 #[tokio::test]
1558 async fn persist_message_roundtrip_preserves_role_part_alignment() {
1559 use zeph_llm::provider::MessagePart;
1560
1561 let provider = mock_provider(vec![]);
1562 let channel = MockChannel::new(vec![]);
1563 let registry = create_test_registry();
1564 let executor = MockToolExecutor::no_tools();
1565
1566 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1567 let cid = memory.sqlite().create_conversation().await.unwrap();
1568
1569 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1570 std::sync::Arc::new(memory),
1571 cid,
1572 50,
1573 5,
1574 100,
1575 );
1576
1577 let assistant_parts = vec![MessagePart::ToolUse {
1579 id: "id_1".to_string(),
1580 name: "list_dir".to_string(),
1581 input: serde_json::json!({"path": "/tmp"}),
1582 }];
1583 agent
1584 .persist_message(
1585 Role::Assistant,
1586 "[tool_use: list_dir(id_1)]",
1587 &assistant_parts,
1588 false,
1589 )
1590 .await;
1591
1592 let user_parts = vec![MessagePart::ToolResult {
1594 tool_use_id: "id_1".to_string(),
1595 content: "file1.txt\nfile2.txt".to_string(),
1596 is_error: false,
1597 }];
1598 agent
1599 .persist_message(
1600 Role::User,
1601 "[tool_result: id_1]\nfile1.txt\nfile2.txt",
1602 &user_parts,
1603 false,
1604 )
1605 .await;
1606
1607 let history = agent
1608 .memory_state
1609 .memory
1610 .as_ref()
1611 .unwrap()
1612 .sqlite()
1613 .load_history(cid, 50)
1614 .await
1615 .unwrap();
1616
1617 assert_eq!(history.len(), 2);
1618
1619 assert_eq!(history[0].role, Role::Assistant);
1621 assert_eq!(history[0].content, "[tool_use: list_dir(id_1)]");
1622 assert!(
1623 matches!(&history[0].parts[0], MessagePart::ToolUse { id, .. } if id == "id_1"),
1624 "first message must be assistant ToolUse"
1625 );
1626
1627 assert_eq!(history[1].role, Role::User);
1629 assert_eq!(
1630 history[1].content,
1631 "[tool_result: id_1]\nfile1.txt\nfile2.txt"
1632 );
1633 assert!(
1634 matches!(&history[1].parts[0], MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "id_1"),
1635 "second message must be user ToolResult"
1636 );
1637
1638 assert!(
1640 !history[0]
1641 .parts
1642 .iter()
1643 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1644 "assistant message must not have ToolResult parts"
1645 );
1646 assert!(
1647 !history[1]
1648 .parts
1649 .iter()
1650 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1651 "user message must not have ToolUse parts"
1652 );
1653 }
1654
1655 #[tokio::test]
1656 async fn persist_message_saves_correct_tool_output_parts() {
1657 use zeph_llm::provider::MessagePart;
1658
1659 let provider = mock_provider(vec![]);
1660 let channel = MockChannel::new(vec![]);
1661 let registry = create_test_registry();
1662 let executor = MockToolExecutor::no_tools();
1663
1664 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1665 let cid = memory.sqlite().create_conversation().await.unwrap();
1666
1667 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1668 std::sync::Arc::new(memory),
1669 cid,
1670 50,
1671 5,
1672 100,
1673 );
1674
1675 let parts = vec![MessagePart::ToolOutput {
1676 tool_name: "shell".to_string(),
1677 body: "hello from shell".to_string(),
1678 compacted_at: None,
1679 }];
1680 let content = "[tool: shell]\nhello from shell";
1681
1682 agent
1683 .persist_message(Role::User, content, &parts, false)
1684 .await;
1685
1686 let history = agent
1687 .memory_state
1688 .memory
1689 .as_ref()
1690 .unwrap()
1691 .sqlite()
1692 .load_history(cid, 50)
1693 .await
1694 .unwrap();
1695
1696 assert_eq!(history.len(), 1);
1697 assert_eq!(history[0].role, Role::User);
1698 assert_eq!(history[0].content, content);
1699 assert_eq!(history[0].parts.len(), 1);
1700 match &history[0].parts[0] {
1701 MessagePart::ToolOutput {
1702 tool_name,
1703 body,
1704 compacted_at,
1705 } => {
1706 assert_eq!(tool_name, "shell");
1707 assert_eq!(body, "hello from shell");
1708 assert!(compacted_at.is_none());
1709 }
1710 other => panic!("expected ToolOutput part, got {other:?}"),
1711 }
1712 }
1713
1714 #[tokio::test]
1717 async fn load_history_removes_trailing_orphan_tool_use() {
1718 use zeph_llm::provider::MessagePart;
1719
1720 let provider = mock_provider(vec![]);
1721 let channel = MockChannel::new(vec![]);
1722 let registry = create_test_registry();
1723 let executor = MockToolExecutor::no_tools();
1724
1725 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1726 let cid = memory.sqlite().create_conversation().await.unwrap();
1727 let sqlite = memory.sqlite();
1728
1729 sqlite
1731 .save_message(cid, "user", "do something with a tool")
1732 .await
1733 .unwrap();
1734
1735 let parts = serde_json::to_string(&[MessagePart::ToolUse {
1737 id: "call_orphan".to_string(),
1738 name: "shell".to_string(),
1739 input: serde_json::json!({"command": "ls"}),
1740 }])
1741 .unwrap();
1742 sqlite
1743 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_orphan)]", &parts)
1744 .await
1745 .unwrap();
1746
1747 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1748 std::sync::Arc::new(memory),
1749 cid,
1750 50,
1751 5,
1752 100,
1753 );
1754
1755 let messages_before = agent.msg.messages.len();
1756 agent.load_history().await.unwrap();
1757
1758 assert_eq!(
1760 agent.msg.messages.len(),
1761 messages_before + 1,
1762 "orphaned trailing tool_use must be removed"
1763 );
1764 assert_eq!(agent.msg.messages.last().unwrap().role, Role::User);
1765 }
1766
1767 #[tokio::test]
1768 async fn load_history_removes_leading_orphan_tool_result() {
1769 use zeph_llm::provider::MessagePart;
1770
1771 let provider = mock_provider(vec![]);
1772 let channel = MockChannel::new(vec![]);
1773 let registry = create_test_registry();
1774 let executor = MockToolExecutor::no_tools();
1775
1776 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1777 let cid = memory.sqlite().create_conversation().await.unwrap();
1778 let sqlite = memory.sqlite();
1779
1780 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1782 tool_use_id: "call_missing".to_string(),
1783 content: "result data".to_string(),
1784 is_error: false,
1785 }])
1786 .unwrap();
1787 sqlite
1788 .save_message_with_parts(
1789 cid,
1790 "user",
1791 "[tool_result: call_missing]\nresult data",
1792 &result_parts,
1793 )
1794 .await
1795 .unwrap();
1796
1797 sqlite
1799 .save_message(cid, "assistant", "here is my response")
1800 .await
1801 .unwrap();
1802
1803 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1804 std::sync::Arc::new(memory),
1805 cid,
1806 50,
1807 5,
1808 100,
1809 );
1810
1811 let messages_before = agent.msg.messages.len();
1812 agent.load_history().await.unwrap();
1813
1814 assert_eq!(
1816 agent.msg.messages.len(),
1817 messages_before + 1,
1818 "orphaned leading tool_result must be removed"
1819 );
1820 assert_eq!(agent.msg.messages.last().unwrap().role, Role::Assistant);
1821 }
1822
1823 #[tokio::test]
1824 async fn load_history_preserves_complete_tool_pairs() {
1825 use zeph_llm::provider::MessagePart;
1826
1827 let provider = mock_provider(vec![]);
1828 let channel = MockChannel::new(vec![]);
1829 let registry = create_test_registry();
1830 let executor = MockToolExecutor::no_tools();
1831
1832 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1833 let cid = memory.sqlite().create_conversation().await.unwrap();
1834 let sqlite = memory.sqlite();
1835
1836 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1838 id: "call_ok".to_string(),
1839 name: "shell".to_string(),
1840 input: serde_json::json!({"command": "pwd"}),
1841 }])
1842 .unwrap();
1843 sqlite
1844 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_ok)]", &use_parts)
1845 .await
1846 .unwrap();
1847
1848 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1849 tool_use_id: "call_ok".to_string(),
1850 content: "/home/user".to_string(),
1851 is_error: false,
1852 }])
1853 .unwrap();
1854 sqlite
1855 .save_message_with_parts(
1856 cid,
1857 "user",
1858 "[tool_result: call_ok]\n/home/user",
1859 &result_parts,
1860 )
1861 .await
1862 .unwrap();
1863
1864 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1865 std::sync::Arc::new(memory),
1866 cid,
1867 50,
1868 5,
1869 100,
1870 );
1871
1872 let messages_before = agent.msg.messages.len();
1873 agent.load_history().await.unwrap();
1874
1875 assert_eq!(
1877 agent.msg.messages.len(),
1878 messages_before + 2,
1879 "complete tool_use/tool_result pair must be preserved"
1880 );
1881 assert_eq!(agent.msg.messages[messages_before].role, Role::Assistant);
1882 assert_eq!(agent.msg.messages[messages_before + 1].role, Role::User);
1883 }
1884
1885 #[tokio::test]
1886 async fn load_history_handles_multiple_trailing_orphans() {
1887 use zeph_llm::provider::MessagePart;
1888
1889 let provider = mock_provider(vec![]);
1890 let channel = MockChannel::new(vec![]);
1891 let registry = create_test_registry();
1892 let executor = MockToolExecutor::no_tools();
1893
1894 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1895 let cid = memory.sqlite().create_conversation().await.unwrap();
1896 let sqlite = memory.sqlite();
1897
1898 sqlite.save_message(cid, "user", "start").await.unwrap();
1900
1901 let parts1 = serde_json::to_string(&[MessagePart::ToolUse {
1903 id: "call_1".to_string(),
1904 name: "shell".to_string(),
1905 input: serde_json::json!({}),
1906 }])
1907 .unwrap();
1908 sqlite
1909 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_1)]", &parts1)
1910 .await
1911 .unwrap();
1912
1913 let parts2 = serde_json::to_string(&[MessagePart::ToolUse {
1915 id: "call_2".to_string(),
1916 name: "read_file".to_string(),
1917 input: serde_json::json!({}),
1918 }])
1919 .unwrap();
1920 sqlite
1921 .save_message_with_parts(cid, "assistant", "[tool_use: read_file(call_2)]", &parts2)
1922 .await
1923 .unwrap();
1924
1925 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1926 std::sync::Arc::new(memory),
1927 cid,
1928 50,
1929 5,
1930 100,
1931 );
1932
1933 let messages_before = agent.msg.messages.len();
1934 agent.load_history().await.unwrap();
1935
1936 assert_eq!(
1938 agent.msg.messages.len(),
1939 messages_before + 1,
1940 "all trailing orphaned tool_use messages must be removed"
1941 );
1942 assert_eq!(agent.msg.messages.last().unwrap().role, Role::User);
1943 }
1944
1945 #[tokio::test]
1946 async fn load_history_no_tool_messages_unchanged() {
1947 let provider = mock_provider(vec![]);
1948 let channel = MockChannel::new(vec![]);
1949 let registry = create_test_registry();
1950 let executor = MockToolExecutor::no_tools();
1951
1952 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1953 let cid = memory.sqlite().create_conversation().await.unwrap();
1954 let sqlite = memory.sqlite();
1955
1956 sqlite.save_message(cid, "user", "hello").await.unwrap();
1957 sqlite
1958 .save_message(cid, "assistant", "hi there")
1959 .await
1960 .unwrap();
1961 sqlite
1962 .save_message(cid, "user", "how are you?")
1963 .await
1964 .unwrap();
1965
1966 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1967 std::sync::Arc::new(memory),
1968 cid,
1969 50,
1970 5,
1971 100,
1972 );
1973
1974 let messages_before = agent.msg.messages.len();
1975 agent.load_history().await.unwrap();
1976
1977 assert_eq!(
1979 agent.msg.messages.len(),
1980 messages_before + 3,
1981 "plain messages without tool parts must pass through unchanged"
1982 );
1983 }
1984
1985 #[tokio::test]
1986 async fn load_history_removes_both_leading_and_trailing_orphans() {
1987 use zeph_llm::provider::MessagePart;
1988
1989 let provider = mock_provider(vec![]);
1990 let channel = MockChannel::new(vec![]);
1991 let registry = create_test_registry();
1992 let executor = MockToolExecutor::no_tools();
1993
1994 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1995 let cid = memory.sqlite().create_conversation().await.unwrap();
1996 let sqlite = memory.sqlite();
1997
1998 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2000 tool_use_id: "call_leading".to_string(),
2001 content: "orphaned result".to_string(),
2002 is_error: false,
2003 }])
2004 .unwrap();
2005 sqlite
2006 .save_message_with_parts(
2007 cid,
2008 "user",
2009 "[tool_result: call_leading]\norphaned result",
2010 &result_parts,
2011 )
2012 .await
2013 .unwrap();
2014
2015 sqlite
2017 .save_message(cid, "user", "what is 2+2?")
2018 .await
2019 .unwrap();
2020 sqlite.save_message(cid, "assistant", "4").await.unwrap();
2021
2022 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2024 id: "call_trailing".to_string(),
2025 name: "shell".to_string(),
2026 input: serde_json::json!({"command": "date"}),
2027 }])
2028 .unwrap();
2029 sqlite
2030 .save_message_with_parts(
2031 cid,
2032 "assistant",
2033 "[tool_use: shell(call_trailing)]",
2034 &use_parts,
2035 )
2036 .await
2037 .unwrap();
2038
2039 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2040 std::sync::Arc::new(memory),
2041 cid,
2042 50,
2043 5,
2044 100,
2045 );
2046
2047 let messages_before = agent.msg.messages.len();
2048 agent.load_history().await.unwrap();
2049
2050 assert_eq!(
2052 agent.msg.messages.len(),
2053 messages_before + 2,
2054 "both leading and trailing orphans must be removed"
2055 );
2056 assert_eq!(agent.msg.messages[messages_before].role, Role::User);
2057 assert_eq!(agent.msg.messages[messages_before].content, "what is 2+2?");
2058 assert_eq!(
2059 agent.msg.messages[messages_before + 1].role,
2060 Role::Assistant
2061 );
2062 assert_eq!(agent.msg.messages[messages_before + 1].content, "4");
2063 }
2064
2065 #[tokio::test]
2070 async fn sanitize_tool_pairs_strips_mid_history_orphan_tool_use() {
2071 use zeph_llm::provider::MessagePart;
2072
2073 let provider = mock_provider(vec![]);
2074 let channel = MockChannel::new(vec![]);
2075 let registry = create_test_registry();
2076 let executor = MockToolExecutor::no_tools();
2077
2078 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2079 let cid = memory.sqlite().create_conversation().await.unwrap();
2080 let sqlite = memory.sqlite();
2081
2082 sqlite
2084 .save_message(cid, "user", "first question")
2085 .await
2086 .unwrap();
2087 sqlite
2088 .save_message(cid, "assistant", "first answer")
2089 .await
2090 .unwrap();
2091
2092 let use_parts = serde_json::to_string(&[
2096 MessagePart::ToolUse {
2097 id: "call_mid_1".to_string(),
2098 name: "shell".to_string(),
2099 input: serde_json::json!({"command": "ls"}),
2100 },
2101 MessagePart::Text {
2102 text: "Let me check the files.".to_string(),
2103 },
2104 ])
2105 .unwrap();
2106 sqlite
2107 .save_message_with_parts(cid, "assistant", "Let me check the files.", &use_parts)
2108 .await
2109 .unwrap();
2110
2111 sqlite
2113 .save_message(cid, "user", "second question")
2114 .await
2115 .unwrap();
2116 sqlite
2117 .save_message(cid, "assistant", "second answer")
2118 .await
2119 .unwrap();
2120
2121 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2122 std::sync::Arc::new(memory),
2123 cid,
2124 50,
2125 5,
2126 100,
2127 );
2128
2129 let messages_before = agent.msg.messages.len();
2130 agent.load_history().await.unwrap();
2131
2132 assert_eq!(
2135 agent.msg.messages.len(),
2136 messages_before + 5,
2137 "message count must be 5 (orphan message kept — has text content)"
2138 );
2139
2140 let orphan = &agent.msg.messages[messages_before + 2];
2142 assert_eq!(orphan.role, Role::Assistant);
2143 assert!(
2144 !orphan
2145 .parts
2146 .iter()
2147 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
2148 "orphaned ToolUse parts must be stripped from mid-history message"
2149 );
2150 assert!(
2152 orphan.parts.iter().any(
2153 |p| matches!(p, MessagePart::Text { text } if text == "Let me check the files.")
2154 ),
2155 "text content of orphaned assistant message must be preserved"
2156 );
2157 }
2158
2159 #[tokio::test]
2164 async fn load_history_keeps_tool_only_user_message() {
2165 use zeph_llm::provider::MessagePart;
2166
2167 let provider = mock_provider(vec![]);
2168 let channel = MockChannel::new(vec![]);
2169 let registry = create_test_registry();
2170 let executor = MockToolExecutor::no_tools();
2171
2172 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2173 let cid = memory.sqlite().create_conversation().await.unwrap();
2174 let sqlite = memory.sqlite();
2175
2176 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2178 id: "call_rc3".to_string(),
2179 name: "memory_save".to_string(),
2180 input: serde_json::json!({"content": "something"}),
2181 }])
2182 .unwrap();
2183 sqlite
2184 .save_message_with_parts(cid, "assistant", "[tool_use: memory_save]", &use_parts)
2185 .await
2186 .unwrap();
2187
2188 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2190 tool_use_id: "call_rc3".to_string(),
2191 content: "saved".to_string(),
2192 is_error: false,
2193 }])
2194 .unwrap();
2195 sqlite
2196 .save_message_with_parts(cid, "user", "", &result_parts)
2197 .await
2198 .unwrap();
2199
2200 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2201
2202 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2203 std::sync::Arc::new(memory),
2204 cid,
2205 50,
2206 5,
2207 100,
2208 );
2209
2210 let messages_before = agent.msg.messages.len();
2211 agent.load_history().await.unwrap();
2212
2213 assert_eq!(
2216 agent.msg.messages.len(),
2217 messages_before + 3,
2218 "user message with empty content but ToolResult parts must not be dropped"
2219 );
2220
2221 let user_msg = &agent.msg.messages[messages_before + 1];
2223 assert_eq!(user_msg.role, Role::User);
2224 assert!(
2225 user_msg.parts.iter().any(
2226 |p| matches!(p, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_rc3")
2227 ),
2228 "ToolResult part must be preserved on user message with empty content"
2229 );
2230 }
2231
2232 #[tokio::test]
2236 async fn strip_orphans_removes_orphaned_tool_result() {
2237 use zeph_llm::provider::MessagePart;
2238
2239 let provider = mock_provider(vec![]);
2240 let channel = MockChannel::new(vec![]);
2241 let registry = create_test_registry();
2242 let executor = MockToolExecutor::no_tools();
2243
2244 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2245 let cid = memory.sqlite().create_conversation().await.unwrap();
2246 let sqlite = memory.sqlite();
2247
2248 sqlite.save_message(cid, "user", "hello").await.unwrap();
2250 sqlite.save_message(cid, "assistant", "hi").await.unwrap();
2251
2252 sqlite
2254 .save_message(cid, "assistant", "plain answer")
2255 .await
2256 .unwrap();
2257
2258 let orphan_result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2260 tool_use_id: "call_nonexistent".to_string(),
2261 content: "stale result".to_string(),
2262 is_error: false,
2263 }])
2264 .unwrap();
2265 sqlite
2266 .save_message_with_parts(
2267 cid,
2268 "user",
2269 "[tool_result: call_nonexistent]\nstale result",
2270 &orphan_result_parts,
2271 )
2272 .await
2273 .unwrap();
2274
2275 sqlite
2276 .save_message(cid, "assistant", "final")
2277 .await
2278 .unwrap();
2279
2280 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2281 std::sync::Arc::new(memory),
2282 cid,
2283 50,
2284 5,
2285 100,
2286 );
2287
2288 let messages_before = agent.msg.messages.len();
2289 agent.load_history().await.unwrap();
2290
2291 let loaded = &agent.msg.messages[messages_before..];
2295 for msg in loaded {
2296 assert!(
2297 !msg.parts.iter().any(|p| matches!(
2298 p,
2299 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_nonexistent"
2300 )),
2301 "orphaned ToolResult part must be stripped from history"
2302 );
2303 }
2304 }
2305
2306 #[tokio::test]
2309 async fn strip_orphans_keeps_complete_pair() {
2310 use zeph_llm::provider::MessagePart;
2311
2312 let provider = mock_provider(vec![]);
2313 let channel = MockChannel::new(vec![]);
2314 let registry = create_test_registry();
2315 let executor = MockToolExecutor::no_tools();
2316
2317 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2318 let cid = memory.sqlite().create_conversation().await.unwrap();
2319 let sqlite = memory.sqlite();
2320
2321 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2322 id: "call_valid".to_string(),
2323 name: "shell".to_string(),
2324 input: serde_json::json!({"command": "ls"}),
2325 }])
2326 .unwrap();
2327 sqlite
2328 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
2329 .await
2330 .unwrap();
2331
2332 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2333 tool_use_id: "call_valid".to_string(),
2334 content: "file.rs".to_string(),
2335 is_error: false,
2336 }])
2337 .unwrap();
2338 sqlite
2339 .save_message_with_parts(cid, "user", "", &result_parts)
2340 .await
2341 .unwrap();
2342
2343 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2344 std::sync::Arc::new(memory),
2345 cid,
2346 50,
2347 5,
2348 100,
2349 );
2350
2351 let messages_before = agent.msg.messages.len();
2352 agent.load_history().await.unwrap();
2353
2354 assert_eq!(
2355 agent.msg.messages.len(),
2356 messages_before + 2,
2357 "complete tool_use/tool_result pair must be preserved"
2358 );
2359
2360 let user_msg = &agent.msg.messages[messages_before + 1];
2361 assert!(
2362 user_msg.parts.iter().any(|p| matches!(
2363 p,
2364 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_valid"
2365 )),
2366 "ToolResult part for a matched tool_use must not be stripped"
2367 );
2368 }
2369
2370 #[tokio::test]
2373 async fn strip_orphans_mixed_history() {
2374 use zeph_llm::provider::MessagePart;
2375
2376 let provider = mock_provider(vec![]);
2377 let channel = MockChannel::new(vec![]);
2378 let registry = create_test_registry();
2379 let executor = MockToolExecutor::no_tools();
2380
2381 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2382 let cid = memory.sqlite().create_conversation().await.unwrap();
2383 let sqlite = memory.sqlite();
2384
2385 let use_parts_ok = serde_json::to_string(&[MessagePart::ToolUse {
2387 id: "call_good".to_string(),
2388 name: "shell".to_string(),
2389 input: serde_json::json!({"command": "pwd"}),
2390 }])
2391 .unwrap();
2392 sqlite
2393 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts_ok)
2394 .await
2395 .unwrap();
2396
2397 let result_parts_ok = serde_json::to_string(&[MessagePart::ToolResult {
2398 tool_use_id: "call_good".to_string(),
2399 content: "/home".to_string(),
2400 is_error: false,
2401 }])
2402 .unwrap();
2403 sqlite
2404 .save_message_with_parts(cid, "user", "", &result_parts_ok)
2405 .await
2406 .unwrap();
2407
2408 sqlite
2410 .save_message(cid, "assistant", "text only")
2411 .await
2412 .unwrap();
2413
2414 let orphan_parts = serde_json::to_string(&[MessagePart::ToolResult {
2415 tool_use_id: "call_ghost".to_string(),
2416 content: "ghost result".to_string(),
2417 is_error: false,
2418 }])
2419 .unwrap();
2420 sqlite
2421 .save_message_with_parts(
2422 cid,
2423 "user",
2424 "[tool_result: call_ghost]\nghost result",
2425 &orphan_parts,
2426 )
2427 .await
2428 .unwrap();
2429
2430 sqlite
2431 .save_message(cid, "assistant", "final reply")
2432 .await
2433 .unwrap();
2434
2435 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2436 std::sync::Arc::new(memory),
2437 cid,
2438 50,
2439 5,
2440 100,
2441 );
2442
2443 let messages_before = agent.msg.messages.len();
2444 agent.load_history().await.unwrap();
2445
2446 let loaded = &agent.msg.messages[messages_before..];
2447
2448 for msg in loaded {
2450 assert!(
2451 !msg.parts.iter().any(|p| matches!(
2452 p,
2453 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_ghost"
2454 )),
2455 "orphaned ToolResult (call_ghost) must be stripped from history"
2456 );
2457 }
2458
2459 let has_good_result = loaded.iter().any(|msg| {
2462 msg.role == Role::User
2463 && msg.parts.iter().any(|p| {
2464 matches!(
2465 p,
2466 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_good"
2467 )
2468 })
2469 });
2470 assert!(
2471 has_good_result,
2472 "matched ToolResult (call_good) must be preserved in history"
2473 );
2474 }
2475
2476 #[tokio::test]
2479 async fn sanitize_tool_pairs_preserves_matched_tool_pair() {
2480 use zeph_llm::provider::MessagePart;
2481
2482 let provider = mock_provider(vec![]);
2483 let channel = MockChannel::new(vec![]);
2484 let registry = create_test_registry();
2485 let executor = MockToolExecutor::no_tools();
2486
2487 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2488 let cid = memory.sqlite().create_conversation().await.unwrap();
2489 let sqlite = memory.sqlite();
2490
2491 sqlite
2492 .save_message(cid, "user", "run a command")
2493 .await
2494 .unwrap();
2495
2496 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2498 id: "call_ok".to_string(),
2499 name: "shell".to_string(),
2500 input: serde_json::json!({"command": "echo hi"}),
2501 }])
2502 .unwrap();
2503 sqlite
2504 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
2505 .await
2506 .unwrap();
2507
2508 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2510 tool_use_id: "call_ok".to_string(),
2511 content: "hi".to_string(),
2512 is_error: false,
2513 }])
2514 .unwrap();
2515 sqlite
2516 .save_message_with_parts(cid, "user", "[tool_result: call_ok]\nhi", &result_parts)
2517 .await
2518 .unwrap();
2519
2520 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2521
2522 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2523 std::sync::Arc::new(memory),
2524 cid,
2525 50,
2526 5,
2527 100,
2528 );
2529
2530 let messages_before = agent.msg.messages.len();
2531 agent.load_history().await.unwrap();
2532
2533 assert_eq!(
2535 agent.msg.messages.len(),
2536 messages_before + 4,
2537 "matched tool pair must not be removed"
2538 );
2539 let tool_msg = &agent.msg.messages[messages_before + 1];
2540 assert!(
2541 tool_msg
2542 .parts
2543 .iter()
2544 .any(|p| matches!(p, MessagePart::ToolUse { id, .. } if id == "call_ok")),
2545 "matched ToolUse parts must be preserved"
2546 );
2547 }
2548
2549 #[tokio::test]
2553 async fn persist_cancelled_tool_results_pairs_tool_use() {
2554 use zeph_llm::provider::MessagePart;
2555
2556 let provider = mock_provider(vec![]);
2557 let channel = MockChannel::new(vec![]);
2558 let registry = create_test_registry();
2559 let executor = MockToolExecutor::no_tools();
2560
2561 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2562 let cid = memory.sqlite().create_conversation().await.unwrap();
2563
2564 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2565 std::sync::Arc::new(memory),
2566 cid,
2567 50,
2568 5,
2569 100,
2570 );
2571
2572 let tool_calls = vec![
2574 zeph_llm::provider::ToolUseRequest {
2575 id: "cancel_id_1".to_string(),
2576 name: "shell".to_string(),
2577 input: serde_json::json!({}),
2578 },
2579 zeph_llm::provider::ToolUseRequest {
2580 id: "cancel_id_2".to_string(),
2581 name: "read_file".to_string(),
2582 input: serde_json::json!({}),
2583 },
2584 ];
2585
2586 agent.persist_cancelled_tool_results(&tool_calls).await;
2587
2588 let history = agent
2589 .memory_state
2590 .memory
2591 .as_ref()
2592 .unwrap()
2593 .sqlite()
2594 .load_history(cid, 50)
2595 .await
2596 .unwrap();
2597
2598 assert_eq!(history.len(), 1);
2600 assert_eq!(history[0].role, Role::User);
2601
2602 for tc in &tool_calls {
2604 assert!(
2605 history[0].parts.iter().any(|p| matches!(
2606 p,
2607 MessagePart::ToolResult { tool_use_id, is_error, .. }
2608 if tool_use_id == &tc.id && *is_error
2609 )),
2610 "tombstone ToolResult for {} must be present and is_error=true",
2611 tc.id
2612 );
2613 }
2614 }
2615
/// The empty string must be reported as not meaningful.
#[test]
fn meaningful_content_empty_string() {
    let input = "";
    assert!(!has_meaningful_content(input));
}
2622
/// Input made only of spaces, a newline and a tab must be reported as not meaningful.
#[test]
fn meaningful_content_whitespace_only() {
    let input = " \n\t ";
    assert!(!has_meaningful_content(input));
}
2627
/// A lone `[tool_use: name(id)]` tag must be reported as not meaningful.
#[test]
fn meaningful_content_tool_use_only() {
    let input = "[tool_use: shell(call_1)]";
    assert!(!has_meaningful_content(input));
}
2632
/// A `[tool_use: name]` tag without a parenthesised id must be reported as not meaningful.
#[test]
fn meaningful_content_tool_use_no_parens() {
    let input = "[tool_use: memory_save]";
    assert!(!has_meaningful_content(input));
}
2638
/// A `[tool_result: id]` tag followed by a body line must be reported as not meaningful.
#[test]
fn meaningful_content_tool_result_with_body() {
    let input = "[tool_result: call_1]\nsome output here";
    assert!(!has_meaningful_content(input));
}
2645
/// A `[tool_result: id]` tag followed only by a newline must be reported as not meaningful.
#[test]
fn meaningful_content_tool_result_empty_body() {
    let input = "[tool_result: call_1]\n";
    assert!(!has_meaningful_content(input));
}
2650
/// A `[tool output: name]` tag with its result on the same line must be reported as not meaningful.
#[test]
fn meaningful_content_tool_output_inline() {
    let input = "[tool output: bash] some result";
    assert!(!has_meaningful_content(input));
}
2655
/// A `[tool output: name]` tag followed by `(pruned)` must be reported as not meaningful.
#[test]
fn meaningful_content_tool_output_pruned() {
    let input = "[tool output: bash] (pruned)";
    assert!(!has_meaningful_content(input));
}
2660
/// A `[tool output: name]` tag followed by a fenced code block must be reported as not meaningful.
#[test]
fn meaningful_content_tool_output_fenced() {
    let input = "[tool output: bash]\n```\nls output\n```";
    assert!(!has_meaningful_content(input));
}
2667
/// Two adjacent `[tool_use: ...]` tags with nothing between them must be reported as not meaningful.
#[test]
fn meaningful_content_multiple_tool_use_tags() {
    let input = "[tool_use: bash(id1)][tool_use: read(id2)]";
    assert!(!has_meaningful_content(input));
}
2674
/// Two `[tool_use: ...]` tags separated by a single space must be reported as not meaningful.
#[test]
fn meaningful_content_multiple_tool_use_tags_space_separator() {
    let input = "[tool_use: bash(id1)] [tool_use: read(id2)]";
    assert!(!has_meaningful_content(input));
}
2682
/// Two `[tool_use: ...]` tags separated by a newline must be reported as not meaningful.
#[test]
fn meaningful_content_multiple_tool_use_tags_newline_separator() {
    let input = "[tool_use: bash(id1)]\n[tool_use: read(id2)]";
    assert!(!has_meaningful_content(input));
}
2690
/// A `[tool_result: ...]` tag with a body, followed by a `[tool_use: ...]` tag,
/// must be reported as not meaningful.
#[test]
fn meaningful_content_tool_result_followed_by_tool_use() {
    let input = "[tool_result: call_1]\nresult\n[tool_use: bash(call_2)]";
    assert!(!has_meaningful_content(input));
}
2698
/// Plain prose with no tool tags must be reported as meaningful.
#[test]
fn meaningful_content_real_text_only() {
    let input = "Hello, how can I help you?";
    assert!(has_meaningful_content(input));
}
2703
/// Prose that precedes a `[tool_use: ...]` tag must make the input meaningful.
#[test]
fn meaningful_content_text_before_tool_tag() {
    let input = "Let me check. [tool_use: bash(id)]";
    assert!(has_meaningful_content(input));
}
2708
/// Prose that follows a `[tool_use: ...]` tag on the same line must make the input meaningful.
#[test]
fn meaningful_content_text_after_tool_use_tag() {
    let input = "[tool_use: bash] I ran the command";
    assert!(has_meaningful_content(input));
}
2716
/// Prose on its own line between two `[tool_use: ...]` tags must make the input meaningful.
#[test]
fn meaningful_content_text_between_tags() {
    let input = "[tool_use: bash(id1)]\nand then\n[tool_use: read(id2)]";
    assert!(has_meaningful_content(input));
}
2723
/// A `[tool_use: ` prefix with no closing bracket must be reported as meaningful.
#[test]
fn meaningful_content_malformed_tag_no_closing_bracket() {
    let input = "[tool_use: ";
    assert!(has_meaningful_content(input));
}
2729
/// A `[tool_use: ...]` tag followed by its `[tool_result: ...]` tag and body
/// must be reported as not meaningful.
#[test]
fn meaningful_content_tool_use_and_tool_result_only() {
    let input = "[tool_use: memory_save(call_abc)]\n[tool_result: call_abc]\nsaved";
    assert!(!has_meaningful_content(input));
}
2737
/// A `[tool_result: ...]` body that itself starts with `[` (a JSON array)
/// must still be reported as not meaningful.
#[test]
fn meaningful_content_tool_result_body_with_json_array() {
    let input = "[tool_result: id1]\n[\"array\", \"value\"]";
    assert!(!has_meaningful_content(input));
}
2744
2745 #[tokio::test]
2756 async fn issue_2529_orphaned_legacy_content_pair_is_soft_deleted() {
2757 use zeph_llm::provider::MessagePart;
2758
2759 let provider = mock_provider(vec![]);
2760 let channel = MockChannel::new(vec![]);
2761 let registry = create_test_registry();
2762 let executor = MockToolExecutor::no_tools();
2763
2764 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2765 let cid = memory.sqlite().create_conversation().await.unwrap();
2766 let sqlite = memory.sqlite();
2767
2768 sqlite
2770 .save_message(cid, "user", "save this for me")
2771 .await
2772 .unwrap();
2773
2774 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2777 id: "call_2529".to_string(),
2778 name: "memory_save".to_string(),
2779 input: serde_json::json!({"content": "save this"}),
2780 }])
2781 .unwrap();
2782 let orphan_assistant_id = sqlite
2783 .save_message_with_parts(
2784 cid,
2785 "assistant",
2786 "[tool_use: memory_save(call_2529)]",
2787 &use_parts,
2788 )
2789 .await
2790 .unwrap();
2791
2792 sqlite
2797 .save_message(cid, "assistant", "here is a plain reply")
2798 .await
2799 .unwrap();
2800
2801 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2802 tool_use_id: "call_2529".to_string(),
2803 content: "saved".to_string(),
2804 is_error: false,
2805 }])
2806 .unwrap();
2807 let orphan_user_id = sqlite
2808 .save_message_with_parts(
2809 cid,
2810 "user",
2811 "[tool_result: call_2529]\nsaved",
2812 &result_parts,
2813 )
2814 .await
2815 .unwrap();
2816
2817 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2819
2820 let memory_arc = std::sync::Arc::new(memory);
2821 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2822 memory_arc.clone(),
2823 cid,
2824 50,
2825 5,
2826 100,
2827 );
2828
2829 agent.load_history().await.unwrap();
2830
2831 let assistant_deleted_count: Vec<i64> = zeph_db::query_scalar(
2834 "SELECT COUNT(*) FROM messages WHERE id = ? AND deleted_at IS NOT NULL",
2835 )
2836 .bind(orphan_assistant_id)
2837 .fetch_all(memory_arc.sqlite().pool())
2838 .await
2839 .unwrap();
2840
2841 let user_deleted_count: Vec<i64> = zeph_db::query_scalar(
2842 "SELECT COUNT(*) FROM messages WHERE id = ? AND deleted_at IS NOT NULL",
2843 )
2844 .bind(orphan_user_id)
2845 .fetch_all(memory_arc.sqlite().pool())
2846 .await
2847 .unwrap();
2848
2849 assert_eq!(
2850 assistant_deleted_count.first().copied().unwrap_or(0),
2851 1,
2852 "orphaned assistant[ToolUse] with legacy-only content must be soft-deleted (deleted_at IS NOT NULL)"
2853 );
2854 assert_eq!(
2855 user_deleted_count.first().copied().unwrap_or(0),
2856 1,
2857 "orphaned user[ToolResult] with legacy-only content must be soft-deleted (deleted_at IS NOT NULL)"
2858 );
2859 }
2860
2861 #[tokio::test]
2865 async fn issue_2529_soft_delete_is_idempotent_across_sessions() {
2866 use zeph_llm::provider::MessagePart;
2867
2868 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2869 let cid = memory.sqlite().create_conversation().await.unwrap();
2870 let sqlite = memory.sqlite();
2871
2872 sqlite
2874 .save_message(cid, "user", "do something")
2875 .await
2876 .unwrap();
2877
2878 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2880 id: "call_idem".to_string(),
2881 name: "shell".to_string(),
2882 input: serde_json::json!({"command": "ls"}),
2883 }])
2884 .unwrap();
2885 sqlite
2886 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_idem)]", &use_parts)
2887 .await
2888 .unwrap();
2889
2890 sqlite
2892 .save_message(cid, "assistant", "continuing")
2893 .await
2894 .unwrap();
2895
2896 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2898 tool_use_id: "call_idem".to_string(),
2899 content: "output".to_string(),
2900 is_error: false,
2901 }])
2902 .unwrap();
2903 sqlite
2904 .save_message_with_parts(
2905 cid,
2906 "user",
2907 "[tool_result: call_idem]\noutput",
2908 &result_parts,
2909 )
2910 .await
2911 .unwrap();
2912
2913 sqlite
2914 .save_message(cid, "assistant", "final")
2915 .await
2916 .unwrap();
2917
2918 let memory_arc = std::sync::Arc::new(memory);
2919
2920 let mut agent1 = Agent::new(
2922 mock_provider(vec![]),
2923 MockChannel::new(vec![]),
2924 create_test_registry(),
2925 None,
2926 5,
2927 MockToolExecutor::no_tools(),
2928 )
2929 .with_memory(memory_arc.clone(), cid, 50, 5, 100);
2930 agent1.load_history().await.unwrap();
2931 let count_after_first = agent1.msg.messages.len();
2932
2933 let mut agent2 = Agent::new(
2936 mock_provider(vec![]),
2937 MockChannel::new(vec![]),
2938 create_test_registry(),
2939 None,
2940 5,
2941 MockToolExecutor::no_tools(),
2942 )
2943 .with_memory(memory_arc.clone(), cid, 50, 5, 100);
2944 agent2.load_history().await.unwrap();
2945 let count_after_second = agent2.msg.messages.len();
2946
2947 assert_eq!(
2949 count_after_first, count_after_second,
2950 "second load_history must load the same message count as the first (soft-deleted orphans excluded)"
2951 );
2952 }
2953
2954 #[tokio::test]
2958 async fn issue_2529_message_with_text_and_tool_tag_is_kept_after_part_strip() {
2959 use zeph_llm::provider::MessagePart;
2960
2961 let provider = mock_provider(vec![]);
2962 let channel = MockChannel::new(vec![]);
2963 let registry = create_test_registry();
2964 let executor = MockToolExecutor::no_tools();
2965
2966 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2967 let cid = memory.sqlite().create_conversation().await.unwrap();
2968 let sqlite = memory.sqlite();
2969
2970 sqlite
2972 .save_message(cid, "user", "check the files")
2973 .await
2974 .unwrap();
2975
2976 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2979 id: "call_mixed".to_string(),
2980 name: "shell".to_string(),
2981 input: serde_json::json!({"command": "ls"}),
2982 }])
2983 .unwrap();
2984 sqlite
2985 .save_message_with_parts(
2986 cid,
2987 "assistant",
2988 "Let me list the directory. [tool_use: shell(call_mixed)]",
2989 &use_parts,
2990 )
2991 .await
2992 .unwrap();
2993
2994 sqlite.save_message(cid, "user", "thanks").await.unwrap();
2996 sqlite
2997 .save_message(cid, "assistant", "you are welcome")
2998 .await
2999 .unwrap();
3000
3001 let memory_arc = std::sync::Arc::new(memory);
3002 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
3003 memory_arc.clone(),
3004 cid,
3005 50,
3006 5,
3007 100,
3008 );
3009
3010 let messages_before = agent.msg.messages.len();
3011 agent.load_history().await.unwrap();
3012
3013 assert_eq!(
3015 agent.msg.messages.len(),
3016 messages_before + 4,
3017 "assistant message with text + tool tag must not be removed after ToolUse strip"
3018 );
3019
3020 let mixed_msg = agent
3022 .msg
3023 .messages
3024 .iter()
3025 .find(|m| m.content.contains("Let me list the directory"))
3026 .expect("mixed-content assistant message must still be in history");
3027 assert!(
3028 !mixed_msg
3029 .parts
3030 .iter()
3031 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
3032 "orphaned ToolUse parts must be stripped even when message has meaningful text"
3033 );
3034 assert_eq!(
3035 mixed_msg.content, "Let me list the directory. [tool_use: shell(call_mixed)]",
3036 "content field must be unchanged — only parts are stripped"
3037 );
3038 }
3039}