1use std::collections::HashSet;
5
6use crate::channel::Channel;
7use zeph_llm::provider::{Message, MessagePart, Role};
8use zeph_memory::sqlite::role_str;
9
10use super::Agent;
11
12fn sanitize_tool_pairs(messages: &mut Vec<Message>) -> usize {
28 let mut removed = 0;
29
30 loop {
31 if let Some(last) = messages.last()
33 && last.role == Role::Assistant
34 && last
35 .parts
36 .iter()
37 .any(|p| matches!(p, MessagePart::ToolUse { .. }))
38 {
39 let ids: Vec<String> = last
40 .parts
41 .iter()
42 .filter_map(|p| {
43 if let MessagePart::ToolUse { id, .. } = p {
44 Some(id.clone())
45 } else {
46 None
47 }
48 })
49 .collect();
50 tracing::warn!(
51 tool_ids = ?ids,
52 "removing orphaned trailing tool_use message from restored history"
53 );
54 messages.pop();
55 removed += 1;
56 continue;
57 }
58
59 if let Some(first) = messages.first()
61 && first.role == Role::User
62 && first
63 .parts
64 .iter()
65 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
66 {
67 let ids: Vec<String> = first
68 .parts
69 .iter()
70 .filter_map(|p| {
71 if let MessagePart::ToolResult { tool_use_id, .. } = p {
72 Some(tool_use_id.clone())
73 } else {
74 None
75 }
76 })
77 .collect();
78 tracing::warn!(
79 tool_use_ids = ?ids,
80 "removing orphaned leading tool_result message from restored history"
81 );
82 messages.remove(0);
83 removed += 1;
84 continue;
85 }
86
87 break;
88 }
89
90 removed += strip_mid_history_orphans(messages);
93
94 removed
95}
96
97#[allow(clippy::too_many_lines)]
109fn strip_mid_history_orphans(messages: &mut Vec<Message>) -> usize {
110 let mut removed = 0;
111 let mut i = 0;
112 while i < messages.len() {
113 if messages[i].role == Role::Assistant
117 && messages[i]
118 .parts
119 .iter()
120 .any(|p| matches!(p, MessagePart::ToolUse { .. }))
121 {
122 let matched_ids: HashSet<String> = messages
124 .get(i + 1)
125 .filter(|next| next.role == Role::User)
126 .map(|next| {
127 messages[i]
128 .parts
129 .iter()
130 .filter_map(|p| {
131 if let MessagePart::ToolUse { id, .. } = p {
132 Some(id.clone())
133 } else {
134 None
135 }
136 })
137 .filter(|uid| {
138 next.parts.iter().any(|np| {
139 matches!(np, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == uid)
140 })
141 })
142 .collect()
143 })
144 .unwrap_or_default();
145
146 let orphaned_ids: HashSet<String> = messages[i]
148 .parts
149 .iter()
150 .filter_map(|p| {
151 if let MessagePart::ToolUse { id, .. } = p
152 && !matched_ids.contains(id)
153 {
154 return Some(id.clone());
155 }
156 None
157 })
158 .collect();
159
160 if !orphaned_ids.is_empty() {
161 tracing::warn!(
162 tool_ids = ?orphaned_ids,
163 index = i,
164 "stripping orphaned mid-history tool_use parts from assistant message"
165 );
166 messages[i].parts.retain(|p| {
167 !matches!(
168 p,
169 MessagePart::ToolUse { id, .. } if orphaned_ids.contains(id)
170 )
171 });
172
173 let is_empty =
174 messages[i].content.trim().is_empty() && messages[i].parts.is_empty();
175 if is_empty {
176 messages.remove(i);
177 removed += 1;
178 continue;
180 }
181 }
182 }
183
184 if messages[i].role == Role::User
186 && messages[i]
187 .parts
188 .iter()
189 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
190 {
191 let preceding_tool_use_ids: HashSet<&str> =
193 if i > 0 && messages[i - 1].role == Role::Assistant {
194 messages[i - 1]
195 .parts
196 .iter()
197 .filter_map(|p| {
198 if let MessagePart::ToolUse { id, .. } = p {
199 Some(id.as_str())
200 } else {
201 None
202 }
203 })
204 .collect()
205 } else {
206 HashSet::new()
207 };
208
209 let orphaned_ids: HashSet<String> = messages[i]
212 .parts
213 .iter()
214 .filter_map(|p| {
215 if let MessagePart::ToolResult { tool_use_id, .. } = p
216 && !preceding_tool_use_ids.contains(tool_use_id.as_str())
217 {
218 return Some(tool_use_id.clone());
219 }
220 None
221 })
222 .collect();
223
224 if !orphaned_ids.is_empty() {
225 tracing::warn!(
226 tool_use_ids = ?orphaned_ids,
227 index = i,
228 "stripping orphaned mid-history tool_result parts from user message"
229 );
230 messages[i].parts.retain(|p| {
231 !matches!(
232 p,
233 MessagePart::ToolResult { tool_use_id, .. }
234 if orphaned_ids.contains(tool_use_id.as_str())
235 )
236 });
237
238 let is_empty =
239 messages[i].content.trim().is_empty() && messages[i].parts.is_empty();
240 if is_empty {
241 messages.remove(i);
242 removed += 1;
243 continue;
245 }
246 }
247 }
248
249 i += 1;
250 }
251 removed
252}
253
254impl<C: Channel> Agent<C> {
255 pub async fn load_history(&mut self) -> Result<(), super::error::AgentError> {
261 let (Some(memory), Some(cid)) =
262 (&self.memory_state.memory, self.memory_state.conversation_id)
263 else {
264 return Ok(());
265 };
266
267 let history = memory
268 .sqlite()
269 .load_history_filtered(cid, self.memory_state.history_limit, Some(true), None)
270 .await?;
271 if !history.is_empty() {
272 let mut loaded = 0;
273 let mut skipped = 0;
274
275 for msg in history {
276 if msg.content.trim().is_empty() && msg.parts.is_empty() {
281 tracing::warn!("skipping empty message from history (role: {:?})", msg.role);
282 skipped += 1;
283 continue;
284 }
285 self.messages.push(msg);
286 loaded += 1;
287 }
288
289 let history_start = self.messages.len() - loaded;
291 let mut restored_slice = self.messages.split_off(history_start);
292 let orphans = sanitize_tool_pairs(&mut restored_slice);
293 skipped += orphans;
294 loaded = loaded.saturating_sub(orphans);
295 self.messages.append(&mut restored_slice);
296
297 tracing::info!("restored {loaded} message(s) from conversation {cid}");
298 if skipped > 0 {
299 tracing::warn!("skipped {skipped} empty/orphaned message(s) from history");
300 }
301 }
302
303 if let Ok(count) = memory.message_count(cid).await {
304 let count_u64 = u64::try_from(count).unwrap_or(0);
305 self.update_metrics(|m| {
306 m.sqlite_message_count = count_u64;
307 });
308 }
309
310 if let Ok(count) = memory.unsummarized_message_count(cid).await {
311 self.memory_state.unsummarized_count = usize::try_from(count).unwrap_or(0);
312 }
313
314 self.recompute_prompt_tokens();
315 Ok(())
316 }
317
318 pub(crate) async fn persist_message(
324 &mut self,
325 role: Role,
326 content: &str,
327 parts: &[MessagePart],
328 has_injection_flags: bool,
329 ) {
330 let (Some(memory), Some(cid)) =
331 (&self.memory_state.memory, self.memory_state.conversation_id)
332 else {
333 return;
334 };
335
336 let parts_json = if parts.is_empty() {
337 "[]".to_string()
338 } else {
339 serde_json::to_string(parts).unwrap_or_else(|e| {
340 tracing::warn!("failed to serialize message parts, storing empty: {e}");
341 "[]".to_string()
342 })
343 };
344
345 let guard_event = self
349 .security
350 .exfiltration_guard
351 .should_guard_memory_write(has_injection_flags);
352 if let Some(ref event) = guard_event {
353 tracing::warn!(
354 ?event,
355 "exfiltration guard: skipping Qdrant embedding for flagged content"
356 );
357 self.update_metrics(|m| m.exfiltration_memory_guards += 1);
358 self.push_security_event(
359 crate::metrics::SecurityEventCategory::ExfiltrationBlock,
360 "memory_write",
361 "Qdrant embedding skipped: flagged content",
362 );
363 }
364
365 let skip_embedding = guard_event.is_some();
366
367 let should_embed = if skip_embedding {
368 false
369 } else {
370 match role {
371 Role::Assistant => {
372 self.memory_state.autosave_assistant
373 && content.len() >= self.memory_state.autosave_min_length
374 }
375 _ => true,
376 }
377 };
378
379 let embedding_stored = if should_embed {
380 match memory
381 .remember_with_parts(cid, role_str(role), content, &parts_json)
382 .await
383 {
384 Ok((_message_id, stored)) => stored,
385 Err(e) => {
386 tracing::error!("failed to persist message: {e:#}");
387 return;
388 }
389 }
390 } else {
391 match memory
392 .save_only(cid, role_str(role), content, &parts_json)
393 .await
394 {
395 Ok(_) => false,
396 Err(e) => {
397 tracing::error!("failed to persist message: {e:#}");
398 return;
399 }
400 }
401 };
402
403 self.memory_state.unsummarized_count += 1;
404
405 self.update_metrics(|m| {
406 m.sqlite_message_count += 1;
407 if embedding_stored {
408 m.embeddings_generated += 1;
409 }
410 });
411
412 self.check_summarization().await;
413
414 self.maybe_spawn_graph_extraction(content, has_injection_flags)
415 .await;
416 }
417
418 async fn maybe_spawn_graph_extraction(&mut self, content: &str, has_injection_flags: bool) {
419 use zeph_memory::semantic::GraphExtractionConfig;
420
421 if self.memory_state.memory.is_none() || self.memory_state.conversation_id.is_none() {
422 return;
423 }
424
425 if has_injection_flags {
427 tracing::warn!("graph extraction skipped: injection patterns detected in content");
428 return;
429 }
430
431 let extraction_cfg = {
433 let cfg = &self.memory_state.graph_config;
434 if !cfg.enabled {
435 return;
436 }
437 GraphExtractionConfig {
438 max_entities: cfg.max_entities_per_message,
439 max_edges: cfg.max_edges_per_message,
440 extraction_timeout_secs: cfg.extraction_timeout_secs,
441 community_refresh_interval: cfg.community_refresh_interval,
442 expired_edge_retention_days: cfg.expired_edge_retention_days,
443 max_entities_cap: cfg.max_entities,
444 community_summary_max_prompt_bytes: cfg.community_summary_max_prompt_bytes,
445 community_summary_concurrency: cfg.community_summary_concurrency,
446 lpa_edge_chunk_size: cfg.lpa_edge_chunk_size,
447 }
448 };
449
450 let context_messages: Vec<String> = self
452 .messages
453 .iter()
454 .rev()
455 .filter(|m| m.role == Role::User)
456 .take(4)
457 .map(|m| m.content.clone())
458 .collect();
459
460 let _ = self.channel.send_status("extracting graph...").await;
461
462 if let Some(memory) = &self.memory_state.memory {
463 memory.spawn_graph_extraction(content.to_owned(), context_messages, extraction_cfg);
464 }
465 self.sync_community_detection_failures();
466 self.sync_graph_extraction_metrics();
467 self.sync_graph_counts().await;
468 }
469
470 pub(crate) async fn check_summarization(&mut self) {
471 let (Some(memory), Some(cid)) =
472 (&self.memory_state.memory, self.memory_state.conversation_id)
473 else {
474 return;
475 };
476
477 if self.memory_state.unsummarized_count > self.memory_state.summarization_threshold {
478 let _ = self.channel.send_status("summarizing...").await;
479 let batch_size = self.memory_state.summarization_threshold / 2;
480 match memory.summarize(cid, batch_size).await {
481 Ok(Some(summary_id)) => {
482 tracing::info!("created summary {summary_id} for conversation {cid}");
483 self.memory_state.unsummarized_count = 0;
484 self.update_metrics(|m| {
485 m.summaries_count += 1;
486 });
487 }
488 Ok(None) => {
489 tracing::debug!("no summarization needed");
490 }
491 Err(e) => {
492 tracing::error!("summarization failed: {e:#}");
493 }
494 }
495 let _ = self.channel.send_status("").await;
496 }
497 }
498}
499
500#[cfg(test)]
501mod tests {
502 use super::super::agent_tests::{
503 MetricsSnapshot, MockChannel, MockToolExecutor, create_test_registry, mock_provider,
504 };
505 use super::*;
506 use zeph_llm::any::AnyProvider;
507 use zeph_memory::semantic::SemanticMemory;
508
509 async fn test_memory(provider: &AnyProvider) -> SemanticMemory {
510 SemanticMemory::new(
511 ":memory:",
512 "http://127.0.0.1:1",
513 provider.clone(),
514 "test-model",
515 )
516 .await
517 .unwrap()
518 }
519
520 #[tokio::test]
521 async fn load_history_without_memory_returns_ok() {
522 let provider = mock_provider(vec![]);
523 let channel = MockChannel::new(vec![]);
524 let registry = create_test_registry();
525 let executor = MockToolExecutor::no_tools();
526 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
527
528 let result = agent.load_history().await;
529 assert!(result.is_ok());
530 assert_eq!(agent.messages.len(), 1); }
533
534 #[tokio::test]
535 async fn load_history_with_messages_injects_into_agent() {
536 let provider = mock_provider(vec![]);
537 let channel = MockChannel::new(vec![]);
538 let registry = create_test_registry();
539 let executor = MockToolExecutor::no_tools();
540
541 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
542 let cid = memory.sqlite().create_conversation().await.unwrap();
543
544 memory
545 .sqlite()
546 .save_message(cid, "user", "hello from history")
547 .await
548 .unwrap();
549 memory
550 .sqlite()
551 .save_message(cid, "assistant", "hi back")
552 .await
553 .unwrap();
554
555 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
556 std::sync::Arc::new(memory),
557 cid,
558 50,
559 5,
560 100,
561 );
562
563 let messages_before = agent.messages.len();
564 agent.load_history().await.unwrap();
565 assert_eq!(agent.messages.len(), messages_before + 2);
567 }
568
569 #[tokio::test]
570 async fn load_history_skips_empty_messages() {
571 let provider = mock_provider(vec![]);
572 let channel = MockChannel::new(vec![]);
573 let registry = create_test_registry();
574 let executor = MockToolExecutor::no_tools();
575
576 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
577 let cid = memory.sqlite().create_conversation().await.unwrap();
578
579 memory
581 .sqlite()
582 .save_message(cid, "user", " ")
583 .await
584 .unwrap();
585 memory
586 .sqlite()
587 .save_message(cid, "user", "real message")
588 .await
589 .unwrap();
590
591 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
592 std::sync::Arc::new(memory),
593 cid,
594 50,
595 5,
596 100,
597 );
598
599 let messages_before = agent.messages.len();
600 agent.load_history().await.unwrap();
601 assert_eq!(agent.messages.len(), messages_before + 1);
603 }
604
605 #[tokio::test]
606 async fn load_history_with_empty_store_returns_ok() {
607 let provider = mock_provider(vec![]);
608 let channel = MockChannel::new(vec![]);
609 let registry = create_test_registry();
610 let executor = MockToolExecutor::no_tools();
611
612 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
613 let cid = memory.sqlite().create_conversation().await.unwrap();
614
615 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
616 std::sync::Arc::new(memory),
617 cid,
618 50,
619 5,
620 100,
621 );
622
623 let messages_before = agent.messages.len();
624 agent.load_history().await.unwrap();
625 assert_eq!(agent.messages.len(), messages_before);
627 }
628
629 #[tokio::test]
630 async fn persist_message_without_memory_silently_returns() {
631 let provider = mock_provider(vec![]);
633 let channel = MockChannel::new(vec![]);
634 let registry = create_test_registry();
635 let executor = MockToolExecutor::no_tools();
636 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
637
638 agent.persist_message(Role::User, "hello", &[], false).await;
640 }
641
642 #[tokio::test]
643 async fn persist_message_assistant_autosave_false_uses_save_only() {
644 let provider = mock_provider(vec![]);
645 let channel = MockChannel::new(vec![]);
646 let registry = create_test_registry();
647 let executor = MockToolExecutor::no_tools();
648
649 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
650 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
651 let cid = memory.sqlite().create_conversation().await.unwrap();
652
653 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
654 .with_metrics(tx)
655 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
656 .with_autosave_config(false, 20);
657
658 agent
659 .persist_message(Role::Assistant, "short assistant reply", &[], false)
660 .await;
661
662 let history = agent
663 .memory_state
664 .memory
665 .as_ref()
666 .unwrap()
667 .sqlite()
668 .load_history(cid, 50)
669 .await
670 .unwrap();
671 assert_eq!(history.len(), 1, "message must be saved");
672 assert_eq!(history[0].content, "short assistant reply");
673 assert_eq!(rx.borrow().embeddings_generated, 0);
675 }
676
677 #[tokio::test]
678 async fn persist_message_assistant_below_min_length_uses_save_only() {
679 let provider = mock_provider(vec![]);
680 let channel = MockChannel::new(vec![]);
681 let registry = create_test_registry();
682 let executor = MockToolExecutor::no_tools();
683
684 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
685 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
686 let cid = memory.sqlite().create_conversation().await.unwrap();
687
688 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
690 .with_metrics(tx)
691 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
692 .with_autosave_config(true, 1000);
693
694 agent
695 .persist_message(Role::Assistant, "too short", &[], false)
696 .await;
697
698 let history = agent
699 .memory_state
700 .memory
701 .as_ref()
702 .unwrap()
703 .sqlite()
704 .load_history(cid, 50)
705 .await
706 .unwrap();
707 assert_eq!(history.len(), 1, "message must be saved");
708 assert_eq!(history[0].content, "too short");
709 assert_eq!(rx.borrow().embeddings_generated, 0);
710 }
711
712 #[tokio::test]
713 async fn persist_message_assistant_at_min_length_boundary_uses_embed() {
714 let provider = mock_provider(vec![]);
716 let channel = MockChannel::new(vec![]);
717 let registry = create_test_registry();
718 let executor = MockToolExecutor::no_tools();
719
720 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
721 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
722 let cid = memory.sqlite().create_conversation().await.unwrap();
723
724 let min_length = 10usize;
725 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
726 .with_metrics(tx)
727 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
728 .with_autosave_config(true, min_length);
729
730 let content_at_boundary = "A".repeat(min_length);
732 assert_eq!(content_at_boundary.len(), min_length);
733 agent
734 .persist_message(Role::Assistant, &content_at_boundary, &[], false)
735 .await;
736
737 assert_eq!(rx.borrow().sqlite_message_count, 1);
739 }
740
741 #[tokio::test]
742 async fn persist_message_assistant_one_below_min_length_uses_save_only() {
743 let provider = mock_provider(vec![]);
745 let channel = MockChannel::new(vec![]);
746 let registry = create_test_registry();
747 let executor = MockToolExecutor::no_tools();
748
749 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
750 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
751 let cid = memory.sqlite().create_conversation().await.unwrap();
752
753 let min_length = 10usize;
754 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
755 .with_metrics(tx)
756 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
757 .with_autosave_config(true, min_length);
758
759 let content_below_boundary = "A".repeat(min_length - 1);
761 assert_eq!(content_below_boundary.len(), min_length - 1);
762 agent
763 .persist_message(Role::Assistant, &content_below_boundary, &[], false)
764 .await;
765
766 let history = agent
767 .memory_state
768 .memory
769 .as_ref()
770 .unwrap()
771 .sqlite()
772 .load_history(cid, 50)
773 .await
774 .unwrap();
775 assert_eq!(history.len(), 1, "message must still be saved");
776 assert_eq!(rx.borrow().embeddings_generated, 0);
778 }
779
780 #[tokio::test]
781 async fn persist_message_increments_unsummarized_count() {
782 let provider = mock_provider(vec![]);
783 let channel = MockChannel::new(vec![]);
784 let registry = create_test_registry();
785 let executor = MockToolExecutor::no_tools();
786
787 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
788 let cid = memory.sqlite().create_conversation().await.unwrap();
789
790 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
792 std::sync::Arc::new(memory),
793 cid,
794 50,
795 5,
796 100,
797 );
798
799 assert_eq!(agent.memory_state.unsummarized_count, 0);
800
801 agent.persist_message(Role::User, "first", &[], false).await;
802 assert_eq!(agent.memory_state.unsummarized_count, 1);
803
804 agent
805 .persist_message(Role::User, "second", &[], false)
806 .await;
807 assert_eq!(agent.memory_state.unsummarized_count, 2);
808 }
809
810 #[tokio::test]
811 async fn check_summarization_resets_counter_on_success() {
812 let provider = mock_provider(vec![]);
813 let channel = MockChannel::new(vec![]);
814 let registry = create_test_registry();
815 let executor = MockToolExecutor::no_tools();
816
817 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
818 let cid = memory.sqlite().create_conversation().await.unwrap();
819
820 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
822 std::sync::Arc::new(memory),
823 cid,
824 50,
825 5,
826 1,
827 );
828
829 agent.persist_message(Role::User, "msg1", &[], false).await;
830 agent.persist_message(Role::User, "msg2", &[], false).await;
831
832 assert!(agent.memory_state.unsummarized_count <= 2);
837 }
838
839 #[tokio::test]
840 async fn unsummarized_count_not_incremented_without_memory() {
841 let provider = mock_provider(vec![]);
842 let channel = MockChannel::new(vec![]);
843 let registry = create_test_registry();
844 let executor = MockToolExecutor::no_tools();
845 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
846
847 agent.persist_message(Role::User, "hello", &[], false).await;
848 assert_eq!(agent.memory_state.unsummarized_count, 0);
850 }
851
852 mod graph_extraction_guards {
854 use super::*;
855 use crate::config::GraphConfig;
856 use zeph_memory::graph::GraphStore;
857
858 fn enabled_graph_config() -> GraphConfig {
859 GraphConfig {
860 enabled: true,
861 ..GraphConfig::default()
862 }
863 }
864
865 async fn agent_with_graph(
866 provider: &AnyProvider,
867 config: GraphConfig,
868 ) -> Agent<MockChannel> {
869 let memory =
870 test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
871 let cid = memory.sqlite().create_conversation().await.unwrap();
872 Agent::new(
873 provider.clone(),
874 MockChannel::new(vec![]),
875 create_test_registry(),
876 None,
877 5,
878 MockToolExecutor::no_tools(),
879 )
880 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
881 .with_graph_config(config)
882 }
883
884 #[tokio::test]
885 async fn injection_flag_guard_skips_extraction() {
886 let provider = mock_provider(vec![]);
888 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
889 let pool = agent
890 .memory_state
891 .memory
892 .as_ref()
893 .unwrap()
894 .sqlite()
895 .pool()
896 .clone();
897
898 agent.maybe_spawn_graph_extraction("I use Rust", true).await;
899
900 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
902
903 let store = GraphStore::new(pool);
904 let count = store.get_metadata("extraction_count").await.unwrap();
905 assert!(
906 count.is_none(),
907 "injection flag must prevent extraction_count from being written"
908 );
909 }
910
911 #[tokio::test]
912 async fn disabled_config_guard_skips_extraction() {
913 let provider = mock_provider(vec![]);
915 let disabled_cfg = GraphConfig {
916 enabled: false,
917 ..GraphConfig::default()
918 };
919 let mut agent = agent_with_graph(&provider, disabled_cfg).await;
920 let pool = agent
921 .memory_state
922 .memory
923 .as_ref()
924 .unwrap()
925 .sqlite()
926 .pool()
927 .clone();
928
929 agent
930 .maybe_spawn_graph_extraction("I use Rust", false)
931 .await;
932
933 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
934
935 let store = GraphStore::new(pool);
936 let count = store.get_metadata("extraction_count").await.unwrap();
937 assert!(
938 count.is_none(),
939 "disabled graph config must prevent extraction"
940 );
941 }
942
943 #[tokio::test]
944 async fn happy_path_fires_extraction() {
945 let provider = mock_provider(vec![]);
948 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
949 let pool = agent
950 .memory_state
951 .memory
952 .as_ref()
953 .unwrap()
954 .sqlite()
955 .pool()
956 .clone();
957
958 agent
959 .maybe_spawn_graph_extraction("I use Rust for systems programming", false)
960 .await;
961
962 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
964
965 let store = GraphStore::new(pool);
966 let count = store.get_metadata("extraction_count").await.unwrap();
967 assert!(
968 count.is_some(),
969 "happy-path extraction must increment extraction_count"
970 );
971 }
972 }
973
974 #[tokio::test]
975 async fn persist_message_user_always_embeds_regardless_of_autosave_flag() {
976 let provider = mock_provider(vec![]);
977 let channel = MockChannel::new(vec![]);
978 let registry = create_test_registry();
979 let executor = MockToolExecutor::no_tools();
980
981 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
982 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
983 let cid = memory.sqlite().create_conversation().await.unwrap();
984
985 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
987 .with_metrics(tx)
988 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
989 .with_autosave_config(false, 20);
990
991 let long_user_msg = "A".repeat(100);
992 agent
993 .persist_message(Role::User, &long_user_msg, &[], false)
994 .await;
995
996 let history = agent
997 .memory_state
998 .memory
999 .as_ref()
1000 .unwrap()
1001 .sqlite()
1002 .load_history(cid, 50)
1003 .await
1004 .unwrap();
1005 assert_eq!(history.len(), 1, "user message must be saved");
1006 assert_eq!(rx.borrow().sqlite_message_count, 1);
1009 }
1010
1011 #[tokio::test]
1015 async fn persist_message_saves_correct_tool_use_parts() {
1016 use zeph_llm::provider::MessagePart;
1017
1018 let provider = mock_provider(vec![]);
1019 let channel = MockChannel::new(vec![]);
1020 let registry = create_test_registry();
1021 let executor = MockToolExecutor::no_tools();
1022
1023 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1024 let cid = memory.sqlite().create_conversation().await.unwrap();
1025
1026 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1027 std::sync::Arc::new(memory),
1028 cid,
1029 50,
1030 5,
1031 100,
1032 );
1033
1034 let parts = vec![MessagePart::ToolUse {
1035 id: "call_abc123".to_string(),
1036 name: "read_file".to_string(),
1037 input: serde_json::json!({"path": "/tmp/test.txt"}),
1038 }];
1039 let content = "[tool_use: read_file(call_abc123)]";
1040
1041 agent
1042 .persist_message(Role::Assistant, content, &parts, false)
1043 .await;
1044
1045 let history = agent
1046 .memory_state
1047 .memory
1048 .as_ref()
1049 .unwrap()
1050 .sqlite()
1051 .load_history(cid, 50)
1052 .await
1053 .unwrap();
1054
1055 assert_eq!(history.len(), 1);
1056 assert_eq!(history[0].role, Role::Assistant);
1057 assert_eq!(history[0].content, content);
1058 assert_eq!(history[0].parts.len(), 1);
1059 match &history[0].parts[0] {
1060 MessagePart::ToolUse { id, name, .. } => {
1061 assert_eq!(id, "call_abc123");
1062 assert_eq!(name, "read_file");
1063 }
1064 other => panic!("expected ToolUse part, got {other:?}"),
1065 }
1066 assert!(
1068 !history[0]
1069 .parts
1070 .iter()
1071 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1072 "assistant message must not contain ToolResult parts"
1073 );
1074 }
1075
1076 #[tokio::test]
1077 async fn persist_message_saves_correct_tool_result_parts() {
1078 use zeph_llm::provider::MessagePart;
1079
1080 let provider = mock_provider(vec![]);
1081 let channel = MockChannel::new(vec![]);
1082 let registry = create_test_registry();
1083 let executor = MockToolExecutor::no_tools();
1084
1085 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1086 let cid = memory.sqlite().create_conversation().await.unwrap();
1087
1088 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1089 std::sync::Arc::new(memory),
1090 cid,
1091 50,
1092 5,
1093 100,
1094 );
1095
1096 let parts = vec![MessagePart::ToolResult {
1097 tool_use_id: "call_abc123".to_string(),
1098 content: "file contents here".to_string(),
1099 is_error: false,
1100 }];
1101 let content = "[tool_result: call_abc123]\nfile contents here";
1102
1103 agent
1104 .persist_message(Role::User, content, &parts, false)
1105 .await;
1106
1107 let history = agent
1108 .memory_state
1109 .memory
1110 .as_ref()
1111 .unwrap()
1112 .sqlite()
1113 .load_history(cid, 50)
1114 .await
1115 .unwrap();
1116
1117 assert_eq!(history.len(), 1);
1118 assert_eq!(history[0].role, Role::User);
1119 assert_eq!(history[0].content, content);
1120 assert_eq!(history[0].parts.len(), 1);
1121 match &history[0].parts[0] {
1122 MessagePart::ToolResult {
1123 tool_use_id,
1124 content: result_content,
1125 is_error,
1126 } => {
1127 assert_eq!(tool_use_id, "call_abc123");
1128 assert_eq!(result_content, "file contents here");
1129 assert!(!is_error);
1130 }
1131 other => panic!("expected ToolResult part, got {other:?}"),
1132 }
1133 assert!(
1135 !history[0]
1136 .parts
1137 .iter()
1138 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1139 "user ToolResult message must not contain ToolUse parts"
1140 );
1141 }
1142
1143 #[tokio::test]
1144 async fn persist_message_roundtrip_preserves_role_part_alignment() {
1145 use zeph_llm::provider::MessagePart;
1146
1147 let provider = mock_provider(vec![]);
1148 let channel = MockChannel::new(vec![]);
1149 let registry = create_test_registry();
1150 let executor = MockToolExecutor::no_tools();
1151
1152 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1153 let cid = memory.sqlite().create_conversation().await.unwrap();
1154
1155 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1156 std::sync::Arc::new(memory),
1157 cid,
1158 50,
1159 5,
1160 100,
1161 );
1162
1163 let assistant_parts = vec![MessagePart::ToolUse {
1165 id: "id_1".to_string(),
1166 name: "list_dir".to_string(),
1167 input: serde_json::json!({"path": "/tmp"}),
1168 }];
1169 agent
1170 .persist_message(
1171 Role::Assistant,
1172 "[tool_use: list_dir(id_1)]",
1173 &assistant_parts,
1174 false,
1175 )
1176 .await;
1177
1178 let user_parts = vec![MessagePart::ToolResult {
1180 tool_use_id: "id_1".to_string(),
1181 content: "file1.txt\nfile2.txt".to_string(),
1182 is_error: false,
1183 }];
1184 agent
1185 .persist_message(
1186 Role::User,
1187 "[tool_result: id_1]\nfile1.txt\nfile2.txt",
1188 &user_parts,
1189 false,
1190 )
1191 .await;
1192
1193 let history = agent
1194 .memory_state
1195 .memory
1196 .as_ref()
1197 .unwrap()
1198 .sqlite()
1199 .load_history(cid, 50)
1200 .await
1201 .unwrap();
1202
1203 assert_eq!(history.len(), 2);
1204
1205 assert_eq!(history[0].role, Role::Assistant);
1207 assert_eq!(history[0].content, "[tool_use: list_dir(id_1)]");
1208 assert!(
1209 matches!(&history[0].parts[0], MessagePart::ToolUse { id, .. } if id == "id_1"),
1210 "first message must be assistant ToolUse"
1211 );
1212
1213 assert_eq!(history[1].role, Role::User);
1215 assert_eq!(
1216 history[1].content,
1217 "[tool_result: id_1]\nfile1.txt\nfile2.txt"
1218 );
1219 assert!(
1220 matches!(&history[1].parts[0], MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "id_1"),
1221 "second message must be user ToolResult"
1222 );
1223
1224 assert!(
1226 !history[0]
1227 .parts
1228 .iter()
1229 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1230 "assistant message must not have ToolResult parts"
1231 );
1232 assert!(
1233 !history[1]
1234 .parts
1235 .iter()
1236 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1237 "user message must not have ToolUse parts"
1238 );
1239 }
1240
1241 #[tokio::test]
1242 async fn persist_message_saves_correct_tool_output_parts() {
1243 use zeph_llm::provider::MessagePart;
1244
1245 let provider = mock_provider(vec![]);
1246 let channel = MockChannel::new(vec![]);
1247 let registry = create_test_registry();
1248 let executor = MockToolExecutor::no_tools();
1249
1250 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1251 let cid = memory.sqlite().create_conversation().await.unwrap();
1252
1253 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1254 std::sync::Arc::new(memory),
1255 cid,
1256 50,
1257 5,
1258 100,
1259 );
1260
1261 let parts = vec![MessagePart::ToolOutput {
1262 tool_name: "shell".to_string(),
1263 body: "hello from shell".to_string(),
1264 compacted_at: None,
1265 }];
1266 let content = "[tool: shell]\nhello from shell";
1267
1268 agent
1269 .persist_message(Role::User, content, &parts, false)
1270 .await;
1271
1272 let history = agent
1273 .memory_state
1274 .memory
1275 .as_ref()
1276 .unwrap()
1277 .sqlite()
1278 .load_history(cid, 50)
1279 .await
1280 .unwrap();
1281
1282 assert_eq!(history.len(), 1);
1283 assert_eq!(history[0].role, Role::User);
1284 assert_eq!(history[0].content, content);
1285 assert_eq!(history[0].parts.len(), 1);
1286 match &history[0].parts[0] {
1287 MessagePart::ToolOutput {
1288 tool_name,
1289 body,
1290 compacted_at,
1291 } => {
1292 assert_eq!(tool_name, "shell");
1293 assert_eq!(body, "hello from shell");
1294 assert!(compacted_at.is_none());
1295 }
1296 other => panic!("expected ToolOutput part, got {other:?}"),
1297 }
1298 }
1299
1300 #[tokio::test]
1303 async fn load_history_removes_trailing_orphan_tool_use() {
1304 use zeph_llm::provider::MessagePart;
1305
1306 let provider = mock_provider(vec![]);
1307 let channel = MockChannel::new(vec![]);
1308 let registry = create_test_registry();
1309 let executor = MockToolExecutor::no_tools();
1310
1311 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1312 let cid = memory.sqlite().create_conversation().await.unwrap();
1313 let sqlite = memory.sqlite();
1314
1315 sqlite
1317 .save_message(cid, "user", "do something with a tool")
1318 .await
1319 .unwrap();
1320
1321 let parts = serde_json::to_string(&[MessagePart::ToolUse {
1323 id: "call_orphan".to_string(),
1324 name: "shell".to_string(),
1325 input: serde_json::json!({"command": "ls"}),
1326 }])
1327 .unwrap();
1328 sqlite
1329 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_orphan)]", &parts)
1330 .await
1331 .unwrap();
1332
1333 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1334 std::sync::Arc::new(memory),
1335 cid,
1336 50,
1337 5,
1338 100,
1339 );
1340
1341 let messages_before = agent.messages.len();
1342 agent.load_history().await.unwrap();
1343
1344 assert_eq!(
1346 agent.messages.len(),
1347 messages_before + 1,
1348 "orphaned trailing tool_use must be removed"
1349 );
1350 assert_eq!(agent.messages.last().unwrap().role, Role::User);
1351 }
1352
1353 #[tokio::test]
1354 async fn load_history_removes_leading_orphan_tool_result() {
1355 use zeph_llm::provider::MessagePart;
1356
1357 let provider = mock_provider(vec![]);
1358 let channel = MockChannel::new(vec![]);
1359 let registry = create_test_registry();
1360 let executor = MockToolExecutor::no_tools();
1361
1362 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1363 let cid = memory.sqlite().create_conversation().await.unwrap();
1364 let sqlite = memory.sqlite();
1365
1366 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1368 tool_use_id: "call_missing".to_string(),
1369 content: "result data".to_string(),
1370 is_error: false,
1371 }])
1372 .unwrap();
1373 sqlite
1374 .save_message_with_parts(
1375 cid,
1376 "user",
1377 "[tool_result: call_missing]\nresult data",
1378 &result_parts,
1379 )
1380 .await
1381 .unwrap();
1382
1383 sqlite
1385 .save_message(cid, "assistant", "here is my response")
1386 .await
1387 .unwrap();
1388
1389 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1390 std::sync::Arc::new(memory),
1391 cid,
1392 50,
1393 5,
1394 100,
1395 );
1396
1397 let messages_before = agent.messages.len();
1398 agent.load_history().await.unwrap();
1399
1400 assert_eq!(
1402 agent.messages.len(),
1403 messages_before + 1,
1404 "orphaned leading tool_result must be removed"
1405 );
1406 assert_eq!(agent.messages.last().unwrap().role, Role::Assistant);
1407 }
1408
1409 #[tokio::test]
1410 async fn load_history_preserves_complete_tool_pairs() {
1411 use zeph_llm::provider::MessagePart;
1412
1413 let provider = mock_provider(vec![]);
1414 let channel = MockChannel::new(vec![]);
1415 let registry = create_test_registry();
1416 let executor = MockToolExecutor::no_tools();
1417
1418 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1419 let cid = memory.sqlite().create_conversation().await.unwrap();
1420 let sqlite = memory.sqlite();
1421
1422 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1424 id: "call_ok".to_string(),
1425 name: "shell".to_string(),
1426 input: serde_json::json!({"command": "pwd"}),
1427 }])
1428 .unwrap();
1429 sqlite
1430 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_ok)]", &use_parts)
1431 .await
1432 .unwrap();
1433
1434 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1435 tool_use_id: "call_ok".to_string(),
1436 content: "/home/user".to_string(),
1437 is_error: false,
1438 }])
1439 .unwrap();
1440 sqlite
1441 .save_message_with_parts(
1442 cid,
1443 "user",
1444 "[tool_result: call_ok]\n/home/user",
1445 &result_parts,
1446 )
1447 .await
1448 .unwrap();
1449
1450 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1451 std::sync::Arc::new(memory),
1452 cid,
1453 50,
1454 5,
1455 100,
1456 );
1457
1458 let messages_before = agent.messages.len();
1459 agent.load_history().await.unwrap();
1460
1461 assert_eq!(
1463 agent.messages.len(),
1464 messages_before + 2,
1465 "complete tool_use/tool_result pair must be preserved"
1466 );
1467 assert_eq!(agent.messages[messages_before].role, Role::Assistant);
1468 assert_eq!(agent.messages[messages_before + 1].role, Role::User);
1469 }
1470
1471 #[tokio::test]
1472 async fn load_history_handles_multiple_trailing_orphans() {
1473 use zeph_llm::provider::MessagePart;
1474
1475 let provider = mock_provider(vec![]);
1476 let channel = MockChannel::new(vec![]);
1477 let registry = create_test_registry();
1478 let executor = MockToolExecutor::no_tools();
1479
1480 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1481 let cid = memory.sqlite().create_conversation().await.unwrap();
1482 let sqlite = memory.sqlite();
1483
1484 sqlite.save_message(cid, "user", "start").await.unwrap();
1486
1487 let parts1 = serde_json::to_string(&[MessagePart::ToolUse {
1489 id: "call_1".to_string(),
1490 name: "shell".to_string(),
1491 input: serde_json::json!({}),
1492 }])
1493 .unwrap();
1494 sqlite
1495 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_1)]", &parts1)
1496 .await
1497 .unwrap();
1498
1499 let parts2 = serde_json::to_string(&[MessagePart::ToolUse {
1501 id: "call_2".to_string(),
1502 name: "read_file".to_string(),
1503 input: serde_json::json!({}),
1504 }])
1505 .unwrap();
1506 sqlite
1507 .save_message_with_parts(cid, "assistant", "[tool_use: read_file(call_2)]", &parts2)
1508 .await
1509 .unwrap();
1510
1511 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1512 std::sync::Arc::new(memory),
1513 cid,
1514 50,
1515 5,
1516 100,
1517 );
1518
1519 let messages_before = agent.messages.len();
1520 agent.load_history().await.unwrap();
1521
1522 assert_eq!(
1524 agent.messages.len(),
1525 messages_before + 1,
1526 "all trailing orphaned tool_use messages must be removed"
1527 );
1528 assert_eq!(agent.messages.last().unwrap().role, Role::User);
1529 }
1530
1531 #[tokio::test]
1532 async fn load_history_no_tool_messages_unchanged() {
1533 let provider = mock_provider(vec![]);
1534 let channel = MockChannel::new(vec![]);
1535 let registry = create_test_registry();
1536 let executor = MockToolExecutor::no_tools();
1537
1538 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1539 let cid = memory.sqlite().create_conversation().await.unwrap();
1540 let sqlite = memory.sqlite();
1541
1542 sqlite.save_message(cid, "user", "hello").await.unwrap();
1543 sqlite
1544 .save_message(cid, "assistant", "hi there")
1545 .await
1546 .unwrap();
1547 sqlite
1548 .save_message(cid, "user", "how are you?")
1549 .await
1550 .unwrap();
1551
1552 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1553 std::sync::Arc::new(memory),
1554 cid,
1555 50,
1556 5,
1557 100,
1558 );
1559
1560 let messages_before = agent.messages.len();
1561 agent.load_history().await.unwrap();
1562
1563 assert_eq!(
1565 agent.messages.len(),
1566 messages_before + 3,
1567 "plain messages without tool parts must pass through unchanged"
1568 );
1569 }
1570
1571 #[tokio::test]
1572 async fn load_history_removes_both_leading_and_trailing_orphans() {
1573 use zeph_llm::provider::MessagePart;
1574
1575 let provider = mock_provider(vec![]);
1576 let channel = MockChannel::new(vec![]);
1577 let registry = create_test_registry();
1578 let executor = MockToolExecutor::no_tools();
1579
1580 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1581 let cid = memory.sqlite().create_conversation().await.unwrap();
1582 let sqlite = memory.sqlite();
1583
1584 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1586 tool_use_id: "call_leading".to_string(),
1587 content: "orphaned result".to_string(),
1588 is_error: false,
1589 }])
1590 .unwrap();
1591 sqlite
1592 .save_message_with_parts(
1593 cid,
1594 "user",
1595 "[tool_result: call_leading]\norphaned result",
1596 &result_parts,
1597 )
1598 .await
1599 .unwrap();
1600
1601 sqlite
1603 .save_message(cid, "user", "what is 2+2?")
1604 .await
1605 .unwrap();
1606 sqlite.save_message(cid, "assistant", "4").await.unwrap();
1607
1608 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1610 id: "call_trailing".to_string(),
1611 name: "shell".to_string(),
1612 input: serde_json::json!({"command": "date"}),
1613 }])
1614 .unwrap();
1615 sqlite
1616 .save_message_with_parts(
1617 cid,
1618 "assistant",
1619 "[tool_use: shell(call_trailing)]",
1620 &use_parts,
1621 )
1622 .await
1623 .unwrap();
1624
1625 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1626 std::sync::Arc::new(memory),
1627 cid,
1628 50,
1629 5,
1630 100,
1631 );
1632
1633 let messages_before = agent.messages.len();
1634 agent.load_history().await.unwrap();
1635
1636 assert_eq!(
1638 agent.messages.len(),
1639 messages_before + 2,
1640 "both leading and trailing orphans must be removed"
1641 );
1642 assert_eq!(agent.messages[messages_before].role, Role::User);
1643 assert_eq!(agent.messages[messages_before].content, "what is 2+2?");
1644 assert_eq!(agent.messages[messages_before + 1].role, Role::Assistant);
1645 assert_eq!(agent.messages[messages_before + 1].content, "4");
1646 }
1647
1648 #[tokio::test]
1653 async fn sanitize_tool_pairs_strips_mid_history_orphan_tool_use() {
1654 use zeph_llm::provider::MessagePart;
1655
1656 let provider = mock_provider(vec![]);
1657 let channel = MockChannel::new(vec![]);
1658 let registry = create_test_registry();
1659 let executor = MockToolExecutor::no_tools();
1660
1661 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1662 let cid = memory.sqlite().create_conversation().await.unwrap();
1663 let sqlite = memory.sqlite();
1664
1665 sqlite
1667 .save_message(cid, "user", "first question")
1668 .await
1669 .unwrap();
1670 sqlite
1671 .save_message(cid, "assistant", "first answer")
1672 .await
1673 .unwrap();
1674
1675 let use_parts = serde_json::to_string(&[
1679 MessagePart::ToolUse {
1680 id: "call_mid_1".to_string(),
1681 name: "shell".to_string(),
1682 input: serde_json::json!({"command": "ls"}),
1683 },
1684 MessagePart::Text {
1685 text: "Let me check the files.".to_string(),
1686 },
1687 ])
1688 .unwrap();
1689 sqlite
1690 .save_message_with_parts(cid, "assistant", "Let me check the files.", &use_parts)
1691 .await
1692 .unwrap();
1693
1694 sqlite
1696 .save_message(cid, "user", "second question")
1697 .await
1698 .unwrap();
1699 sqlite
1700 .save_message(cid, "assistant", "second answer")
1701 .await
1702 .unwrap();
1703
1704 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1705 std::sync::Arc::new(memory),
1706 cid,
1707 50,
1708 5,
1709 100,
1710 );
1711
1712 let messages_before = agent.messages.len();
1713 agent.load_history().await.unwrap();
1714
1715 assert_eq!(
1718 agent.messages.len(),
1719 messages_before + 5,
1720 "message count must be 5 (orphan message kept — has text content)"
1721 );
1722
1723 let orphan = &agent.messages[messages_before + 2];
1725 assert_eq!(orphan.role, Role::Assistant);
1726 assert!(
1727 !orphan
1728 .parts
1729 .iter()
1730 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1731 "orphaned ToolUse parts must be stripped from mid-history message"
1732 );
1733 assert!(
1735 orphan.parts.iter().any(
1736 |p| matches!(p, MessagePart::Text { text } if text == "Let me check the files.")
1737 ),
1738 "text content of orphaned assistant message must be preserved"
1739 );
1740 }
1741
1742 #[tokio::test]
1747 async fn load_history_keeps_tool_only_user_message() {
1748 use zeph_llm::provider::MessagePart;
1749
1750 let provider = mock_provider(vec![]);
1751 let channel = MockChannel::new(vec![]);
1752 let registry = create_test_registry();
1753 let executor = MockToolExecutor::no_tools();
1754
1755 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1756 let cid = memory.sqlite().create_conversation().await.unwrap();
1757 let sqlite = memory.sqlite();
1758
1759 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1761 id: "call_rc3".to_string(),
1762 name: "memory_save".to_string(),
1763 input: serde_json::json!({"content": "something"}),
1764 }])
1765 .unwrap();
1766 sqlite
1767 .save_message_with_parts(cid, "assistant", "[tool_use: memory_save]", &use_parts)
1768 .await
1769 .unwrap();
1770
1771 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1773 tool_use_id: "call_rc3".to_string(),
1774 content: "saved".to_string(),
1775 is_error: false,
1776 }])
1777 .unwrap();
1778 sqlite
1779 .save_message_with_parts(cid, "user", "", &result_parts)
1780 .await
1781 .unwrap();
1782
1783 sqlite.save_message(cid, "assistant", "done").await.unwrap();
1784
1785 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1786 std::sync::Arc::new(memory),
1787 cid,
1788 50,
1789 5,
1790 100,
1791 );
1792
1793 let messages_before = agent.messages.len();
1794 agent.load_history().await.unwrap();
1795
1796 assert_eq!(
1799 agent.messages.len(),
1800 messages_before + 3,
1801 "user message with empty content but ToolResult parts must not be dropped"
1802 );
1803
1804 let user_msg = &agent.messages[messages_before + 1];
1806 assert_eq!(user_msg.role, Role::User);
1807 assert!(
1808 user_msg.parts.iter().any(
1809 |p| matches!(p, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_rc3")
1810 ),
1811 "ToolResult part must be preserved on user message with empty content"
1812 );
1813 }
1814
1815 #[tokio::test]
1819 async fn strip_orphans_removes_orphaned_tool_result() {
1820 use zeph_llm::provider::MessagePart;
1821
1822 let provider = mock_provider(vec![]);
1823 let channel = MockChannel::new(vec![]);
1824 let registry = create_test_registry();
1825 let executor = MockToolExecutor::no_tools();
1826
1827 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1828 let cid = memory.sqlite().create_conversation().await.unwrap();
1829 let sqlite = memory.sqlite();
1830
1831 sqlite.save_message(cid, "user", "hello").await.unwrap();
1833 sqlite.save_message(cid, "assistant", "hi").await.unwrap();
1834
1835 sqlite
1837 .save_message(cid, "assistant", "plain answer")
1838 .await
1839 .unwrap();
1840
1841 let orphan_result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1843 tool_use_id: "call_nonexistent".to_string(),
1844 content: "stale result".to_string(),
1845 is_error: false,
1846 }])
1847 .unwrap();
1848 sqlite
1849 .save_message_with_parts(
1850 cid,
1851 "user",
1852 "[tool_result: call_nonexistent]\nstale result",
1853 &orphan_result_parts,
1854 )
1855 .await
1856 .unwrap();
1857
1858 sqlite
1859 .save_message(cid, "assistant", "final")
1860 .await
1861 .unwrap();
1862
1863 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1864 std::sync::Arc::new(memory),
1865 cid,
1866 50,
1867 5,
1868 100,
1869 );
1870
1871 let messages_before = agent.messages.len();
1872 agent.load_history().await.unwrap();
1873
1874 let loaded = &agent.messages[messages_before..];
1878 for msg in loaded {
1879 assert!(
1880 !msg.parts.iter().any(|p| matches!(
1881 p,
1882 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_nonexistent"
1883 )),
1884 "orphaned ToolResult part must be stripped from history"
1885 );
1886 }
1887 }
1888
1889 #[tokio::test]
1892 async fn strip_orphans_keeps_complete_pair() {
1893 use zeph_llm::provider::MessagePart;
1894
1895 let provider = mock_provider(vec![]);
1896 let channel = MockChannel::new(vec![]);
1897 let registry = create_test_registry();
1898 let executor = MockToolExecutor::no_tools();
1899
1900 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1901 let cid = memory.sqlite().create_conversation().await.unwrap();
1902 let sqlite = memory.sqlite();
1903
1904 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1905 id: "call_valid".to_string(),
1906 name: "shell".to_string(),
1907 input: serde_json::json!({"command": "ls"}),
1908 }])
1909 .unwrap();
1910 sqlite
1911 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
1912 .await
1913 .unwrap();
1914
1915 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1916 tool_use_id: "call_valid".to_string(),
1917 content: "file.rs".to_string(),
1918 is_error: false,
1919 }])
1920 .unwrap();
1921 sqlite
1922 .save_message_with_parts(cid, "user", "", &result_parts)
1923 .await
1924 .unwrap();
1925
1926 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1927 std::sync::Arc::new(memory),
1928 cid,
1929 50,
1930 5,
1931 100,
1932 );
1933
1934 let messages_before = agent.messages.len();
1935 agent.load_history().await.unwrap();
1936
1937 assert_eq!(
1938 agent.messages.len(),
1939 messages_before + 2,
1940 "complete tool_use/tool_result pair must be preserved"
1941 );
1942
1943 let user_msg = &agent.messages[messages_before + 1];
1944 assert!(
1945 user_msg.parts.iter().any(|p| matches!(
1946 p,
1947 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_valid"
1948 )),
1949 "ToolResult part for a matched tool_use must not be stripped"
1950 );
1951 }
1952
1953 #[tokio::test]
1956 async fn strip_orphans_mixed_history() {
1957 use zeph_llm::provider::MessagePart;
1958
1959 let provider = mock_provider(vec![]);
1960 let channel = MockChannel::new(vec![]);
1961 let registry = create_test_registry();
1962 let executor = MockToolExecutor::no_tools();
1963
1964 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1965 let cid = memory.sqlite().create_conversation().await.unwrap();
1966 let sqlite = memory.sqlite();
1967
1968 let use_parts_ok = serde_json::to_string(&[MessagePart::ToolUse {
1970 id: "call_good".to_string(),
1971 name: "shell".to_string(),
1972 input: serde_json::json!({"command": "pwd"}),
1973 }])
1974 .unwrap();
1975 sqlite
1976 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts_ok)
1977 .await
1978 .unwrap();
1979
1980 let result_parts_ok = serde_json::to_string(&[MessagePart::ToolResult {
1981 tool_use_id: "call_good".to_string(),
1982 content: "/home".to_string(),
1983 is_error: false,
1984 }])
1985 .unwrap();
1986 sqlite
1987 .save_message_with_parts(cid, "user", "", &result_parts_ok)
1988 .await
1989 .unwrap();
1990
1991 sqlite
1993 .save_message(cid, "assistant", "text only")
1994 .await
1995 .unwrap();
1996
1997 let orphan_parts = serde_json::to_string(&[MessagePart::ToolResult {
1998 tool_use_id: "call_ghost".to_string(),
1999 content: "ghost result".to_string(),
2000 is_error: false,
2001 }])
2002 .unwrap();
2003 sqlite
2004 .save_message_with_parts(
2005 cid,
2006 "user",
2007 "[tool_result: call_ghost]\nghost result",
2008 &orphan_parts,
2009 )
2010 .await
2011 .unwrap();
2012
2013 sqlite
2014 .save_message(cid, "assistant", "final reply")
2015 .await
2016 .unwrap();
2017
2018 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2019 std::sync::Arc::new(memory),
2020 cid,
2021 50,
2022 5,
2023 100,
2024 );
2025
2026 let messages_before = agent.messages.len();
2027 agent.load_history().await.unwrap();
2028
2029 let loaded = &agent.messages[messages_before..];
2030
2031 for msg in loaded {
2033 assert!(
2034 !msg.parts.iter().any(|p| matches!(
2035 p,
2036 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_ghost"
2037 )),
2038 "orphaned ToolResult (call_ghost) must be stripped from history"
2039 );
2040 }
2041
2042 let has_good_result = loaded.iter().any(|msg| {
2045 msg.role == Role::User
2046 && msg.parts.iter().any(|p| {
2047 matches!(
2048 p,
2049 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_good"
2050 )
2051 })
2052 });
2053 assert!(
2054 has_good_result,
2055 "matched ToolResult (call_good) must be preserved in history"
2056 );
2057 }
2058
2059 #[tokio::test]
2062 async fn sanitize_tool_pairs_preserves_matched_tool_pair() {
2063 use zeph_llm::provider::MessagePart;
2064
2065 let provider = mock_provider(vec![]);
2066 let channel = MockChannel::new(vec![]);
2067 let registry = create_test_registry();
2068 let executor = MockToolExecutor::no_tools();
2069
2070 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2071 let cid = memory.sqlite().create_conversation().await.unwrap();
2072 let sqlite = memory.sqlite();
2073
2074 sqlite
2075 .save_message(cid, "user", "run a command")
2076 .await
2077 .unwrap();
2078
2079 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2081 id: "call_ok".to_string(),
2082 name: "shell".to_string(),
2083 input: serde_json::json!({"command": "echo hi"}),
2084 }])
2085 .unwrap();
2086 sqlite
2087 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
2088 .await
2089 .unwrap();
2090
2091 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2093 tool_use_id: "call_ok".to_string(),
2094 content: "hi".to_string(),
2095 is_error: false,
2096 }])
2097 .unwrap();
2098 sqlite
2099 .save_message_with_parts(cid, "user", "[tool_result: call_ok]\nhi", &result_parts)
2100 .await
2101 .unwrap();
2102
2103 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2104
2105 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2106 std::sync::Arc::new(memory),
2107 cid,
2108 50,
2109 5,
2110 100,
2111 );
2112
2113 let messages_before = agent.messages.len();
2114 agent.load_history().await.unwrap();
2115
2116 assert_eq!(
2118 agent.messages.len(),
2119 messages_before + 4,
2120 "matched tool pair must not be removed"
2121 );
2122 let tool_msg = &agent.messages[messages_before + 1];
2123 assert!(
2124 tool_msg
2125 .parts
2126 .iter()
2127 .any(|p| matches!(p, MessagePart::ToolUse { id, .. } if id == "call_ok")),
2128 "matched ToolUse parts must be preserved"
2129 );
2130 }
2131
2132 #[tokio::test]
2136 async fn persist_cancelled_tool_results_pairs_tool_use() {
2137 use zeph_llm::provider::MessagePart;
2138
2139 let provider = mock_provider(vec![]);
2140 let channel = MockChannel::new(vec![]);
2141 let registry = create_test_registry();
2142 let executor = MockToolExecutor::no_tools();
2143
2144 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2145 let cid = memory.sqlite().create_conversation().await.unwrap();
2146
2147 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2148 std::sync::Arc::new(memory),
2149 cid,
2150 50,
2151 5,
2152 100,
2153 );
2154
2155 let tool_calls = vec![
2157 zeph_llm::provider::ToolUseRequest {
2158 id: "cancel_id_1".to_string(),
2159 name: "shell".to_string(),
2160 input: serde_json::json!({}),
2161 },
2162 zeph_llm::provider::ToolUseRequest {
2163 id: "cancel_id_2".to_string(),
2164 name: "read_file".to_string(),
2165 input: serde_json::json!({}),
2166 },
2167 ];
2168
2169 agent.persist_cancelled_tool_results(&tool_calls).await;
2170
2171 let history = agent
2172 .memory_state
2173 .memory
2174 .as_ref()
2175 .unwrap()
2176 .sqlite()
2177 .load_history(cid, 50)
2178 .await
2179 .unwrap();
2180
2181 assert_eq!(history.len(), 1);
2183 assert_eq!(history[0].role, Role::User);
2184
2185 for tc in &tool_calls {
2187 assert!(
2188 history[0].parts.iter().any(|p| matches!(
2189 p,
2190 MessagePart::ToolResult { tool_use_id, is_error, .. }
2191 if tool_use_id == &tc.id && *is_error
2192 )),
2193 "tombstone ToolResult for {} must be present and is_error=true",
2194 tc.id
2195 );
2196 }
2197 }
2198}