1use crate::channel::Channel;
5use zeph_llm::provider::{Message, MessagePart, Role};
6use zeph_memory::sqlite::role_str;
7
8use super::Agent;
9
10fn sanitize_tool_pairs(messages: &mut Vec<Message>) -> usize {
25 let mut removed = 0;
26
27 loop {
28 if let Some(last) = messages.last()
30 && last.role == Role::Assistant
31 && last
32 .parts
33 .iter()
34 .any(|p| matches!(p, MessagePart::ToolUse { .. }))
35 {
36 let ids: Vec<String> = last
37 .parts
38 .iter()
39 .filter_map(|p| {
40 if let MessagePart::ToolUse { id, .. } = p {
41 Some(id.clone())
42 } else {
43 None
44 }
45 })
46 .collect();
47 tracing::warn!(
48 tool_ids = ?ids,
49 "removing orphaned trailing tool_use message from restored history"
50 );
51 messages.pop();
52 removed += 1;
53 continue;
54 }
55
56 if let Some(first) = messages.first()
58 && first.role == Role::User
59 && first
60 .parts
61 .iter()
62 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
63 {
64 let ids: Vec<String> = first
65 .parts
66 .iter()
67 .filter_map(|p| {
68 if let MessagePart::ToolResult { tool_use_id, .. } = p {
69 Some(tool_use_id.clone())
70 } else {
71 None
72 }
73 })
74 .collect();
75 tracing::warn!(
76 tool_use_ids = ?ids,
77 "removing orphaned leading tool_result message from restored history"
78 );
79 messages.remove(0);
80 removed += 1;
81 continue;
82 }
83
84 break;
85 }
86
87 removed += strip_mid_history_orphans(messages);
90
91 removed
92}
93
94fn strip_mid_history_orphans(messages: &mut Vec<Message>) -> usize {
100 let mut removed = 0;
101 let mut i = 0;
102 while i < messages.len() {
103 let has_tool_use = messages[i].role == Role::Assistant
104 && messages[i]
105 .parts
106 .iter()
107 .any(|p| matches!(p, MessagePart::ToolUse { .. }));
108
109 if !has_tool_use {
110 i += 1;
111 continue;
112 }
113
114 let tool_use_ids: Vec<String> = messages[i]
116 .parts
117 .iter()
118 .filter_map(|p| {
119 if let MessagePart::ToolUse { id, .. } = p {
120 Some(id.clone())
121 } else {
122 None
123 }
124 })
125 .collect();
126
127 let next_has_results = messages
129 .get(i + 1)
130 .is_some_and(|next| {
131 if next.role != Role::User {
132 return false;
133 }
134 tool_use_ids.iter().all(|uid| {
135 next.parts.iter().any(|p| {
136 matches!(p, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == uid)
137 })
138 })
139 });
140
141 if next_has_results {
142 i += 1;
143 continue;
144 }
145
146 tracing::warn!(
148 tool_ids = ?tool_use_ids,
149 index = i,
150 "stripping orphaned mid-history tool_use parts from assistant message"
151 );
152 messages[i]
153 .parts
154 .retain(|p| !matches!(p, MessagePart::ToolUse { .. }));
155
156 let is_empty = messages[i].content.trim().is_empty() && messages[i].parts.is_empty();
158 if is_empty {
159 messages.remove(i);
160 removed += 1;
161 } else {
163 i += 1;
164 }
165 }
166 removed
167}
168
169impl<C: Channel> Agent<C> {
170 pub async fn load_history(&mut self) -> Result<(), super::error::AgentError> {
176 let (Some(memory), Some(cid)) =
177 (&self.memory_state.memory, self.memory_state.conversation_id)
178 else {
179 return Ok(());
180 };
181
182 let history = memory
183 .sqlite()
184 .load_history_filtered(cid, self.memory_state.history_limit, Some(true), None)
185 .await?;
186 if !history.is_empty() {
187 let mut loaded = 0;
188 let mut skipped = 0;
189
190 for msg in history {
191 if msg.content.trim().is_empty() {
192 tracing::warn!("skipping empty message from history (role: {:?})", msg.role);
193 skipped += 1;
194 continue;
195 }
196 self.messages.push(msg);
197 loaded += 1;
198 }
199
200 let history_start = self.messages.len() - loaded;
202 let mut restored_slice = self.messages.split_off(history_start);
203 let orphans = sanitize_tool_pairs(&mut restored_slice);
204 skipped += orphans;
205 loaded = loaded.saturating_sub(orphans);
206 self.messages.append(&mut restored_slice);
207
208 tracing::info!("restored {loaded} message(s) from conversation {cid}");
209 if skipped > 0 {
210 tracing::warn!("skipped {skipped} empty/orphaned message(s) from history");
211 }
212 }
213
214 if let Ok(count) = memory.message_count(cid).await {
215 let count_u64 = u64::try_from(count).unwrap_or(0);
216 self.update_metrics(|m| {
217 m.sqlite_message_count = count_u64;
218 });
219 }
220
221 if let Ok(count) = memory.unsummarized_message_count(cid).await {
222 self.memory_state.unsummarized_count = usize::try_from(count).unwrap_or(0);
223 }
224
225 self.recompute_prompt_tokens();
226 Ok(())
227 }
228
229 pub(crate) async fn persist_message(
235 &mut self,
236 role: Role,
237 content: &str,
238 parts: &[MessagePart],
239 has_injection_flags: bool,
240 ) {
241 let (Some(memory), Some(cid)) =
242 (&self.memory_state.memory, self.memory_state.conversation_id)
243 else {
244 return;
245 };
246
247 let parts_json = if parts.is_empty() {
248 "[]".to_string()
249 } else {
250 serde_json::to_string(parts).unwrap_or_else(|e| {
251 tracing::warn!("failed to serialize message parts, storing empty: {e}");
252 "[]".to_string()
253 })
254 };
255
256 let guard_event = self
260 .exfiltration_guard
261 .should_guard_memory_write(has_injection_flags);
262 if let Some(ref event) = guard_event {
263 tracing::warn!(
264 ?event,
265 "exfiltration guard: skipping Qdrant embedding for flagged content"
266 );
267 self.update_metrics(|m| m.exfiltration_memory_guards += 1);
268 self.push_security_event(
269 crate::metrics::SecurityEventCategory::ExfiltrationBlock,
270 "memory_write",
271 "Qdrant embedding skipped: flagged content",
272 );
273 }
274
275 let skip_embedding = guard_event.is_some();
276
277 let should_embed = if skip_embedding {
278 false
279 } else {
280 match role {
281 Role::Assistant => {
282 self.memory_state.autosave_assistant
283 && content.len() >= self.memory_state.autosave_min_length
284 }
285 _ => true,
286 }
287 };
288
289 let embedding_stored = if should_embed {
290 match memory
291 .remember_with_parts(cid, role_str(role), content, &parts_json)
292 .await
293 {
294 Ok((_message_id, stored)) => stored,
295 Err(e) => {
296 tracing::error!("failed to persist message: {e:#}");
297 return;
298 }
299 }
300 } else {
301 match memory
302 .save_only(cid, role_str(role), content, &parts_json)
303 .await
304 {
305 Ok(_) => false,
306 Err(e) => {
307 tracing::error!("failed to persist message: {e:#}");
308 return;
309 }
310 }
311 };
312
313 self.memory_state.unsummarized_count += 1;
314
315 self.update_metrics(|m| {
316 m.sqlite_message_count += 1;
317 if embedding_stored {
318 m.embeddings_generated += 1;
319 }
320 });
321
322 self.check_summarization().await;
323
324 self.maybe_spawn_graph_extraction(content, has_injection_flags)
325 .await;
326 }
327
328 async fn maybe_spawn_graph_extraction(&mut self, content: &str, has_injection_flags: bool) {
329 use zeph_memory::semantic::GraphExtractionConfig;
330
331 if self.memory_state.memory.is_none() || self.memory_state.conversation_id.is_none() {
332 return;
333 }
334
335 if has_injection_flags {
337 tracing::warn!("graph extraction skipped: injection patterns detected in content");
338 return;
339 }
340
341 let extraction_cfg = {
343 let cfg = &self.memory_state.graph_config;
344 if !cfg.enabled {
345 return;
346 }
347 GraphExtractionConfig {
348 max_entities: cfg.max_entities_per_message,
349 max_edges: cfg.max_edges_per_message,
350 extraction_timeout_secs: cfg.extraction_timeout_secs,
351 community_refresh_interval: cfg.community_refresh_interval,
352 expired_edge_retention_days: cfg.expired_edge_retention_days,
353 max_entities_cap: cfg.max_entities,
354 community_summary_max_prompt_bytes: cfg.community_summary_max_prompt_bytes,
355 community_summary_concurrency: cfg.community_summary_concurrency,
356 }
357 };
358
359 let context_messages: Vec<String> = self
361 .messages
362 .iter()
363 .rev()
364 .filter(|m| m.role == Role::User)
365 .take(4)
366 .map(|m| m.content.clone())
367 .collect();
368
369 let _ = self.channel.send_status("extracting graph...").await;
370
371 if let Some(memory) = &self.memory_state.memory {
372 memory.spawn_graph_extraction(content.to_owned(), context_messages, extraction_cfg);
373 }
374 self.sync_community_detection_failures();
375 self.sync_graph_extraction_metrics();
376 self.sync_graph_counts().await;
377 }
378
379 pub(crate) async fn check_summarization(&mut self) {
380 let (Some(memory), Some(cid)) =
381 (&self.memory_state.memory, self.memory_state.conversation_id)
382 else {
383 return;
384 };
385
386 if self.memory_state.unsummarized_count > self.memory_state.summarization_threshold {
387 let _ = self.channel.send_status("summarizing...").await;
388 let batch_size = self.memory_state.summarization_threshold / 2;
389 match memory.summarize(cid, batch_size).await {
390 Ok(Some(summary_id)) => {
391 tracing::info!("created summary {summary_id} for conversation {cid}");
392 self.memory_state.unsummarized_count = 0;
393 self.update_metrics(|m| {
394 m.summaries_count += 1;
395 });
396 }
397 Ok(None) => {
398 tracing::debug!("no summarization needed");
399 }
400 Err(e) => {
401 tracing::error!("summarization failed: {e:#}");
402 }
403 }
404 let _ = self.channel.send_status("").await;
405 }
406 }
407}
408
409#[cfg(test)]
410mod tests {
411 use super::super::agent_tests::{
412 MetricsSnapshot, MockChannel, MockToolExecutor, create_test_registry, mock_provider,
413 };
414 use super::*;
415 use zeph_llm::any::AnyProvider;
416 use zeph_memory::semantic::SemanticMemory;
417
418 async fn test_memory(provider: &AnyProvider) -> SemanticMemory {
419 SemanticMemory::new(
420 ":memory:",
421 "http://127.0.0.1:1",
422 provider.clone(),
423 "test-model",
424 )
425 .await
426 .unwrap()
427 }
428
429 #[tokio::test]
430 async fn load_history_without_memory_returns_ok() {
431 let provider = mock_provider(vec![]);
432 let channel = MockChannel::new(vec![]);
433 let registry = create_test_registry();
434 let executor = MockToolExecutor::no_tools();
435 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
436
437 let result = agent.load_history().await;
438 assert!(result.is_ok());
439 assert_eq!(agent.messages.len(), 1); }
442
443 #[tokio::test]
444 async fn load_history_with_messages_injects_into_agent() {
445 let provider = mock_provider(vec![]);
446 let channel = MockChannel::new(vec![]);
447 let registry = create_test_registry();
448 let executor = MockToolExecutor::no_tools();
449
450 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
451 let cid = memory.sqlite().create_conversation().await.unwrap();
452
453 memory
454 .sqlite()
455 .save_message(cid, "user", "hello from history")
456 .await
457 .unwrap();
458 memory
459 .sqlite()
460 .save_message(cid, "assistant", "hi back")
461 .await
462 .unwrap();
463
464 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
465 std::sync::Arc::new(memory),
466 cid,
467 50,
468 5,
469 100,
470 );
471
472 let messages_before = agent.messages.len();
473 agent.load_history().await.unwrap();
474 assert_eq!(agent.messages.len(), messages_before + 2);
476 }
477
478 #[tokio::test]
479 async fn load_history_skips_empty_messages() {
480 let provider = mock_provider(vec![]);
481 let channel = MockChannel::new(vec![]);
482 let registry = create_test_registry();
483 let executor = MockToolExecutor::no_tools();
484
485 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
486 let cid = memory.sqlite().create_conversation().await.unwrap();
487
488 memory
490 .sqlite()
491 .save_message(cid, "user", " ")
492 .await
493 .unwrap();
494 memory
495 .sqlite()
496 .save_message(cid, "user", "real message")
497 .await
498 .unwrap();
499
500 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
501 std::sync::Arc::new(memory),
502 cid,
503 50,
504 5,
505 100,
506 );
507
508 let messages_before = agent.messages.len();
509 agent.load_history().await.unwrap();
510 assert_eq!(agent.messages.len(), messages_before + 1);
512 }
513
514 #[tokio::test]
515 async fn load_history_with_empty_store_returns_ok() {
516 let provider = mock_provider(vec![]);
517 let channel = MockChannel::new(vec![]);
518 let registry = create_test_registry();
519 let executor = MockToolExecutor::no_tools();
520
521 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
522 let cid = memory.sqlite().create_conversation().await.unwrap();
523
524 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
525 std::sync::Arc::new(memory),
526 cid,
527 50,
528 5,
529 100,
530 );
531
532 let messages_before = agent.messages.len();
533 agent.load_history().await.unwrap();
534 assert_eq!(agent.messages.len(), messages_before);
536 }
537
538 #[tokio::test]
539 async fn persist_message_without_memory_silently_returns() {
540 let provider = mock_provider(vec![]);
542 let channel = MockChannel::new(vec![]);
543 let registry = create_test_registry();
544 let executor = MockToolExecutor::no_tools();
545 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
546
547 agent.persist_message(Role::User, "hello", &[], false).await;
549 }
550
551 #[tokio::test]
552 async fn persist_message_assistant_autosave_false_uses_save_only() {
553 let provider = mock_provider(vec![]);
554 let channel = MockChannel::new(vec![]);
555 let registry = create_test_registry();
556 let executor = MockToolExecutor::no_tools();
557
558 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
559 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
560 let cid = memory.sqlite().create_conversation().await.unwrap();
561
562 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
563 .with_metrics(tx)
564 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
565 .with_autosave_config(false, 20);
566
567 agent
568 .persist_message(Role::Assistant, "short assistant reply", &[], false)
569 .await;
570
571 let history = agent
572 .memory_state
573 .memory
574 .as_ref()
575 .unwrap()
576 .sqlite()
577 .load_history(cid, 50)
578 .await
579 .unwrap();
580 assert_eq!(history.len(), 1, "message must be saved");
581 assert_eq!(history[0].content, "short assistant reply");
582 assert_eq!(rx.borrow().embeddings_generated, 0);
584 }
585
586 #[tokio::test]
587 async fn persist_message_assistant_below_min_length_uses_save_only() {
588 let provider = mock_provider(vec![]);
589 let channel = MockChannel::new(vec![]);
590 let registry = create_test_registry();
591 let executor = MockToolExecutor::no_tools();
592
593 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
594 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
595 let cid = memory.sqlite().create_conversation().await.unwrap();
596
597 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
599 .with_metrics(tx)
600 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
601 .with_autosave_config(true, 1000);
602
603 agent
604 .persist_message(Role::Assistant, "too short", &[], false)
605 .await;
606
607 let history = agent
608 .memory_state
609 .memory
610 .as_ref()
611 .unwrap()
612 .sqlite()
613 .load_history(cid, 50)
614 .await
615 .unwrap();
616 assert_eq!(history.len(), 1, "message must be saved");
617 assert_eq!(history[0].content, "too short");
618 assert_eq!(rx.borrow().embeddings_generated, 0);
619 }
620
621 #[tokio::test]
622 async fn persist_message_assistant_at_min_length_boundary_uses_embed() {
623 let provider = mock_provider(vec![]);
625 let channel = MockChannel::new(vec![]);
626 let registry = create_test_registry();
627 let executor = MockToolExecutor::no_tools();
628
629 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
630 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
631 let cid = memory.sqlite().create_conversation().await.unwrap();
632
633 let min_length = 10usize;
634 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
635 .with_metrics(tx)
636 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
637 .with_autosave_config(true, min_length);
638
639 let content_at_boundary = "A".repeat(min_length);
641 assert_eq!(content_at_boundary.len(), min_length);
642 agent
643 .persist_message(Role::Assistant, &content_at_boundary, &[], false)
644 .await;
645
646 assert_eq!(rx.borrow().sqlite_message_count, 1);
648 }
649
650 #[tokio::test]
651 async fn persist_message_assistant_one_below_min_length_uses_save_only() {
652 let provider = mock_provider(vec![]);
654 let channel = MockChannel::new(vec![]);
655 let registry = create_test_registry();
656 let executor = MockToolExecutor::no_tools();
657
658 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
659 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
660 let cid = memory.sqlite().create_conversation().await.unwrap();
661
662 let min_length = 10usize;
663 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
664 .with_metrics(tx)
665 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
666 .with_autosave_config(true, min_length);
667
668 let content_below_boundary = "A".repeat(min_length - 1);
670 assert_eq!(content_below_boundary.len(), min_length - 1);
671 agent
672 .persist_message(Role::Assistant, &content_below_boundary, &[], false)
673 .await;
674
675 let history = agent
676 .memory_state
677 .memory
678 .as_ref()
679 .unwrap()
680 .sqlite()
681 .load_history(cid, 50)
682 .await
683 .unwrap();
684 assert_eq!(history.len(), 1, "message must still be saved");
685 assert_eq!(rx.borrow().embeddings_generated, 0);
687 }
688
689 #[tokio::test]
690 async fn persist_message_increments_unsummarized_count() {
691 let provider = mock_provider(vec![]);
692 let channel = MockChannel::new(vec![]);
693 let registry = create_test_registry();
694 let executor = MockToolExecutor::no_tools();
695
696 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
697 let cid = memory.sqlite().create_conversation().await.unwrap();
698
699 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
701 std::sync::Arc::new(memory),
702 cid,
703 50,
704 5,
705 100,
706 );
707
708 assert_eq!(agent.memory_state.unsummarized_count, 0);
709
710 agent.persist_message(Role::User, "first", &[], false).await;
711 assert_eq!(agent.memory_state.unsummarized_count, 1);
712
713 agent
714 .persist_message(Role::User, "second", &[], false)
715 .await;
716 assert_eq!(agent.memory_state.unsummarized_count, 2);
717 }
718
719 #[tokio::test]
720 async fn check_summarization_resets_counter_on_success() {
721 let provider = mock_provider(vec![]);
722 let channel = MockChannel::new(vec![]);
723 let registry = create_test_registry();
724 let executor = MockToolExecutor::no_tools();
725
726 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
727 let cid = memory.sqlite().create_conversation().await.unwrap();
728
729 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
731 std::sync::Arc::new(memory),
732 cid,
733 50,
734 5,
735 1,
736 );
737
738 agent.persist_message(Role::User, "msg1", &[], false).await;
739 agent.persist_message(Role::User, "msg2", &[], false).await;
740
741 assert!(agent.memory_state.unsummarized_count <= 2);
746 }
747
748 #[tokio::test]
749 async fn unsummarized_count_not_incremented_without_memory() {
750 let provider = mock_provider(vec![]);
751 let channel = MockChannel::new(vec![]);
752 let registry = create_test_registry();
753 let executor = MockToolExecutor::no_tools();
754 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
755
756 agent.persist_message(Role::User, "hello", &[], false).await;
757 assert_eq!(agent.memory_state.unsummarized_count, 0);
759 }
760
761 mod graph_extraction_guards {
763 use super::*;
764 use crate::config::GraphConfig;
765 use zeph_memory::graph::GraphStore;
766
767 fn enabled_graph_config() -> GraphConfig {
768 GraphConfig {
769 enabled: true,
770 ..GraphConfig::default()
771 }
772 }
773
774 async fn agent_with_graph(
775 provider: &AnyProvider,
776 config: GraphConfig,
777 ) -> Agent<MockChannel> {
778 let memory =
779 test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
780 let cid = memory.sqlite().create_conversation().await.unwrap();
781 Agent::new(
782 provider.clone(),
783 MockChannel::new(vec![]),
784 create_test_registry(),
785 None,
786 5,
787 MockToolExecutor::no_tools(),
788 )
789 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
790 .with_graph_config(config)
791 }
792
793 #[tokio::test]
794 async fn injection_flag_guard_skips_extraction() {
795 let provider = mock_provider(vec![]);
797 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
798 let pool = agent
799 .memory_state
800 .memory
801 .as_ref()
802 .unwrap()
803 .sqlite()
804 .pool()
805 .clone();
806
807 agent.maybe_spawn_graph_extraction("I use Rust", true).await;
808
809 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
811
812 let store = GraphStore::new(pool);
813 let count = store.get_metadata("extraction_count").await.unwrap();
814 assert!(
815 count.is_none(),
816 "injection flag must prevent extraction_count from being written"
817 );
818 }
819
820 #[tokio::test]
821 async fn disabled_config_guard_skips_extraction() {
822 let provider = mock_provider(vec![]);
824 let disabled_cfg = GraphConfig {
825 enabled: false,
826 ..GraphConfig::default()
827 };
828 let mut agent = agent_with_graph(&provider, disabled_cfg).await;
829 let pool = agent
830 .memory_state
831 .memory
832 .as_ref()
833 .unwrap()
834 .sqlite()
835 .pool()
836 .clone();
837
838 agent
839 .maybe_spawn_graph_extraction("I use Rust", false)
840 .await;
841
842 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
843
844 let store = GraphStore::new(pool);
845 let count = store.get_metadata("extraction_count").await.unwrap();
846 assert!(
847 count.is_none(),
848 "disabled graph config must prevent extraction"
849 );
850 }
851
852 #[tokio::test]
853 async fn happy_path_fires_extraction() {
854 let provider = mock_provider(vec![]);
857 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
858 let pool = agent
859 .memory_state
860 .memory
861 .as_ref()
862 .unwrap()
863 .sqlite()
864 .pool()
865 .clone();
866
867 agent
868 .maybe_spawn_graph_extraction("I use Rust for systems programming", false)
869 .await;
870
871 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
873
874 let store = GraphStore::new(pool);
875 let count = store.get_metadata("extraction_count").await.unwrap();
876 assert!(
877 count.is_some(),
878 "happy-path extraction must increment extraction_count"
879 );
880 }
881 }
882
883 #[tokio::test]
884 async fn persist_message_user_always_embeds_regardless_of_autosave_flag() {
885 let provider = mock_provider(vec![]);
886 let channel = MockChannel::new(vec![]);
887 let registry = create_test_registry();
888 let executor = MockToolExecutor::no_tools();
889
890 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
891 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
892 let cid = memory.sqlite().create_conversation().await.unwrap();
893
894 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
896 .with_metrics(tx)
897 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
898 .with_autosave_config(false, 20);
899
900 let long_user_msg = "A".repeat(100);
901 agent
902 .persist_message(Role::User, &long_user_msg, &[], false)
903 .await;
904
905 let history = agent
906 .memory_state
907 .memory
908 .as_ref()
909 .unwrap()
910 .sqlite()
911 .load_history(cid, 50)
912 .await
913 .unwrap();
914 assert_eq!(history.len(), 1, "user message must be saved");
915 assert_eq!(rx.borrow().sqlite_message_count, 1);
918 }
919
920 #[tokio::test]
924 async fn persist_message_saves_correct_tool_use_parts() {
925 use zeph_llm::provider::MessagePart;
926
927 let provider = mock_provider(vec![]);
928 let channel = MockChannel::new(vec![]);
929 let registry = create_test_registry();
930 let executor = MockToolExecutor::no_tools();
931
932 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
933 let cid = memory.sqlite().create_conversation().await.unwrap();
934
935 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
936 std::sync::Arc::new(memory),
937 cid,
938 50,
939 5,
940 100,
941 );
942
943 let parts = vec![MessagePart::ToolUse {
944 id: "call_abc123".to_string(),
945 name: "read_file".to_string(),
946 input: serde_json::json!({"path": "/tmp/test.txt"}),
947 }];
948 let content = "[tool_use: read_file(call_abc123)]";
949
950 agent
951 .persist_message(Role::Assistant, content, &parts, false)
952 .await;
953
954 let history = agent
955 .memory_state
956 .memory
957 .as_ref()
958 .unwrap()
959 .sqlite()
960 .load_history(cid, 50)
961 .await
962 .unwrap();
963
964 assert_eq!(history.len(), 1);
965 assert_eq!(history[0].role, Role::Assistant);
966 assert_eq!(history[0].content, content);
967 assert_eq!(history[0].parts.len(), 1);
968 match &history[0].parts[0] {
969 MessagePart::ToolUse { id, name, .. } => {
970 assert_eq!(id, "call_abc123");
971 assert_eq!(name, "read_file");
972 }
973 other => panic!("expected ToolUse part, got {other:?}"),
974 }
975 assert!(
977 !history[0]
978 .parts
979 .iter()
980 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
981 "assistant message must not contain ToolResult parts"
982 );
983 }
984
985 #[tokio::test]
986 async fn persist_message_saves_correct_tool_result_parts() {
987 use zeph_llm::provider::MessagePart;
988
989 let provider = mock_provider(vec![]);
990 let channel = MockChannel::new(vec![]);
991 let registry = create_test_registry();
992 let executor = MockToolExecutor::no_tools();
993
994 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
995 let cid = memory.sqlite().create_conversation().await.unwrap();
996
997 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
998 std::sync::Arc::new(memory),
999 cid,
1000 50,
1001 5,
1002 100,
1003 );
1004
1005 let parts = vec![MessagePart::ToolResult {
1006 tool_use_id: "call_abc123".to_string(),
1007 content: "file contents here".to_string(),
1008 is_error: false,
1009 }];
1010 let content = "[tool_result: call_abc123]\nfile contents here";
1011
1012 agent
1013 .persist_message(Role::User, content, &parts, false)
1014 .await;
1015
1016 let history = agent
1017 .memory_state
1018 .memory
1019 .as_ref()
1020 .unwrap()
1021 .sqlite()
1022 .load_history(cid, 50)
1023 .await
1024 .unwrap();
1025
1026 assert_eq!(history.len(), 1);
1027 assert_eq!(history[0].role, Role::User);
1028 assert_eq!(history[0].content, content);
1029 assert_eq!(history[0].parts.len(), 1);
1030 match &history[0].parts[0] {
1031 MessagePart::ToolResult {
1032 tool_use_id,
1033 content: result_content,
1034 is_error,
1035 } => {
1036 assert_eq!(tool_use_id, "call_abc123");
1037 assert_eq!(result_content, "file contents here");
1038 assert!(!is_error);
1039 }
1040 other => panic!("expected ToolResult part, got {other:?}"),
1041 }
1042 assert!(
1044 !history[0]
1045 .parts
1046 .iter()
1047 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1048 "user ToolResult message must not contain ToolUse parts"
1049 );
1050 }
1051
1052 #[tokio::test]
1053 async fn persist_message_roundtrip_preserves_role_part_alignment() {
1054 use zeph_llm::provider::MessagePart;
1055
1056 let provider = mock_provider(vec![]);
1057 let channel = MockChannel::new(vec![]);
1058 let registry = create_test_registry();
1059 let executor = MockToolExecutor::no_tools();
1060
1061 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1062 let cid = memory.sqlite().create_conversation().await.unwrap();
1063
1064 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1065 std::sync::Arc::new(memory),
1066 cid,
1067 50,
1068 5,
1069 100,
1070 );
1071
1072 let assistant_parts = vec![MessagePart::ToolUse {
1074 id: "id_1".to_string(),
1075 name: "list_dir".to_string(),
1076 input: serde_json::json!({"path": "/tmp"}),
1077 }];
1078 agent
1079 .persist_message(
1080 Role::Assistant,
1081 "[tool_use: list_dir(id_1)]",
1082 &assistant_parts,
1083 false,
1084 )
1085 .await;
1086
1087 let user_parts = vec![MessagePart::ToolResult {
1089 tool_use_id: "id_1".to_string(),
1090 content: "file1.txt\nfile2.txt".to_string(),
1091 is_error: false,
1092 }];
1093 agent
1094 .persist_message(
1095 Role::User,
1096 "[tool_result: id_1]\nfile1.txt\nfile2.txt",
1097 &user_parts,
1098 false,
1099 )
1100 .await;
1101
1102 let history = agent
1103 .memory_state
1104 .memory
1105 .as_ref()
1106 .unwrap()
1107 .sqlite()
1108 .load_history(cid, 50)
1109 .await
1110 .unwrap();
1111
1112 assert_eq!(history.len(), 2);
1113
1114 assert_eq!(history[0].role, Role::Assistant);
1116 assert_eq!(history[0].content, "[tool_use: list_dir(id_1)]");
1117 assert!(
1118 matches!(&history[0].parts[0], MessagePart::ToolUse { id, .. } if id == "id_1"),
1119 "first message must be assistant ToolUse"
1120 );
1121
1122 assert_eq!(history[1].role, Role::User);
1124 assert_eq!(
1125 history[1].content,
1126 "[tool_result: id_1]\nfile1.txt\nfile2.txt"
1127 );
1128 assert!(
1129 matches!(&history[1].parts[0], MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "id_1"),
1130 "second message must be user ToolResult"
1131 );
1132
1133 assert!(
1135 !history[0]
1136 .parts
1137 .iter()
1138 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1139 "assistant message must not have ToolResult parts"
1140 );
1141 assert!(
1142 !history[1]
1143 .parts
1144 .iter()
1145 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1146 "user message must not have ToolUse parts"
1147 );
1148 }
1149
1150 #[tokio::test]
1151 async fn persist_message_saves_correct_tool_output_parts() {
1152 use zeph_llm::provider::MessagePart;
1153
1154 let provider = mock_provider(vec![]);
1155 let channel = MockChannel::new(vec![]);
1156 let registry = create_test_registry();
1157 let executor = MockToolExecutor::no_tools();
1158
1159 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1160 let cid = memory.sqlite().create_conversation().await.unwrap();
1161
1162 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1163 std::sync::Arc::new(memory),
1164 cid,
1165 50,
1166 5,
1167 100,
1168 );
1169
1170 let parts = vec![MessagePart::ToolOutput {
1171 tool_name: "shell".to_string(),
1172 body: "hello from shell".to_string(),
1173 compacted_at: None,
1174 }];
1175 let content = "[tool: shell]\nhello from shell";
1176
1177 agent
1178 .persist_message(Role::User, content, &parts, false)
1179 .await;
1180
1181 let history = agent
1182 .memory_state
1183 .memory
1184 .as_ref()
1185 .unwrap()
1186 .sqlite()
1187 .load_history(cid, 50)
1188 .await
1189 .unwrap();
1190
1191 assert_eq!(history.len(), 1);
1192 assert_eq!(history[0].role, Role::User);
1193 assert_eq!(history[0].content, content);
1194 assert_eq!(history[0].parts.len(), 1);
1195 match &history[0].parts[0] {
1196 MessagePart::ToolOutput {
1197 tool_name,
1198 body,
1199 compacted_at,
1200 } => {
1201 assert_eq!(tool_name, "shell");
1202 assert_eq!(body, "hello from shell");
1203 assert!(compacted_at.is_none());
1204 }
1205 other => panic!("expected ToolOutput part, got {other:?}"),
1206 }
1207 }
1208
1209 #[tokio::test]
1212 async fn load_history_removes_trailing_orphan_tool_use() {
1213 use zeph_llm::provider::MessagePart;
1214
1215 let provider = mock_provider(vec![]);
1216 let channel = MockChannel::new(vec![]);
1217 let registry = create_test_registry();
1218 let executor = MockToolExecutor::no_tools();
1219
1220 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1221 let cid = memory.sqlite().create_conversation().await.unwrap();
1222 let sqlite = memory.sqlite();
1223
1224 sqlite
1226 .save_message(cid, "user", "do something with a tool")
1227 .await
1228 .unwrap();
1229
1230 let parts = serde_json::to_string(&[MessagePart::ToolUse {
1232 id: "call_orphan".to_string(),
1233 name: "shell".to_string(),
1234 input: serde_json::json!({"command": "ls"}),
1235 }])
1236 .unwrap();
1237 sqlite
1238 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_orphan)]", &parts)
1239 .await
1240 .unwrap();
1241
1242 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1243 std::sync::Arc::new(memory),
1244 cid,
1245 50,
1246 5,
1247 100,
1248 );
1249
1250 let messages_before = agent.messages.len();
1251 agent.load_history().await.unwrap();
1252
1253 assert_eq!(
1255 agent.messages.len(),
1256 messages_before + 1,
1257 "orphaned trailing tool_use must be removed"
1258 );
1259 assert_eq!(agent.messages.last().unwrap().role, Role::User);
1260 }
1261
1262 #[tokio::test]
1263 async fn load_history_removes_leading_orphan_tool_result() {
1264 use zeph_llm::provider::MessagePart;
1265
1266 let provider = mock_provider(vec![]);
1267 let channel = MockChannel::new(vec![]);
1268 let registry = create_test_registry();
1269 let executor = MockToolExecutor::no_tools();
1270
1271 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1272 let cid = memory.sqlite().create_conversation().await.unwrap();
1273 let sqlite = memory.sqlite();
1274
1275 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1277 tool_use_id: "call_missing".to_string(),
1278 content: "result data".to_string(),
1279 is_error: false,
1280 }])
1281 .unwrap();
1282 sqlite
1283 .save_message_with_parts(
1284 cid,
1285 "user",
1286 "[tool_result: call_missing]\nresult data",
1287 &result_parts,
1288 )
1289 .await
1290 .unwrap();
1291
1292 sqlite
1294 .save_message(cid, "assistant", "here is my response")
1295 .await
1296 .unwrap();
1297
1298 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1299 std::sync::Arc::new(memory),
1300 cid,
1301 50,
1302 5,
1303 100,
1304 );
1305
1306 let messages_before = agent.messages.len();
1307 agent.load_history().await.unwrap();
1308
1309 assert_eq!(
1311 agent.messages.len(),
1312 messages_before + 1,
1313 "orphaned leading tool_result must be removed"
1314 );
1315 assert_eq!(agent.messages.last().unwrap().role, Role::Assistant);
1316 }
1317
1318 #[tokio::test]
1319 async fn load_history_preserves_complete_tool_pairs() {
1320 use zeph_llm::provider::MessagePart;
1321
1322 let provider = mock_provider(vec![]);
1323 let channel = MockChannel::new(vec![]);
1324 let registry = create_test_registry();
1325 let executor = MockToolExecutor::no_tools();
1326
1327 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1328 let cid = memory.sqlite().create_conversation().await.unwrap();
1329 let sqlite = memory.sqlite();
1330
1331 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1333 id: "call_ok".to_string(),
1334 name: "shell".to_string(),
1335 input: serde_json::json!({"command": "pwd"}),
1336 }])
1337 .unwrap();
1338 sqlite
1339 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_ok)]", &use_parts)
1340 .await
1341 .unwrap();
1342
1343 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1344 tool_use_id: "call_ok".to_string(),
1345 content: "/home/user".to_string(),
1346 is_error: false,
1347 }])
1348 .unwrap();
1349 sqlite
1350 .save_message_with_parts(
1351 cid,
1352 "user",
1353 "[tool_result: call_ok]\n/home/user",
1354 &result_parts,
1355 )
1356 .await
1357 .unwrap();
1358
1359 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1360 std::sync::Arc::new(memory),
1361 cid,
1362 50,
1363 5,
1364 100,
1365 );
1366
1367 let messages_before = agent.messages.len();
1368 agent.load_history().await.unwrap();
1369
1370 assert_eq!(
1372 agent.messages.len(),
1373 messages_before + 2,
1374 "complete tool_use/tool_result pair must be preserved"
1375 );
1376 assert_eq!(agent.messages[messages_before].role, Role::Assistant);
1377 assert_eq!(agent.messages[messages_before + 1].role, Role::User);
1378 }
1379
1380 #[tokio::test]
1381 async fn load_history_handles_multiple_trailing_orphans() {
1382 use zeph_llm::provider::MessagePart;
1383
1384 let provider = mock_provider(vec![]);
1385 let channel = MockChannel::new(vec![]);
1386 let registry = create_test_registry();
1387 let executor = MockToolExecutor::no_tools();
1388
1389 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1390 let cid = memory.sqlite().create_conversation().await.unwrap();
1391 let sqlite = memory.sqlite();
1392
1393 sqlite.save_message(cid, "user", "start").await.unwrap();
1395
1396 let parts1 = serde_json::to_string(&[MessagePart::ToolUse {
1398 id: "call_1".to_string(),
1399 name: "shell".to_string(),
1400 input: serde_json::json!({}),
1401 }])
1402 .unwrap();
1403 sqlite
1404 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_1)]", &parts1)
1405 .await
1406 .unwrap();
1407
1408 let parts2 = serde_json::to_string(&[MessagePart::ToolUse {
1410 id: "call_2".to_string(),
1411 name: "read_file".to_string(),
1412 input: serde_json::json!({}),
1413 }])
1414 .unwrap();
1415 sqlite
1416 .save_message_with_parts(cid, "assistant", "[tool_use: read_file(call_2)]", &parts2)
1417 .await
1418 .unwrap();
1419
1420 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1421 std::sync::Arc::new(memory),
1422 cid,
1423 50,
1424 5,
1425 100,
1426 );
1427
1428 let messages_before = agent.messages.len();
1429 agent.load_history().await.unwrap();
1430
1431 assert_eq!(
1433 agent.messages.len(),
1434 messages_before + 1,
1435 "all trailing orphaned tool_use messages must be removed"
1436 );
1437 assert_eq!(agent.messages.last().unwrap().role, Role::User);
1438 }
1439
1440 #[tokio::test]
1441 async fn load_history_no_tool_messages_unchanged() {
1442 let provider = mock_provider(vec![]);
1443 let channel = MockChannel::new(vec![]);
1444 let registry = create_test_registry();
1445 let executor = MockToolExecutor::no_tools();
1446
1447 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1448 let cid = memory.sqlite().create_conversation().await.unwrap();
1449 let sqlite = memory.sqlite();
1450
1451 sqlite.save_message(cid, "user", "hello").await.unwrap();
1452 sqlite
1453 .save_message(cid, "assistant", "hi there")
1454 .await
1455 .unwrap();
1456 sqlite
1457 .save_message(cid, "user", "how are you?")
1458 .await
1459 .unwrap();
1460
1461 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1462 std::sync::Arc::new(memory),
1463 cid,
1464 50,
1465 5,
1466 100,
1467 );
1468
1469 let messages_before = agent.messages.len();
1470 agent.load_history().await.unwrap();
1471
1472 assert_eq!(
1474 agent.messages.len(),
1475 messages_before + 3,
1476 "plain messages without tool parts must pass through unchanged"
1477 );
1478 }
1479
1480 #[tokio::test]
1481 async fn load_history_removes_both_leading_and_trailing_orphans() {
1482 use zeph_llm::provider::MessagePart;
1483
1484 let provider = mock_provider(vec![]);
1485 let channel = MockChannel::new(vec![]);
1486 let registry = create_test_registry();
1487 let executor = MockToolExecutor::no_tools();
1488
1489 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1490 let cid = memory.sqlite().create_conversation().await.unwrap();
1491 let sqlite = memory.sqlite();
1492
1493 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1495 tool_use_id: "call_leading".to_string(),
1496 content: "orphaned result".to_string(),
1497 is_error: false,
1498 }])
1499 .unwrap();
1500 sqlite
1501 .save_message_with_parts(
1502 cid,
1503 "user",
1504 "[tool_result: call_leading]\norphaned result",
1505 &result_parts,
1506 )
1507 .await
1508 .unwrap();
1509
1510 sqlite
1512 .save_message(cid, "user", "what is 2+2?")
1513 .await
1514 .unwrap();
1515 sqlite.save_message(cid, "assistant", "4").await.unwrap();
1516
1517 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1519 id: "call_trailing".to_string(),
1520 name: "shell".to_string(),
1521 input: serde_json::json!({"command": "date"}),
1522 }])
1523 .unwrap();
1524 sqlite
1525 .save_message_with_parts(
1526 cid,
1527 "assistant",
1528 "[tool_use: shell(call_trailing)]",
1529 &use_parts,
1530 )
1531 .await
1532 .unwrap();
1533
1534 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1535 std::sync::Arc::new(memory),
1536 cid,
1537 50,
1538 5,
1539 100,
1540 );
1541
1542 let messages_before = agent.messages.len();
1543 agent.load_history().await.unwrap();
1544
1545 assert_eq!(
1547 agent.messages.len(),
1548 messages_before + 2,
1549 "both leading and trailing orphans must be removed"
1550 );
1551 assert_eq!(agent.messages[messages_before].role, Role::User);
1552 assert_eq!(agent.messages[messages_before].content, "what is 2+2?");
1553 assert_eq!(agent.messages[messages_before + 1].role, Role::Assistant);
1554 assert_eq!(agent.messages[messages_before + 1].content, "4");
1555 }
1556
1557 #[tokio::test]
1561 async fn sanitize_tool_pairs_strips_mid_history_orphan_tool_use() {
1562 use zeph_llm::provider::MessagePart;
1563
1564 let provider = mock_provider(vec![]);
1565 let channel = MockChannel::new(vec![]);
1566 let registry = create_test_registry();
1567 let executor = MockToolExecutor::no_tools();
1568
1569 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1570 let cid = memory.sqlite().create_conversation().await.unwrap();
1571 let sqlite = memory.sqlite();
1572
1573 sqlite
1575 .save_message(cid, "user", "first question")
1576 .await
1577 .unwrap();
1578 sqlite
1579 .save_message(cid, "assistant", "first answer")
1580 .await
1581 .unwrap();
1582
1583 let use_parts = serde_json::to_string(&[
1587 MessagePart::ToolUse {
1588 id: "call_mid_1".to_string(),
1589 name: "shell".to_string(),
1590 input: serde_json::json!({"command": "ls"}),
1591 },
1592 MessagePart::Text {
1593 text: "Let me check the files.".to_string(),
1594 },
1595 ])
1596 .unwrap();
1597 sqlite
1598 .save_message_with_parts(cid, "assistant", "Let me check the files.", &use_parts)
1599 .await
1600 .unwrap();
1601
1602 sqlite
1604 .save_message(cid, "user", "second question")
1605 .await
1606 .unwrap();
1607 sqlite
1608 .save_message(cid, "assistant", "second answer")
1609 .await
1610 .unwrap();
1611
1612 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1613 std::sync::Arc::new(memory),
1614 cid,
1615 50,
1616 5,
1617 100,
1618 );
1619
1620 let messages_before = agent.messages.len();
1621 agent.load_history().await.unwrap();
1622
1623 assert_eq!(
1626 agent.messages.len(),
1627 messages_before + 5,
1628 "message count must be 5 (orphan message kept — has text content)"
1629 );
1630
1631 let orphan = &agent.messages[messages_before + 2];
1633 assert_eq!(orphan.role, Role::Assistant);
1634 assert!(
1635 !orphan
1636 .parts
1637 .iter()
1638 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1639 "orphaned ToolUse parts must be stripped from mid-history message"
1640 );
1641 assert!(
1643 orphan.parts.iter().any(
1644 |p| matches!(p, MessagePart::Text { text } if text == "Let me check the files.")
1645 ),
1646 "text content of orphaned assistant message must be preserved"
1647 );
1648 }
1649
1650 #[tokio::test]
1653 async fn sanitize_tool_pairs_preserves_matched_tool_pair() {
1654 use zeph_llm::provider::MessagePart;
1655
1656 let provider = mock_provider(vec![]);
1657 let channel = MockChannel::new(vec![]);
1658 let registry = create_test_registry();
1659 let executor = MockToolExecutor::no_tools();
1660
1661 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1662 let cid = memory.sqlite().create_conversation().await.unwrap();
1663 let sqlite = memory.sqlite();
1664
1665 sqlite
1666 .save_message(cid, "user", "run a command")
1667 .await
1668 .unwrap();
1669
1670 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1672 id: "call_ok".to_string(),
1673 name: "shell".to_string(),
1674 input: serde_json::json!({"command": "echo hi"}),
1675 }])
1676 .unwrap();
1677 sqlite
1678 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
1679 .await
1680 .unwrap();
1681
1682 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1684 tool_use_id: "call_ok".to_string(),
1685 content: "hi".to_string(),
1686 is_error: false,
1687 }])
1688 .unwrap();
1689 sqlite
1690 .save_message_with_parts(cid, "user", "[tool_result: call_ok]\nhi", &result_parts)
1691 .await
1692 .unwrap();
1693
1694 sqlite.save_message(cid, "assistant", "done").await.unwrap();
1695
1696 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1697 std::sync::Arc::new(memory),
1698 cid,
1699 50,
1700 5,
1701 100,
1702 );
1703
1704 let messages_before = agent.messages.len();
1705 agent.load_history().await.unwrap();
1706
1707 assert_eq!(
1709 agent.messages.len(),
1710 messages_before + 4,
1711 "matched tool pair must not be removed"
1712 );
1713 let tool_msg = &agent.messages[messages_before + 1];
1714 assert!(
1715 tool_msg
1716 .parts
1717 .iter()
1718 .any(|p| matches!(p, MessagePart::ToolUse { id, .. } if id == "call_ok")),
1719 "matched ToolUse parts must be preserved"
1720 );
1721 }
1722}