1use std::collections::HashSet;
5
6use crate::channel::Channel;
7use zeph_llm::provider::{Message, MessagePart, Role};
8use zeph_memory::sqlite::role_str;
9
10use super::Agent;
11
12fn sanitize_tool_pairs(messages: &mut Vec<Message>) -> usize {
28 let mut removed = 0;
29
30 loop {
31 if let Some(last) = messages.last()
33 && last.role == Role::Assistant
34 && last
35 .parts
36 .iter()
37 .any(|p| matches!(p, MessagePart::ToolUse { .. }))
38 {
39 let ids: Vec<String> = last
40 .parts
41 .iter()
42 .filter_map(|p| {
43 if let MessagePart::ToolUse { id, .. } = p {
44 Some(id.clone())
45 } else {
46 None
47 }
48 })
49 .collect();
50 tracing::warn!(
51 tool_ids = ?ids,
52 "removing orphaned trailing tool_use message from restored history"
53 );
54 messages.pop();
55 removed += 1;
56 continue;
57 }
58
59 if let Some(first) = messages.first()
61 && first.role == Role::User
62 && first
63 .parts
64 .iter()
65 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
66 {
67 let ids: Vec<String> = first
68 .parts
69 .iter()
70 .filter_map(|p| {
71 if let MessagePart::ToolResult { tool_use_id, .. } = p {
72 Some(tool_use_id.clone())
73 } else {
74 None
75 }
76 })
77 .collect();
78 tracing::warn!(
79 tool_use_ids = ?ids,
80 "removing orphaned leading tool_result message from restored history"
81 );
82 messages.remove(0);
83 removed += 1;
84 continue;
85 }
86
87 break;
88 }
89
90 removed += strip_mid_history_orphans(messages);
93
94 removed
95}
96
97fn orphaned_tool_use_ids(msg: &Message, next_msg: Option<&Message>) -> HashSet<String> {
110 let matched: HashSet<String> = next_msg
111 .filter(|n| n.role == Role::User)
112 .map(|n| {
113 msg.parts
114 .iter()
115 .filter_map(|p| if let MessagePart::ToolUse { id, .. } = p { Some(id.clone()) } else { None })
116 .filter(|uid| n.parts.iter().any(|np| matches!(np, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == uid)))
117 .collect()
118 })
119 .unwrap_or_default();
120 msg.parts
121 .iter()
122 .filter_map(|p| {
123 if let MessagePart::ToolUse { id, .. } = p
124 && !matched.contains(id)
125 {
126 Some(id.clone())
127 } else {
128 None
129 }
130 })
131 .collect()
132}
133
134fn orphaned_tool_result_ids(msg: &Message, prev_msg: Option<&Message>) -> HashSet<String> {
136 let avail: HashSet<&str> = prev_msg
137 .filter(|p| p.role == Role::Assistant)
138 .map(|p| {
139 p.parts
140 .iter()
141 .filter_map(|part| {
142 if let MessagePart::ToolUse { id, .. } = part {
143 Some(id.as_str())
144 } else {
145 None
146 }
147 })
148 .collect()
149 })
150 .unwrap_or_default();
151 msg.parts
152 .iter()
153 .filter_map(|p| {
154 if let MessagePart::ToolResult { tool_use_id, .. } = p
155 && !avail.contains(tool_use_id.as_str())
156 {
157 Some(tool_use_id.clone())
158 } else {
159 None
160 }
161 })
162 .collect()
163}
164
165fn strip_mid_history_orphans(messages: &mut Vec<Message>) -> usize {
166 let mut removed = 0;
167 let mut i = 0;
168 while i < messages.len() {
169 if messages[i].role == Role::Assistant
173 && messages[i]
174 .parts
175 .iter()
176 .any(|p| matches!(p, MessagePart::ToolUse { .. }))
177 {
178 let orphaned_ids = orphaned_tool_use_ids(&messages[i], messages.get(i + 1));
179 if !orphaned_ids.is_empty() {
180 tracing::warn!(
181 tool_ids = ?orphaned_ids,
182 index = i,
183 "stripping orphaned mid-history tool_use parts from assistant message"
184 );
185 messages[i].parts.retain(
186 |p| !matches!(p, MessagePart::ToolUse { id, .. } if orphaned_ids.contains(id)),
187 );
188 let is_empty =
189 messages[i].content.trim().is_empty() && messages[i].parts.is_empty();
190 if is_empty {
191 messages.remove(i);
192 removed += 1;
193 continue; }
195 }
196 }
197
198 if messages[i].role == Role::User
200 && messages[i]
201 .parts
202 .iter()
203 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
204 {
205 let orphaned_ids = orphaned_tool_result_ids(
206 &messages[i],
207 if i > 0 { messages.get(i - 1) } else { None },
208 );
209 if !orphaned_ids.is_empty() {
210 tracing::warn!(
211 tool_use_ids = ?orphaned_ids,
212 index = i,
213 "stripping orphaned mid-history tool_result parts from user message"
214 );
215 messages[i].parts.retain(|p| {
216 !matches!(p, MessagePart::ToolResult { tool_use_id, .. } if orphaned_ids.contains(tool_use_id.as_str()))
217 });
218
219 let is_empty =
220 messages[i].content.trim().is_empty() && messages[i].parts.is_empty();
221 if is_empty {
222 messages.remove(i);
223 removed += 1;
224 continue;
226 }
227 }
228 }
229
230 i += 1;
231 }
232 removed
233}
234
235impl<C: Channel> Agent<C> {
236 pub async fn load_history(&mut self) -> Result<(), super::error::AgentError> {
242 let (Some(memory), Some(cid)) =
243 (&self.memory_state.memory, self.memory_state.conversation_id)
244 else {
245 return Ok(());
246 };
247
248 let history = memory
249 .sqlite()
250 .load_history_filtered(cid, self.memory_state.history_limit, Some(true), None)
251 .await?;
252 if !history.is_empty() {
253 let mut loaded = 0;
254 let mut skipped = 0;
255
256 for msg in history {
257 if msg.content.trim().is_empty() && msg.parts.is_empty() {
262 tracing::warn!("skipping empty message from history (role: {:?})", msg.role);
263 skipped += 1;
264 continue;
265 }
266 self.messages.push(msg);
267 loaded += 1;
268 }
269
270 let history_start = self.messages.len() - loaded;
272 let mut restored_slice = self.messages.split_off(history_start);
273 let orphans = sanitize_tool_pairs(&mut restored_slice);
274 skipped += orphans;
275 loaded = loaded.saturating_sub(orphans);
276 self.messages.append(&mut restored_slice);
277
278 tracing::info!("restored {loaded} message(s) from conversation {cid}");
279 if skipped > 0 {
280 tracing::warn!("skipped {skipped} empty/orphaned message(s) from history");
281 }
282 }
283
284 if let Ok(count) = memory.message_count(cid).await {
285 let count_u64 = u64::try_from(count).unwrap_or(0);
286 self.update_metrics(|m| {
287 m.sqlite_message_count = count_u64;
288 });
289 }
290
291 if let Ok(count) = memory.unsummarized_message_count(cid).await {
292 self.memory_state.unsummarized_count = usize::try_from(count).unwrap_or(0);
293 }
294
295 self.recompute_prompt_tokens();
296 Ok(())
297 }
298
299 pub(crate) async fn persist_message(
305 &mut self,
306 role: Role,
307 content: &str,
308 parts: &[MessagePart],
309 has_injection_flags: bool,
310 ) {
311 let (Some(memory), Some(cid)) =
312 (&self.memory_state.memory, self.memory_state.conversation_id)
313 else {
314 return;
315 };
316
317 let parts_json = if parts.is_empty() {
318 "[]".to_string()
319 } else {
320 serde_json::to_string(parts).unwrap_or_else(|e| {
321 tracing::warn!("failed to serialize message parts, storing empty: {e}");
322 "[]".to_string()
323 })
324 };
325
326 let guard_event = self
330 .security
331 .exfiltration_guard
332 .should_guard_memory_write(has_injection_flags);
333 if let Some(ref event) = guard_event {
334 tracing::warn!(
335 ?event,
336 "exfiltration guard: skipping Qdrant embedding for flagged content"
337 );
338 self.update_metrics(|m| m.exfiltration_memory_guards += 1);
339 self.push_security_event(
340 crate::metrics::SecurityEventCategory::ExfiltrationBlock,
341 "memory_write",
342 "Qdrant embedding skipped: flagged content",
343 );
344 }
345
346 let skip_embedding = guard_event.is_some();
347
348 let should_embed = if skip_embedding {
349 false
350 } else {
351 match role {
352 Role::Assistant => {
353 self.memory_state.autosave_assistant
354 && content.len() >= self.memory_state.autosave_min_length
355 }
356 _ => true,
357 }
358 };
359
360 let embedding_stored = if should_embed {
361 match memory
362 .remember_with_parts(cid, role_str(role), content, &parts_json)
363 .await
364 {
365 Ok((_message_id, stored)) => stored,
366 Err(e) => {
367 tracing::error!("failed to persist message: {e:#}");
368 return;
369 }
370 }
371 } else {
372 match memory
373 .save_only(cid, role_str(role), content, &parts_json)
374 .await
375 {
376 Ok(_) => false,
377 Err(e) => {
378 tracing::error!("failed to persist message: {e:#}");
379 return;
380 }
381 }
382 };
383
384 self.memory_state.unsummarized_count += 1;
385
386 self.update_metrics(|m| {
387 m.sqlite_message_count += 1;
388 if embedding_stored {
389 m.embeddings_generated += 1;
390 }
391 });
392
393 self.check_summarization().await;
394
395 self.maybe_spawn_graph_extraction(content, has_injection_flags)
396 .await;
397 }
398
399 async fn maybe_spawn_graph_extraction(&mut self, content: &str, has_injection_flags: bool) {
400 use zeph_memory::semantic::GraphExtractionConfig;
401
402 if self.memory_state.memory.is_none() || self.memory_state.conversation_id.is_none() {
403 return;
404 }
405
406 if has_injection_flags {
408 tracing::warn!("graph extraction skipped: injection patterns detected in content");
409 return;
410 }
411
412 let extraction_cfg = {
414 let cfg = &self.memory_state.graph_config;
415 if !cfg.enabled {
416 return;
417 }
418 GraphExtractionConfig {
419 max_entities: cfg.max_entities_per_message,
420 max_edges: cfg.max_edges_per_message,
421 extraction_timeout_secs: cfg.extraction_timeout_secs,
422 community_refresh_interval: cfg.community_refresh_interval,
423 expired_edge_retention_days: cfg.expired_edge_retention_days,
424 max_entities_cap: cfg.max_entities,
425 community_summary_max_prompt_bytes: cfg.community_summary_max_prompt_bytes,
426 community_summary_concurrency: cfg.community_summary_concurrency,
427 lpa_edge_chunk_size: cfg.lpa_edge_chunk_size,
428 note_linking: zeph_memory::NoteLinkingConfig {
429 enabled: cfg.note_linking.enabled,
430 similarity_threshold: cfg.note_linking.similarity_threshold,
431 top_k: cfg.note_linking.top_k,
432 timeout_secs: cfg.note_linking.timeout_secs,
433 },
434 }
435 };
436
437 let context_messages: Vec<String> = self
439 .messages
440 .iter()
441 .rev()
442 .filter(|m| m.role == Role::User)
443 .take(4)
444 .map(|m| m.content.clone())
445 .collect();
446
447 let _ = self.channel.send_status("extracting graph...").await;
448
449 if let Some(memory) = &self.memory_state.memory {
450 let validator: zeph_memory::semantic::PostExtractValidator =
453 if self.security.memory_validator.is_enabled() {
454 let v = self.security.memory_validator.clone();
455 Some(Box::new(move |result| {
456 v.validate_graph_extraction(result)
457 .map_err(|e| e.to_string())
458 }))
459 } else {
460 None
461 };
462 memory.spawn_graph_extraction(
463 content.to_owned(),
464 context_messages,
465 extraction_cfg,
466 validator,
467 );
468 }
469 self.sync_community_detection_failures();
470 self.sync_graph_extraction_metrics();
471 self.sync_graph_counts().await;
472 }
473
474 pub(crate) async fn check_summarization(&mut self) {
475 let (Some(memory), Some(cid)) =
476 (&self.memory_state.memory, self.memory_state.conversation_id)
477 else {
478 return;
479 };
480
481 if self.memory_state.unsummarized_count > self.memory_state.summarization_threshold {
482 let _ = self.channel.send_status("summarizing...").await;
483 let batch_size = self.memory_state.summarization_threshold / 2;
484 match memory.summarize(cid, batch_size).await {
485 Ok(Some(summary_id)) => {
486 tracing::info!("created summary {summary_id} for conversation {cid}");
487 self.memory_state.unsummarized_count = 0;
488 self.update_metrics(|m| {
489 m.summaries_count += 1;
490 });
491 }
492 Ok(None) => {
493 tracing::debug!("no summarization needed");
494 }
495 Err(e) => {
496 tracing::error!("summarization failed: {e:#}");
497 }
498 }
499 let _ = self.channel.send_status("").await;
500 }
501 }
502}
503
504#[cfg(test)]
505mod tests {
506 use super::super::agent_tests::{
507 MetricsSnapshot, MockChannel, MockToolExecutor, create_test_registry, mock_provider,
508 };
509 use super::*;
510 use zeph_llm::any::AnyProvider;
511 use zeph_memory::semantic::SemanticMemory;
512
513 async fn test_memory(provider: &AnyProvider) -> SemanticMemory {
514 SemanticMemory::new(
515 ":memory:",
516 "http://127.0.0.1:1",
517 provider.clone(),
518 "test-model",
519 )
520 .await
521 .unwrap()
522 }
523
524 #[tokio::test]
525 async fn load_history_without_memory_returns_ok() {
526 let provider = mock_provider(vec![]);
527 let channel = MockChannel::new(vec![]);
528 let registry = create_test_registry();
529 let executor = MockToolExecutor::no_tools();
530 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
531
532 let result = agent.load_history().await;
533 assert!(result.is_ok());
534 assert_eq!(agent.messages.len(), 1); }
537
538 #[tokio::test]
539 async fn load_history_with_messages_injects_into_agent() {
540 let provider = mock_provider(vec![]);
541 let channel = MockChannel::new(vec![]);
542 let registry = create_test_registry();
543 let executor = MockToolExecutor::no_tools();
544
545 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
546 let cid = memory.sqlite().create_conversation().await.unwrap();
547
548 memory
549 .sqlite()
550 .save_message(cid, "user", "hello from history")
551 .await
552 .unwrap();
553 memory
554 .sqlite()
555 .save_message(cid, "assistant", "hi back")
556 .await
557 .unwrap();
558
559 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
560 std::sync::Arc::new(memory),
561 cid,
562 50,
563 5,
564 100,
565 );
566
567 let messages_before = agent.messages.len();
568 agent.load_history().await.unwrap();
569 assert_eq!(agent.messages.len(), messages_before + 2);
571 }
572
573 #[tokio::test]
574 async fn load_history_skips_empty_messages() {
575 let provider = mock_provider(vec![]);
576 let channel = MockChannel::new(vec![]);
577 let registry = create_test_registry();
578 let executor = MockToolExecutor::no_tools();
579
580 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
581 let cid = memory.sqlite().create_conversation().await.unwrap();
582
583 memory
585 .sqlite()
586 .save_message(cid, "user", " ")
587 .await
588 .unwrap();
589 memory
590 .sqlite()
591 .save_message(cid, "user", "real message")
592 .await
593 .unwrap();
594
595 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
596 std::sync::Arc::new(memory),
597 cid,
598 50,
599 5,
600 100,
601 );
602
603 let messages_before = agent.messages.len();
604 agent.load_history().await.unwrap();
605 assert_eq!(agent.messages.len(), messages_before + 1);
607 }
608
609 #[tokio::test]
610 async fn load_history_with_empty_store_returns_ok() {
611 let provider = mock_provider(vec![]);
612 let channel = MockChannel::new(vec![]);
613 let registry = create_test_registry();
614 let executor = MockToolExecutor::no_tools();
615
616 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
617 let cid = memory.sqlite().create_conversation().await.unwrap();
618
619 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
620 std::sync::Arc::new(memory),
621 cid,
622 50,
623 5,
624 100,
625 );
626
627 let messages_before = agent.messages.len();
628 agent.load_history().await.unwrap();
629 assert_eq!(agent.messages.len(), messages_before);
631 }
632
633 #[tokio::test]
634 async fn persist_message_without_memory_silently_returns() {
635 let provider = mock_provider(vec![]);
637 let channel = MockChannel::new(vec![]);
638 let registry = create_test_registry();
639 let executor = MockToolExecutor::no_tools();
640 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
641
642 agent.persist_message(Role::User, "hello", &[], false).await;
644 }
645
646 #[tokio::test]
647 async fn persist_message_assistant_autosave_false_uses_save_only() {
648 let provider = mock_provider(vec![]);
649 let channel = MockChannel::new(vec![]);
650 let registry = create_test_registry();
651 let executor = MockToolExecutor::no_tools();
652
653 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
654 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
655 let cid = memory.sqlite().create_conversation().await.unwrap();
656
657 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
658 .with_metrics(tx)
659 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
660 .with_autosave_config(false, 20);
661
662 agent
663 .persist_message(Role::Assistant, "short assistant reply", &[], false)
664 .await;
665
666 let history = agent
667 .memory_state
668 .memory
669 .as_ref()
670 .unwrap()
671 .sqlite()
672 .load_history(cid, 50)
673 .await
674 .unwrap();
675 assert_eq!(history.len(), 1, "message must be saved");
676 assert_eq!(history[0].content, "short assistant reply");
677 assert_eq!(rx.borrow().embeddings_generated, 0);
679 }
680
681 #[tokio::test]
682 async fn persist_message_assistant_below_min_length_uses_save_only() {
683 let provider = mock_provider(vec![]);
684 let channel = MockChannel::new(vec![]);
685 let registry = create_test_registry();
686 let executor = MockToolExecutor::no_tools();
687
688 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
689 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
690 let cid = memory.sqlite().create_conversation().await.unwrap();
691
692 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
694 .with_metrics(tx)
695 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
696 .with_autosave_config(true, 1000);
697
698 agent
699 .persist_message(Role::Assistant, "too short", &[], false)
700 .await;
701
702 let history = agent
703 .memory_state
704 .memory
705 .as_ref()
706 .unwrap()
707 .sqlite()
708 .load_history(cid, 50)
709 .await
710 .unwrap();
711 assert_eq!(history.len(), 1, "message must be saved");
712 assert_eq!(history[0].content, "too short");
713 assert_eq!(rx.borrow().embeddings_generated, 0);
714 }
715
716 #[tokio::test]
717 async fn persist_message_assistant_at_min_length_boundary_uses_embed() {
718 let provider = mock_provider(vec![]);
720 let channel = MockChannel::new(vec![]);
721 let registry = create_test_registry();
722 let executor = MockToolExecutor::no_tools();
723
724 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
725 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
726 let cid = memory.sqlite().create_conversation().await.unwrap();
727
728 let min_length = 10usize;
729 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
730 .with_metrics(tx)
731 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
732 .with_autosave_config(true, min_length);
733
734 let content_at_boundary = "A".repeat(min_length);
736 assert_eq!(content_at_boundary.len(), min_length);
737 agent
738 .persist_message(Role::Assistant, &content_at_boundary, &[], false)
739 .await;
740
741 assert_eq!(rx.borrow().sqlite_message_count, 1);
743 }
744
745 #[tokio::test]
746 async fn persist_message_assistant_one_below_min_length_uses_save_only() {
747 let provider = mock_provider(vec![]);
749 let channel = MockChannel::new(vec![]);
750 let registry = create_test_registry();
751 let executor = MockToolExecutor::no_tools();
752
753 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
754 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
755 let cid = memory.sqlite().create_conversation().await.unwrap();
756
757 let min_length = 10usize;
758 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
759 .with_metrics(tx)
760 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
761 .with_autosave_config(true, min_length);
762
763 let content_below_boundary = "A".repeat(min_length - 1);
765 assert_eq!(content_below_boundary.len(), min_length - 1);
766 agent
767 .persist_message(Role::Assistant, &content_below_boundary, &[], false)
768 .await;
769
770 let history = agent
771 .memory_state
772 .memory
773 .as_ref()
774 .unwrap()
775 .sqlite()
776 .load_history(cid, 50)
777 .await
778 .unwrap();
779 assert_eq!(history.len(), 1, "message must still be saved");
780 assert_eq!(rx.borrow().embeddings_generated, 0);
782 }
783
784 #[tokio::test]
785 async fn persist_message_increments_unsummarized_count() {
786 let provider = mock_provider(vec![]);
787 let channel = MockChannel::new(vec![]);
788 let registry = create_test_registry();
789 let executor = MockToolExecutor::no_tools();
790
791 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
792 let cid = memory.sqlite().create_conversation().await.unwrap();
793
794 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
796 std::sync::Arc::new(memory),
797 cid,
798 50,
799 5,
800 100,
801 );
802
803 assert_eq!(agent.memory_state.unsummarized_count, 0);
804
805 agent.persist_message(Role::User, "first", &[], false).await;
806 assert_eq!(agent.memory_state.unsummarized_count, 1);
807
808 agent
809 .persist_message(Role::User, "second", &[], false)
810 .await;
811 assert_eq!(agent.memory_state.unsummarized_count, 2);
812 }
813
814 #[tokio::test]
815 async fn check_summarization_resets_counter_on_success() {
816 let provider = mock_provider(vec![]);
817 let channel = MockChannel::new(vec![]);
818 let registry = create_test_registry();
819 let executor = MockToolExecutor::no_tools();
820
821 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
822 let cid = memory.sqlite().create_conversation().await.unwrap();
823
824 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
826 std::sync::Arc::new(memory),
827 cid,
828 50,
829 5,
830 1,
831 );
832
833 agent.persist_message(Role::User, "msg1", &[], false).await;
834 agent.persist_message(Role::User, "msg2", &[], false).await;
835
836 assert!(agent.memory_state.unsummarized_count <= 2);
841 }
842
843 #[tokio::test]
844 async fn unsummarized_count_not_incremented_without_memory() {
845 let provider = mock_provider(vec![]);
846 let channel = MockChannel::new(vec![]);
847 let registry = create_test_registry();
848 let executor = MockToolExecutor::no_tools();
849 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
850
851 agent.persist_message(Role::User, "hello", &[], false).await;
852 assert_eq!(agent.memory_state.unsummarized_count, 0);
854 }
855
856 mod graph_extraction_guards {
858 use super::*;
859 use crate::config::GraphConfig;
860 use zeph_memory::graph::GraphStore;
861
862 fn enabled_graph_config() -> GraphConfig {
863 GraphConfig {
864 enabled: true,
865 ..GraphConfig::default()
866 }
867 }
868
869 async fn agent_with_graph(
870 provider: &AnyProvider,
871 config: GraphConfig,
872 ) -> Agent<MockChannel> {
873 let memory =
874 test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
875 let cid = memory.sqlite().create_conversation().await.unwrap();
876 Agent::new(
877 provider.clone(),
878 MockChannel::new(vec![]),
879 create_test_registry(),
880 None,
881 5,
882 MockToolExecutor::no_tools(),
883 )
884 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
885 .with_graph_config(config)
886 }
887
888 #[tokio::test]
889 async fn injection_flag_guard_skips_extraction() {
890 let provider = mock_provider(vec![]);
892 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
893 let pool = agent
894 .memory_state
895 .memory
896 .as_ref()
897 .unwrap()
898 .sqlite()
899 .pool()
900 .clone();
901
902 agent.maybe_spawn_graph_extraction("I use Rust", true).await;
903
904 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
906
907 let store = GraphStore::new(pool);
908 let count = store.get_metadata("extraction_count").await.unwrap();
909 assert!(
910 count.is_none(),
911 "injection flag must prevent extraction_count from being written"
912 );
913 }
914
915 #[tokio::test]
916 async fn disabled_config_guard_skips_extraction() {
917 let provider = mock_provider(vec![]);
919 let disabled_cfg = GraphConfig {
920 enabled: false,
921 ..GraphConfig::default()
922 };
923 let mut agent = agent_with_graph(&provider, disabled_cfg).await;
924 let pool = agent
925 .memory_state
926 .memory
927 .as_ref()
928 .unwrap()
929 .sqlite()
930 .pool()
931 .clone();
932
933 agent
934 .maybe_spawn_graph_extraction("I use Rust", false)
935 .await;
936
937 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
938
939 let store = GraphStore::new(pool);
940 let count = store.get_metadata("extraction_count").await.unwrap();
941 assert!(
942 count.is_none(),
943 "disabled graph config must prevent extraction"
944 );
945 }
946
947 #[tokio::test]
948 async fn happy_path_fires_extraction() {
949 let provider = mock_provider(vec![]);
952 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
953 let pool = agent
954 .memory_state
955 .memory
956 .as_ref()
957 .unwrap()
958 .sqlite()
959 .pool()
960 .clone();
961
962 agent
963 .maybe_spawn_graph_extraction("I use Rust for systems programming", false)
964 .await;
965
966 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
968
969 let store = GraphStore::new(pool);
970 let count = store.get_metadata("extraction_count").await.unwrap();
971 assert!(
972 count.is_some(),
973 "happy-path extraction must increment extraction_count"
974 );
975 }
976 }
977
978 #[tokio::test]
979 async fn persist_message_user_always_embeds_regardless_of_autosave_flag() {
980 let provider = mock_provider(vec![]);
981 let channel = MockChannel::new(vec![]);
982 let registry = create_test_registry();
983 let executor = MockToolExecutor::no_tools();
984
985 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
986 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
987 let cid = memory.sqlite().create_conversation().await.unwrap();
988
989 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
991 .with_metrics(tx)
992 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
993 .with_autosave_config(false, 20);
994
995 let long_user_msg = "A".repeat(100);
996 agent
997 .persist_message(Role::User, &long_user_msg, &[], false)
998 .await;
999
1000 let history = agent
1001 .memory_state
1002 .memory
1003 .as_ref()
1004 .unwrap()
1005 .sqlite()
1006 .load_history(cid, 50)
1007 .await
1008 .unwrap();
1009 assert_eq!(history.len(), 1, "user message must be saved");
1010 assert_eq!(rx.borrow().sqlite_message_count, 1);
1013 }
1014
1015 #[tokio::test]
1019 async fn persist_message_saves_correct_tool_use_parts() {
1020 use zeph_llm::provider::MessagePart;
1021
1022 let provider = mock_provider(vec![]);
1023 let channel = MockChannel::new(vec![]);
1024 let registry = create_test_registry();
1025 let executor = MockToolExecutor::no_tools();
1026
1027 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1028 let cid = memory.sqlite().create_conversation().await.unwrap();
1029
1030 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1031 std::sync::Arc::new(memory),
1032 cid,
1033 50,
1034 5,
1035 100,
1036 );
1037
1038 let parts = vec![MessagePart::ToolUse {
1039 id: "call_abc123".to_string(),
1040 name: "read_file".to_string(),
1041 input: serde_json::json!({"path": "/tmp/test.txt"}),
1042 }];
1043 let content = "[tool_use: read_file(call_abc123)]";
1044
1045 agent
1046 .persist_message(Role::Assistant, content, &parts, false)
1047 .await;
1048
1049 let history = agent
1050 .memory_state
1051 .memory
1052 .as_ref()
1053 .unwrap()
1054 .sqlite()
1055 .load_history(cid, 50)
1056 .await
1057 .unwrap();
1058
1059 assert_eq!(history.len(), 1);
1060 assert_eq!(history[0].role, Role::Assistant);
1061 assert_eq!(history[0].content, content);
1062 assert_eq!(history[0].parts.len(), 1);
1063 match &history[0].parts[0] {
1064 MessagePart::ToolUse { id, name, .. } => {
1065 assert_eq!(id, "call_abc123");
1066 assert_eq!(name, "read_file");
1067 }
1068 other => panic!("expected ToolUse part, got {other:?}"),
1069 }
1070 assert!(
1072 !history[0]
1073 .parts
1074 .iter()
1075 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1076 "assistant message must not contain ToolResult parts"
1077 );
1078 }
1079
1080 #[tokio::test]
1081 async fn persist_message_saves_correct_tool_result_parts() {
1082 use zeph_llm::provider::MessagePart;
1083
1084 let provider = mock_provider(vec![]);
1085 let channel = MockChannel::new(vec![]);
1086 let registry = create_test_registry();
1087 let executor = MockToolExecutor::no_tools();
1088
1089 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1090 let cid = memory.sqlite().create_conversation().await.unwrap();
1091
1092 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1093 std::sync::Arc::new(memory),
1094 cid,
1095 50,
1096 5,
1097 100,
1098 );
1099
1100 let parts = vec![MessagePart::ToolResult {
1101 tool_use_id: "call_abc123".to_string(),
1102 content: "file contents here".to_string(),
1103 is_error: false,
1104 }];
1105 let content = "[tool_result: call_abc123]\nfile contents here";
1106
1107 agent
1108 .persist_message(Role::User, content, &parts, false)
1109 .await;
1110
1111 let history = agent
1112 .memory_state
1113 .memory
1114 .as_ref()
1115 .unwrap()
1116 .sqlite()
1117 .load_history(cid, 50)
1118 .await
1119 .unwrap();
1120
1121 assert_eq!(history.len(), 1);
1122 assert_eq!(history[0].role, Role::User);
1123 assert_eq!(history[0].content, content);
1124 assert_eq!(history[0].parts.len(), 1);
1125 match &history[0].parts[0] {
1126 MessagePart::ToolResult {
1127 tool_use_id,
1128 content: result_content,
1129 is_error,
1130 } => {
1131 assert_eq!(tool_use_id, "call_abc123");
1132 assert_eq!(result_content, "file contents here");
1133 assert!(!is_error);
1134 }
1135 other => panic!("expected ToolResult part, got {other:?}"),
1136 }
1137 assert!(
1139 !history[0]
1140 .parts
1141 .iter()
1142 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1143 "user ToolResult message must not contain ToolUse parts"
1144 );
1145 }
1146
1147 #[tokio::test]
1148 async fn persist_message_roundtrip_preserves_role_part_alignment() {
1149 use zeph_llm::provider::MessagePart;
1150
1151 let provider = mock_provider(vec![]);
1152 let channel = MockChannel::new(vec![]);
1153 let registry = create_test_registry();
1154 let executor = MockToolExecutor::no_tools();
1155
1156 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1157 let cid = memory.sqlite().create_conversation().await.unwrap();
1158
1159 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1160 std::sync::Arc::new(memory),
1161 cid,
1162 50,
1163 5,
1164 100,
1165 );
1166
1167 let assistant_parts = vec![MessagePart::ToolUse {
1169 id: "id_1".to_string(),
1170 name: "list_dir".to_string(),
1171 input: serde_json::json!({"path": "/tmp"}),
1172 }];
1173 agent
1174 .persist_message(
1175 Role::Assistant,
1176 "[tool_use: list_dir(id_1)]",
1177 &assistant_parts,
1178 false,
1179 )
1180 .await;
1181
1182 let user_parts = vec![MessagePart::ToolResult {
1184 tool_use_id: "id_1".to_string(),
1185 content: "file1.txt\nfile2.txt".to_string(),
1186 is_error: false,
1187 }];
1188 agent
1189 .persist_message(
1190 Role::User,
1191 "[tool_result: id_1]\nfile1.txt\nfile2.txt",
1192 &user_parts,
1193 false,
1194 )
1195 .await;
1196
1197 let history = agent
1198 .memory_state
1199 .memory
1200 .as_ref()
1201 .unwrap()
1202 .sqlite()
1203 .load_history(cid, 50)
1204 .await
1205 .unwrap();
1206
1207 assert_eq!(history.len(), 2);
1208
1209 assert_eq!(history[0].role, Role::Assistant);
1211 assert_eq!(history[0].content, "[tool_use: list_dir(id_1)]");
1212 assert!(
1213 matches!(&history[0].parts[0], MessagePart::ToolUse { id, .. } if id == "id_1"),
1214 "first message must be assistant ToolUse"
1215 );
1216
1217 assert_eq!(history[1].role, Role::User);
1219 assert_eq!(
1220 history[1].content,
1221 "[tool_result: id_1]\nfile1.txt\nfile2.txt"
1222 );
1223 assert!(
1224 matches!(&history[1].parts[0], MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "id_1"),
1225 "second message must be user ToolResult"
1226 );
1227
1228 assert!(
1230 !history[0]
1231 .parts
1232 .iter()
1233 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1234 "assistant message must not have ToolResult parts"
1235 );
1236 assert!(
1237 !history[1]
1238 .parts
1239 .iter()
1240 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1241 "user message must not have ToolUse parts"
1242 );
1243 }
1244
1245 #[tokio::test]
1246 async fn persist_message_saves_correct_tool_output_parts() {
1247 use zeph_llm::provider::MessagePart;
1248
1249 let provider = mock_provider(vec![]);
1250 let channel = MockChannel::new(vec![]);
1251 let registry = create_test_registry();
1252 let executor = MockToolExecutor::no_tools();
1253
1254 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1255 let cid = memory.sqlite().create_conversation().await.unwrap();
1256
1257 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1258 std::sync::Arc::new(memory),
1259 cid,
1260 50,
1261 5,
1262 100,
1263 );
1264
1265 let parts = vec![MessagePart::ToolOutput {
1266 tool_name: "shell".to_string(),
1267 body: "hello from shell".to_string(),
1268 compacted_at: None,
1269 }];
1270 let content = "[tool: shell]\nhello from shell";
1271
1272 agent
1273 .persist_message(Role::User, content, &parts, false)
1274 .await;
1275
1276 let history = agent
1277 .memory_state
1278 .memory
1279 .as_ref()
1280 .unwrap()
1281 .sqlite()
1282 .load_history(cid, 50)
1283 .await
1284 .unwrap();
1285
1286 assert_eq!(history.len(), 1);
1287 assert_eq!(history[0].role, Role::User);
1288 assert_eq!(history[0].content, content);
1289 assert_eq!(history[0].parts.len(), 1);
1290 match &history[0].parts[0] {
1291 MessagePart::ToolOutput {
1292 tool_name,
1293 body,
1294 compacted_at,
1295 } => {
1296 assert_eq!(tool_name, "shell");
1297 assert_eq!(body, "hello from shell");
1298 assert!(compacted_at.is_none());
1299 }
1300 other => panic!("expected ToolOutput part, got {other:?}"),
1301 }
1302 }
1303
1304 #[tokio::test]
1307 async fn load_history_removes_trailing_orphan_tool_use() {
1308 use zeph_llm::provider::MessagePart;
1309
1310 let provider = mock_provider(vec![]);
1311 let channel = MockChannel::new(vec![]);
1312 let registry = create_test_registry();
1313 let executor = MockToolExecutor::no_tools();
1314
1315 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1316 let cid = memory.sqlite().create_conversation().await.unwrap();
1317 let sqlite = memory.sqlite();
1318
1319 sqlite
1321 .save_message(cid, "user", "do something with a tool")
1322 .await
1323 .unwrap();
1324
1325 let parts = serde_json::to_string(&[MessagePart::ToolUse {
1327 id: "call_orphan".to_string(),
1328 name: "shell".to_string(),
1329 input: serde_json::json!({"command": "ls"}),
1330 }])
1331 .unwrap();
1332 sqlite
1333 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_orphan)]", &parts)
1334 .await
1335 .unwrap();
1336
1337 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1338 std::sync::Arc::new(memory),
1339 cid,
1340 50,
1341 5,
1342 100,
1343 );
1344
1345 let messages_before = agent.messages.len();
1346 agent.load_history().await.unwrap();
1347
1348 assert_eq!(
1350 agent.messages.len(),
1351 messages_before + 1,
1352 "orphaned trailing tool_use must be removed"
1353 );
1354 assert_eq!(agent.messages.last().unwrap().role, Role::User);
1355 }
1356
1357 #[tokio::test]
1358 async fn load_history_removes_leading_orphan_tool_result() {
1359 use zeph_llm::provider::MessagePart;
1360
1361 let provider = mock_provider(vec![]);
1362 let channel = MockChannel::new(vec![]);
1363 let registry = create_test_registry();
1364 let executor = MockToolExecutor::no_tools();
1365
1366 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1367 let cid = memory.sqlite().create_conversation().await.unwrap();
1368 let sqlite = memory.sqlite();
1369
1370 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1372 tool_use_id: "call_missing".to_string(),
1373 content: "result data".to_string(),
1374 is_error: false,
1375 }])
1376 .unwrap();
1377 sqlite
1378 .save_message_with_parts(
1379 cid,
1380 "user",
1381 "[tool_result: call_missing]\nresult data",
1382 &result_parts,
1383 )
1384 .await
1385 .unwrap();
1386
1387 sqlite
1389 .save_message(cid, "assistant", "here is my response")
1390 .await
1391 .unwrap();
1392
1393 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1394 std::sync::Arc::new(memory),
1395 cid,
1396 50,
1397 5,
1398 100,
1399 );
1400
1401 let messages_before = agent.messages.len();
1402 agent.load_history().await.unwrap();
1403
1404 assert_eq!(
1406 agent.messages.len(),
1407 messages_before + 1,
1408 "orphaned leading tool_result must be removed"
1409 );
1410 assert_eq!(agent.messages.last().unwrap().role, Role::Assistant);
1411 }
1412
1413 #[tokio::test]
1414 async fn load_history_preserves_complete_tool_pairs() {
1415 use zeph_llm::provider::MessagePart;
1416
1417 let provider = mock_provider(vec![]);
1418 let channel = MockChannel::new(vec![]);
1419 let registry = create_test_registry();
1420 let executor = MockToolExecutor::no_tools();
1421
1422 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1423 let cid = memory.sqlite().create_conversation().await.unwrap();
1424 let sqlite = memory.sqlite();
1425
1426 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1428 id: "call_ok".to_string(),
1429 name: "shell".to_string(),
1430 input: serde_json::json!({"command": "pwd"}),
1431 }])
1432 .unwrap();
1433 sqlite
1434 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_ok)]", &use_parts)
1435 .await
1436 .unwrap();
1437
1438 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1439 tool_use_id: "call_ok".to_string(),
1440 content: "/home/user".to_string(),
1441 is_error: false,
1442 }])
1443 .unwrap();
1444 sqlite
1445 .save_message_with_parts(
1446 cid,
1447 "user",
1448 "[tool_result: call_ok]\n/home/user",
1449 &result_parts,
1450 )
1451 .await
1452 .unwrap();
1453
1454 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1455 std::sync::Arc::new(memory),
1456 cid,
1457 50,
1458 5,
1459 100,
1460 );
1461
1462 let messages_before = agent.messages.len();
1463 agent.load_history().await.unwrap();
1464
1465 assert_eq!(
1467 agent.messages.len(),
1468 messages_before + 2,
1469 "complete tool_use/tool_result pair must be preserved"
1470 );
1471 assert_eq!(agent.messages[messages_before].role, Role::Assistant);
1472 assert_eq!(agent.messages[messages_before + 1].role, Role::User);
1473 }
1474
1475 #[tokio::test]
1476 async fn load_history_handles_multiple_trailing_orphans() {
1477 use zeph_llm::provider::MessagePart;
1478
1479 let provider = mock_provider(vec![]);
1480 let channel = MockChannel::new(vec![]);
1481 let registry = create_test_registry();
1482 let executor = MockToolExecutor::no_tools();
1483
1484 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1485 let cid = memory.sqlite().create_conversation().await.unwrap();
1486 let sqlite = memory.sqlite();
1487
1488 sqlite.save_message(cid, "user", "start").await.unwrap();
1490
1491 let parts1 = serde_json::to_string(&[MessagePart::ToolUse {
1493 id: "call_1".to_string(),
1494 name: "shell".to_string(),
1495 input: serde_json::json!({}),
1496 }])
1497 .unwrap();
1498 sqlite
1499 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_1)]", &parts1)
1500 .await
1501 .unwrap();
1502
1503 let parts2 = serde_json::to_string(&[MessagePart::ToolUse {
1505 id: "call_2".to_string(),
1506 name: "read_file".to_string(),
1507 input: serde_json::json!({}),
1508 }])
1509 .unwrap();
1510 sqlite
1511 .save_message_with_parts(cid, "assistant", "[tool_use: read_file(call_2)]", &parts2)
1512 .await
1513 .unwrap();
1514
1515 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1516 std::sync::Arc::new(memory),
1517 cid,
1518 50,
1519 5,
1520 100,
1521 );
1522
1523 let messages_before = agent.messages.len();
1524 agent.load_history().await.unwrap();
1525
1526 assert_eq!(
1528 agent.messages.len(),
1529 messages_before + 1,
1530 "all trailing orphaned tool_use messages must be removed"
1531 );
1532 assert_eq!(agent.messages.last().unwrap().role, Role::User);
1533 }
1534
1535 #[tokio::test]
1536 async fn load_history_no_tool_messages_unchanged() {
1537 let provider = mock_provider(vec![]);
1538 let channel = MockChannel::new(vec![]);
1539 let registry = create_test_registry();
1540 let executor = MockToolExecutor::no_tools();
1541
1542 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1543 let cid = memory.sqlite().create_conversation().await.unwrap();
1544 let sqlite = memory.sqlite();
1545
1546 sqlite.save_message(cid, "user", "hello").await.unwrap();
1547 sqlite
1548 .save_message(cid, "assistant", "hi there")
1549 .await
1550 .unwrap();
1551 sqlite
1552 .save_message(cid, "user", "how are you?")
1553 .await
1554 .unwrap();
1555
1556 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1557 std::sync::Arc::new(memory),
1558 cid,
1559 50,
1560 5,
1561 100,
1562 );
1563
1564 let messages_before = agent.messages.len();
1565 agent.load_history().await.unwrap();
1566
1567 assert_eq!(
1569 agent.messages.len(),
1570 messages_before + 3,
1571 "plain messages without tool parts must pass through unchanged"
1572 );
1573 }
1574
1575 #[tokio::test]
1576 async fn load_history_removes_both_leading_and_trailing_orphans() {
1577 use zeph_llm::provider::MessagePart;
1578
1579 let provider = mock_provider(vec![]);
1580 let channel = MockChannel::new(vec![]);
1581 let registry = create_test_registry();
1582 let executor = MockToolExecutor::no_tools();
1583
1584 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1585 let cid = memory.sqlite().create_conversation().await.unwrap();
1586 let sqlite = memory.sqlite();
1587
1588 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1590 tool_use_id: "call_leading".to_string(),
1591 content: "orphaned result".to_string(),
1592 is_error: false,
1593 }])
1594 .unwrap();
1595 sqlite
1596 .save_message_with_parts(
1597 cid,
1598 "user",
1599 "[tool_result: call_leading]\norphaned result",
1600 &result_parts,
1601 )
1602 .await
1603 .unwrap();
1604
1605 sqlite
1607 .save_message(cid, "user", "what is 2+2?")
1608 .await
1609 .unwrap();
1610 sqlite.save_message(cid, "assistant", "4").await.unwrap();
1611
1612 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1614 id: "call_trailing".to_string(),
1615 name: "shell".to_string(),
1616 input: serde_json::json!({"command": "date"}),
1617 }])
1618 .unwrap();
1619 sqlite
1620 .save_message_with_parts(
1621 cid,
1622 "assistant",
1623 "[tool_use: shell(call_trailing)]",
1624 &use_parts,
1625 )
1626 .await
1627 .unwrap();
1628
1629 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1630 std::sync::Arc::new(memory),
1631 cid,
1632 50,
1633 5,
1634 100,
1635 );
1636
1637 let messages_before = agent.messages.len();
1638 agent.load_history().await.unwrap();
1639
1640 assert_eq!(
1642 agent.messages.len(),
1643 messages_before + 2,
1644 "both leading and trailing orphans must be removed"
1645 );
1646 assert_eq!(agent.messages[messages_before].role, Role::User);
1647 assert_eq!(agent.messages[messages_before].content, "what is 2+2?");
1648 assert_eq!(agent.messages[messages_before + 1].role, Role::Assistant);
1649 assert_eq!(agent.messages[messages_before + 1].content, "4");
1650 }
1651
1652 #[tokio::test]
1657 async fn sanitize_tool_pairs_strips_mid_history_orphan_tool_use() {
1658 use zeph_llm::provider::MessagePart;
1659
1660 let provider = mock_provider(vec![]);
1661 let channel = MockChannel::new(vec![]);
1662 let registry = create_test_registry();
1663 let executor = MockToolExecutor::no_tools();
1664
1665 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1666 let cid = memory.sqlite().create_conversation().await.unwrap();
1667 let sqlite = memory.sqlite();
1668
1669 sqlite
1671 .save_message(cid, "user", "first question")
1672 .await
1673 .unwrap();
1674 sqlite
1675 .save_message(cid, "assistant", "first answer")
1676 .await
1677 .unwrap();
1678
1679 let use_parts = serde_json::to_string(&[
1683 MessagePart::ToolUse {
1684 id: "call_mid_1".to_string(),
1685 name: "shell".to_string(),
1686 input: serde_json::json!({"command": "ls"}),
1687 },
1688 MessagePart::Text {
1689 text: "Let me check the files.".to_string(),
1690 },
1691 ])
1692 .unwrap();
1693 sqlite
1694 .save_message_with_parts(cid, "assistant", "Let me check the files.", &use_parts)
1695 .await
1696 .unwrap();
1697
1698 sqlite
1700 .save_message(cid, "user", "second question")
1701 .await
1702 .unwrap();
1703 sqlite
1704 .save_message(cid, "assistant", "second answer")
1705 .await
1706 .unwrap();
1707
1708 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1709 std::sync::Arc::new(memory),
1710 cid,
1711 50,
1712 5,
1713 100,
1714 );
1715
1716 let messages_before = agent.messages.len();
1717 agent.load_history().await.unwrap();
1718
1719 assert_eq!(
1722 agent.messages.len(),
1723 messages_before + 5,
1724 "message count must be 5 (orphan message kept — has text content)"
1725 );
1726
1727 let orphan = &agent.messages[messages_before + 2];
1729 assert_eq!(orphan.role, Role::Assistant);
1730 assert!(
1731 !orphan
1732 .parts
1733 .iter()
1734 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1735 "orphaned ToolUse parts must be stripped from mid-history message"
1736 );
1737 assert!(
1739 orphan.parts.iter().any(
1740 |p| matches!(p, MessagePart::Text { text } if text == "Let me check the files.")
1741 ),
1742 "text content of orphaned assistant message must be preserved"
1743 );
1744 }
1745
1746 #[tokio::test]
1751 async fn load_history_keeps_tool_only_user_message() {
1752 use zeph_llm::provider::MessagePart;
1753
1754 let provider = mock_provider(vec![]);
1755 let channel = MockChannel::new(vec![]);
1756 let registry = create_test_registry();
1757 let executor = MockToolExecutor::no_tools();
1758
1759 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1760 let cid = memory.sqlite().create_conversation().await.unwrap();
1761 let sqlite = memory.sqlite();
1762
1763 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1765 id: "call_rc3".to_string(),
1766 name: "memory_save".to_string(),
1767 input: serde_json::json!({"content": "something"}),
1768 }])
1769 .unwrap();
1770 sqlite
1771 .save_message_with_parts(cid, "assistant", "[tool_use: memory_save]", &use_parts)
1772 .await
1773 .unwrap();
1774
1775 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1777 tool_use_id: "call_rc3".to_string(),
1778 content: "saved".to_string(),
1779 is_error: false,
1780 }])
1781 .unwrap();
1782 sqlite
1783 .save_message_with_parts(cid, "user", "", &result_parts)
1784 .await
1785 .unwrap();
1786
1787 sqlite.save_message(cid, "assistant", "done").await.unwrap();
1788
1789 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1790 std::sync::Arc::new(memory),
1791 cid,
1792 50,
1793 5,
1794 100,
1795 );
1796
1797 let messages_before = agent.messages.len();
1798 agent.load_history().await.unwrap();
1799
1800 assert_eq!(
1803 agent.messages.len(),
1804 messages_before + 3,
1805 "user message with empty content but ToolResult parts must not be dropped"
1806 );
1807
1808 let user_msg = &agent.messages[messages_before + 1];
1810 assert_eq!(user_msg.role, Role::User);
1811 assert!(
1812 user_msg.parts.iter().any(
1813 |p| matches!(p, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_rc3")
1814 ),
1815 "ToolResult part must be preserved on user message with empty content"
1816 );
1817 }
1818
1819 #[tokio::test]
1823 async fn strip_orphans_removes_orphaned_tool_result() {
1824 use zeph_llm::provider::MessagePart;
1825
1826 let provider = mock_provider(vec![]);
1827 let channel = MockChannel::new(vec![]);
1828 let registry = create_test_registry();
1829 let executor = MockToolExecutor::no_tools();
1830
1831 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1832 let cid = memory.sqlite().create_conversation().await.unwrap();
1833 let sqlite = memory.sqlite();
1834
1835 sqlite.save_message(cid, "user", "hello").await.unwrap();
1837 sqlite.save_message(cid, "assistant", "hi").await.unwrap();
1838
1839 sqlite
1841 .save_message(cid, "assistant", "plain answer")
1842 .await
1843 .unwrap();
1844
1845 let orphan_result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1847 tool_use_id: "call_nonexistent".to_string(),
1848 content: "stale result".to_string(),
1849 is_error: false,
1850 }])
1851 .unwrap();
1852 sqlite
1853 .save_message_with_parts(
1854 cid,
1855 "user",
1856 "[tool_result: call_nonexistent]\nstale result",
1857 &orphan_result_parts,
1858 )
1859 .await
1860 .unwrap();
1861
1862 sqlite
1863 .save_message(cid, "assistant", "final")
1864 .await
1865 .unwrap();
1866
1867 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1868 std::sync::Arc::new(memory),
1869 cid,
1870 50,
1871 5,
1872 100,
1873 );
1874
1875 let messages_before = agent.messages.len();
1876 agent.load_history().await.unwrap();
1877
1878 let loaded = &agent.messages[messages_before..];
1882 for msg in loaded {
1883 assert!(
1884 !msg.parts.iter().any(|p| matches!(
1885 p,
1886 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_nonexistent"
1887 )),
1888 "orphaned ToolResult part must be stripped from history"
1889 );
1890 }
1891 }
1892
1893 #[tokio::test]
1896 async fn strip_orphans_keeps_complete_pair() {
1897 use zeph_llm::provider::MessagePart;
1898
1899 let provider = mock_provider(vec![]);
1900 let channel = MockChannel::new(vec![]);
1901 let registry = create_test_registry();
1902 let executor = MockToolExecutor::no_tools();
1903
1904 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1905 let cid = memory.sqlite().create_conversation().await.unwrap();
1906 let sqlite = memory.sqlite();
1907
1908 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1909 id: "call_valid".to_string(),
1910 name: "shell".to_string(),
1911 input: serde_json::json!({"command": "ls"}),
1912 }])
1913 .unwrap();
1914 sqlite
1915 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
1916 .await
1917 .unwrap();
1918
1919 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1920 tool_use_id: "call_valid".to_string(),
1921 content: "file.rs".to_string(),
1922 is_error: false,
1923 }])
1924 .unwrap();
1925 sqlite
1926 .save_message_with_parts(cid, "user", "", &result_parts)
1927 .await
1928 .unwrap();
1929
1930 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1931 std::sync::Arc::new(memory),
1932 cid,
1933 50,
1934 5,
1935 100,
1936 );
1937
1938 let messages_before = agent.messages.len();
1939 agent.load_history().await.unwrap();
1940
1941 assert_eq!(
1942 agent.messages.len(),
1943 messages_before + 2,
1944 "complete tool_use/tool_result pair must be preserved"
1945 );
1946
1947 let user_msg = &agent.messages[messages_before + 1];
1948 assert!(
1949 user_msg.parts.iter().any(|p| matches!(
1950 p,
1951 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_valid"
1952 )),
1953 "ToolResult part for a matched tool_use must not be stripped"
1954 );
1955 }
1956
1957 #[tokio::test]
1960 async fn strip_orphans_mixed_history() {
1961 use zeph_llm::provider::MessagePart;
1962
1963 let provider = mock_provider(vec![]);
1964 let channel = MockChannel::new(vec![]);
1965 let registry = create_test_registry();
1966 let executor = MockToolExecutor::no_tools();
1967
1968 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1969 let cid = memory.sqlite().create_conversation().await.unwrap();
1970 let sqlite = memory.sqlite();
1971
1972 let use_parts_ok = serde_json::to_string(&[MessagePart::ToolUse {
1974 id: "call_good".to_string(),
1975 name: "shell".to_string(),
1976 input: serde_json::json!({"command": "pwd"}),
1977 }])
1978 .unwrap();
1979 sqlite
1980 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts_ok)
1981 .await
1982 .unwrap();
1983
1984 let result_parts_ok = serde_json::to_string(&[MessagePart::ToolResult {
1985 tool_use_id: "call_good".to_string(),
1986 content: "/home".to_string(),
1987 is_error: false,
1988 }])
1989 .unwrap();
1990 sqlite
1991 .save_message_with_parts(cid, "user", "", &result_parts_ok)
1992 .await
1993 .unwrap();
1994
1995 sqlite
1997 .save_message(cid, "assistant", "text only")
1998 .await
1999 .unwrap();
2000
2001 let orphan_parts = serde_json::to_string(&[MessagePart::ToolResult {
2002 tool_use_id: "call_ghost".to_string(),
2003 content: "ghost result".to_string(),
2004 is_error: false,
2005 }])
2006 .unwrap();
2007 sqlite
2008 .save_message_with_parts(
2009 cid,
2010 "user",
2011 "[tool_result: call_ghost]\nghost result",
2012 &orphan_parts,
2013 )
2014 .await
2015 .unwrap();
2016
2017 sqlite
2018 .save_message(cid, "assistant", "final reply")
2019 .await
2020 .unwrap();
2021
2022 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2023 std::sync::Arc::new(memory),
2024 cid,
2025 50,
2026 5,
2027 100,
2028 );
2029
2030 let messages_before = agent.messages.len();
2031 agent.load_history().await.unwrap();
2032
2033 let loaded = &agent.messages[messages_before..];
2034
2035 for msg in loaded {
2037 assert!(
2038 !msg.parts.iter().any(|p| matches!(
2039 p,
2040 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_ghost"
2041 )),
2042 "orphaned ToolResult (call_ghost) must be stripped from history"
2043 );
2044 }
2045
2046 let has_good_result = loaded.iter().any(|msg| {
2049 msg.role == Role::User
2050 && msg.parts.iter().any(|p| {
2051 matches!(
2052 p,
2053 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_good"
2054 )
2055 })
2056 });
2057 assert!(
2058 has_good_result,
2059 "matched ToolResult (call_good) must be preserved in history"
2060 );
2061 }
2062
2063 #[tokio::test]
2066 async fn sanitize_tool_pairs_preserves_matched_tool_pair() {
2067 use zeph_llm::provider::MessagePart;
2068
2069 let provider = mock_provider(vec![]);
2070 let channel = MockChannel::new(vec![]);
2071 let registry = create_test_registry();
2072 let executor = MockToolExecutor::no_tools();
2073
2074 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2075 let cid = memory.sqlite().create_conversation().await.unwrap();
2076 let sqlite = memory.sqlite();
2077
2078 sqlite
2079 .save_message(cid, "user", "run a command")
2080 .await
2081 .unwrap();
2082
2083 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2085 id: "call_ok".to_string(),
2086 name: "shell".to_string(),
2087 input: serde_json::json!({"command": "echo hi"}),
2088 }])
2089 .unwrap();
2090 sqlite
2091 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
2092 .await
2093 .unwrap();
2094
2095 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2097 tool_use_id: "call_ok".to_string(),
2098 content: "hi".to_string(),
2099 is_error: false,
2100 }])
2101 .unwrap();
2102 sqlite
2103 .save_message_with_parts(cid, "user", "[tool_result: call_ok]\nhi", &result_parts)
2104 .await
2105 .unwrap();
2106
2107 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2108
2109 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2110 std::sync::Arc::new(memory),
2111 cid,
2112 50,
2113 5,
2114 100,
2115 );
2116
2117 let messages_before = agent.messages.len();
2118 agent.load_history().await.unwrap();
2119
2120 assert_eq!(
2122 agent.messages.len(),
2123 messages_before + 4,
2124 "matched tool pair must not be removed"
2125 );
2126 let tool_msg = &agent.messages[messages_before + 1];
2127 assert!(
2128 tool_msg
2129 .parts
2130 .iter()
2131 .any(|p| matches!(p, MessagePart::ToolUse { id, .. } if id == "call_ok")),
2132 "matched ToolUse parts must be preserved"
2133 );
2134 }
2135
2136 #[tokio::test]
2140 async fn persist_cancelled_tool_results_pairs_tool_use() {
2141 use zeph_llm::provider::MessagePart;
2142
2143 let provider = mock_provider(vec![]);
2144 let channel = MockChannel::new(vec![]);
2145 let registry = create_test_registry();
2146 let executor = MockToolExecutor::no_tools();
2147
2148 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2149 let cid = memory.sqlite().create_conversation().await.unwrap();
2150
2151 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2152 std::sync::Arc::new(memory),
2153 cid,
2154 50,
2155 5,
2156 100,
2157 );
2158
2159 let tool_calls = vec![
2161 zeph_llm::provider::ToolUseRequest {
2162 id: "cancel_id_1".to_string(),
2163 name: "shell".to_string(),
2164 input: serde_json::json!({}),
2165 },
2166 zeph_llm::provider::ToolUseRequest {
2167 id: "cancel_id_2".to_string(),
2168 name: "read_file".to_string(),
2169 input: serde_json::json!({}),
2170 },
2171 ];
2172
2173 agent.persist_cancelled_tool_results(&tool_calls).await;
2174
2175 let history = agent
2176 .memory_state
2177 .memory
2178 .as_ref()
2179 .unwrap()
2180 .sqlite()
2181 .load_history(cid, 50)
2182 .await
2183 .unwrap();
2184
2185 assert_eq!(history.len(), 1);
2187 assert_eq!(history[0].role, Role::User);
2188
2189 for tc in &tool_calls {
2191 assert!(
2192 history[0].parts.iter().any(|p| matches!(
2193 p,
2194 MessagePart::ToolResult { tool_use_id, is_error, .. }
2195 if tool_use_id == &tc.id && *is_error
2196 )),
2197 "tombstone ToolResult for {} must be present and is_error=true",
2198 tc.id
2199 );
2200 }
2201 }
2202}