1use std::collections::HashSet;
5
6use crate::channel::Channel;
7use zeph_llm::provider::{LlmProvider as _, Message, MessagePart, Role};
8use zeph_memory::store::role_str;
9
10use super::Agent;
11
12fn sanitize_tool_pairs(messages: &mut Vec<Message>) -> (usize, Vec<i64>) {
28 let mut removed = 0;
29 let mut db_ids: Vec<i64> = Vec::new();
30
31 loop {
32 if let Some(last) = messages.last()
34 && last.role == Role::Assistant
35 && last
36 .parts
37 .iter()
38 .any(|p| matches!(p, MessagePart::ToolUse { .. }))
39 {
40 let ids: Vec<String> = last
41 .parts
42 .iter()
43 .filter_map(|p| {
44 if let MessagePart::ToolUse { id, .. } = p {
45 Some(id.clone())
46 } else {
47 None
48 }
49 })
50 .collect();
51 tracing::warn!(
52 tool_ids = ?ids,
53 "removing orphaned trailing tool_use message from restored history"
54 );
55 if let Some(db_id) = messages.last().and_then(|m| m.metadata.db_id) {
56 db_ids.push(db_id);
57 }
58 messages.pop();
59 removed += 1;
60 continue;
61 }
62
63 if let Some(first) = messages.first()
65 && first.role == Role::User
66 && first
67 .parts
68 .iter()
69 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
70 {
71 let ids: Vec<String> = first
72 .parts
73 .iter()
74 .filter_map(|p| {
75 if let MessagePart::ToolResult { tool_use_id, .. } = p {
76 Some(tool_use_id.clone())
77 } else {
78 None
79 }
80 })
81 .collect();
82 tracing::warn!(
83 tool_use_ids = ?ids,
84 "removing orphaned leading tool_result message from restored history"
85 );
86 if let Some(db_id) = messages.first().and_then(|m| m.metadata.db_id) {
87 db_ids.push(db_id);
88 }
89 messages.remove(0);
90 removed += 1;
91 continue;
92 }
93
94 break;
95 }
96
97 let (mid_removed, mid_db_ids) = strip_mid_history_orphans(messages);
100 removed += mid_removed;
101 db_ids.extend(mid_db_ids);
102
103 (removed, db_ids)
104}
105
106fn orphaned_tool_use_ids(msg: &Message, next_msg: Option<&Message>) -> HashSet<String> {
108 let matched: HashSet<String> = next_msg
109 .filter(|n| n.role == Role::User)
110 .map(|n| {
111 msg.parts
112 .iter()
113 .filter_map(|p| if let MessagePart::ToolUse { id, .. } = p { Some(id.clone()) } else { None })
114 .filter(|uid| n.parts.iter().any(|np| matches!(np, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == uid)))
115 .collect()
116 })
117 .unwrap_or_default();
118 msg.parts
119 .iter()
120 .filter_map(|p| {
121 if let MessagePart::ToolUse { id, .. } = p
122 && !matched.contains(id)
123 {
124 Some(id.clone())
125 } else {
126 None
127 }
128 })
129 .collect()
130}
131
132fn orphaned_tool_result_ids(msg: &Message, prev_msg: Option<&Message>) -> HashSet<String> {
134 let avail: HashSet<&str> = prev_msg
135 .filter(|p| p.role == Role::Assistant)
136 .map(|p| {
137 p.parts
138 .iter()
139 .filter_map(|part| {
140 if let MessagePart::ToolUse { id, .. } = part {
141 Some(id.as_str())
142 } else {
143 None
144 }
145 })
146 .collect()
147 })
148 .unwrap_or_default();
149 msg.parts
150 .iter()
151 .filter_map(|p| {
152 if let MessagePart::ToolResult { tool_use_id, .. } = p
153 && !avail.contains(tool_use_id.as_str())
154 {
155 Some(tool_use_id.clone())
156 } else {
157 None
158 }
159 })
160 .collect()
161}
162
/// Reports whether `content` holds anything besides tool bookkeeping tags.
///
/// The text is scanned for the markers `[tool_use: `, `[tool_result: ` and
/// `[tool output: `:
/// - non-whitespace text before a marker makes the content meaningful;
/// - a marker without a closing `]` makes the content meaningful;
/// - after a `[tool_use: …]` tag scanning resumes right behind the `]`;
/// - after a `[tool_result: …]` / `[tool output: …]` tag everything up to the
///   next marker (or the end of the text) is treated as the result body and
///   skipped.
///
/// Whatever is left once no marker remains decides the answer: `true` if it
/// contains non-whitespace, `false` otherwise.
fn has_meaningful_content(content: &str) -> bool {
    const TOOL_USE: &str = "[tool_use: ";
    const TOOL_RESULT: &str = "[tool_result: ";
    const TOOL_OUTPUT: &str = "[tool output: ";
    const MARKERS: [&str; 3] = [TOOL_USE, TOOL_RESULT, TOOL_OUTPUT];

    /// Position and text of the marker that occurs first in `text`, if any.
    fn earliest_marker(text: &str) -> Option<(usize, &'static str)> {
        MARKERS
            .iter()
            .filter_map(|marker| text.find(marker).map(|at| (at, *marker)))
            .min_by_key(|&(at, _)| at)
    }

    let mut rest = content.trim();

    while let Some((at, marker)) = earliest_marker(rest) {
        let (before, from_marker) = rest.split_at(at);
        if !before.trim().is_empty() {
            // Free text precedes the tag.
            return true;
        }

        let inside = &from_marker[marker.len()..];
        let Some(close) = inside.find(']') else {
            // Unterminated tag: treat it as ordinary text.
            return true;
        };
        let after_tag = &inside[close + 1..];

        rest = if marker == TOOL_USE {
            after_tag
        } else {
            // Result/output tags are followed by a body that is not counted.
            let body = after_tag.trim_start_matches('\n');
            match earliest_marker(body) {
                Some((next, _)) => &body[next..],
                None => "",
            }
        };
    }

    !rest.trim().is_empty()
}
232
233fn strip_mid_history_orphans(messages: &mut Vec<Message>) -> (usize, Vec<i64>) {
246 let mut removed = 0;
247 let mut db_ids: Vec<i64> = Vec::new();
248 let mut i = 0;
249 while i < messages.len() {
250 if messages[i].role == Role::Assistant
254 && messages[i]
255 .parts
256 .iter()
257 .any(|p| matches!(p, MessagePart::ToolUse { .. }))
258 {
259 let orphaned_ids = orphaned_tool_use_ids(&messages[i], messages.get(i + 1));
260 if !orphaned_ids.is_empty() {
261 tracing::warn!(
262 tool_ids = ?orphaned_ids,
263 index = i,
264 "stripping orphaned mid-history tool_use parts from assistant message"
265 );
266 messages[i].parts.retain(
267 |p| !matches!(p, MessagePart::ToolUse { id, .. } if orphaned_ids.contains(id)),
268 );
269 let is_empty =
270 !has_meaningful_content(&messages[i].content) && messages[i].parts.is_empty();
271 if is_empty {
272 if let Some(db_id) = messages[i].metadata.db_id {
273 db_ids.push(db_id);
274 }
275 messages.remove(i);
276 removed += 1;
277 continue; }
279 }
280 }
281
282 if messages[i].role == Role::User
284 && messages[i]
285 .parts
286 .iter()
287 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
288 {
289 let orphaned_ids = orphaned_tool_result_ids(
290 &messages[i],
291 if i > 0 { messages.get(i - 1) } else { None },
292 );
293 if !orphaned_ids.is_empty() {
294 tracing::warn!(
295 tool_use_ids = ?orphaned_ids,
296 index = i,
297 "stripping orphaned mid-history tool_result parts from user message"
298 );
299 messages[i].parts.retain(|p| {
300 !matches!(p, MessagePart::ToolResult { tool_use_id, .. } if orphaned_ids.contains(tool_use_id.as_str()))
301 });
302
303 let is_empty =
304 !has_meaningful_content(&messages[i].content) && messages[i].parts.is_empty();
305 if is_empty {
306 if let Some(db_id) = messages[i].metadata.db_id {
307 db_ids.push(db_id);
308 }
309 messages.remove(i);
310 removed += 1;
311 continue;
313 }
314 }
315 }
316
317 i += 1;
318 }
319 (removed, db_ids)
320}
321
322impl<C: Channel> Agent<C> {
323 pub async fn load_history(&mut self) -> Result<(), super::error::AgentError> {
329 let (Some(memory), Some(cid)) =
330 (&self.memory_state.memory, self.memory_state.conversation_id)
331 else {
332 return Ok(());
333 };
334
335 let history = memory
336 .sqlite()
337 .load_history_filtered(cid, self.memory_state.history_limit, Some(true), None)
338 .await?;
339 if !history.is_empty() {
340 let mut loaded = 0;
341 let mut skipped = 0;
342
343 for msg in history {
344 if !has_meaningful_content(&msg.content) && msg.parts.is_empty() {
349 tracing::warn!("skipping empty message from history (role: {:?})", msg.role);
350 skipped += 1;
351 continue;
352 }
353 self.msg.messages.push(msg);
354 loaded += 1;
355 }
356
357 let history_start = self.msg.messages.len() - loaded;
359 let mut restored_slice = self.msg.messages.split_off(history_start);
360 let (orphans, orphan_db_ids) = sanitize_tool_pairs(&mut restored_slice);
361 skipped += orphans;
362 loaded = loaded.saturating_sub(orphans);
363 self.msg.messages.append(&mut restored_slice);
364
365 if !orphan_db_ids.is_empty() {
366 let ids: Vec<zeph_memory::types::MessageId> = orphan_db_ids
367 .iter()
368 .map(|&id| zeph_memory::types::MessageId(id))
369 .collect();
370 if let Err(e) = memory.sqlite().soft_delete_messages(&ids).await {
371 tracing::warn!(
372 count = ids.len(),
373 error = %e,
374 "failed to soft-delete orphaned tool-pair messages from DB"
375 );
376 } else {
377 tracing::debug!(
378 count = ids.len(),
379 "soft-deleted orphaned tool-pair messages from DB"
380 );
381 }
382 }
383
384 tracing::info!("restored {loaded} message(s) from conversation {cid}");
385 if skipped > 0 {
386 tracing::warn!("skipped {skipped} empty/orphaned message(s) from history");
387 }
388
389 if loaded > 0 {
390 let _ = memory
393 .sqlite()
394 .increment_session_counts_for_conversation(cid)
395 .await
396 .inspect_err(|e| {
397 tracing::warn!(error = %e, "failed to increment tier session counts");
398 });
399 }
400 }
401
402 if let Ok(count) = memory.message_count(cid).await {
403 let count_u64 = u64::try_from(count).unwrap_or(0);
404 self.update_metrics(|m| {
405 m.sqlite_message_count = count_u64;
406 });
407 }
408
409 if let Ok(count) = memory.sqlite().count_semantic_facts().await {
410 let count_u64 = u64::try_from(count).unwrap_or(0);
411 self.update_metrics(|m| {
412 m.semantic_fact_count = count_u64;
413 });
414 }
415
416 if let Ok(count) = memory.unsummarized_message_count(cid).await {
417 self.memory_state.unsummarized_count = usize::try_from(count).unwrap_or(0);
418 }
419
420 self.recompute_prompt_tokens();
421 Ok(())
422 }
423
424 #[allow(clippy::too_many_lines)]
430 pub(crate) async fn persist_message(
431 &mut self,
432 role: Role,
433 content: &str,
434 parts: &[MessagePart],
435 has_injection_flags: bool,
436 ) {
437 let (Some(memory), Some(cid)) =
438 (&self.memory_state.memory, self.memory_state.conversation_id)
439 else {
440 return;
441 };
442
443 let parts_json = if parts.is_empty() {
444 "[]".to_string()
445 } else {
446 serde_json::to_string(parts).unwrap_or_else(|e| {
447 tracing::warn!("failed to serialize message parts, storing empty: {e}");
448 "[]".to_string()
449 })
450 };
451
452 let guard_event = self
456 .security
457 .exfiltration_guard
458 .should_guard_memory_write(has_injection_flags);
459 if let Some(ref event) = guard_event {
460 tracing::warn!(
461 ?event,
462 "exfiltration guard: skipping Qdrant embedding for flagged content"
463 );
464 self.update_metrics(|m| m.exfiltration_memory_guards += 1);
465 self.push_security_event(
466 crate::metrics::SecurityEventCategory::ExfiltrationBlock,
467 "memory_write",
468 "Qdrant embedding skipped: flagged content",
469 );
470 }
471
472 let skip_embedding = guard_event.is_some();
473
474 let has_skipped_tool_result = parts.iter().any(|p| {
478 if let MessagePart::ToolResult { content, .. } = p {
479 content.starts_with("[skipped]") || content.starts_with("[stopped]")
480 } else {
481 false
482 }
483 });
484
485 let should_embed = if skip_embedding || has_skipped_tool_result {
486 false
487 } else {
488 match role {
489 Role::Assistant => {
490 self.memory_state.autosave_assistant
491 && content.len() >= self.memory_state.autosave_min_length
492 }
493 _ => true,
494 }
495 };
496
497 let goal_text = self.memory_state.goal_text.clone();
498
499 let (embedding_stored, was_persisted) = if should_embed {
500 match memory
501 .remember_with_parts(
502 cid,
503 role_str(role),
504 content,
505 &parts_json,
506 goal_text.as_deref(),
507 )
508 .await
509 {
510 Ok((Some(message_id), stored)) => {
511 self.last_persisted_message_id = Some(message_id.0);
512 (stored, true)
513 }
514 Ok((None, _)) => {
515 return;
517 }
518 Err(e) => {
519 tracing::error!("failed to persist message: {e:#}");
520 return;
521 }
522 }
523 } else {
524 match memory
525 .save_only(cid, role_str(role), content, &parts_json)
526 .await
527 {
528 Ok(message_id) => {
529 self.last_persisted_message_id = Some(message_id.0);
530 (false, true)
531 }
532 Err(e) => {
533 tracing::error!("failed to persist message: {e:#}");
534 return;
535 }
536 }
537 };
538
539 if !was_persisted {
540 return;
541 }
542
543 self.memory_state.unsummarized_count += 1;
544
545 self.update_metrics(|m| {
546 m.sqlite_message_count += 1;
547 if embedding_stored {
548 m.embeddings_generated += 1;
549 }
550 });
551
552 self.check_summarization().await;
553
554 let has_tool_result_parts = parts
557 .iter()
558 .any(|p| matches!(p, MessagePart::ToolResult { .. }));
559
560 self.maybe_spawn_graph_extraction(content, has_injection_flags, has_tool_result_parts)
561 .await;
562 }
563
    /// Starts background graph extraction for `content` unless a guard vetoes it.
    ///
    /// Extraction is skipped, in this order, when:
    /// 1. no memory backend or conversation id is configured,
    /// 2. `has_tool_result_parts` is set,
    /// 3. `has_injection_flags` is set,
    /// 4. the graph config is disabled,
    /// 5. `rpe_should_skip` reports the turn as skippable.
    ///
    /// Otherwise up to four user messages without `ToolResult` parts are
    /// collected as context, extraction is spawned through the memory backend,
    /// and, when both a graph store and a metrics sender are available, a
    /// follow-up task refreshes the graph counters once extraction finishes.
    /// The status line is set while spawning and cleared afterwards, then the
    /// `sync_*` helpers are run.
    #[allow(clippy::too_many_lines)]
    async fn maybe_spawn_graph_extraction(
        &mut self,
        content: &str,
        has_injection_flags: bool,
        has_tool_result_parts: bool,
    ) {
        use zeph_memory::semantic::GraphExtractionConfig;

        if self.memory_state.memory.is_none() || self.memory_state.conversation_id.is_none() {
            return;
        }

        if has_tool_result_parts {
            tracing::debug!("graph extraction skipped: message contains ToolResult parts");
            return;
        }

        if has_injection_flags {
            tracing::warn!("graph extraction skipped: injection patterns detected in content");
            return;
        }

        // Snapshot the graph settings into the extraction config; the block
        // scope ends the borrow of `self.memory_state.graph_config` before the
        // `&mut self` call below.
        let extraction_cfg = {
            let cfg = &self.memory_state.graph_config;
            if !cfg.enabled {
                return;
            }
            GraphExtractionConfig {
                max_entities: cfg.max_entities_per_message,
                max_edges: cfg.max_edges_per_message,
                extraction_timeout_secs: cfg.extraction_timeout_secs,
                community_refresh_interval: cfg.community_refresh_interval,
                expired_edge_retention_days: cfg.expired_edge_retention_days,
                max_entities_cap: cfg.max_entities,
                community_summary_max_prompt_bytes: cfg.community_summary_max_prompt_bytes,
                community_summary_concurrency: cfg.community_summary_concurrency,
                lpa_edge_chunk_size: cfg.lpa_edge_chunk_size,
                note_linking: zeph_memory::NoteLinkingConfig {
                    enabled: cfg.note_linking.enabled,
                    similarity_threshold: cfg.note_linking.similarity_threshold,
                    top_k: cfg.note_linking.top_k,
                    timeout_secs: cfg.note_linking.timeout_secs,
                },
                link_weight_decay_lambda: cfg.link_weight_decay_lambda,
                link_weight_decay_interval_secs: cfg.link_weight_decay_interval_secs,
                belief_revision_enabled: cfg.belief_revision.enabled,
                belief_revision_similarity_threshold: cfg.belief_revision.similarity_threshold,
                conversation_id: self.memory_state.conversation_id.map(|c| c.0),
            }
        };

        if self.rpe_should_skip(content).await {
            tracing::debug!("D-MEM RPE: low-surprise turn, skipping graph extraction");
            return;
        }

        // Walks the history backwards, so the collected context is ordered
        // newest first and holds at most four plain user messages.
        let context_messages: Vec<String> = self
            .msg
            .messages
            .iter()
            .rev()
            .filter(|m| {
                m.role == Role::User
                    && !m
                        .parts
                        .iter()
                        .any(|p| matches!(p, MessagePart::ToolResult { .. }))
            })
            .take(4)
            .map(|m| m.content.clone())
            .collect();

        // Status updates are best effort; a send failure is ignored.
        let _ = self.channel.send_status("saving to graph...").await;

        if let Some(memory) = &self.memory_state.memory {
            // Only attach a post-extraction validator when the memory
            // validator is enabled; its error is converted to a string.
            let validator: zeph_memory::semantic::PostExtractValidator =
                if self.security.memory_validator.is_enabled() {
                    let v = self.security.memory_validator.clone();
                    Some(Box::new(move |result| {
                        v.validate_graph_extraction(result)
                            .map_err(|e| e.to_string())
                    }))
                } else {
                    None
                };
            let extraction_handle = memory.spawn_graph_extraction(
                content.to_owned(),
                context_messages,
                extraction_cfg,
                validator,
            );
            if let (Some(store), Some(tx)) =
                (memory.graph_store.clone(), self.metrics.metrics_tx.clone())
            {
                let start = self.lifecycle.start_time;
                // Detached task: waits for the extraction to finish, then
                // publishes fresh graph counts. Failed count queries fall
                // back to 0.
                tokio::spawn(async move {
                    let _ = extraction_handle.await;
                    let (entities, edges, communities) = tokio::join!(
                        store.entity_count(),
                        store.active_edge_count(),
                        store.community_count()
                    );
                    let elapsed = start.elapsed().as_secs();
                    tx.send_modify(|m| {
                        m.uptime_seconds = elapsed;
                        m.graph_entities_total = entities.unwrap_or(0).cast_unsigned();
                        m.graph_edges_total = edges.unwrap_or(0).cast_unsigned();
                        m.graph_communities_total = communities.unwrap_or(0).cast_unsigned();
                    });
                });
            }
        }
        // Clear the status line set above.
        let _ = self.channel.send_status("").await;
        self.sync_community_detection_failures();
        self.sync_graph_extraction_metrics();
        self.sync_graph_counts().await;
        self.sync_guidelines_status().await;
    }
695
696 pub(crate) async fn check_summarization(&mut self) {
697 let (Some(memory), Some(cid)) =
698 (&self.memory_state.memory, self.memory_state.conversation_id)
699 else {
700 return;
701 };
702
703 if self.memory_state.unsummarized_count > self.memory_state.summarization_threshold {
704 let _ = self.channel.send_status("summarizing...").await;
705 let batch_size = self.memory_state.summarization_threshold / 2;
706 match memory.summarize(cid, batch_size).await {
707 Ok(Some(summary_id)) => {
708 tracing::info!("created summary {summary_id} for conversation {cid}");
709 self.memory_state.unsummarized_count = 0;
710 self.update_metrics(|m| {
711 m.summaries_count += 1;
712 });
713 }
714 Ok(None) => {
715 tracing::debug!("no summarization needed");
716 }
717 Err(e) => {
718 tracing::error!("summarization failed: {e:#}");
719 }
720 }
721 let _ = self.channel.send_status("").await;
722 }
723 }
724
725 async fn rpe_should_skip(&mut self, content: &str) -> bool {
730 let Some(ref rpe_mutex) = self.memory_state.rpe_router else {
731 return false;
732 };
733 let Some(memory) = &self.memory_state.memory else {
734 return false;
735 };
736 let candidates = zeph_memory::extract_candidate_entities(content);
737 let provider = memory.provider();
738 let Ok(Ok(emb_vec)) =
739 tokio::time::timeout(std::time::Duration::from_secs(5), provider.embed(content)).await
740 else {
741 return false; };
743 if let Ok(mut router) = rpe_mutex.lock() {
744 let signal = router.compute(&emb_vec, &candidates);
745 router.push_embedding(emb_vec);
746 router.push_entities(&candidates);
747 !signal.should_extract
748 } else {
749 tracing::warn!("rpe_router mutex poisoned; falling through to extract");
750 false
751 }
752 }
753}
754
755#[cfg(test)]
756mod tests {
757 use super::super::agent_tests::{
758 MetricsSnapshot, MockChannel, MockToolExecutor, create_test_registry, mock_provider,
759 };
760 use super::*;
761 use zeph_llm::any::AnyProvider;
762 use zeph_memory::semantic::SemanticMemory;
763
764 async fn test_memory(provider: &AnyProvider) -> SemanticMemory {
765 SemanticMemory::new(
766 ":memory:",
767 "http://127.0.0.1:1",
768 provider.clone(),
769 "test-model",
770 )
771 .await
772 .unwrap()
773 }
774
775 #[tokio::test]
776 async fn load_history_without_memory_returns_ok() {
777 let provider = mock_provider(vec![]);
778 let channel = MockChannel::new(vec![]);
779 let registry = create_test_registry();
780 let executor = MockToolExecutor::no_tools();
781 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
782
783 let result = agent.load_history().await;
784 assert!(result.is_ok());
785 assert_eq!(agent.msg.messages.len(), 1); }
788
789 #[tokio::test]
790 async fn load_history_with_messages_injects_into_agent() {
791 let provider = mock_provider(vec![]);
792 let channel = MockChannel::new(vec![]);
793 let registry = create_test_registry();
794 let executor = MockToolExecutor::no_tools();
795
796 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
797 let cid = memory.sqlite().create_conversation().await.unwrap();
798
799 memory
800 .sqlite()
801 .save_message(cid, "user", "hello from history")
802 .await
803 .unwrap();
804 memory
805 .sqlite()
806 .save_message(cid, "assistant", "hi back")
807 .await
808 .unwrap();
809
810 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
811 std::sync::Arc::new(memory),
812 cid,
813 50,
814 5,
815 100,
816 );
817
818 let messages_before = agent.msg.messages.len();
819 agent.load_history().await.unwrap();
820 assert_eq!(agent.msg.messages.len(), messages_before + 2);
822 }
823
824 #[tokio::test]
825 async fn load_history_skips_empty_messages() {
826 let provider = mock_provider(vec![]);
827 let channel = MockChannel::new(vec![]);
828 let registry = create_test_registry();
829 let executor = MockToolExecutor::no_tools();
830
831 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
832 let cid = memory.sqlite().create_conversation().await.unwrap();
833
834 memory
836 .sqlite()
837 .save_message(cid, "user", " ")
838 .await
839 .unwrap();
840 memory
841 .sqlite()
842 .save_message(cid, "user", "real message")
843 .await
844 .unwrap();
845
846 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
847 std::sync::Arc::new(memory),
848 cid,
849 50,
850 5,
851 100,
852 );
853
854 let messages_before = agent.msg.messages.len();
855 agent.load_history().await.unwrap();
856 assert_eq!(agent.msg.messages.len(), messages_before + 1);
858 }
859
860 #[tokio::test]
861 async fn load_history_with_empty_store_returns_ok() {
862 let provider = mock_provider(vec![]);
863 let channel = MockChannel::new(vec![]);
864 let registry = create_test_registry();
865 let executor = MockToolExecutor::no_tools();
866
867 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
868 let cid = memory.sqlite().create_conversation().await.unwrap();
869
870 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
871 std::sync::Arc::new(memory),
872 cid,
873 50,
874 5,
875 100,
876 );
877
878 let messages_before = agent.msg.messages.len();
879 agent.load_history().await.unwrap();
880 assert_eq!(agent.msg.messages.len(), messages_before);
882 }
883
884 #[tokio::test]
885 async fn load_history_increments_session_count_for_existing_messages() {
886 let provider = mock_provider(vec![]);
887 let channel = MockChannel::new(vec![]);
888 let registry = create_test_registry();
889 let executor = MockToolExecutor::no_tools();
890
891 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
892 let cid = memory.sqlite().create_conversation().await.unwrap();
893
894 let id1 = memory
896 .sqlite()
897 .save_message(cid, "user", "hello")
898 .await
899 .unwrap();
900 let id2 = memory
901 .sqlite()
902 .save_message(cid, "assistant", "hi")
903 .await
904 .unwrap();
905
906 let memory_arc = std::sync::Arc::new(memory);
907 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
908 memory_arc.clone(),
909 cid,
910 50,
911 5,
912 100,
913 );
914
915 agent.load_history().await.unwrap();
916
917 let counts: Vec<i64> = zeph_db::query_scalar(
919 "SELECT session_count FROM messages WHERE id IN (?, ?) ORDER BY id",
920 )
921 .bind(id1)
922 .bind(id2)
923 .fetch_all(memory_arc.sqlite().pool())
924 .await
925 .unwrap();
926 assert_eq!(
927 counts,
928 vec![1, 1],
929 "session_count must be 1 after first restore"
930 );
931 }
932
933 #[tokio::test]
934 async fn load_history_does_not_increment_session_count_for_new_conversation() {
935 let provider = mock_provider(vec![]);
936 let channel = MockChannel::new(vec![]);
937 let registry = create_test_registry();
938 let executor = MockToolExecutor::no_tools();
939
940 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
941 let cid = memory.sqlite().create_conversation().await.unwrap();
942
943 let memory_arc = std::sync::Arc::new(memory);
945 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
946 memory_arc.clone(),
947 cid,
948 50,
949 5,
950 100,
951 );
952
953 agent.load_history().await.unwrap();
954
955 let counts: Vec<i64> =
957 zeph_db::query_scalar("SELECT session_count FROM messages WHERE conversation_id = ?")
958 .bind(cid)
959 .fetch_all(memory_arc.sqlite().pool())
960 .await
961 .unwrap();
962 assert!(counts.is_empty(), "new conversation must have no messages");
963 }
964
965 #[tokio::test]
966 async fn persist_message_without_memory_silently_returns() {
967 let provider = mock_provider(vec![]);
969 let channel = MockChannel::new(vec![]);
970 let registry = create_test_registry();
971 let executor = MockToolExecutor::no_tools();
972 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
973
974 agent.persist_message(Role::User, "hello", &[], false).await;
976 }
977
978 #[tokio::test]
979 async fn persist_message_assistant_autosave_false_uses_save_only() {
980 let provider = mock_provider(vec![]);
981 let channel = MockChannel::new(vec![]);
982 let registry = create_test_registry();
983 let executor = MockToolExecutor::no_tools();
984
985 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
986 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
987 let cid = memory.sqlite().create_conversation().await.unwrap();
988
989 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
990 .with_metrics(tx)
991 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
992 .with_autosave_config(false, 20);
993
994 agent
995 .persist_message(Role::Assistant, "short assistant reply", &[], false)
996 .await;
997
998 let history = agent
999 .memory_state
1000 .memory
1001 .as_ref()
1002 .unwrap()
1003 .sqlite()
1004 .load_history(cid, 50)
1005 .await
1006 .unwrap();
1007 assert_eq!(history.len(), 1, "message must be saved");
1008 assert_eq!(history[0].content, "short assistant reply");
1009 assert_eq!(rx.borrow().embeddings_generated, 0);
1011 }
1012
1013 #[tokio::test]
1014 async fn persist_message_assistant_below_min_length_uses_save_only() {
1015 let provider = mock_provider(vec![]);
1016 let channel = MockChannel::new(vec![]);
1017 let registry = create_test_registry();
1018 let executor = MockToolExecutor::no_tools();
1019
1020 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1021 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1022 let cid = memory.sqlite().create_conversation().await.unwrap();
1023
1024 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1026 .with_metrics(tx)
1027 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1028 .with_autosave_config(true, 1000);
1029
1030 agent
1031 .persist_message(Role::Assistant, "too short", &[], false)
1032 .await;
1033
1034 let history = agent
1035 .memory_state
1036 .memory
1037 .as_ref()
1038 .unwrap()
1039 .sqlite()
1040 .load_history(cid, 50)
1041 .await
1042 .unwrap();
1043 assert_eq!(history.len(), 1, "message must be saved");
1044 assert_eq!(history[0].content, "too short");
1045 assert_eq!(rx.borrow().embeddings_generated, 0);
1046 }
1047
1048 #[tokio::test]
1049 async fn persist_message_assistant_at_min_length_boundary_uses_embed() {
1050 let provider = mock_provider(vec![]);
1052 let channel = MockChannel::new(vec![]);
1053 let registry = create_test_registry();
1054 let executor = MockToolExecutor::no_tools();
1055
1056 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1057 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1058 let cid = memory.sqlite().create_conversation().await.unwrap();
1059
1060 let min_length = 10usize;
1061 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1062 .with_metrics(tx)
1063 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1064 .with_autosave_config(true, min_length);
1065
1066 let content_at_boundary = "A".repeat(min_length);
1068 assert_eq!(content_at_boundary.len(), min_length);
1069 agent
1070 .persist_message(Role::Assistant, &content_at_boundary, &[], false)
1071 .await;
1072
1073 assert_eq!(rx.borrow().sqlite_message_count, 1);
1075 }
1076
1077 #[tokio::test]
1078 async fn persist_message_assistant_one_below_min_length_uses_save_only() {
1079 let provider = mock_provider(vec![]);
1081 let channel = MockChannel::new(vec![]);
1082 let registry = create_test_registry();
1083 let executor = MockToolExecutor::no_tools();
1084
1085 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1086 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1087 let cid = memory.sqlite().create_conversation().await.unwrap();
1088
1089 let min_length = 10usize;
1090 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1091 .with_metrics(tx)
1092 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1093 .with_autosave_config(true, min_length);
1094
1095 let content_below_boundary = "A".repeat(min_length - 1);
1097 assert_eq!(content_below_boundary.len(), min_length - 1);
1098 agent
1099 .persist_message(Role::Assistant, &content_below_boundary, &[], false)
1100 .await;
1101
1102 let history = agent
1103 .memory_state
1104 .memory
1105 .as_ref()
1106 .unwrap()
1107 .sqlite()
1108 .load_history(cid, 50)
1109 .await
1110 .unwrap();
1111 assert_eq!(history.len(), 1, "message must still be saved");
1112 assert_eq!(rx.borrow().embeddings_generated, 0);
1114 }
1115
1116 #[tokio::test]
1117 async fn persist_message_increments_unsummarized_count() {
1118 let provider = mock_provider(vec![]);
1119 let channel = MockChannel::new(vec![]);
1120 let registry = create_test_registry();
1121 let executor = MockToolExecutor::no_tools();
1122
1123 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1124 let cid = memory.sqlite().create_conversation().await.unwrap();
1125
1126 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1128 std::sync::Arc::new(memory),
1129 cid,
1130 50,
1131 5,
1132 100,
1133 );
1134
1135 assert_eq!(agent.memory_state.unsummarized_count, 0);
1136
1137 agent.persist_message(Role::User, "first", &[], false).await;
1138 assert_eq!(agent.memory_state.unsummarized_count, 1);
1139
1140 agent
1141 .persist_message(Role::User, "second", &[], false)
1142 .await;
1143 assert_eq!(agent.memory_state.unsummarized_count, 2);
1144 }
1145
1146 #[tokio::test]
1147 async fn check_summarization_resets_counter_on_success() {
1148 let provider = mock_provider(vec![]);
1149 let channel = MockChannel::new(vec![]);
1150 let registry = create_test_registry();
1151 let executor = MockToolExecutor::no_tools();
1152
1153 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1154 let cid = memory.sqlite().create_conversation().await.unwrap();
1155
1156 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1158 std::sync::Arc::new(memory),
1159 cid,
1160 50,
1161 5,
1162 1,
1163 );
1164
1165 agent.persist_message(Role::User, "msg1", &[], false).await;
1166 agent.persist_message(Role::User, "msg2", &[], false).await;
1167
1168 assert!(agent.memory_state.unsummarized_count <= 2);
1173 }
1174
1175 #[tokio::test]
1176 async fn unsummarized_count_not_incremented_without_memory() {
1177 let provider = mock_provider(vec![]);
1178 let channel = MockChannel::new(vec![]);
1179 let registry = create_test_registry();
1180 let executor = MockToolExecutor::no_tools();
1181 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
1182
1183 agent.persist_message(Role::User, "hello", &[], false).await;
1184 assert_eq!(agent.memory_state.unsummarized_count, 0);
1186 }
1187
1188 mod graph_extraction_guards {
1190 use super::*;
1191 use crate::config::GraphConfig;
1192 use zeph_llm::provider::MessageMetadata;
1193 use zeph_memory::graph::GraphStore;
1194
1195 fn enabled_graph_config() -> GraphConfig {
1196 GraphConfig {
1197 enabled: true,
1198 ..GraphConfig::default()
1199 }
1200 }
1201
1202 async fn agent_with_graph(
1203 provider: &AnyProvider,
1204 config: GraphConfig,
1205 ) -> Agent<MockChannel> {
1206 let memory =
1207 test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1208 let cid = memory.sqlite().create_conversation().await.unwrap();
1209 Agent::new(
1210 provider.clone(),
1211 MockChannel::new(vec![]),
1212 create_test_registry(),
1213 None,
1214 5,
1215 MockToolExecutor::no_tools(),
1216 )
1217 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1218 .with_graph_config(config)
1219 }
1220
1221 #[tokio::test]
1222 async fn injection_flag_guard_skips_extraction() {
1223 let provider = mock_provider(vec![]);
1225 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1226 let pool = agent
1227 .memory_state
1228 .memory
1229 .as_ref()
1230 .unwrap()
1231 .sqlite()
1232 .pool()
1233 .clone();
1234
1235 agent
1236 .maybe_spawn_graph_extraction("I use Rust", true, false)
1237 .await;
1238
1239 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1241
1242 let store = GraphStore::new(pool);
1243 let count = store.get_metadata("extraction_count").await.unwrap();
1244 assert!(
1245 count.is_none(),
1246 "injection flag must prevent extraction_count from being written"
1247 );
1248 }
1249
1250 #[tokio::test]
1251 async fn disabled_config_guard_skips_extraction() {
1252 let provider = mock_provider(vec![]);
1254 let disabled_cfg = GraphConfig {
1255 enabled: false,
1256 ..GraphConfig::default()
1257 };
1258 let mut agent = agent_with_graph(&provider, disabled_cfg).await;
1259 let pool = agent
1260 .memory_state
1261 .memory
1262 .as_ref()
1263 .unwrap()
1264 .sqlite()
1265 .pool()
1266 .clone();
1267
1268 agent
1269 .maybe_spawn_graph_extraction("I use Rust", false, false)
1270 .await;
1271
1272 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1273
1274 let store = GraphStore::new(pool);
1275 let count = store.get_metadata("extraction_count").await.unwrap();
1276 assert!(
1277 count.is_none(),
1278 "disabled graph config must prevent extraction"
1279 );
1280 }
1281
1282 #[tokio::test]
1283 async fn happy_path_fires_extraction() {
1284 let provider = mock_provider(vec![]);
1287 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1288 let pool = agent
1289 .memory_state
1290 .memory
1291 .as_ref()
1292 .unwrap()
1293 .sqlite()
1294 .pool()
1295 .clone();
1296
1297 agent
1298 .maybe_spawn_graph_extraction("I use Rust for systems programming", false, false)
1299 .await;
1300
1301 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1303
1304 let store = GraphStore::new(pool);
1305 let count = store.get_metadata("extraction_count").await.unwrap();
1306 assert!(
1307 count.is_some(),
1308 "happy-path extraction must increment extraction_count"
1309 );
1310 }
1311
1312 #[tokio::test]
1313 async fn tool_result_parts_guard_skips_extraction() {
1314 let provider = mock_provider(vec![]);
1318 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1319 let pool = agent
1320 .memory_state
1321 .memory
1322 .as_ref()
1323 .unwrap()
1324 .sqlite()
1325 .pool()
1326 .clone();
1327
1328 agent
1329 .maybe_spawn_graph_extraction(
1330 "[tool_result: abc123]\nprovider_type = \"claude\"\nallowed_commands = []",
1331 false,
1332 true, )
1334 .await;
1335
1336 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1337
1338 let store = GraphStore::new(pool);
1339 let count = store.get_metadata("extraction_count").await.unwrap();
1340 assert!(
1341 count.is_none(),
1342 "tool result message must not trigger graph extraction"
1343 );
1344 }
1345
1346 #[tokio::test]
1347 async fn context_filter_excludes_tool_result_messages() {
1348 let provider = mock_provider(vec![]);
1359 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1360
1361 agent.msg.messages.push(Message {
1364 role: Role::User,
1365 content: "[tool_result: abc]\nprovider_type = \"openai\"".to_owned(),
1366 parts: vec![MessagePart::ToolResult {
1367 tool_use_id: "abc".to_owned(),
1368 content: "provider_type = \"openai\"".to_owned(),
1369 is_error: false,
1370 }],
1371 metadata: MessageMetadata::default(),
1372 });
1373
1374 let pool = agent
1375 .memory_state
1376 .memory
1377 .as_ref()
1378 .unwrap()
1379 .sqlite()
1380 .pool()
1381 .clone();
1382
1383 agent
1385 .maybe_spawn_graph_extraction("I prefer Rust for systems programming", false, false)
1386 .await;
1387
1388 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1389
1390 let store = GraphStore::new(pool);
1392 let count = store.get_metadata("extraction_count").await.unwrap();
1393 assert!(
1394 count.is_some(),
1395 "conversational message must trigger extraction even with prior tool result in history"
1396 );
1397 }
1398 }
1399
1400 #[tokio::test]
1401 async fn persist_message_user_always_embeds_regardless_of_autosave_flag() {
1402 let provider = mock_provider(vec![]);
1403 let channel = MockChannel::new(vec![]);
1404 let registry = create_test_registry();
1405 let executor = MockToolExecutor::no_tools();
1406
1407 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1408 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1409 let cid = memory.sqlite().create_conversation().await.unwrap();
1410
1411 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1413 .with_metrics(tx)
1414 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1415 .with_autosave_config(false, 20);
1416
1417 let long_user_msg = "A".repeat(100);
1418 agent
1419 .persist_message(Role::User, &long_user_msg, &[], false)
1420 .await;
1421
1422 let history = agent
1423 .memory_state
1424 .memory
1425 .as_ref()
1426 .unwrap()
1427 .sqlite()
1428 .load_history(cid, 50)
1429 .await
1430 .unwrap();
1431 assert_eq!(history.len(), 1, "user message must be saved");
1432 assert_eq!(rx.borrow().sqlite_message_count, 1);
1435 }
1436
1437 #[tokio::test]
1441 async fn persist_message_saves_correct_tool_use_parts() {
1442 use zeph_llm::provider::MessagePart;
1443
1444 let provider = mock_provider(vec![]);
1445 let channel = MockChannel::new(vec![]);
1446 let registry = create_test_registry();
1447 let executor = MockToolExecutor::no_tools();
1448
1449 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1450 let cid = memory.sqlite().create_conversation().await.unwrap();
1451
1452 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1453 std::sync::Arc::new(memory),
1454 cid,
1455 50,
1456 5,
1457 100,
1458 );
1459
1460 let parts = vec![MessagePart::ToolUse {
1461 id: "call_abc123".to_string(),
1462 name: "read_file".to_string(),
1463 input: serde_json::json!({"path": "/tmp/test.txt"}),
1464 }];
1465 let content = "[tool_use: read_file(call_abc123)]";
1466
1467 agent
1468 .persist_message(Role::Assistant, content, &parts, false)
1469 .await;
1470
1471 let history = agent
1472 .memory_state
1473 .memory
1474 .as_ref()
1475 .unwrap()
1476 .sqlite()
1477 .load_history(cid, 50)
1478 .await
1479 .unwrap();
1480
1481 assert_eq!(history.len(), 1);
1482 assert_eq!(history[0].role, Role::Assistant);
1483 assert_eq!(history[0].content, content);
1484 assert_eq!(history[0].parts.len(), 1);
1485 match &history[0].parts[0] {
1486 MessagePart::ToolUse { id, name, .. } => {
1487 assert_eq!(id, "call_abc123");
1488 assert_eq!(name, "read_file");
1489 }
1490 other => panic!("expected ToolUse part, got {other:?}"),
1491 }
1492 assert!(
1494 !history[0]
1495 .parts
1496 .iter()
1497 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1498 "assistant message must not contain ToolResult parts"
1499 );
1500 }
1501
1502 #[tokio::test]
1503 async fn persist_message_saves_correct_tool_result_parts() {
1504 use zeph_llm::provider::MessagePart;
1505
1506 let provider = mock_provider(vec![]);
1507 let channel = MockChannel::new(vec![]);
1508 let registry = create_test_registry();
1509 let executor = MockToolExecutor::no_tools();
1510
1511 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1512 let cid = memory.sqlite().create_conversation().await.unwrap();
1513
1514 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1515 std::sync::Arc::new(memory),
1516 cid,
1517 50,
1518 5,
1519 100,
1520 );
1521
1522 let parts = vec![MessagePart::ToolResult {
1523 tool_use_id: "call_abc123".to_string(),
1524 content: "file contents here".to_string(),
1525 is_error: false,
1526 }];
1527 let content = "[tool_result: call_abc123]\nfile contents here";
1528
1529 agent
1530 .persist_message(Role::User, content, &parts, false)
1531 .await;
1532
1533 let history = agent
1534 .memory_state
1535 .memory
1536 .as_ref()
1537 .unwrap()
1538 .sqlite()
1539 .load_history(cid, 50)
1540 .await
1541 .unwrap();
1542
1543 assert_eq!(history.len(), 1);
1544 assert_eq!(history[0].role, Role::User);
1545 assert_eq!(history[0].content, content);
1546 assert_eq!(history[0].parts.len(), 1);
1547 match &history[0].parts[0] {
1548 MessagePart::ToolResult {
1549 tool_use_id,
1550 content: result_content,
1551 is_error,
1552 } => {
1553 assert_eq!(tool_use_id, "call_abc123");
1554 assert_eq!(result_content, "file contents here");
1555 assert!(!is_error);
1556 }
1557 other => panic!("expected ToolResult part, got {other:?}"),
1558 }
1559 assert!(
1561 !history[0]
1562 .parts
1563 .iter()
1564 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1565 "user ToolResult message must not contain ToolUse parts"
1566 );
1567 }
1568
1569 #[tokio::test]
1570 async fn persist_message_roundtrip_preserves_role_part_alignment() {
1571 use zeph_llm::provider::MessagePart;
1572
1573 let provider = mock_provider(vec![]);
1574 let channel = MockChannel::new(vec![]);
1575 let registry = create_test_registry();
1576 let executor = MockToolExecutor::no_tools();
1577
1578 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1579 let cid = memory.sqlite().create_conversation().await.unwrap();
1580
1581 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1582 std::sync::Arc::new(memory),
1583 cid,
1584 50,
1585 5,
1586 100,
1587 );
1588
1589 let assistant_parts = vec![MessagePart::ToolUse {
1591 id: "id_1".to_string(),
1592 name: "list_dir".to_string(),
1593 input: serde_json::json!({"path": "/tmp"}),
1594 }];
1595 agent
1596 .persist_message(
1597 Role::Assistant,
1598 "[tool_use: list_dir(id_1)]",
1599 &assistant_parts,
1600 false,
1601 )
1602 .await;
1603
1604 let user_parts = vec![MessagePart::ToolResult {
1606 tool_use_id: "id_1".to_string(),
1607 content: "file1.txt\nfile2.txt".to_string(),
1608 is_error: false,
1609 }];
1610 agent
1611 .persist_message(
1612 Role::User,
1613 "[tool_result: id_1]\nfile1.txt\nfile2.txt",
1614 &user_parts,
1615 false,
1616 )
1617 .await;
1618
1619 let history = agent
1620 .memory_state
1621 .memory
1622 .as_ref()
1623 .unwrap()
1624 .sqlite()
1625 .load_history(cid, 50)
1626 .await
1627 .unwrap();
1628
1629 assert_eq!(history.len(), 2);
1630
1631 assert_eq!(history[0].role, Role::Assistant);
1633 assert_eq!(history[0].content, "[tool_use: list_dir(id_1)]");
1634 assert!(
1635 matches!(&history[0].parts[0], MessagePart::ToolUse { id, .. } if id == "id_1"),
1636 "first message must be assistant ToolUse"
1637 );
1638
1639 assert_eq!(history[1].role, Role::User);
1641 assert_eq!(
1642 history[1].content,
1643 "[tool_result: id_1]\nfile1.txt\nfile2.txt"
1644 );
1645 assert!(
1646 matches!(&history[1].parts[0], MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "id_1"),
1647 "second message must be user ToolResult"
1648 );
1649
1650 assert!(
1652 !history[0]
1653 .parts
1654 .iter()
1655 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1656 "assistant message must not have ToolResult parts"
1657 );
1658 assert!(
1659 !history[1]
1660 .parts
1661 .iter()
1662 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1663 "user message must not have ToolUse parts"
1664 );
1665 }
1666
1667 #[tokio::test]
1668 async fn persist_message_saves_correct_tool_output_parts() {
1669 use zeph_llm::provider::MessagePart;
1670
1671 let provider = mock_provider(vec![]);
1672 let channel = MockChannel::new(vec![]);
1673 let registry = create_test_registry();
1674 let executor = MockToolExecutor::no_tools();
1675
1676 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1677 let cid = memory.sqlite().create_conversation().await.unwrap();
1678
1679 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1680 std::sync::Arc::new(memory),
1681 cid,
1682 50,
1683 5,
1684 100,
1685 );
1686
1687 let parts = vec![MessagePart::ToolOutput {
1688 tool_name: "shell".to_string(),
1689 body: "hello from shell".to_string(),
1690 compacted_at: None,
1691 }];
1692 let content = "[tool: shell]\nhello from shell";
1693
1694 agent
1695 .persist_message(Role::User, content, &parts, false)
1696 .await;
1697
1698 let history = agent
1699 .memory_state
1700 .memory
1701 .as_ref()
1702 .unwrap()
1703 .sqlite()
1704 .load_history(cid, 50)
1705 .await
1706 .unwrap();
1707
1708 assert_eq!(history.len(), 1);
1709 assert_eq!(history[0].role, Role::User);
1710 assert_eq!(history[0].content, content);
1711 assert_eq!(history[0].parts.len(), 1);
1712 match &history[0].parts[0] {
1713 MessagePart::ToolOutput {
1714 tool_name,
1715 body,
1716 compacted_at,
1717 } => {
1718 assert_eq!(tool_name, "shell");
1719 assert_eq!(body, "hello from shell");
1720 assert!(compacted_at.is_none());
1721 }
1722 other => panic!("expected ToolOutput part, got {other:?}"),
1723 }
1724 }
1725
1726 #[tokio::test]
1729 async fn load_history_removes_trailing_orphan_tool_use() {
1730 use zeph_llm::provider::MessagePart;
1731
1732 let provider = mock_provider(vec![]);
1733 let channel = MockChannel::new(vec![]);
1734 let registry = create_test_registry();
1735 let executor = MockToolExecutor::no_tools();
1736
1737 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1738 let cid = memory.sqlite().create_conversation().await.unwrap();
1739 let sqlite = memory.sqlite();
1740
1741 sqlite
1743 .save_message(cid, "user", "do something with a tool")
1744 .await
1745 .unwrap();
1746
1747 let parts = serde_json::to_string(&[MessagePart::ToolUse {
1749 id: "call_orphan".to_string(),
1750 name: "shell".to_string(),
1751 input: serde_json::json!({"command": "ls"}),
1752 }])
1753 .unwrap();
1754 sqlite
1755 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_orphan)]", &parts)
1756 .await
1757 .unwrap();
1758
1759 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1760 std::sync::Arc::new(memory),
1761 cid,
1762 50,
1763 5,
1764 100,
1765 );
1766
1767 let messages_before = agent.msg.messages.len();
1768 agent.load_history().await.unwrap();
1769
1770 assert_eq!(
1772 agent.msg.messages.len(),
1773 messages_before + 1,
1774 "orphaned trailing tool_use must be removed"
1775 );
1776 assert_eq!(agent.msg.messages.last().unwrap().role, Role::User);
1777 }
1778
1779 #[tokio::test]
1780 async fn load_history_removes_leading_orphan_tool_result() {
1781 use zeph_llm::provider::MessagePart;
1782
1783 let provider = mock_provider(vec![]);
1784 let channel = MockChannel::new(vec![]);
1785 let registry = create_test_registry();
1786 let executor = MockToolExecutor::no_tools();
1787
1788 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1789 let cid = memory.sqlite().create_conversation().await.unwrap();
1790 let sqlite = memory.sqlite();
1791
1792 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1794 tool_use_id: "call_missing".to_string(),
1795 content: "result data".to_string(),
1796 is_error: false,
1797 }])
1798 .unwrap();
1799 sqlite
1800 .save_message_with_parts(
1801 cid,
1802 "user",
1803 "[tool_result: call_missing]\nresult data",
1804 &result_parts,
1805 )
1806 .await
1807 .unwrap();
1808
1809 sqlite
1811 .save_message(cid, "assistant", "here is my response")
1812 .await
1813 .unwrap();
1814
1815 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1816 std::sync::Arc::new(memory),
1817 cid,
1818 50,
1819 5,
1820 100,
1821 );
1822
1823 let messages_before = agent.msg.messages.len();
1824 agent.load_history().await.unwrap();
1825
1826 assert_eq!(
1828 agent.msg.messages.len(),
1829 messages_before + 1,
1830 "orphaned leading tool_result must be removed"
1831 );
1832 assert_eq!(agent.msg.messages.last().unwrap().role, Role::Assistant);
1833 }
1834
1835 #[tokio::test]
1836 async fn load_history_preserves_complete_tool_pairs() {
1837 use zeph_llm::provider::MessagePart;
1838
1839 let provider = mock_provider(vec![]);
1840 let channel = MockChannel::new(vec![]);
1841 let registry = create_test_registry();
1842 let executor = MockToolExecutor::no_tools();
1843
1844 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1845 let cid = memory.sqlite().create_conversation().await.unwrap();
1846 let sqlite = memory.sqlite();
1847
1848 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
1850 id: "call_ok".to_string(),
1851 name: "shell".to_string(),
1852 input: serde_json::json!({"command": "pwd"}),
1853 }])
1854 .unwrap();
1855 sqlite
1856 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_ok)]", &use_parts)
1857 .await
1858 .unwrap();
1859
1860 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
1861 tool_use_id: "call_ok".to_string(),
1862 content: "/home/user".to_string(),
1863 is_error: false,
1864 }])
1865 .unwrap();
1866 sqlite
1867 .save_message_with_parts(
1868 cid,
1869 "user",
1870 "[tool_result: call_ok]\n/home/user",
1871 &result_parts,
1872 )
1873 .await
1874 .unwrap();
1875
1876 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1877 std::sync::Arc::new(memory),
1878 cid,
1879 50,
1880 5,
1881 100,
1882 );
1883
1884 let messages_before = agent.msg.messages.len();
1885 agent.load_history().await.unwrap();
1886
1887 assert_eq!(
1889 agent.msg.messages.len(),
1890 messages_before + 2,
1891 "complete tool_use/tool_result pair must be preserved"
1892 );
1893 assert_eq!(agent.msg.messages[messages_before].role, Role::Assistant);
1894 assert_eq!(agent.msg.messages[messages_before + 1].role, Role::User);
1895 }
1896
1897 #[tokio::test]
1898 async fn load_history_handles_multiple_trailing_orphans() {
1899 use zeph_llm::provider::MessagePart;
1900
1901 let provider = mock_provider(vec![]);
1902 let channel = MockChannel::new(vec![]);
1903 let registry = create_test_registry();
1904 let executor = MockToolExecutor::no_tools();
1905
1906 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1907 let cid = memory.sqlite().create_conversation().await.unwrap();
1908 let sqlite = memory.sqlite();
1909
1910 sqlite.save_message(cid, "user", "start").await.unwrap();
1912
1913 let parts1 = serde_json::to_string(&[MessagePart::ToolUse {
1915 id: "call_1".to_string(),
1916 name: "shell".to_string(),
1917 input: serde_json::json!({}),
1918 }])
1919 .unwrap();
1920 sqlite
1921 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_1)]", &parts1)
1922 .await
1923 .unwrap();
1924
1925 let parts2 = serde_json::to_string(&[MessagePart::ToolUse {
1927 id: "call_2".to_string(),
1928 name: "read_file".to_string(),
1929 input: serde_json::json!({}),
1930 }])
1931 .unwrap();
1932 sqlite
1933 .save_message_with_parts(cid, "assistant", "[tool_use: read_file(call_2)]", &parts2)
1934 .await
1935 .unwrap();
1936
1937 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1938 std::sync::Arc::new(memory),
1939 cid,
1940 50,
1941 5,
1942 100,
1943 );
1944
1945 let messages_before = agent.msg.messages.len();
1946 agent.load_history().await.unwrap();
1947
1948 assert_eq!(
1950 agent.msg.messages.len(),
1951 messages_before + 1,
1952 "all trailing orphaned tool_use messages must be removed"
1953 );
1954 assert_eq!(agent.msg.messages.last().unwrap().role, Role::User);
1955 }
1956
1957 #[tokio::test]
1958 async fn load_history_no_tool_messages_unchanged() {
1959 let provider = mock_provider(vec![]);
1960 let channel = MockChannel::new(vec![]);
1961 let registry = create_test_registry();
1962 let executor = MockToolExecutor::no_tools();
1963
1964 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1965 let cid = memory.sqlite().create_conversation().await.unwrap();
1966 let sqlite = memory.sqlite();
1967
1968 sqlite.save_message(cid, "user", "hello").await.unwrap();
1969 sqlite
1970 .save_message(cid, "assistant", "hi there")
1971 .await
1972 .unwrap();
1973 sqlite
1974 .save_message(cid, "user", "how are you?")
1975 .await
1976 .unwrap();
1977
1978 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1979 std::sync::Arc::new(memory),
1980 cid,
1981 50,
1982 5,
1983 100,
1984 );
1985
1986 let messages_before = agent.msg.messages.len();
1987 agent.load_history().await.unwrap();
1988
1989 assert_eq!(
1991 agent.msg.messages.len(),
1992 messages_before + 3,
1993 "plain messages without tool parts must pass through unchanged"
1994 );
1995 }
1996
1997 #[tokio::test]
1998 async fn load_history_removes_both_leading_and_trailing_orphans() {
1999 use zeph_llm::provider::MessagePart;
2000
2001 let provider = mock_provider(vec![]);
2002 let channel = MockChannel::new(vec![]);
2003 let registry = create_test_registry();
2004 let executor = MockToolExecutor::no_tools();
2005
2006 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2007 let cid = memory.sqlite().create_conversation().await.unwrap();
2008 let sqlite = memory.sqlite();
2009
2010 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2012 tool_use_id: "call_leading".to_string(),
2013 content: "orphaned result".to_string(),
2014 is_error: false,
2015 }])
2016 .unwrap();
2017 sqlite
2018 .save_message_with_parts(
2019 cid,
2020 "user",
2021 "[tool_result: call_leading]\norphaned result",
2022 &result_parts,
2023 )
2024 .await
2025 .unwrap();
2026
2027 sqlite
2029 .save_message(cid, "user", "what is 2+2?")
2030 .await
2031 .unwrap();
2032 sqlite.save_message(cid, "assistant", "4").await.unwrap();
2033
2034 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2036 id: "call_trailing".to_string(),
2037 name: "shell".to_string(),
2038 input: serde_json::json!({"command": "date"}),
2039 }])
2040 .unwrap();
2041 sqlite
2042 .save_message_with_parts(
2043 cid,
2044 "assistant",
2045 "[tool_use: shell(call_trailing)]",
2046 &use_parts,
2047 )
2048 .await
2049 .unwrap();
2050
2051 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2052 std::sync::Arc::new(memory),
2053 cid,
2054 50,
2055 5,
2056 100,
2057 );
2058
2059 let messages_before = agent.msg.messages.len();
2060 agent.load_history().await.unwrap();
2061
2062 assert_eq!(
2064 agent.msg.messages.len(),
2065 messages_before + 2,
2066 "both leading and trailing orphans must be removed"
2067 );
2068 assert_eq!(agent.msg.messages[messages_before].role, Role::User);
2069 assert_eq!(agent.msg.messages[messages_before].content, "what is 2+2?");
2070 assert_eq!(
2071 agent.msg.messages[messages_before + 1].role,
2072 Role::Assistant
2073 );
2074 assert_eq!(agent.msg.messages[messages_before + 1].content, "4");
2075 }
2076
2077 #[tokio::test]
2082 async fn sanitize_tool_pairs_strips_mid_history_orphan_tool_use() {
2083 use zeph_llm::provider::MessagePart;
2084
2085 let provider = mock_provider(vec![]);
2086 let channel = MockChannel::new(vec![]);
2087 let registry = create_test_registry();
2088 let executor = MockToolExecutor::no_tools();
2089
2090 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2091 let cid = memory.sqlite().create_conversation().await.unwrap();
2092 let sqlite = memory.sqlite();
2093
2094 sqlite
2096 .save_message(cid, "user", "first question")
2097 .await
2098 .unwrap();
2099 sqlite
2100 .save_message(cid, "assistant", "first answer")
2101 .await
2102 .unwrap();
2103
2104 let use_parts = serde_json::to_string(&[
2108 MessagePart::ToolUse {
2109 id: "call_mid_1".to_string(),
2110 name: "shell".to_string(),
2111 input: serde_json::json!({"command": "ls"}),
2112 },
2113 MessagePart::Text {
2114 text: "Let me check the files.".to_string(),
2115 },
2116 ])
2117 .unwrap();
2118 sqlite
2119 .save_message_with_parts(cid, "assistant", "Let me check the files.", &use_parts)
2120 .await
2121 .unwrap();
2122
2123 sqlite
2125 .save_message(cid, "user", "second question")
2126 .await
2127 .unwrap();
2128 sqlite
2129 .save_message(cid, "assistant", "second answer")
2130 .await
2131 .unwrap();
2132
2133 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2134 std::sync::Arc::new(memory),
2135 cid,
2136 50,
2137 5,
2138 100,
2139 );
2140
2141 let messages_before = agent.msg.messages.len();
2142 agent.load_history().await.unwrap();
2143
2144 assert_eq!(
2147 agent.msg.messages.len(),
2148 messages_before + 5,
2149 "message count must be 5 (orphan message kept — has text content)"
2150 );
2151
2152 let orphan = &agent.msg.messages[messages_before + 2];
2154 assert_eq!(orphan.role, Role::Assistant);
2155 assert!(
2156 !orphan
2157 .parts
2158 .iter()
2159 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
2160 "orphaned ToolUse parts must be stripped from mid-history message"
2161 );
2162 assert!(
2164 orphan.parts.iter().any(
2165 |p| matches!(p, MessagePart::Text { text } if text == "Let me check the files.")
2166 ),
2167 "text content of orphaned assistant message must be preserved"
2168 );
2169 }
2170
2171 #[tokio::test]
2176 async fn load_history_keeps_tool_only_user_message() {
2177 use zeph_llm::provider::MessagePart;
2178
2179 let provider = mock_provider(vec![]);
2180 let channel = MockChannel::new(vec![]);
2181 let registry = create_test_registry();
2182 let executor = MockToolExecutor::no_tools();
2183
2184 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2185 let cid = memory.sqlite().create_conversation().await.unwrap();
2186 let sqlite = memory.sqlite();
2187
2188 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2190 id: "call_rc3".to_string(),
2191 name: "memory_save".to_string(),
2192 input: serde_json::json!({"content": "something"}),
2193 }])
2194 .unwrap();
2195 sqlite
2196 .save_message_with_parts(cid, "assistant", "[tool_use: memory_save]", &use_parts)
2197 .await
2198 .unwrap();
2199
2200 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2202 tool_use_id: "call_rc3".to_string(),
2203 content: "saved".to_string(),
2204 is_error: false,
2205 }])
2206 .unwrap();
2207 sqlite
2208 .save_message_with_parts(cid, "user", "", &result_parts)
2209 .await
2210 .unwrap();
2211
2212 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2213
2214 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2215 std::sync::Arc::new(memory),
2216 cid,
2217 50,
2218 5,
2219 100,
2220 );
2221
2222 let messages_before = agent.msg.messages.len();
2223 agent.load_history().await.unwrap();
2224
2225 assert_eq!(
2228 agent.msg.messages.len(),
2229 messages_before + 3,
2230 "user message with empty content but ToolResult parts must not be dropped"
2231 );
2232
2233 let user_msg = &agent.msg.messages[messages_before + 1];
2235 assert_eq!(user_msg.role, Role::User);
2236 assert!(
2237 user_msg.parts.iter().any(
2238 |p| matches!(p, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_rc3")
2239 ),
2240 "ToolResult part must be preserved on user message with empty content"
2241 );
2242 }
2243
2244 #[tokio::test]
2248 async fn strip_orphans_removes_orphaned_tool_result() {
2249 use zeph_llm::provider::MessagePart;
2250
2251 let provider = mock_provider(vec![]);
2252 let channel = MockChannel::new(vec![]);
2253 let registry = create_test_registry();
2254 let executor = MockToolExecutor::no_tools();
2255
2256 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2257 let cid = memory.sqlite().create_conversation().await.unwrap();
2258 let sqlite = memory.sqlite();
2259
2260 sqlite.save_message(cid, "user", "hello").await.unwrap();
2262 sqlite.save_message(cid, "assistant", "hi").await.unwrap();
2263
2264 sqlite
2266 .save_message(cid, "assistant", "plain answer")
2267 .await
2268 .unwrap();
2269
2270 let orphan_result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2272 tool_use_id: "call_nonexistent".to_string(),
2273 content: "stale result".to_string(),
2274 is_error: false,
2275 }])
2276 .unwrap();
2277 sqlite
2278 .save_message_with_parts(
2279 cid,
2280 "user",
2281 "[tool_result: call_nonexistent]\nstale result",
2282 &orphan_result_parts,
2283 )
2284 .await
2285 .unwrap();
2286
2287 sqlite
2288 .save_message(cid, "assistant", "final")
2289 .await
2290 .unwrap();
2291
2292 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2293 std::sync::Arc::new(memory),
2294 cid,
2295 50,
2296 5,
2297 100,
2298 );
2299
2300 let messages_before = agent.msg.messages.len();
2301 agent.load_history().await.unwrap();
2302
2303 let loaded = &agent.msg.messages[messages_before..];
2307 for msg in loaded {
2308 assert!(
2309 !msg.parts.iter().any(|p| matches!(
2310 p,
2311 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_nonexistent"
2312 )),
2313 "orphaned ToolResult part must be stripped from history"
2314 );
2315 }
2316 }
2317
2318 #[tokio::test]
2321 async fn strip_orphans_keeps_complete_pair() {
2322 use zeph_llm::provider::MessagePart;
2323
2324 let provider = mock_provider(vec![]);
2325 let channel = MockChannel::new(vec![]);
2326 let registry = create_test_registry();
2327 let executor = MockToolExecutor::no_tools();
2328
2329 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2330 let cid = memory.sqlite().create_conversation().await.unwrap();
2331 let sqlite = memory.sqlite();
2332
2333 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2334 id: "call_valid".to_string(),
2335 name: "shell".to_string(),
2336 input: serde_json::json!({"command": "ls"}),
2337 }])
2338 .unwrap();
2339 sqlite
2340 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
2341 .await
2342 .unwrap();
2343
2344 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2345 tool_use_id: "call_valid".to_string(),
2346 content: "file.rs".to_string(),
2347 is_error: false,
2348 }])
2349 .unwrap();
2350 sqlite
2351 .save_message_with_parts(cid, "user", "", &result_parts)
2352 .await
2353 .unwrap();
2354
2355 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2356 std::sync::Arc::new(memory),
2357 cid,
2358 50,
2359 5,
2360 100,
2361 );
2362
2363 let messages_before = agent.msg.messages.len();
2364 agent.load_history().await.unwrap();
2365
2366 assert_eq!(
2367 agent.msg.messages.len(),
2368 messages_before + 2,
2369 "complete tool_use/tool_result pair must be preserved"
2370 );
2371
2372 let user_msg = &agent.msg.messages[messages_before + 1];
2373 assert!(
2374 user_msg.parts.iter().any(|p| matches!(
2375 p,
2376 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_valid"
2377 )),
2378 "ToolResult part for a matched tool_use must not be stripped"
2379 );
2380 }
2381
2382 #[tokio::test]
2385 async fn strip_orphans_mixed_history() {
2386 use zeph_llm::provider::MessagePart;
2387
2388 let provider = mock_provider(vec![]);
2389 let channel = MockChannel::new(vec![]);
2390 let registry = create_test_registry();
2391 let executor = MockToolExecutor::no_tools();
2392
2393 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2394 let cid = memory.sqlite().create_conversation().await.unwrap();
2395 let sqlite = memory.sqlite();
2396
2397 let use_parts_ok = serde_json::to_string(&[MessagePart::ToolUse {
2399 id: "call_good".to_string(),
2400 name: "shell".to_string(),
2401 input: serde_json::json!({"command": "pwd"}),
2402 }])
2403 .unwrap();
2404 sqlite
2405 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts_ok)
2406 .await
2407 .unwrap();
2408
2409 let result_parts_ok = serde_json::to_string(&[MessagePart::ToolResult {
2410 tool_use_id: "call_good".to_string(),
2411 content: "/home".to_string(),
2412 is_error: false,
2413 }])
2414 .unwrap();
2415 sqlite
2416 .save_message_with_parts(cid, "user", "", &result_parts_ok)
2417 .await
2418 .unwrap();
2419
2420 sqlite
2422 .save_message(cid, "assistant", "text only")
2423 .await
2424 .unwrap();
2425
2426 let orphan_parts = serde_json::to_string(&[MessagePart::ToolResult {
2427 tool_use_id: "call_ghost".to_string(),
2428 content: "ghost result".to_string(),
2429 is_error: false,
2430 }])
2431 .unwrap();
2432 sqlite
2433 .save_message_with_parts(
2434 cid,
2435 "user",
2436 "[tool_result: call_ghost]\nghost result",
2437 &orphan_parts,
2438 )
2439 .await
2440 .unwrap();
2441
2442 sqlite
2443 .save_message(cid, "assistant", "final reply")
2444 .await
2445 .unwrap();
2446
2447 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2448 std::sync::Arc::new(memory),
2449 cid,
2450 50,
2451 5,
2452 100,
2453 );
2454
2455 let messages_before = agent.msg.messages.len();
2456 agent.load_history().await.unwrap();
2457
2458 let loaded = &agent.msg.messages[messages_before..];
2459
2460 for msg in loaded {
2462 assert!(
2463 !msg.parts.iter().any(|p| matches!(
2464 p,
2465 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_ghost"
2466 )),
2467 "orphaned ToolResult (call_ghost) must be stripped from history"
2468 );
2469 }
2470
2471 let has_good_result = loaded.iter().any(|msg| {
2474 msg.role == Role::User
2475 && msg.parts.iter().any(|p| {
2476 matches!(
2477 p,
2478 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_good"
2479 )
2480 })
2481 });
2482 assert!(
2483 has_good_result,
2484 "matched ToolResult (call_good) must be preserved in history"
2485 );
2486 }
2487
2488 #[tokio::test]
2491 async fn sanitize_tool_pairs_preserves_matched_tool_pair() {
2492 use zeph_llm::provider::MessagePart;
2493
2494 let provider = mock_provider(vec![]);
2495 let channel = MockChannel::new(vec![]);
2496 let registry = create_test_registry();
2497 let executor = MockToolExecutor::no_tools();
2498
2499 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2500 let cid = memory.sqlite().create_conversation().await.unwrap();
2501 let sqlite = memory.sqlite();
2502
2503 sqlite
2504 .save_message(cid, "user", "run a command")
2505 .await
2506 .unwrap();
2507
2508 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2510 id: "call_ok".to_string(),
2511 name: "shell".to_string(),
2512 input: serde_json::json!({"command": "echo hi"}),
2513 }])
2514 .unwrap();
2515 sqlite
2516 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
2517 .await
2518 .unwrap();
2519
2520 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2522 tool_use_id: "call_ok".to_string(),
2523 content: "hi".to_string(),
2524 is_error: false,
2525 }])
2526 .unwrap();
2527 sqlite
2528 .save_message_with_parts(cid, "user", "[tool_result: call_ok]\nhi", &result_parts)
2529 .await
2530 .unwrap();
2531
2532 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2533
2534 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2535 std::sync::Arc::new(memory),
2536 cid,
2537 50,
2538 5,
2539 100,
2540 );
2541
2542 let messages_before = agent.msg.messages.len();
2543 agent.load_history().await.unwrap();
2544
2545 assert_eq!(
2547 agent.msg.messages.len(),
2548 messages_before + 4,
2549 "matched tool pair must not be removed"
2550 );
2551 let tool_msg = &agent.msg.messages[messages_before + 1];
2552 assert!(
2553 tool_msg
2554 .parts
2555 .iter()
2556 .any(|p| matches!(p, MessagePart::ToolUse { id, .. } if id == "call_ok")),
2557 "matched ToolUse parts must be preserved"
2558 );
2559 }
2560
2561 #[tokio::test]
2565 async fn persist_cancelled_tool_results_pairs_tool_use() {
2566 use zeph_llm::provider::MessagePart;
2567
2568 let provider = mock_provider(vec![]);
2569 let channel = MockChannel::new(vec![]);
2570 let registry = create_test_registry();
2571 let executor = MockToolExecutor::no_tools();
2572
2573 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2574 let cid = memory.sqlite().create_conversation().await.unwrap();
2575
2576 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2577 std::sync::Arc::new(memory),
2578 cid,
2579 50,
2580 5,
2581 100,
2582 );
2583
2584 let tool_calls = vec![
2586 zeph_llm::provider::ToolUseRequest {
2587 id: "cancel_id_1".to_string(),
2588 name: "shell".to_string(),
2589 input: serde_json::json!({}),
2590 },
2591 zeph_llm::provider::ToolUseRequest {
2592 id: "cancel_id_2".to_string(),
2593 name: "read_file".to_string(),
2594 input: serde_json::json!({}),
2595 },
2596 ];
2597
2598 agent.persist_cancelled_tool_results(&tool_calls).await;
2599
2600 let history = agent
2601 .memory_state
2602 .memory
2603 .as_ref()
2604 .unwrap()
2605 .sqlite()
2606 .load_history(cid, 50)
2607 .await
2608 .unwrap();
2609
2610 assert_eq!(history.len(), 1);
2612 assert_eq!(history[0].role, Role::User);
2613
2614 for tc in &tool_calls {
2616 assert!(
2617 history[0].parts.iter().any(|p| matches!(
2618 p,
2619 MessagePart::ToolResult { tool_use_id, is_error, .. }
2620 if tool_use_id == &tc.id && *is_error
2621 )),
2622 "tombstone ToolResult for {} must be present and is_error=true",
2623 tc.id
2624 );
2625 }
2626 }
2627
2628 #[test]
2631 fn meaningful_content_empty_string() {
2632 assert!(!has_meaningful_content(""));
2633 }
2634
2635 #[test]
2636 fn meaningful_content_whitespace_only() {
2637 assert!(!has_meaningful_content(" \n\t "));
2638 }
2639
2640 #[test]
2641 fn meaningful_content_tool_use_only() {
2642 assert!(!has_meaningful_content("[tool_use: shell(call_1)]"));
2643 }
2644
2645 #[test]
2646 fn meaningful_content_tool_use_no_parens() {
2647 assert!(!has_meaningful_content("[tool_use: memory_save]"));
2649 }
2650
2651 #[test]
2652 fn meaningful_content_tool_result_with_body() {
2653 assert!(!has_meaningful_content(
2654 "[tool_result: call_1]\nsome output here"
2655 ));
2656 }
2657
2658 #[test]
2659 fn meaningful_content_tool_result_empty_body() {
2660 assert!(!has_meaningful_content("[tool_result: call_1]\n"));
2661 }
2662
2663 #[test]
2664 fn meaningful_content_tool_output_inline() {
2665 assert!(!has_meaningful_content("[tool output: bash] some result"));
2666 }
2667
2668 #[test]
2669 fn meaningful_content_tool_output_pruned() {
2670 assert!(!has_meaningful_content("[tool output: bash] (pruned)"));
2671 }
2672
2673 #[test]
2674 fn meaningful_content_tool_output_fenced() {
2675 assert!(!has_meaningful_content(
2676 "[tool output: bash]\n```\nls output\n```"
2677 ));
2678 }
2679
2680 #[test]
2681 fn meaningful_content_multiple_tool_use_tags() {
2682 assert!(!has_meaningful_content(
2683 "[tool_use: bash(id1)][tool_use: read(id2)]"
2684 ));
2685 }
2686
2687 #[test]
2688 fn meaningful_content_multiple_tool_use_tags_space_separator() {
2689 assert!(!has_meaningful_content(
2691 "[tool_use: bash(id1)] [tool_use: read(id2)]"
2692 ));
2693 }
2694
2695 #[test]
2696 fn meaningful_content_multiple_tool_use_tags_newline_separator() {
2697 assert!(!has_meaningful_content(
2699 "[tool_use: bash(id1)]\n[tool_use: read(id2)]"
2700 ));
2701 }
2702
2703 #[test]
2704 fn meaningful_content_tool_result_followed_by_tool_use() {
2705 assert!(!has_meaningful_content(
2707 "[tool_result: call_1]\nresult\n[tool_use: bash(call_2)]"
2708 ));
2709 }
2710
2711 #[test]
2712 fn meaningful_content_real_text_only() {
2713 assert!(has_meaningful_content("Hello, how can I help you?"));
2714 }
2715
2716 #[test]
2717 fn meaningful_content_text_before_tool_tag() {
2718 assert!(has_meaningful_content("Let me check. [tool_use: bash(id)]"));
2719 }
2720
2721 #[test]
2722 fn meaningful_content_text_after_tool_use_tag() {
2723 assert!(has_meaningful_content("[tool_use: bash] I ran the command"));
2727 }
2728
2729 #[test]
2730 fn meaningful_content_text_between_tags() {
2731 assert!(has_meaningful_content(
2732 "[tool_use: bash(id1)]\nand then\n[tool_use: read(id2)]"
2733 ));
2734 }
2735
2736 #[test]
2737 fn meaningful_content_malformed_tag_no_closing_bracket() {
2738 assert!(has_meaningful_content("[tool_use: "));
2740 }
2741
2742 #[test]
2743 fn meaningful_content_tool_use_and_tool_result_only() {
2744 assert!(!has_meaningful_content(
2746 "[tool_use: memory_save(call_abc)]\n[tool_result: call_abc]\nsaved"
2747 ));
2748 }
2749
2750 #[test]
2751 fn meaningful_content_tool_result_body_with_json_array() {
2752 assert!(!has_meaningful_content(
2753 "[tool_result: id1]\n[\"array\", \"value\"]"
2754 ));
2755 }
2756
2757 #[tokio::test]
2768 async fn issue_2529_orphaned_legacy_content_pair_is_soft_deleted() {
2769 use zeph_llm::provider::MessagePart;
2770
2771 let provider = mock_provider(vec![]);
2772 let channel = MockChannel::new(vec![]);
2773 let registry = create_test_registry();
2774 let executor = MockToolExecutor::no_tools();
2775
2776 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2777 let cid = memory.sqlite().create_conversation().await.unwrap();
2778 let sqlite = memory.sqlite();
2779
2780 sqlite
2782 .save_message(cid, "user", "save this for me")
2783 .await
2784 .unwrap();
2785
2786 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2789 id: "call_2529".to_string(),
2790 name: "memory_save".to_string(),
2791 input: serde_json::json!({"content": "save this"}),
2792 }])
2793 .unwrap();
2794 let orphan_assistant_id = sqlite
2795 .save_message_with_parts(
2796 cid,
2797 "assistant",
2798 "[tool_use: memory_save(call_2529)]",
2799 &use_parts,
2800 )
2801 .await
2802 .unwrap();
2803
2804 sqlite
2809 .save_message(cid, "assistant", "here is a plain reply")
2810 .await
2811 .unwrap();
2812
2813 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2814 tool_use_id: "call_2529".to_string(),
2815 content: "saved".to_string(),
2816 is_error: false,
2817 }])
2818 .unwrap();
2819 let orphan_user_id = sqlite
2820 .save_message_with_parts(
2821 cid,
2822 "user",
2823 "[tool_result: call_2529]\nsaved",
2824 &result_parts,
2825 )
2826 .await
2827 .unwrap();
2828
2829 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2831
2832 let memory_arc = std::sync::Arc::new(memory);
2833 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2834 memory_arc.clone(),
2835 cid,
2836 50,
2837 5,
2838 100,
2839 );
2840
2841 agent.load_history().await.unwrap();
2842
2843 let assistant_deleted_count: Vec<i64> = zeph_db::query_scalar(
2846 "SELECT COUNT(*) FROM messages WHERE id = ? AND deleted_at IS NOT NULL",
2847 )
2848 .bind(orphan_assistant_id)
2849 .fetch_all(memory_arc.sqlite().pool())
2850 .await
2851 .unwrap();
2852
2853 let user_deleted_count: Vec<i64> = zeph_db::query_scalar(
2854 "SELECT COUNT(*) FROM messages WHERE id = ? AND deleted_at IS NOT NULL",
2855 )
2856 .bind(orphan_user_id)
2857 .fetch_all(memory_arc.sqlite().pool())
2858 .await
2859 .unwrap();
2860
2861 assert_eq!(
2862 assistant_deleted_count.first().copied().unwrap_or(0),
2863 1,
2864 "orphaned assistant[ToolUse] with legacy-only content must be soft-deleted (deleted_at IS NOT NULL)"
2865 );
2866 assert_eq!(
2867 user_deleted_count.first().copied().unwrap_or(0),
2868 1,
2869 "orphaned user[ToolResult] with legacy-only content must be soft-deleted (deleted_at IS NOT NULL)"
2870 );
2871 }
2872
2873 #[tokio::test]
2877 async fn issue_2529_soft_delete_is_idempotent_across_sessions() {
2878 use zeph_llm::provider::MessagePart;
2879
2880 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2881 let cid = memory.sqlite().create_conversation().await.unwrap();
2882 let sqlite = memory.sqlite();
2883
2884 sqlite
2886 .save_message(cid, "user", "do something")
2887 .await
2888 .unwrap();
2889
2890 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2892 id: "call_idem".to_string(),
2893 name: "shell".to_string(),
2894 input: serde_json::json!({"command": "ls"}),
2895 }])
2896 .unwrap();
2897 sqlite
2898 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_idem)]", &use_parts)
2899 .await
2900 .unwrap();
2901
2902 sqlite
2904 .save_message(cid, "assistant", "continuing")
2905 .await
2906 .unwrap();
2907
2908 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2910 tool_use_id: "call_idem".to_string(),
2911 content: "output".to_string(),
2912 is_error: false,
2913 }])
2914 .unwrap();
2915 sqlite
2916 .save_message_with_parts(
2917 cid,
2918 "user",
2919 "[tool_result: call_idem]\noutput",
2920 &result_parts,
2921 )
2922 .await
2923 .unwrap();
2924
2925 sqlite
2926 .save_message(cid, "assistant", "final")
2927 .await
2928 .unwrap();
2929
2930 let memory_arc = std::sync::Arc::new(memory);
2931
2932 let mut agent1 = Agent::new(
2934 mock_provider(vec![]),
2935 MockChannel::new(vec![]),
2936 create_test_registry(),
2937 None,
2938 5,
2939 MockToolExecutor::no_tools(),
2940 )
2941 .with_memory(memory_arc.clone(), cid, 50, 5, 100);
2942 agent1.load_history().await.unwrap();
2943 let count_after_first = agent1.msg.messages.len();
2944
2945 let mut agent2 = Agent::new(
2948 mock_provider(vec![]),
2949 MockChannel::new(vec![]),
2950 create_test_registry(),
2951 None,
2952 5,
2953 MockToolExecutor::no_tools(),
2954 )
2955 .with_memory(memory_arc.clone(), cid, 50, 5, 100);
2956 agent2.load_history().await.unwrap();
2957 let count_after_second = agent2.msg.messages.len();
2958
2959 assert_eq!(
2961 count_after_first, count_after_second,
2962 "second load_history must load the same message count as the first (soft-deleted orphans excluded)"
2963 );
2964 }
2965
2966 #[tokio::test]
2970 async fn issue_2529_message_with_text_and_tool_tag_is_kept_after_part_strip() {
2971 use zeph_llm::provider::MessagePart;
2972
2973 let provider = mock_provider(vec![]);
2974 let channel = MockChannel::new(vec![]);
2975 let registry = create_test_registry();
2976 let executor = MockToolExecutor::no_tools();
2977
2978 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2979 let cid = memory.sqlite().create_conversation().await.unwrap();
2980 let sqlite = memory.sqlite();
2981
2982 sqlite
2984 .save_message(cid, "user", "check the files")
2985 .await
2986 .unwrap();
2987
2988 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2991 id: "call_mixed".to_string(),
2992 name: "shell".to_string(),
2993 input: serde_json::json!({"command": "ls"}),
2994 }])
2995 .unwrap();
2996 sqlite
2997 .save_message_with_parts(
2998 cid,
2999 "assistant",
3000 "Let me list the directory. [tool_use: shell(call_mixed)]",
3001 &use_parts,
3002 )
3003 .await
3004 .unwrap();
3005
3006 sqlite.save_message(cid, "user", "thanks").await.unwrap();
3008 sqlite
3009 .save_message(cid, "assistant", "you are welcome")
3010 .await
3011 .unwrap();
3012
3013 let memory_arc = std::sync::Arc::new(memory);
3014 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
3015 memory_arc.clone(),
3016 cid,
3017 50,
3018 5,
3019 100,
3020 );
3021
3022 let messages_before = agent.msg.messages.len();
3023 agent.load_history().await.unwrap();
3024
3025 assert_eq!(
3027 agent.msg.messages.len(),
3028 messages_before + 4,
3029 "assistant message with text + tool tag must not be removed after ToolUse strip"
3030 );
3031
3032 let mixed_msg = agent
3034 .msg
3035 .messages
3036 .iter()
3037 .find(|m| m.content.contains("Let me list the directory"))
3038 .expect("mixed-content assistant message must still be in history");
3039 assert!(
3040 !mixed_msg
3041 .parts
3042 .iter()
3043 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
3044 "orphaned ToolUse parts must be stripped even when message has meaningful text"
3045 );
3046 assert_eq!(
3047 mixed_msg.content, "Let me list the directory. [tool_use: shell(call_mixed)]",
3048 "content field must be unchanged — only parts are stripped"
3049 );
3050 }
3051
3052 #[tokio::test]
3055 async fn persist_message_skipped_tool_result_does_not_embed() {
3056 use zeph_llm::provider::MessagePart;
3057
3058 let provider = mock_provider(vec![]);
3059 let channel = MockChannel::new(vec![]);
3060 let registry = create_test_registry();
3061 let executor = MockToolExecutor::no_tools();
3062
3063 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
3064 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3065 let cid = memory.sqlite().create_conversation().await.unwrap();
3066
3067 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
3068 .with_metrics(tx)
3069 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
3070 .with_autosave_config(true, 0);
3071
3072 let parts = vec![MessagePart::ToolResult {
3073 tool_use_id: "tu1".into(),
3074 content: "[skipped] bash tool was blocked by utility gate".into(),
3075 is_error: false,
3076 }];
3077
3078 agent
3079 .persist_message(
3080 Role::User,
3081 "[skipped] bash tool was blocked by utility gate",
3082 &parts,
3083 false,
3084 )
3085 .await;
3086
3087 assert_eq!(
3088 rx.borrow().embeddings_generated,
3089 0,
3090 "[skipped] ToolResult must not be embedded into Qdrant"
3091 );
3092 }
3093
3094 #[tokio::test]
3095 async fn persist_message_stopped_tool_result_does_not_embed() {
3096 use zeph_llm::provider::MessagePart;
3097
3098 let provider = mock_provider(vec![]);
3099 let channel = MockChannel::new(vec![]);
3100 let registry = create_test_registry();
3101 let executor = MockToolExecutor::no_tools();
3102
3103 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
3104 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3105 let cid = memory.sqlite().create_conversation().await.unwrap();
3106
3107 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
3108 .with_metrics(tx)
3109 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
3110 .with_autosave_config(true, 0);
3111
3112 let parts = vec![MessagePart::ToolResult {
3113 tool_use_id: "tu2".into(),
3114 content: "[stopped] execution limit reached".into(),
3115 is_error: false,
3116 }];
3117
3118 agent
3119 .persist_message(
3120 Role::User,
3121 "[stopped] execution limit reached",
3122 &parts,
3123 false,
3124 )
3125 .await;
3126
3127 assert_eq!(
3128 rx.borrow().embeddings_generated,
3129 0,
3130 "[stopped] ToolResult must not be embedded into Qdrant"
3131 );
3132 }
3133
3134 #[tokio::test]
3135 async fn persist_message_normal_tool_result_is_saved_not_blocked_by_guard() {
3136 use zeph_llm::provider::MessagePart;
3139
3140 let provider = mock_provider(vec![]);
3141 let channel = MockChannel::new(vec![]);
3142 let registry = create_test_registry();
3143 let executor = MockToolExecutor::no_tools();
3144
3145 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3146 let cid = memory.sqlite().create_conversation().await.unwrap();
3147 let memory_arc = std::sync::Arc::new(memory);
3148
3149 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
3150 .with_memory(memory_arc.clone(), cid, 50, 5, 100)
3151 .with_autosave_config(true, 0);
3152
3153 let content = "total 42\ndrwxr-xr-x 5 user group";
3154 let parts = vec![MessagePart::ToolResult {
3155 tool_use_id: "tu3".into(),
3156 content: content.into(),
3157 is_error: false,
3158 }];
3159
3160 agent
3161 .persist_message(Role::User, content, &parts, false)
3162 .await;
3163
3164 let history = memory_arc.sqlite().load_history(cid, 50).await.unwrap();
3166 assert_eq!(
3167 history.len(),
3168 1,
3169 "normal ToolResult must be saved to SQLite"
3170 );
3171 assert_eq!(history[0].content, content);
3172 }
3173}