1use std::collections::HashSet;
5
6use crate::channel::Channel;
7use zeph_llm::provider::{LlmProvider as _, Message, MessagePart, Role};
8use zeph_memory::store::role_str;
9
10use super::Agent;
11
12fn sanitize_tool_pairs(messages: &mut Vec<Message>) -> (usize, Vec<i64>) {
28 let mut removed = 0;
29 let mut db_ids: Vec<i64> = Vec::new();
30
31 loop {
32 if let Some(last) = messages.last()
34 && last.role == Role::Assistant
35 && last
36 .parts
37 .iter()
38 .any(|p| matches!(p, MessagePart::ToolUse { .. }))
39 {
40 let ids: Vec<String> = last
41 .parts
42 .iter()
43 .filter_map(|p| {
44 if let MessagePart::ToolUse { id, .. } = p {
45 Some(id.clone())
46 } else {
47 None
48 }
49 })
50 .collect();
51 tracing::warn!(
52 tool_ids = ?ids,
53 "removing orphaned trailing tool_use message from restored history"
54 );
55 if let Some(db_id) = messages.last().and_then(|m| m.metadata.db_id) {
56 db_ids.push(db_id);
57 }
58 messages.pop();
59 removed += 1;
60 continue;
61 }
62
63 if let Some(first) = messages.first()
65 && first.role == Role::User
66 && first
67 .parts
68 .iter()
69 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
70 {
71 let ids: Vec<String> = first
72 .parts
73 .iter()
74 .filter_map(|p| {
75 if let MessagePart::ToolResult { tool_use_id, .. } = p {
76 Some(tool_use_id.clone())
77 } else {
78 None
79 }
80 })
81 .collect();
82 tracing::warn!(
83 tool_use_ids = ?ids,
84 "removing orphaned leading tool_result message from restored history"
85 );
86 if let Some(db_id) = messages.first().and_then(|m| m.metadata.db_id) {
87 db_ids.push(db_id);
88 }
89 messages.remove(0);
90 removed += 1;
91 continue;
92 }
93
94 break;
95 }
96
97 let (mid_removed, mid_db_ids) = strip_mid_history_orphans(messages);
100 removed += mid_removed;
101 db_ids.extend(mid_db_ids);
102
103 (removed, db_ids)
104}
105
106fn orphaned_tool_use_ids(msg: &Message, next_msg: Option<&Message>) -> HashSet<String> {
108 let matched: HashSet<String> = next_msg
109 .filter(|n| n.role == Role::User)
110 .map(|n| {
111 msg.parts
112 .iter()
113 .filter_map(|p| if let MessagePart::ToolUse { id, .. } = p { Some(id.clone()) } else { None })
114 .filter(|uid| n.parts.iter().any(|np| matches!(np, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == uid)))
115 .collect()
116 })
117 .unwrap_or_default();
118 msg.parts
119 .iter()
120 .filter_map(|p| {
121 if let MessagePart::ToolUse { id, .. } = p
122 && !matched.contains(id)
123 {
124 Some(id.clone())
125 } else {
126 None
127 }
128 })
129 .collect()
130}
131
132fn orphaned_tool_result_ids(msg: &Message, prev_msg: Option<&Message>) -> HashSet<String> {
134 let avail: HashSet<&str> = prev_msg
135 .filter(|p| p.role == Role::Assistant)
136 .map(|p| {
137 p.parts
138 .iter()
139 .filter_map(|part| {
140 if let MessagePart::ToolUse { id, .. } = part {
141 Some(id.as_str())
142 } else {
143 None
144 }
145 })
146 .collect()
147 })
148 .unwrap_or_default();
149 msg.parts
150 .iter()
151 .filter_map(|p| {
152 if let MessagePart::ToolResult { tool_use_id, .. } = p
153 && !avail.contains(tool_use_id.as_str())
154 {
155 Some(tool_use_id.clone())
156 } else {
157 None
158 }
159 })
160 .collect()
161}
162
/// Reports whether `content` holds any text besides bracketed tool markers.
///
/// Recognised markers start with `[tool_use: `, `[tool_result: ` or `[tool output: ` and
/// end at the next `]`. For the result/output markers, everything after the tag up to the
/// next marker (or the end of the string) is treated as the marker's body and ignored.
/// Text before a marker, text after a `[tool_use: …]` tag, or a marker with no closing `]`
/// all make the content count as meaningful. Whitespace-only content is not meaningful.
fn has_meaningful_content(content: &str) -> bool {
    const TOOL_USE: &str = "[tool_use: ";
    const MARKERS: [&str; 3] = [TOOL_USE, "[tool_result: ", "[tool output: "];

    // Earliest marker occurrence in `text`: (byte offset, marker).
    let earliest = |text: &str| {
        MARKERS
            .iter()
            .filter_map(|m| text.find(m).map(|at| (at, *m)))
            .min_by_key(|&(at, _)| at)
    };

    let mut rest = content.trim();

    while let Some((at, marker)) = earliest(rest) {
        let (before, tagged) = rest.split_at(at);
        if !before.trim().is_empty() {
            // Free text precedes the marker.
            return true;
        }

        let inner = &tagged[marker.len()..];
        let Some(close) = inner.find(']') else {
            // Unterminated marker: be conservative and call it meaningful.
            return true;
        };
        let after_tag = &inner[close + 1..];

        rest = if marker == TOOL_USE {
            after_tag
        } else {
            // Result/output markers own the body that follows them; skip to the next marker.
            let body = after_tag.trim_start_matches('\n');
            match earliest(body) {
                Some((next, _)) => &body[next..],
                None => "",
            }
        };
    }

    !rest.trim().is_empty()
}
232
233fn strip_mid_history_orphans(messages: &mut Vec<Message>) -> (usize, Vec<i64>) {
246 let mut removed = 0;
247 let mut db_ids: Vec<i64> = Vec::new();
248 let mut i = 0;
249 while i < messages.len() {
250 if messages[i].role == Role::Assistant
254 && messages[i]
255 .parts
256 .iter()
257 .any(|p| matches!(p, MessagePart::ToolUse { .. }))
258 {
259 let orphaned_ids = orphaned_tool_use_ids(&messages[i], messages.get(i + 1));
260 if !orphaned_ids.is_empty() {
261 tracing::warn!(
262 tool_ids = ?orphaned_ids,
263 index = i,
264 "stripping orphaned mid-history tool_use parts from assistant message"
265 );
266 messages[i].parts.retain(
267 |p| !matches!(p, MessagePart::ToolUse { id, .. } if orphaned_ids.contains(id)),
268 );
269 let is_empty =
270 !has_meaningful_content(&messages[i].content) && messages[i].parts.is_empty();
271 if is_empty {
272 if let Some(db_id) = messages[i].metadata.db_id {
273 db_ids.push(db_id);
274 }
275 messages.remove(i);
276 removed += 1;
277 continue; }
279 }
280 }
281
282 if messages[i].role == Role::User
284 && messages[i]
285 .parts
286 .iter()
287 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
288 {
289 let orphaned_ids = orphaned_tool_result_ids(
290 &messages[i],
291 if i > 0 { messages.get(i - 1) } else { None },
292 );
293 if !orphaned_ids.is_empty() {
294 tracing::warn!(
295 tool_use_ids = ?orphaned_ids,
296 index = i,
297 "stripping orphaned mid-history tool_result parts from user message"
298 );
299 messages[i].parts.retain(|p| {
300 !matches!(p, MessagePart::ToolResult { tool_use_id, .. } if orphaned_ids.contains(tool_use_id.as_str()))
301 });
302
303 let is_empty =
304 !has_meaningful_content(&messages[i].content) && messages[i].parts.is_empty();
305 if is_empty {
306 if let Some(db_id) = messages[i].metadata.db_id {
307 db_ids.push(db_id);
308 }
309 messages.remove(i);
310 removed += 1;
311 continue;
313 }
314 }
315 }
316
317 i += 1;
318 }
319 (removed, db_ids)
320}
321
322impl<C: Channel> Agent<C> {
323 pub async fn load_history(&mut self) -> Result<(), super::error::AgentError> {
329 let (Some(memory), Some(cid)) =
330 (&self.memory_state.memory, self.memory_state.conversation_id)
331 else {
332 return Ok(());
333 };
334
335 let history = memory
336 .sqlite()
337 .load_history_filtered(cid, self.memory_state.history_limit, Some(true), None)
338 .await?;
339 if !history.is_empty() {
340 let mut loaded = 0;
341 let mut skipped = 0;
342
343 for msg in history {
344 if !has_meaningful_content(&msg.content) && msg.parts.is_empty() {
349 tracing::warn!("skipping empty message from history (role: {:?})", msg.role);
350 skipped += 1;
351 continue;
352 }
353 self.msg.messages.push(msg);
354 loaded += 1;
355 }
356
357 let history_start = self.msg.messages.len() - loaded;
359 let mut restored_slice = self.msg.messages.split_off(history_start);
360 let (orphans, orphan_db_ids) = sanitize_tool_pairs(&mut restored_slice);
361 skipped += orphans;
362 loaded = loaded.saturating_sub(orphans);
363 self.msg.messages.append(&mut restored_slice);
364
365 if !orphan_db_ids.is_empty() {
366 let ids: Vec<zeph_memory::types::MessageId> = orphan_db_ids
367 .iter()
368 .map(|&id| zeph_memory::types::MessageId(id))
369 .collect();
370 if let Err(e) = memory.sqlite().soft_delete_messages(&ids).await {
371 tracing::warn!(
372 count = ids.len(),
373 error = %e,
374 "failed to soft-delete orphaned tool-pair messages from DB"
375 );
376 } else {
377 tracing::debug!(
378 count = ids.len(),
379 "soft-deleted orphaned tool-pair messages from DB"
380 );
381 }
382 }
383
384 tracing::info!("restored {loaded} message(s) from conversation {cid}");
385 if skipped > 0 {
386 tracing::warn!("skipped {skipped} empty/orphaned message(s) from history");
387 }
388
389 if loaded > 0 {
390 let _ = memory
393 .sqlite()
394 .increment_session_counts_for_conversation(cid)
395 .await
396 .inspect_err(|e| {
397 tracing::warn!(error = %e, "failed to increment tier session counts");
398 });
399 }
400 }
401
402 if let Ok(count) = memory.message_count(cid).await {
403 let count_u64 = u64::try_from(count).unwrap_or(0);
404 self.update_metrics(|m| {
405 m.sqlite_message_count = count_u64;
406 });
407 }
408
409 if let Ok(count) = memory.sqlite().count_semantic_facts().await {
410 let count_u64 = u64::try_from(count).unwrap_or(0);
411 self.update_metrics(|m| {
412 m.semantic_fact_count = count_u64;
413 });
414 }
415
416 if let Ok(count) = memory.unsummarized_message_count(cid).await {
417 self.memory_state.unsummarized_count = usize::try_from(count).unwrap_or(0);
418 }
419
420 self.recompute_prompt_tokens();
421 Ok(())
422 }
423
424 #[allow(clippy::too_many_lines)]
430 pub(crate) async fn persist_message(
431 &mut self,
432 role: Role,
433 content: &str,
434 parts: &[MessagePart],
435 has_injection_flags: bool,
436 ) {
437 let (Some(memory), Some(cid)) =
438 (&self.memory_state.memory, self.memory_state.conversation_id)
439 else {
440 return;
441 };
442
443 let parts_json = if parts.is_empty() {
444 "[]".to_string()
445 } else {
446 serde_json::to_string(parts).unwrap_or_else(|e| {
447 tracing::warn!("failed to serialize message parts, storing empty: {e}");
448 "[]".to_string()
449 })
450 };
451
452 let guard_event = self
456 .security
457 .exfiltration_guard
458 .should_guard_memory_write(has_injection_flags);
459 if let Some(ref event) = guard_event {
460 tracing::warn!(
461 ?event,
462 "exfiltration guard: skipping Qdrant embedding for flagged content"
463 );
464 self.update_metrics(|m| m.exfiltration_memory_guards += 1);
465 self.push_security_event(
466 crate::metrics::SecurityEventCategory::ExfiltrationBlock,
467 "memory_write",
468 "Qdrant embedding skipped: flagged content",
469 );
470 }
471
472 let skip_embedding = guard_event.is_some();
473
474 let has_skipped_tool_result = parts.iter().any(|p| {
478 if let MessagePart::ToolResult { content, .. } = p {
479 content.starts_with("[skipped]") || content.starts_with("[stopped]")
480 } else {
481 false
482 }
483 });
484
485 let should_embed = if skip_embedding || has_skipped_tool_result {
486 false
487 } else {
488 match role {
489 Role::Assistant => {
490 self.memory_state.autosave_assistant
491 && content.len() >= self.memory_state.autosave_min_length
492 }
493 _ => true,
494 }
495 };
496
497 let goal_text = self.memory_state.goal_text.clone();
498
499 let (embedding_stored, was_persisted) = if should_embed {
500 match memory
501 .remember_with_parts(
502 cid,
503 role_str(role),
504 content,
505 &parts_json,
506 goal_text.as_deref(),
507 )
508 .await
509 {
510 Ok((Some(message_id), stored)) => {
511 self.last_persisted_message_id = Some(message_id.0);
512 (stored, true)
513 }
514 Ok((None, _)) => {
515 return;
517 }
518 Err(e) => {
519 tracing::error!("failed to persist message: {e:#}");
520 return;
521 }
522 }
523 } else {
524 match memory
525 .save_only(cid, role_str(role), content, &parts_json)
526 .await
527 {
528 Ok(message_id) => {
529 self.last_persisted_message_id = Some(message_id.0);
530 (false, true)
531 }
532 Err(e) => {
533 tracing::error!("failed to persist message: {e:#}");
534 return;
535 }
536 }
537 };
538
539 if !was_persisted {
540 return;
541 }
542
543 self.memory_state.unsummarized_count += 1;
544
545 self.update_metrics(|m| {
546 m.sqlite_message_count += 1;
547 if embedding_stored {
548 m.embeddings_generated += 1;
549 }
550 });
551
552 self.check_summarization().await;
553
554 let has_tool_result_parts = parts
557 .iter()
558 .any(|p| matches!(p, MessagePart::ToolResult { .. }));
559
560 self.maybe_spawn_graph_extraction(content, has_injection_flags, has_tool_result_parts)
561 .await;
562
563 if role == Role::User && !has_tool_result_parts && !has_injection_flags {
565 self.maybe_spawn_persona_extraction().await;
566 }
567 }
568
569 #[allow(clippy::too_many_lines)]
570 async fn maybe_spawn_graph_extraction(
571 &mut self,
572 content: &str,
573 has_injection_flags: bool,
574 has_tool_result_parts: bool,
575 ) {
576 use zeph_memory::semantic::GraphExtractionConfig;
577
578 if self.memory_state.memory.is_none() || self.memory_state.conversation_id.is_none() {
579 return;
580 }
581
582 if has_tool_result_parts {
585 tracing::debug!("graph extraction skipped: message contains ToolResult parts");
586 return;
587 }
588
589 if has_injection_flags {
591 tracing::warn!("graph extraction skipped: injection patterns detected in content");
592 return;
593 }
594
595 let extraction_cfg = {
597 let cfg = &self.memory_state.graph_config;
598 if !cfg.enabled {
599 return;
600 }
601 GraphExtractionConfig {
602 max_entities: cfg.max_entities_per_message,
603 max_edges: cfg.max_edges_per_message,
604 extraction_timeout_secs: cfg.extraction_timeout_secs,
605 community_refresh_interval: cfg.community_refresh_interval,
606 expired_edge_retention_days: cfg.expired_edge_retention_days,
607 max_entities_cap: cfg.max_entities,
608 community_summary_max_prompt_bytes: cfg.community_summary_max_prompt_bytes,
609 community_summary_concurrency: cfg.community_summary_concurrency,
610 lpa_edge_chunk_size: cfg.lpa_edge_chunk_size,
611 note_linking: zeph_memory::NoteLinkingConfig {
612 enabled: cfg.note_linking.enabled,
613 similarity_threshold: cfg.note_linking.similarity_threshold,
614 top_k: cfg.note_linking.top_k,
615 timeout_secs: cfg.note_linking.timeout_secs,
616 },
617 link_weight_decay_lambda: cfg.link_weight_decay_lambda,
618 link_weight_decay_interval_secs: cfg.link_weight_decay_interval_secs,
619 belief_revision_enabled: cfg.belief_revision.enabled,
620 belief_revision_similarity_threshold: cfg.belief_revision.similarity_threshold,
621 conversation_id: self.memory_state.conversation_id.map(|c| c.0),
622 }
623 };
624
625 if self.rpe_should_skip(content).await {
627 tracing::debug!("D-MEM RPE: low-surprise turn, skipping graph extraction");
628 return;
629 }
630
631 let context_messages: Vec<String> = self
635 .msg
636 .messages
637 .iter()
638 .rev()
639 .filter(|m| {
640 m.role == Role::User
641 && !m
642 .parts
643 .iter()
644 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
645 })
646 .take(4)
647 .map(|m| m.content.clone())
648 .collect();
649
650 let _ = self.channel.send_status("saving to graph...").await;
651
652 if let Some(memory) = &self.memory_state.memory {
653 let validator: zeph_memory::semantic::PostExtractValidator =
656 if self.security.memory_validator.is_enabled() {
657 let v = self.security.memory_validator.clone();
658 Some(Box::new(move |result| {
659 v.validate_graph_extraction(result)
660 .map_err(|e| e.to_string())
661 }))
662 } else {
663 None
664 };
665 let extraction_handle = memory.spawn_graph_extraction(
666 content.to_owned(),
667 context_messages,
668 extraction_cfg,
669 validator,
670 );
671 if let (Some(store), Some(tx)) =
674 (memory.graph_store.clone(), self.metrics.metrics_tx.clone())
675 {
676 let start = self.lifecycle.start_time;
677 tokio::spawn(async move {
678 let _ = extraction_handle.await;
679 let (entities, edges, communities) = tokio::join!(
680 store.entity_count(),
681 store.active_edge_count(),
682 store.community_count()
683 );
684 let elapsed = start.elapsed().as_secs();
685 tx.send_modify(|m| {
686 m.uptime_seconds = elapsed;
687 m.graph_entities_total = entities.unwrap_or(0).cast_unsigned();
688 m.graph_edges_total = edges.unwrap_or(0).cast_unsigned();
689 m.graph_communities_total = communities.unwrap_or(0).cast_unsigned();
690 });
691 });
692 }
693 }
694 let _ = self.channel.send_status("").await;
695 self.sync_community_detection_failures();
696 self.sync_graph_extraction_metrics();
697 self.sync_graph_counts().await;
698 self.sync_guidelines_status().await;
699 }
700
701 async fn maybe_spawn_persona_extraction(&mut self) {
702 use std::time::Duration;
703
704 use zeph_memory::semantic::{PersonaExtractionConfig, extract_persona_facts};
705
706 let cfg = &self.memory_state.persona_config;
707 if !cfg.enabled {
708 return;
709 }
710
711 let Some(memory) = &self.memory_state.memory else {
712 return;
713 };
714
715 let user_messages: Vec<String> = self
717 .msg
718 .messages
719 .iter()
720 .filter(|m| {
721 m.role == Role::User
722 && !m
723 .parts
724 .iter()
725 .any(|p| matches!(p, MessagePart::ToolResult { .. }))
726 })
727 .map(|m| m.content.clone())
728 .collect();
729
730 if user_messages.len() < cfg.min_messages {
731 return;
732 }
733
734 let timeout_secs = cfg.extraction_timeout_secs;
735 let extraction_cfg = PersonaExtractionConfig {
736 enabled: cfg.enabled,
737 persona_provider: cfg.persona_provider.as_str().to_owned(),
738 min_messages: cfg.min_messages,
739 max_messages: cfg.max_messages,
740 extraction_timeout_secs: timeout_secs,
741 };
742
743 let provider = self.resolve_background_provider(cfg.persona_provider.as_str());
744 let store = memory.sqlite().clone();
745 let conversation_id = self.memory_state.conversation_id.map(|c| c.0);
746
747 let user_message_refs: Vec<&str> = user_messages.iter().map(String::as_str).collect();
748 let fut = extract_persona_facts(
749 &store,
750 &provider,
751 &user_message_refs,
752 &extraction_cfg,
753 conversation_id,
754 );
755 match tokio::time::timeout(Duration::from_secs(timeout_secs), fut).await {
756 Ok(Ok(n)) => tracing::debug!(upserted = n, "persona extraction complete"),
757 Ok(Err(e)) => tracing::warn!(error = %e, "persona extraction failed"),
758 Err(_) => tracing::warn!(
759 timeout_secs,
760 "persona extraction timed out — no facts written this turn"
761 ),
762 }
763 }
764
765 pub(crate) async fn check_summarization(&mut self) {
766 let (Some(memory), Some(cid)) =
767 (&self.memory_state.memory, self.memory_state.conversation_id)
768 else {
769 return;
770 };
771
772 if self.memory_state.unsummarized_count > self.memory_state.summarization_threshold {
773 let _ = self.channel.send_status("summarizing...").await;
774 let batch_size = self.memory_state.summarization_threshold / 2;
775 match memory.summarize(cid, batch_size).await {
776 Ok(Some(summary_id)) => {
777 tracing::info!("created summary {summary_id} for conversation {cid}");
778 self.memory_state.unsummarized_count = 0;
779 self.update_metrics(|m| {
780 m.summaries_count += 1;
781 });
782 }
783 Ok(None) => {
784 tracing::debug!("no summarization needed");
785 }
786 Err(e) => {
787 tracing::error!("summarization failed: {e:#}");
788 }
789 }
790 let _ = self.channel.send_status("").await;
791 }
792 }
793
794 async fn rpe_should_skip(&mut self, content: &str) -> bool {
799 let Some(ref rpe_mutex) = self.memory_state.rpe_router else {
800 return false;
801 };
802 let Some(memory) = &self.memory_state.memory else {
803 return false;
804 };
805 let candidates = zeph_memory::extract_candidate_entities(content);
806 let provider = memory.provider();
807 let Ok(Ok(emb_vec)) =
808 tokio::time::timeout(std::time::Duration::from_secs(5), provider.embed(content)).await
809 else {
810 return false; };
812 if let Ok(mut router) = rpe_mutex.lock() {
813 let signal = router.compute(&emb_vec, &candidates);
814 router.push_embedding(emb_vec);
815 router.push_entities(&candidates);
816 !signal.should_extract
817 } else {
818 tracing::warn!("rpe_router mutex poisoned; falling through to extract");
819 false
820 }
821 }
822}
823
824#[cfg(test)]
825mod tests {
826 use super::super::agent_tests::{
827 MetricsSnapshot, MockChannel, MockToolExecutor, create_test_registry, mock_provider,
828 };
829 use super::*;
830 use zeph_llm::any::AnyProvider;
831 use zeph_memory::semantic::SemanticMemory;
832
833 async fn test_memory(provider: &AnyProvider) -> SemanticMemory {
834 SemanticMemory::new(
835 ":memory:",
836 "http://127.0.0.1:1",
837 provider.clone(),
838 "test-model",
839 )
840 .await
841 .unwrap()
842 }
843
844 #[tokio::test]
845 async fn load_history_without_memory_returns_ok() {
846 let provider = mock_provider(vec![]);
847 let channel = MockChannel::new(vec![]);
848 let registry = create_test_registry();
849 let executor = MockToolExecutor::no_tools();
850 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
851
852 let result = agent.load_history().await;
853 assert!(result.is_ok());
854 assert_eq!(agent.msg.messages.len(), 1); }
857
858 #[tokio::test]
859 async fn load_history_with_messages_injects_into_agent() {
860 let provider = mock_provider(vec![]);
861 let channel = MockChannel::new(vec![]);
862 let registry = create_test_registry();
863 let executor = MockToolExecutor::no_tools();
864
865 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
866 let cid = memory.sqlite().create_conversation().await.unwrap();
867
868 memory
869 .sqlite()
870 .save_message(cid, "user", "hello from history")
871 .await
872 .unwrap();
873 memory
874 .sqlite()
875 .save_message(cid, "assistant", "hi back")
876 .await
877 .unwrap();
878
879 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
880 std::sync::Arc::new(memory),
881 cid,
882 50,
883 5,
884 100,
885 );
886
887 let messages_before = agent.msg.messages.len();
888 agent.load_history().await.unwrap();
889 assert_eq!(agent.msg.messages.len(), messages_before + 2);
891 }
892
893 #[tokio::test]
894 async fn load_history_skips_empty_messages() {
895 let provider = mock_provider(vec![]);
896 let channel = MockChannel::new(vec![]);
897 let registry = create_test_registry();
898 let executor = MockToolExecutor::no_tools();
899
900 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
901 let cid = memory.sqlite().create_conversation().await.unwrap();
902
903 memory
905 .sqlite()
906 .save_message(cid, "user", " ")
907 .await
908 .unwrap();
909 memory
910 .sqlite()
911 .save_message(cid, "user", "real message")
912 .await
913 .unwrap();
914
915 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
916 std::sync::Arc::new(memory),
917 cid,
918 50,
919 5,
920 100,
921 );
922
923 let messages_before = agent.msg.messages.len();
924 agent.load_history().await.unwrap();
925 assert_eq!(agent.msg.messages.len(), messages_before + 1);
927 }
928
929 #[tokio::test]
930 async fn load_history_with_empty_store_returns_ok() {
931 let provider = mock_provider(vec![]);
932 let channel = MockChannel::new(vec![]);
933 let registry = create_test_registry();
934 let executor = MockToolExecutor::no_tools();
935
936 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
937 let cid = memory.sqlite().create_conversation().await.unwrap();
938
939 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
940 std::sync::Arc::new(memory),
941 cid,
942 50,
943 5,
944 100,
945 );
946
947 let messages_before = agent.msg.messages.len();
948 agent.load_history().await.unwrap();
949 assert_eq!(agent.msg.messages.len(), messages_before);
951 }
952
953 #[tokio::test]
954 async fn load_history_increments_session_count_for_existing_messages() {
955 let provider = mock_provider(vec![]);
956 let channel = MockChannel::new(vec![]);
957 let registry = create_test_registry();
958 let executor = MockToolExecutor::no_tools();
959
960 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
961 let cid = memory.sqlite().create_conversation().await.unwrap();
962
963 let id1 = memory
965 .sqlite()
966 .save_message(cid, "user", "hello")
967 .await
968 .unwrap();
969 let id2 = memory
970 .sqlite()
971 .save_message(cid, "assistant", "hi")
972 .await
973 .unwrap();
974
975 let memory_arc = std::sync::Arc::new(memory);
976 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
977 memory_arc.clone(),
978 cid,
979 50,
980 5,
981 100,
982 );
983
984 agent.load_history().await.unwrap();
985
986 let counts: Vec<i64> = zeph_db::query_scalar(
988 "SELECT session_count FROM messages WHERE id IN (?, ?) ORDER BY id",
989 )
990 .bind(id1)
991 .bind(id2)
992 .fetch_all(memory_arc.sqlite().pool())
993 .await
994 .unwrap();
995 assert_eq!(
996 counts,
997 vec![1, 1],
998 "session_count must be 1 after first restore"
999 );
1000 }
1001
1002 #[tokio::test]
1003 async fn load_history_does_not_increment_session_count_for_new_conversation() {
1004 let provider = mock_provider(vec![]);
1005 let channel = MockChannel::new(vec![]);
1006 let registry = create_test_registry();
1007 let executor = MockToolExecutor::no_tools();
1008
1009 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1010 let cid = memory.sqlite().create_conversation().await.unwrap();
1011
1012 let memory_arc = std::sync::Arc::new(memory);
1014 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1015 memory_arc.clone(),
1016 cid,
1017 50,
1018 5,
1019 100,
1020 );
1021
1022 agent.load_history().await.unwrap();
1023
1024 let counts: Vec<i64> =
1026 zeph_db::query_scalar("SELECT session_count FROM messages WHERE conversation_id = ?")
1027 .bind(cid)
1028 .fetch_all(memory_arc.sqlite().pool())
1029 .await
1030 .unwrap();
1031 assert!(counts.is_empty(), "new conversation must have no messages");
1032 }
1033
1034 #[tokio::test]
1035 async fn persist_message_without_memory_silently_returns() {
1036 let provider = mock_provider(vec![]);
1038 let channel = MockChannel::new(vec![]);
1039 let registry = create_test_registry();
1040 let executor = MockToolExecutor::no_tools();
1041 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
1042
1043 agent.persist_message(Role::User, "hello", &[], false).await;
1045 }
1046
1047 #[tokio::test]
1048 async fn persist_message_assistant_autosave_false_uses_save_only() {
1049 let provider = mock_provider(vec![]);
1050 let channel = MockChannel::new(vec![]);
1051 let registry = create_test_registry();
1052 let executor = MockToolExecutor::no_tools();
1053
1054 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1055 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1056 let cid = memory.sqlite().create_conversation().await.unwrap();
1057
1058 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1059 .with_metrics(tx)
1060 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1061 .with_autosave_config(false, 20);
1062
1063 agent
1064 .persist_message(Role::Assistant, "short assistant reply", &[], false)
1065 .await;
1066
1067 let history = agent
1068 .memory_state
1069 .memory
1070 .as_ref()
1071 .unwrap()
1072 .sqlite()
1073 .load_history(cid, 50)
1074 .await
1075 .unwrap();
1076 assert_eq!(history.len(), 1, "message must be saved");
1077 assert_eq!(history[0].content, "short assistant reply");
1078 assert_eq!(rx.borrow().embeddings_generated, 0);
1080 }
1081
1082 #[tokio::test]
1083 async fn persist_message_assistant_below_min_length_uses_save_only() {
1084 let provider = mock_provider(vec![]);
1085 let channel = MockChannel::new(vec![]);
1086 let registry = create_test_registry();
1087 let executor = MockToolExecutor::no_tools();
1088
1089 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1090 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1091 let cid = memory.sqlite().create_conversation().await.unwrap();
1092
1093 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1095 .with_metrics(tx)
1096 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1097 .with_autosave_config(true, 1000);
1098
1099 agent
1100 .persist_message(Role::Assistant, "too short", &[], false)
1101 .await;
1102
1103 let history = agent
1104 .memory_state
1105 .memory
1106 .as_ref()
1107 .unwrap()
1108 .sqlite()
1109 .load_history(cid, 50)
1110 .await
1111 .unwrap();
1112 assert_eq!(history.len(), 1, "message must be saved");
1113 assert_eq!(history[0].content, "too short");
1114 assert_eq!(rx.borrow().embeddings_generated, 0);
1115 }
1116
1117 #[tokio::test]
1118 async fn persist_message_assistant_at_min_length_boundary_uses_embed() {
1119 let provider = mock_provider(vec![]);
1121 let channel = MockChannel::new(vec![]);
1122 let registry = create_test_registry();
1123 let executor = MockToolExecutor::no_tools();
1124
1125 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1126 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1127 let cid = memory.sqlite().create_conversation().await.unwrap();
1128
1129 let min_length = 10usize;
1130 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1131 .with_metrics(tx)
1132 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1133 .with_autosave_config(true, min_length);
1134
1135 let content_at_boundary = "A".repeat(min_length);
1137 assert_eq!(content_at_boundary.len(), min_length);
1138 agent
1139 .persist_message(Role::Assistant, &content_at_boundary, &[], false)
1140 .await;
1141
1142 assert_eq!(rx.borrow().sqlite_message_count, 1);
1144 }
1145
1146 #[tokio::test]
1147 async fn persist_message_assistant_one_below_min_length_uses_save_only() {
1148 let provider = mock_provider(vec![]);
1150 let channel = MockChannel::new(vec![]);
1151 let registry = create_test_registry();
1152 let executor = MockToolExecutor::no_tools();
1153
1154 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1155 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1156 let cid = memory.sqlite().create_conversation().await.unwrap();
1157
1158 let min_length = 10usize;
1159 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1160 .with_metrics(tx)
1161 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1162 .with_autosave_config(true, min_length);
1163
1164 let content_below_boundary = "A".repeat(min_length - 1);
1166 assert_eq!(content_below_boundary.len(), min_length - 1);
1167 agent
1168 .persist_message(Role::Assistant, &content_below_boundary, &[], false)
1169 .await;
1170
1171 let history = agent
1172 .memory_state
1173 .memory
1174 .as_ref()
1175 .unwrap()
1176 .sqlite()
1177 .load_history(cid, 50)
1178 .await
1179 .unwrap();
1180 assert_eq!(history.len(), 1, "message must still be saved");
1181 assert_eq!(rx.borrow().embeddings_generated, 0);
1183 }
1184
1185 #[tokio::test]
1186 async fn persist_message_increments_unsummarized_count() {
1187 let provider = mock_provider(vec![]);
1188 let channel = MockChannel::new(vec![]);
1189 let registry = create_test_registry();
1190 let executor = MockToolExecutor::no_tools();
1191
1192 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1193 let cid = memory.sqlite().create_conversation().await.unwrap();
1194
1195 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1197 std::sync::Arc::new(memory),
1198 cid,
1199 50,
1200 5,
1201 100,
1202 );
1203
1204 assert_eq!(agent.memory_state.unsummarized_count, 0);
1205
1206 agent.persist_message(Role::User, "first", &[], false).await;
1207 assert_eq!(agent.memory_state.unsummarized_count, 1);
1208
1209 agent
1210 .persist_message(Role::User, "second", &[], false)
1211 .await;
1212 assert_eq!(agent.memory_state.unsummarized_count, 2);
1213 }
1214
1215 #[tokio::test]
1216 async fn check_summarization_resets_counter_on_success() {
1217 let provider = mock_provider(vec![]);
1218 let channel = MockChannel::new(vec![]);
1219 let registry = create_test_registry();
1220 let executor = MockToolExecutor::no_tools();
1221
1222 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1223 let cid = memory.sqlite().create_conversation().await.unwrap();
1224
1225 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1227 std::sync::Arc::new(memory),
1228 cid,
1229 50,
1230 5,
1231 1,
1232 );
1233
1234 agent.persist_message(Role::User, "msg1", &[], false).await;
1235 agent.persist_message(Role::User, "msg2", &[], false).await;
1236
1237 assert!(agent.memory_state.unsummarized_count <= 2);
1242 }
1243
1244 #[tokio::test]
1245 async fn unsummarized_count_not_incremented_without_memory() {
1246 let provider = mock_provider(vec![]);
1247 let channel = MockChannel::new(vec![]);
1248 let registry = create_test_registry();
1249 let executor = MockToolExecutor::no_tools();
1250 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
1251
1252 agent.persist_message(Role::User, "hello", &[], false).await;
1253 assert_eq!(agent.memory_state.unsummarized_count, 0);
1255 }
1256
1257 mod graph_extraction_guards {
1259 use super::*;
1260 use crate::config::GraphConfig;
1261 use zeph_llm::provider::MessageMetadata;
1262 use zeph_memory::graph::GraphStore;
1263
1264 fn enabled_graph_config() -> GraphConfig {
1265 GraphConfig {
1266 enabled: true,
1267 ..GraphConfig::default()
1268 }
1269 }
1270
1271 async fn agent_with_graph(
1272 provider: &AnyProvider,
1273 config: GraphConfig,
1274 ) -> Agent<MockChannel> {
1275 let memory =
1276 test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1277 let cid = memory.sqlite().create_conversation().await.unwrap();
1278 Agent::new(
1279 provider.clone(),
1280 MockChannel::new(vec![]),
1281 create_test_registry(),
1282 None,
1283 5,
1284 MockToolExecutor::no_tools(),
1285 )
1286 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1287 .with_graph_config(config)
1288 }
1289
1290 #[tokio::test]
1291 async fn injection_flag_guard_skips_extraction() {
1292 let provider = mock_provider(vec![]);
1294 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1295 let pool = agent
1296 .memory_state
1297 .memory
1298 .as_ref()
1299 .unwrap()
1300 .sqlite()
1301 .pool()
1302 .clone();
1303
1304 agent
1305 .maybe_spawn_graph_extraction("I use Rust", true, false)
1306 .await;
1307
1308 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1310
1311 let store = GraphStore::new(pool);
1312 let count = store.get_metadata("extraction_count").await.unwrap();
1313 assert!(
1314 count.is_none(),
1315 "injection flag must prevent extraction_count from being written"
1316 );
1317 }
1318
1319 #[tokio::test]
1320 async fn disabled_config_guard_skips_extraction() {
1321 let provider = mock_provider(vec![]);
1323 let disabled_cfg = GraphConfig {
1324 enabled: false,
1325 ..GraphConfig::default()
1326 };
1327 let mut agent = agent_with_graph(&provider, disabled_cfg).await;
1328 let pool = agent
1329 .memory_state
1330 .memory
1331 .as_ref()
1332 .unwrap()
1333 .sqlite()
1334 .pool()
1335 .clone();
1336
1337 agent
1338 .maybe_spawn_graph_extraction("I use Rust", false, false)
1339 .await;
1340
1341 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1342
1343 let store = GraphStore::new(pool);
1344 let count = store.get_metadata("extraction_count").await.unwrap();
1345 assert!(
1346 count.is_none(),
1347 "disabled graph config must prevent extraction"
1348 );
1349 }
1350
1351 #[tokio::test]
1352 async fn happy_path_fires_extraction() {
1353 let provider = mock_provider(vec![]);
1356 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1357 let pool = agent
1358 .memory_state
1359 .memory
1360 .as_ref()
1361 .unwrap()
1362 .sqlite()
1363 .pool()
1364 .clone();
1365
1366 agent
1367 .maybe_spawn_graph_extraction("I use Rust for systems programming", false, false)
1368 .await;
1369
1370 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1372
1373 let store = GraphStore::new(pool);
1374 let count = store.get_metadata("extraction_count").await.unwrap();
1375 assert!(
1376 count.is_some(),
1377 "happy-path extraction must increment extraction_count"
1378 );
1379 }
1380
1381 #[tokio::test]
1382 async fn tool_result_parts_guard_skips_extraction() {
1383 let provider = mock_provider(vec![]);
1387 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1388 let pool = agent
1389 .memory_state
1390 .memory
1391 .as_ref()
1392 .unwrap()
1393 .sqlite()
1394 .pool()
1395 .clone();
1396
1397 agent
1398 .maybe_spawn_graph_extraction(
1399 "[tool_result: abc123]\nprovider_type = \"claude\"\nallowed_commands = []",
1400 false,
1401 true, )
1403 .await;
1404
1405 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1406
1407 let store = GraphStore::new(pool);
1408 let count = store.get_metadata("extraction_count").await.unwrap();
1409 assert!(
1410 count.is_none(),
1411 "tool result message must not trigger graph extraction"
1412 );
1413 }
1414
1415 #[tokio::test]
1416 async fn context_filter_excludes_tool_result_messages() {
1417 let provider = mock_provider(vec![]);
1428 let mut agent = agent_with_graph(&provider, enabled_graph_config()).await;
1429
1430 agent.msg.messages.push(Message {
1433 role: Role::User,
1434 content: "[tool_result: abc]\nprovider_type = \"openai\"".to_owned(),
1435 parts: vec![MessagePart::ToolResult {
1436 tool_use_id: "abc".to_owned(),
1437 content: "provider_type = \"openai\"".to_owned(),
1438 is_error: false,
1439 }],
1440 metadata: MessageMetadata::default(),
1441 });
1442
1443 let pool = agent
1444 .memory_state
1445 .memory
1446 .as_ref()
1447 .unwrap()
1448 .sqlite()
1449 .pool()
1450 .clone();
1451
1452 agent
1454 .maybe_spawn_graph_extraction("I prefer Rust for systems programming", false, false)
1455 .await;
1456
1457 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1458
1459 let store = GraphStore::new(pool);
1461 let count = store.get_metadata("extraction_count").await.unwrap();
1462 assert!(
1463 count.is_some(),
1464 "conversational message must trigger extraction even with prior tool result in history"
1465 );
1466 }
1467 }
1468
1469 mod persona_extraction_guards {
1471 use super::*;
1472 use zeph_config::PersonaConfig;
1473 use zeph_llm::provider::MessageMetadata;
1474
1475 fn enabled_persona_config() -> PersonaConfig {
1476 PersonaConfig {
1477 enabled: true,
1478 min_messages: 1,
1479 ..PersonaConfig::default()
1480 }
1481 }
1482
1483 async fn agent_with_persona(
1484 provider: &AnyProvider,
1485 config: PersonaConfig,
1486 ) -> Agent<MockChannel> {
1487 let memory =
1488 test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1489 let cid = memory.sqlite().create_conversation().await.unwrap();
1490 let mut agent = Agent::new(
1491 provider.clone(),
1492 MockChannel::new(vec![]),
1493 create_test_registry(),
1494 None,
1495 5,
1496 MockToolExecutor::no_tools(),
1497 )
1498 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100);
1499 agent.memory_state.persona_config = config;
1500 agent
1501 }
1502
1503 #[tokio::test]
1504 async fn disabled_config_skips_spawn() {
1505 let provider = mock_provider(vec![]);
1507 let mut agent = agent_with_persona(
1508 &provider,
1509 PersonaConfig {
1510 enabled: false,
1511 ..PersonaConfig::default()
1512 },
1513 )
1514 .await;
1515
1516 agent.msg.messages.push(zeph_llm::provider::Message {
1518 role: Role::User,
1519 content: "I prefer Rust for systems programming".to_owned(),
1520 parts: vec![],
1521 metadata: MessageMetadata::default(),
1522 });
1523
1524 agent.maybe_spawn_persona_extraction().await;
1525
1526 let store = agent.memory_state.memory.as_ref().unwrap().sqlite().clone();
1527 let count = store.count_persona_facts().await.unwrap();
1528 assert_eq!(count, 0, "disabled persona config must not write any facts");
1529 }
1530
1531 #[tokio::test]
1532 async fn below_min_messages_skips_spawn() {
1533 let provider = mock_provider(vec![]);
1535 let mut agent = agent_with_persona(
1536 &provider,
1537 PersonaConfig {
1538 enabled: true,
1539 min_messages: 3,
1540 ..PersonaConfig::default()
1541 },
1542 )
1543 .await;
1544
1545 for text in ["I use Rust", "I prefer async code"] {
1546 agent.msg.messages.push(zeph_llm::provider::Message {
1547 role: Role::User,
1548 content: text.to_owned(),
1549 parts: vec![],
1550 metadata: MessageMetadata::default(),
1551 });
1552 }
1553
1554 agent.maybe_spawn_persona_extraction().await;
1555
1556 let store = agent.memory_state.memory.as_ref().unwrap().sqlite().clone();
1557 let count = store.count_persona_facts().await.unwrap();
1558 assert_eq!(
1559 count, 0,
1560 "below min_messages threshold must not trigger extraction"
1561 );
1562 }
1563
1564 #[tokio::test]
1565 async fn no_memory_skips_spawn() {
1566 let provider = mock_provider(vec![]);
1568 let channel = MockChannel::new(vec![]);
1569 let registry = create_test_registry();
1570 let executor = MockToolExecutor::no_tools();
1571 let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
1572 agent.memory_state.persona_config = enabled_persona_config();
1573 agent.msg.messages.push(zeph_llm::provider::Message {
1574 role: Role::User,
1575 content: "I like Rust".to_owned(),
1576 parts: vec![],
1577 metadata: MessageMetadata::default(),
1578 });
1579
1580 agent.maybe_spawn_persona_extraction().await;
1582 }
1583
1584 #[tokio::test]
1585 async fn enabled_enough_messages_spawns_extraction() {
1586 use zeph_llm::mock::MockProvider;
1589 let (mock, recorded) = MockProvider::default().with_recording();
1590 let provider = AnyProvider::Mock(mock);
1591 let mut agent = agent_with_persona(&provider, enabled_persona_config()).await;
1592
1593 agent.msg.messages.push(zeph_llm::provider::Message {
1594 role: Role::User,
1595 content: "I prefer Rust for systems programming".to_owned(),
1596 parts: vec![],
1597 metadata: MessageMetadata::default(),
1598 });
1599
1600 agent.maybe_spawn_persona_extraction().await;
1601
1602 let calls = recorded.lock().unwrap();
1603 assert!(
1604 !calls.is_empty(),
1605 "happy-path: provider.chat() must be called when extraction completes"
1606 );
1607 }
1608 }
1609
1610 #[tokio::test]
1611 async fn persist_message_user_always_embeds_regardless_of_autosave_flag() {
1612 let provider = mock_provider(vec![]);
1613 let channel = MockChannel::new(vec![]);
1614 let registry = create_test_registry();
1615 let executor = MockToolExecutor::no_tools();
1616
1617 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
1618 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1619 let cid = memory.sqlite().create_conversation().await.unwrap();
1620
1621 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
1623 .with_metrics(tx)
1624 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
1625 .with_autosave_config(false, 20);
1626
1627 let long_user_msg = "A".repeat(100);
1628 agent
1629 .persist_message(Role::User, &long_user_msg, &[], false)
1630 .await;
1631
1632 let history = agent
1633 .memory_state
1634 .memory
1635 .as_ref()
1636 .unwrap()
1637 .sqlite()
1638 .load_history(cid, 50)
1639 .await
1640 .unwrap();
1641 assert_eq!(history.len(), 1, "user message must be saved");
1642 assert_eq!(rx.borrow().sqlite_message_count, 1);
1645 }
1646
1647 #[tokio::test]
1651 async fn persist_message_saves_correct_tool_use_parts() {
1652 use zeph_llm::provider::MessagePart;
1653
1654 let provider = mock_provider(vec![]);
1655 let channel = MockChannel::new(vec![]);
1656 let registry = create_test_registry();
1657 let executor = MockToolExecutor::no_tools();
1658
1659 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1660 let cid = memory.sqlite().create_conversation().await.unwrap();
1661
1662 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1663 std::sync::Arc::new(memory),
1664 cid,
1665 50,
1666 5,
1667 100,
1668 );
1669
1670 let parts = vec![MessagePart::ToolUse {
1671 id: "call_abc123".to_string(),
1672 name: "read_file".to_string(),
1673 input: serde_json::json!({"path": "/tmp/test.txt"}),
1674 }];
1675 let content = "[tool_use: read_file(call_abc123)]";
1676
1677 agent
1678 .persist_message(Role::Assistant, content, &parts, false)
1679 .await;
1680
1681 let history = agent
1682 .memory_state
1683 .memory
1684 .as_ref()
1685 .unwrap()
1686 .sqlite()
1687 .load_history(cid, 50)
1688 .await
1689 .unwrap();
1690
1691 assert_eq!(history.len(), 1);
1692 assert_eq!(history[0].role, Role::Assistant);
1693 assert_eq!(history[0].content, content);
1694 assert_eq!(history[0].parts.len(), 1);
1695 match &history[0].parts[0] {
1696 MessagePart::ToolUse { id, name, .. } => {
1697 assert_eq!(id, "call_abc123");
1698 assert_eq!(name, "read_file");
1699 }
1700 other => panic!("expected ToolUse part, got {other:?}"),
1701 }
1702 assert!(
1704 !history[0]
1705 .parts
1706 .iter()
1707 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1708 "assistant message must not contain ToolResult parts"
1709 );
1710 }
1711
1712 #[tokio::test]
1713 async fn persist_message_saves_correct_tool_result_parts() {
1714 use zeph_llm::provider::MessagePart;
1715
1716 let provider = mock_provider(vec![]);
1717 let channel = MockChannel::new(vec![]);
1718 let registry = create_test_registry();
1719 let executor = MockToolExecutor::no_tools();
1720
1721 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1722 let cid = memory.sqlite().create_conversation().await.unwrap();
1723
1724 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1725 std::sync::Arc::new(memory),
1726 cid,
1727 50,
1728 5,
1729 100,
1730 );
1731
1732 let parts = vec![MessagePart::ToolResult {
1733 tool_use_id: "call_abc123".to_string(),
1734 content: "file contents here".to_string(),
1735 is_error: false,
1736 }];
1737 let content = "[tool_result: call_abc123]\nfile contents here";
1738
1739 agent
1740 .persist_message(Role::User, content, &parts, false)
1741 .await;
1742
1743 let history = agent
1744 .memory_state
1745 .memory
1746 .as_ref()
1747 .unwrap()
1748 .sqlite()
1749 .load_history(cid, 50)
1750 .await
1751 .unwrap();
1752
1753 assert_eq!(history.len(), 1);
1754 assert_eq!(history[0].role, Role::User);
1755 assert_eq!(history[0].content, content);
1756 assert_eq!(history[0].parts.len(), 1);
1757 match &history[0].parts[0] {
1758 MessagePart::ToolResult {
1759 tool_use_id,
1760 content: result_content,
1761 is_error,
1762 } => {
1763 assert_eq!(tool_use_id, "call_abc123");
1764 assert_eq!(result_content, "file contents here");
1765 assert!(!is_error);
1766 }
1767 other => panic!("expected ToolResult part, got {other:?}"),
1768 }
1769 assert!(
1771 !history[0]
1772 .parts
1773 .iter()
1774 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1775 "user ToolResult message must not contain ToolUse parts"
1776 );
1777 }
1778
1779 #[tokio::test]
1780 async fn persist_message_roundtrip_preserves_role_part_alignment() {
1781 use zeph_llm::provider::MessagePart;
1782
1783 let provider = mock_provider(vec![]);
1784 let channel = MockChannel::new(vec![]);
1785 let registry = create_test_registry();
1786 let executor = MockToolExecutor::no_tools();
1787
1788 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1789 let cid = memory.sqlite().create_conversation().await.unwrap();
1790
1791 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1792 std::sync::Arc::new(memory),
1793 cid,
1794 50,
1795 5,
1796 100,
1797 );
1798
1799 let assistant_parts = vec![MessagePart::ToolUse {
1801 id: "id_1".to_string(),
1802 name: "list_dir".to_string(),
1803 input: serde_json::json!({"path": "/tmp"}),
1804 }];
1805 agent
1806 .persist_message(
1807 Role::Assistant,
1808 "[tool_use: list_dir(id_1)]",
1809 &assistant_parts,
1810 false,
1811 )
1812 .await;
1813
1814 let user_parts = vec![MessagePart::ToolResult {
1816 tool_use_id: "id_1".to_string(),
1817 content: "file1.txt\nfile2.txt".to_string(),
1818 is_error: false,
1819 }];
1820 agent
1821 .persist_message(
1822 Role::User,
1823 "[tool_result: id_1]\nfile1.txt\nfile2.txt",
1824 &user_parts,
1825 false,
1826 )
1827 .await;
1828
1829 let history = agent
1830 .memory_state
1831 .memory
1832 .as_ref()
1833 .unwrap()
1834 .sqlite()
1835 .load_history(cid, 50)
1836 .await
1837 .unwrap();
1838
1839 assert_eq!(history.len(), 2);
1840
1841 assert_eq!(history[0].role, Role::Assistant);
1843 assert_eq!(history[0].content, "[tool_use: list_dir(id_1)]");
1844 assert!(
1845 matches!(&history[0].parts[0], MessagePart::ToolUse { id, .. } if id == "id_1"),
1846 "first message must be assistant ToolUse"
1847 );
1848
1849 assert_eq!(history[1].role, Role::User);
1851 assert_eq!(
1852 history[1].content,
1853 "[tool_result: id_1]\nfile1.txt\nfile2.txt"
1854 );
1855 assert!(
1856 matches!(&history[1].parts[0], MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "id_1"),
1857 "second message must be user ToolResult"
1858 );
1859
1860 assert!(
1862 !history[0]
1863 .parts
1864 .iter()
1865 .any(|p| matches!(p, MessagePart::ToolResult { .. })),
1866 "assistant message must not have ToolResult parts"
1867 );
1868 assert!(
1869 !history[1]
1870 .parts
1871 .iter()
1872 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
1873 "user message must not have ToolUse parts"
1874 );
1875 }
1876
1877 #[tokio::test]
1878 async fn persist_message_saves_correct_tool_output_parts() {
1879 use zeph_llm::provider::MessagePart;
1880
1881 let provider = mock_provider(vec![]);
1882 let channel = MockChannel::new(vec![]);
1883 let registry = create_test_registry();
1884 let executor = MockToolExecutor::no_tools();
1885
1886 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1887 let cid = memory.sqlite().create_conversation().await.unwrap();
1888
1889 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1890 std::sync::Arc::new(memory),
1891 cid,
1892 50,
1893 5,
1894 100,
1895 );
1896
1897 let parts = vec![MessagePart::ToolOutput {
1898 tool_name: "shell".to_string(),
1899 body: "hello from shell".to_string(),
1900 compacted_at: None,
1901 }];
1902 let content = "[tool: shell]\nhello from shell";
1903
1904 agent
1905 .persist_message(Role::User, content, &parts, false)
1906 .await;
1907
1908 let history = agent
1909 .memory_state
1910 .memory
1911 .as_ref()
1912 .unwrap()
1913 .sqlite()
1914 .load_history(cid, 50)
1915 .await
1916 .unwrap();
1917
1918 assert_eq!(history.len(), 1);
1919 assert_eq!(history[0].role, Role::User);
1920 assert_eq!(history[0].content, content);
1921 assert_eq!(history[0].parts.len(), 1);
1922 match &history[0].parts[0] {
1923 MessagePart::ToolOutput {
1924 tool_name,
1925 body,
1926 compacted_at,
1927 } => {
1928 assert_eq!(tool_name, "shell");
1929 assert_eq!(body, "hello from shell");
1930 assert!(compacted_at.is_none());
1931 }
1932 other => panic!("expected ToolOutput part, got {other:?}"),
1933 }
1934 }
1935
1936 #[tokio::test]
1939 async fn load_history_removes_trailing_orphan_tool_use() {
1940 use zeph_llm::provider::MessagePart;
1941
1942 let provider = mock_provider(vec![]);
1943 let channel = MockChannel::new(vec![]);
1944 let registry = create_test_registry();
1945 let executor = MockToolExecutor::no_tools();
1946
1947 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1948 let cid = memory.sqlite().create_conversation().await.unwrap();
1949 let sqlite = memory.sqlite();
1950
1951 sqlite
1953 .save_message(cid, "user", "do something with a tool")
1954 .await
1955 .unwrap();
1956
1957 let parts = serde_json::to_string(&[MessagePart::ToolUse {
1959 id: "call_orphan".to_string(),
1960 name: "shell".to_string(),
1961 input: serde_json::json!({"command": "ls"}),
1962 }])
1963 .unwrap();
1964 sqlite
1965 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_orphan)]", &parts)
1966 .await
1967 .unwrap();
1968
1969 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
1970 std::sync::Arc::new(memory),
1971 cid,
1972 50,
1973 5,
1974 100,
1975 );
1976
1977 let messages_before = agent.msg.messages.len();
1978 agent.load_history().await.unwrap();
1979
1980 assert_eq!(
1982 agent.msg.messages.len(),
1983 messages_before + 1,
1984 "orphaned trailing tool_use must be removed"
1985 );
1986 assert_eq!(agent.msg.messages.last().unwrap().role, Role::User);
1987 }
1988
1989 #[tokio::test]
1990 async fn load_history_removes_leading_orphan_tool_result() {
1991 use zeph_llm::provider::MessagePart;
1992
1993 let provider = mock_provider(vec![]);
1994 let channel = MockChannel::new(vec![]);
1995 let registry = create_test_registry();
1996 let executor = MockToolExecutor::no_tools();
1997
1998 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
1999 let cid = memory.sqlite().create_conversation().await.unwrap();
2000 let sqlite = memory.sqlite();
2001
2002 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2004 tool_use_id: "call_missing".to_string(),
2005 content: "result data".to_string(),
2006 is_error: false,
2007 }])
2008 .unwrap();
2009 sqlite
2010 .save_message_with_parts(
2011 cid,
2012 "user",
2013 "[tool_result: call_missing]\nresult data",
2014 &result_parts,
2015 )
2016 .await
2017 .unwrap();
2018
2019 sqlite
2021 .save_message(cid, "assistant", "here is my response")
2022 .await
2023 .unwrap();
2024
2025 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2026 std::sync::Arc::new(memory),
2027 cid,
2028 50,
2029 5,
2030 100,
2031 );
2032
2033 let messages_before = agent.msg.messages.len();
2034 agent.load_history().await.unwrap();
2035
2036 assert_eq!(
2038 agent.msg.messages.len(),
2039 messages_before + 1,
2040 "orphaned leading tool_result must be removed"
2041 );
2042 assert_eq!(agent.msg.messages.last().unwrap().role, Role::Assistant);
2043 }
2044
2045 #[tokio::test]
2046 async fn load_history_preserves_complete_tool_pairs() {
2047 use zeph_llm::provider::MessagePart;
2048
2049 let provider = mock_provider(vec![]);
2050 let channel = MockChannel::new(vec![]);
2051 let registry = create_test_registry();
2052 let executor = MockToolExecutor::no_tools();
2053
2054 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2055 let cid = memory.sqlite().create_conversation().await.unwrap();
2056 let sqlite = memory.sqlite();
2057
2058 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2060 id: "call_ok".to_string(),
2061 name: "shell".to_string(),
2062 input: serde_json::json!({"command": "pwd"}),
2063 }])
2064 .unwrap();
2065 sqlite
2066 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_ok)]", &use_parts)
2067 .await
2068 .unwrap();
2069
2070 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2071 tool_use_id: "call_ok".to_string(),
2072 content: "/home/user".to_string(),
2073 is_error: false,
2074 }])
2075 .unwrap();
2076 sqlite
2077 .save_message_with_parts(
2078 cid,
2079 "user",
2080 "[tool_result: call_ok]\n/home/user",
2081 &result_parts,
2082 )
2083 .await
2084 .unwrap();
2085
2086 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2087 std::sync::Arc::new(memory),
2088 cid,
2089 50,
2090 5,
2091 100,
2092 );
2093
2094 let messages_before = agent.msg.messages.len();
2095 agent.load_history().await.unwrap();
2096
2097 assert_eq!(
2099 agent.msg.messages.len(),
2100 messages_before + 2,
2101 "complete tool_use/tool_result pair must be preserved"
2102 );
2103 assert_eq!(agent.msg.messages[messages_before].role, Role::Assistant);
2104 assert_eq!(agent.msg.messages[messages_before + 1].role, Role::User);
2105 }
2106
2107 #[tokio::test]
2108 async fn load_history_handles_multiple_trailing_orphans() {
2109 use zeph_llm::provider::MessagePart;
2110
2111 let provider = mock_provider(vec![]);
2112 let channel = MockChannel::new(vec![]);
2113 let registry = create_test_registry();
2114 let executor = MockToolExecutor::no_tools();
2115
2116 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2117 let cid = memory.sqlite().create_conversation().await.unwrap();
2118 let sqlite = memory.sqlite();
2119
2120 sqlite.save_message(cid, "user", "start").await.unwrap();
2122
2123 let parts1 = serde_json::to_string(&[MessagePart::ToolUse {
2125 id: "call_1".to_string(),
2126 name: "shell".to_string(),
2127 input: serde_json::json!({}),
2128 }])
2129 .unwrap();
2130 sqlite
2131 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_1)]", &parts1)
2132 .await
2133 .unwrap();
2134
2135 let parts2 = serde_json::to_string(&[MessagePart::ToolUse {
2137 id: "call_2".to_string(),
2138 name: "read_file".to_string(),
2139 input: serde_json::json!({}),
2140 }])
2141 .unwrap();
2142 sqlite
2143 .save_message_with_parts(cid, "assistant", "[tool_use: read_file(call_2)]", &parts2)
2144 .await
2145 .unwrap();
2146
2147 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2148 std::sync::Arc::new(memory),
2149 cid,
2150 50,
2151 5,
2152 100,
2153 );
2154
2155 let messages_before = agent.msg.messages.len();
2156 agent.load_history().await.unwrap();
2157
2158 assert_eq!(
2160 agent.msg.messages.len(),
2161 messages_before + 1,
2162 "all trailing orphaned tool_use messages must be removed"
2163 );
2164 assert_eq!(agent.msg.messages.last().unwrap().role, Role::User);
2165 }
2166
2167 #[tokio::test]
2168 async fn load_history_no_tool_messages_unchanged() {
2169 let provider = mock_provider(vec![]);
2170 let channel = MockChannel::new(vec![]);
2171 let registry = create_test_registry();
2172 let executor = MockToolExecutor::no_tools();
2173
2174 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2175 let cid = memory.sqlite().create_conversation().await.unwrap();
2176 let sqlite = memory.sqlite();
2177
2178 sqlite.save_message(cid, "user", "hello").await.unwrap();
2179 sqlite
2180 .save_message(cid, "assistant", "hi there")
2181 .await
2182 .unwrap();
2183 sqlite
2184 .save_message(cid, "user", "how are you?")
2185 .await
2186 .unwrap();
2187
2188 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2189 std::sync::Arc::new(memory),
2190 cid,
2191 50,
2192 5,
2193 100,
2194 );
2195
2196 let messages_before = agent.msg.messages.len();
2197 agent.load_history().await.unwrap();
2198
2199 assert_eq!(
2201 agent.msg.messages.len(),
2202 messages_before + 3,
2203 "plain messages without tool parts must pass through unchanged"
2204 );
2205 }
2206
2207 #[tokio::test]
2208 async fn load_history_removes_both_leading_and_trailing_orphans() {
2209 use zeph_llm::provider::MessagePart;
2210
2211 let provider = mock_provider(vec![]);
2212 let channel = MockChannel::new(vec![]);
2213 let registry = create_test_registry();
2214 let executor = MockToolExecutor::no_tools();
2215
2216 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2217 let cid = memory.sqlite().create_conversation().await.unwrap();
2218 let sqlite = memory.sqlite();
2219
2220 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2222 tool_use_id: "call_leading".to_string(),
2223 content: "orphaned result".to_string(),
2224 is_error: false,
2225 }])
2226 .unwrap();
2227 sqlite
2228 .save_message_with_parts(
2229 cid,
2230 "user",
2231 "[tool_result: call_leading]\norphaned result",
2232 &result_parts,
2233 )
2234 .await
2235 .unwrap();
2236
2237 sqlite
2239 .save_message(cid, "user", "what is 2+2?")
2240 .await
2241 .unwrap();
2242 sqlite.save_message(cid, "assistant", "4").await.unwrap();
2243
2244 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2246 id: "call_trailing".to_string(),
2247 name: "shell".to_string(),
2248 input: serde_json::json!({"command": "date"}),
2249 }])
2250 .unwrap();
2251 sqlite
2252 .save_message_with_parts(
2253 cid,
2254 "assistant",
2255 "[tool_use: shell(call_trailing)]",
2256 &use_parts,
2257 )
2258 .await
2259 .unwrap();
2260
2261 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2262 std::sync::Arc::new(memory),
2263 cid,
2264 50,
2265 5,
2266 100,
2267 );
2268
2269 let messages_before = agent.msg.messages.len();
2270 agent.load_history().await.unwrap();
2271
2272 assert_eq!(
2274 agent.msg.messages.len(),
2275 messages_before + 2,
2276 "both leading and trailing orphans must be removed"
2277 );
2278 assert_eq!(agent.msg.messages[messages_before].role, Role::User);
2279 assert_eq!(agent.msg.messages[messages_before].content, "what is 2+2?");
2280 assert_eq!(
2281 agent.msg.messages[messages_before + 1].role,
2282 Role::Assistant
2283 );
2284 assert_eq!(agent.msg.messages[messages_before + 1].content, "4");
2285 }
2286
2287 #[tokio::test]
2292 async fn sanitize_tool_pairs_strips_mid_history_orphan_tool_use() {
2293 use zeph_llm::provider::MessagePart;
2294
2295 let provider = mock_provider(vec![]);
2296 let channel = MockChannel::new(vec![]);
2297 let registry = create_test_registry();
2298 let executor = MockToolExecutor::no_tools();
2299
2300 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2301 let cid = memory.sqlite().create_conversation().await.unwrap();
2302 let sqlite = memory.sqlite();
2303
2304 sqlite
2306 .save_message(cid, "user", "first question")
2307 .await
2308 .unwrap();
2309 sqlite
2310 .save_message(cid, "assistant", "first answer")
2311 .await
2312 .unwrap();
2313
2314 let use_parts = serde_json::to_string(&[
2318 MessagePart::ToolUse {
2319 id: "call_mid_1".to_string(),
2320 name: "shell".to_string(),
2321 input: serde_json::json!({"command": "ls"}),
2322 },
2323 MessagePart::Text {
2324 text: "Let me check the files.".to_string(),
2325 },
2326 ])
2327 .unwrap();
2328 sqlite
2329 .save_message_with_parts(cid, "assistant", "Let me check the files.", &use_parts)
2330 .await
2331 .unwrap();
2332
2333 sqlite
2335 .save_message(cid, "user", "second question")
2336 .await
2337 .unwrap();
2338 sqlite
2339 .save_message(cid, "assistant", "second answer")
2340 .await
2341 .unwrap();
2342
2343 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2344 std::sync::Arc::new(memory),
2345 cid,
2346 50,
2347 5,
2348 100,
2349 );
2350
2351 let messages_before = agent.msg.messages.len();
2352 agent.load_history().await.unwrap();
2353
2354 assert_eq!(
2357 agent.msg.messages.len(),
2358 messages_before + 5,
2359 "message count must be 5 (orphan message kept — has text content)"
2360 );
2361
2362 let orphan = &agent.msg.messages[messages_before + 2];
2364 assert_eq!(orphan.role, Role::Assistant);
2365 assert!(
2366 !orphan
2367 .parts
2368 .iter()
2369 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
2370 "orphaned ToolUse parts must be stripped from mid-history message"
2371 );
2372 assert!(
2374 orphan.parts.iter().any(
2375 |p| matches!(p, MessagePart::Text { text } if text == "Let me check the files.")
2376 ),
2377 "text content of orphaned assistant message must be preserved"
2378 );
2379 }
2380
2381 #[tokio::test]
2386 async fn load_history_keeps_tool_only_user_message() {
2387 use zeph_llm::provider::MessagePart;
2388
2389 let provider = mock_provider(vec![]);
2390 let channel = MockChannel::new(vec![]);
2391 let registry = create_test_registry();
2392 let executor = MockToolExecutor::no_tools();
2393
2394 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2395 let cid = memory.sqlite().create_conversation().await.unwrap();
2396 let sqlite = memory.sqlite();
2397
2398 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2400 id: "call_rc3".to_string(),
2401 name: "memory_save".to_string(),
2402 input: serde_json::json!({"content": "something"}),
2403 }])
2404 .unwrap();
2405 sqlite
2406 .save_message_with_parts(cid, "assistant", "[tool_use: memory_save]", &use_parts)
2407 .await
2408 .unwrap();
2409
2410 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2412 tool_use_id: "call_rc3".to_string(),
2413 content: "saved".to_string(),
2414 is_error: false,
2415 }])
2416 .unwrap();
2417 sqlite
2418 .save_message_with_parts(cid, "user", "", &result_parts)
2419 .await
2420 .unwrap();
2421
2422 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2423
2424 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2425 std::sync::Arc::new(memory),
2426 cid,
2427 50,
2428 5,
2429 100,
2430 );
2431
2432 let messages_before = agent.msg.messages.len();
2433 agent.load_history().await.unwrap();
2434
2435 assert_eq!(
2438 agent.msg.messages.len(),
2439 messages_before + 3,
2440 "user message with empty content but ToolResult parts must not be dropped"
2441 );
2442
2443 let user_msg = &agent.msg.messages[messages_before + 1];
2445 assert_eq!(user_msg.role, Role::User);
2446 assert!(
2447 user_msg.parts.iter().any(
2448 |p| matches!(p, MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_rc3")
2449 ),
2450 "ToolResult part must be preserved on user message with empty content"
2451 );
2452 }
2453
2454 #[tokio::test]
2458 async fn strip_orphans_removes_orphaned_tool_result() {
2459 use zeph_llm::provider::MessagePart;
2460
2461 let provider = mock_provider(vec![]);
2462 let channel = MockChannel::new(vec![]);
2463 let registry = create_test_registry();
2464 let executor = MockToolExecutor::no_tools();
2465
2466 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2467 let cid = memory.sqlite().create_conversation().await.unwrap();
2468 let sqlite = memory.sqlite();
2469
2470 sqlite.save_message(cid, "user", "hello").await.unwrap();
2472 sqlite.save_message(cid, "assistant", "hi").await.unwrap();
2473
2474 sqlite
2476 .save_message(cid, "assistant", "plain answer")
2477 .await
2478 .unwrap();
2479
2480 let orphan_result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2482 tool_use_id: "call_nonexistent".to_string(),
2483 content: "stale result".to_string(),
2484 is_error: false,
2485 }])
2486 .unwrap();
2487 sqlite
2488 .save_message_with_parts(
2489 cid,
2490 "user",
2491 "[tool_result: call_nonexistent]\nstale result",
2492 &orphan_result_parts,
2493 )
2494 .await
2495 .unwrap();
2496
2497 sqlite
2498 .save_message(cid, "assistant", "final")
2499 .await
2500 .unwrap();
2501
2502 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2503 std::sync::Arc::new(memory),
2504 cid,
2505 50,
2506 5,
2507 100,
2508 );
2509
2510 let messages_before = agent.msg.messages.len();
2511 agent.load_history().await.unwrap();
2512
2513 let loaded = &agent.msg.messages[messages_before..];
2517 for msg in loaded {
2518 assert!(
2519 !msg.parts.iter().any(|p| matches!(
2520 p,
2521 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_nonexistent"
2522 )),
2523 "orphaned ToolResult part must be stripped from history"
2524 );
2525 }
2526 }
2527
2528 #[tokio::test]
2531 async fn strip_orphans_keeps_complete_pair() {
2532 use zeph_llm::provider::MessagePart;
2533
2534 let provider = mock_provider(vec![]);
2535 let channel = MockChannel::new(vec![]);
2536 let registry = create_test_registry();
2537 let executor = MockToolExecutor::no_tools();
2538
2539 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2540 let cid = memory.sqlite().create_conversation().await.unwrap();
2541 let sqlite = memory.sqlite();
2542
2543 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2544 id: "call_valid".to_string(),
2545 name: "shell".to_string(),
2546 input: serde_json::json!({"command": "ls"}),
2547 }])
2548 .unwrap();
2549 sqlite
2550 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
2551 .await
2552 .unwrap();
2553
2554 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2555 tool_use_id: "call_valid".to_string(),
2556 content: "file.rs".to_string(),
2557 is_error: false,
2558 }])
2559 .unwrap();
2560 sqlite
2561 .save_message_with_parts(cid, "user", "", &result_parts)
2562 .await
2563 .unwrap();
2564
2565 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2566 std::sync::Arc::new(memory),
2567 cid,
2568 50,
2569 5,
2570 100,
2571 );
2572
2573 let messages_before = agent.msg.messages.len();
2574 agent.load_history().await.unwrap();
2575
2576 assert_eq!(
2577 agent.msg.messages.len(),
2578 messages_before + 2,
2579 "complete tool_use/tool_result pair must be preserved"
2580 );
2581
2582 let user_msg = &agent.msg.messages[messages_before + 1];
2583 assert!(
2584 user_msg.parts.iter().any(|p| matches!(
2585 p,
2586 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_valid"
2587 )),
2588 "ToolResult part for a matched tool_use must not be stripped"
2589 );
2590 }
2591
2592 #[tokio::test]
2595 async fn strip_orphans_mixed_history() {
2596 use zeph_llm::provider::MessagePart;
2597
2598 let provider = mock_provider(vec![]);
2599 let channel = MockChannel::new(vec![]);
2600 let registry = create_test_registry();
2601 let executor = MockToolExecutor::no_tools();
2602
2603 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2604 let cid = memory.sqlite().create_conversation().await.unwrap();
2605 let sqlite = memory.sqlite();
2606
2607 let use_parts_ok = serde_json::to_string(&[MessagePart::ToolUse {
2609 id: "call_good".to_string(),
2610 name: "shell".to_string(),
2611 input: serde_json::json!({"command": "pwd"}),
2612 }])
2613 .unwrap();
2614 sqlite
2615 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts_ok)
2616 .await
2617 .unwrap();
2618
2619 let result_parts_ok = serde_json::to_string(&[MessagePart::ToolResult {
2620 tool_use_id: "call_good".to_string(),
2621 content: "/home".to_string(),
2622 is_error: false,
2623 }])
2624 .unwrap();
2625 sqlite
2626 .save_message_with_parts(cid, "user", "", &result_parts_ok)
2627 .await
2628 .unwrap();
2629
2630 sqlite
2632 .save_message(cid, "assistant", "text only")
2633 .await
2634 .unwrap();
2635
2636 let orphan_parts = serde_json::to_string(&[MessagePart::ToolResult {
2637 tool_use_id: "call_ghost".to_string(),
2638 content: "ghost result".to_string(),
2639 is_error: false,
2640 }])
2641 .unwrap();
2642 sqlite
2643 .save_message_with_parts(
2644 cid,
2645 "user",
2646 "[tool_result: call_ghost]\nghost result",
2647 &orphan_parts,
2648 )
2649 .await
2650 .unwrap();
2651
2652 sqlite
2653 .save_message(cid, "assistant", "final reply")
2654 .await
2655 .unwrap();
2656
2657 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2658 std::sync::Arc::new(memory),
2659 cid,
2660 50,
2661 5,
2662 100,
2663 );
2664
2665 let messages_before = agent.msg.messages.len();
2666 agent.load_history().await.unwrap();
2667
2668 let loaded = &agent.msg.messages[messages_before..];
2669
2670 for msg in loaded {
2672 assert!(
2673 !msg.parts.iter().any(|p| matches!(
2674 p,
2675 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_ghost"
2676 )),
2677 "orphaned ToolResult (call_ghost) must be stripped from history"
2678 );
2679 }
2680
2681 let has_good_result = loaded.iter().any(|msg| {
2684 msg.role == Role::User
2685 && msg.parts.iter().any(|p| {
2686 matches!(
2687 p,
2688 MessagePart::ToolResult { tool_use_id, .. } if tool_use_id == "call_good"
2689 )
2690 })
2691 });
2692 assert!(
2693 has_good_result,
2694 "matched ToolResult (call_good) must be preserved in history"
2695 );
2696 }
2697
2698 #[tokio::test]
2701 async fn sanitize_tool_pairs_preserves_matched_tool_pair() {
2702 use zeph_llm::provider::MessagePart;
2703
2704 let provider = mock_provider(vec![]);
2705 let channel = MockChannel::new(vec![]);
2706 let registry = create_test_registry();
2707 let executor = MockToolExecutor::no_tools();
2708
2709 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2710 let cid = memory.sqlite().create_conversation().await.unwrap();
2711 let sqlite = memory.sqlite();
2712
2713 sqlite
2714 .save_message(cid, "user", "run a command")
2715 .await
2716 .unwrap();
2717
2718 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2720 id: "call_ok".to_string(),
2721 name: "shell".to_string(),
2722 input: serde_json::json!({"command": "echo hi"}),
2723 }])
2724 .unwrap();
2725 sqlite
2726 .save_message_with_parts(cid, "assistant", "[tool_use: shell]", &use_parts)
2727 .await
2728 .unwrap();
2729
2730 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
2732 tool_use_id: "call_ok".to_string(),
2733 content: "hi".to_string(),
2734 is_error: false,
2735 }])
2736 .unwrap();
2737 sqlite
2738 .save_message_with_parts(cid, "user", "[tool_result: call_ok]\nhi", &result_parts)
2739 .await
2740 .unwrap();
2741
2742 sqlite.save_message(cid, "assistant", "done").await.unwrap();
2743
2744 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2745 std::sync::Arc::new(memory),
2746 cid,
2747 50,
2748 5,
2749 100,
2750 );
2751
2752 let messages_before = agent.msg.messages.len();
2753 agent.load_history().await.unwrap();
2754
2755 assert_eq!(
2757 agent.msg.messages.len(),
2758 messages_before + 4,
2759 "matched tool pair must not be removed"
2760 );
2761 let tool_msg = &agent.msg.messages[messages_before + 1];
2762 assert!(
2763 tool_msg
2764 .parts
2765 .iter()
2766 .any(|p| matches!(p, MessagePart::ToolUse { id, .. } if id == "call_ok")),
2767 "matched ToolUse parts must be preserved"
2768 );
2769 }
2770
2771 #[tokio::test]
2775 async fn persist_cancelled_tool_results_pairs_tool_use() {
2776 use zeph_llm::provider::MessagePart;
2777
2778 let provider = mock_provider(vec![]);
2779 let channel = MockChannel::new(vec![]);
2780 let registry = create_test_registry();
2781 let executor = MockToolExecutor::no_tools();
2782
2783 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2784 let cid = memory.sqlite().create_conversation().await.unwrap();
2785
2786 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
2787 std::sync::Arc::new(memory),
2788 cid,
2789 50,
2790 5,
2791 100,
2792 );
2793
2794 let tool_calls = vec![
2796 zeph_llm::provider::ToolUseRequest {
2797 id: "cancel_id_1".to_string(),
2798 name: "shell".to_string(),
2799 input: serde_json::json!({}),
2800 },
2801 zeph_llm::provider::ToolUseRequest {
2802 id: "cancel_id_2".to_string(),
2803 name: "read_file".to_string(),
2804 input: serde_json::json!({}),
2805 },
2806 ];
2807
2808 agent.persist_cancelled_tool_results(&tool_calls).await;
2809
2810 let history = agent
2811 .memory_state
2812 .memory
2813 .as_ref()
2814 .unwrap()
2815 .sqlite()
2816 .load_history(cid, 50)
2817 .await
2818 .unwrap();
2819
2820 assert_eq!(history.len(), 1);
2822 assert_eq!(history[0].role, Role::User);
2823
2824 for tc in &tool_calls {
2826 assert!(
2827 history[0].parts.iter().any(|p| matches!(
2828 p,
2829 MessagePart::ToolResult { tool_use_id, is_error, .. }
2830 if tool_use_id == &tc.id && *is_error
2831 )),
2832 "tombstone ToolResult for {} must be present and is_error=true",
2833 tc.id
2834 );
2835 }
2836 }
2837
2838 #[test]
2841 fn meaningful_content_empty_string() {
2842 assert!(!has_meaningful_content(""));
2843 }
2844
2845 #[test]
2846 fn meaningful_content_whitespace_only() {
2847 assert!(!has_meaningful_content(" \n\t "));
2848 }
2849
2850 #[test]
2851 fn meaningful_content_tool_use_only() {
2852 assert!(!has_meaningful_content("[tool_use: shell(call_1)]"));
2853 }
2854
2855 #[test]
2856 fn meaningful_content_tool_use_no_parens() {
2857 assert!(!has_meaningful_content("[tool_use: memory_save]"));
2859 }
2860
2861 #[test]
2862 fn meaningful_content_tool_result_with_body() {
2863 assert!(!has_meaningful_content(
2864 "[tool_result: call_1]\nsome output here"
2865 ));
2866 }
2867
2868 #[test]
2869 fn meaningful_content_tool_result_empty_body() {
2870 assert!(!has_meaningful_content("[tool_result: call_1]\n"));
2871 }
2872
2873 #[test]
2874 fn meaningful_content_tool_output_inline() {
2875 assert!(!has_meaningful_content("[tool output: bash] some result"));
2876 }
2877
2878 #[test]
2879 fn meaningful_content_tool_output_pruned() {
2880 assert!(!has_meaningful_content("[tool output: bash] (pruned)"));
2881 }
2882
2883 #[test]
2884 fn meaningful_content_tool_output_fenced() {
2885 assert!(!has_meaningful_content(
2886 "[tool output: bash]\n```\nls output\n```"
2887 ));
2888 }
2889
2890 #[test]
2891 fn meaningful_content_multiple_tool_use_tags() {
2892 assert!(!has_meaningful_content(
2893 "[tool_use: bash(id1)][tool_use: read(id2)]"
2894 ));
2895 }
2896
2897 #[test]
2898 fn meaningful_content_multiple_tool_use_tags_space_separator() {
2899 assert!(!has_meaningful_content(
2901 "[tool_use: bash(id1)] [tool_use: read(id2)]"
2902 ));
2903 }
2904
2905 #[test]
2906 fn meaningful_content_multiple_tool_use_tags_newline_separator() {
2907 assert!(!has_meaningful_content(
2909 "[tool_use: bash(id1)]\n[tool_use: read(id2)]"
2910 ));
2911 }
2912
2913 #[test]
2914 fn meaningful_content_tool_result_followed_by_tool_use() {
2915 assert!(!has_meaningful_content(
2917 "[tool_result: call_1]\nresult\n[tool_use: bash(call_2)]"
2918 ));
2919 }
2920
2921 #[test]
2922 fn meaningful_content_real_text_only() {
2923 assert!(has_meaningful_content("Hello, how can I help you?"));
2924 }
2925
2926 #[test]
2927 fn meaningful_content_text_before_tool_tag() {
2928 assert!(has_meaningful_content("Let me check. [tool_use: bash(id)]"));
2929 }
2930
2931 #[test]
2932 fn meaningful_content_text_after_tool_use_tag() {
2933 assert!(has_meaningful_content("[tool_use: bash] I ran the command"));
2937 }
2938
2939 #[test]
2940 fn meaningful_content_text_between_tags() {
2941 assert!(has_meaningful_content(
2942 "[tool_use: bash(id1)]\nand then\n[tool_use: read(id2)]"
2943 ));
2944 }
2945
2946 #[test]
2947 fn meaningful_content_malformed_tag_no_closing_bracket() {
2948 assert!(has_meaningful_content("[tool_use: "));
2950 }
2951
2952 #[test]
2953 fn meaningful_content_tool_use_and_tool_result_only() {
2954 assert!(!has_meaningful_content(
2956 "[tool_use: memory_save(call_abc)]\n[tool_result: call_abc]\nsaved"
2957 ));
2958 }
2959
2960 #[test]
2961 fn meaningful_content_tool_result_body_with_json_array() {
2962 assert!(!has_meaningful_content(
2963 "[tool_result: id1]\n[\"array\", \"value\"]"
2964 ));
2965 }
2966
2967 #[tokio::test]
2978 async fn issue_2529_orphaned_legacy_content_pair_is_soft_deleted() {
2979 use zeph_llm::provider::MessagePart;
2980
2981 let provider = mock_provider(vec![]);
2982 let channel = MockChannel::new(vec![]);
2983 let registry = create_test_registry();
2984 let executor = MockToolExecutor::no_tools();
2985
2986 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
2987 let cid = memory.sqlite().create_conversation().await.unwrap();
2988 let sqlite = memory.sqlite();
2989
2990 sqlite
2992 .save_message(cid, "user", "save this for me")
2993 .await
2994 .unwrap();
2995
2996 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
2999 id: "call_2529".to_string(),
3000 name: "memory_save".to_string(),
3001 input: serde_json::json!({"content": "save this"}),
3002 }])
3003 .unwrap();
3004 let orphan_assistant_id = sqlite
3005 .save_message_with_parts(
3006 cid,
3007 "assistant",
3008 "[tool_use: memory_save(call_2529)]",
3009 &use_parts,
3010 )
3011 .await
3012 .unwrap();
3013
3014 sqlite
3019 .save_message(cid, "assistant", "here is a plain reply")
3020 .await
3021 .unwrap();
3022
3023 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
3024 tool_use_id: "call_2529".to_string(),
3025 content: "saved".to_string(),
3026 is_error: false,
3027 }])
3028 .unwrap();
3029 let orphan_user_id = sqlite
3030 .save_message_with_parts(
3031 cid,
3032 "user",
3033 "[tool_result: call_2529]\nsaved",
3034 &result_parts,
3035 )
3036 .await
3037 .unwrap();
3038
3039 sqlite.save_message(cid, "assistant", "done").await.unwrap();
3041
3042 let memory_arc = std::sync::Arc::new(memory);
3043 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
3044 memory_arc.clone(),
3045 cid,
3046 50,
3047 5,
3048 100,
3049 );
3050
3051 agent.load_history().await.unwrap();
3052
3053 let assistant_deleted_count: Vec<i64> = zeph_db::query_scalar(
3056 "SELECT COUNT(*) FROM messages WHERE id = ? AND deleted_at IS NOT NULL",
3057 )
3058 .bind(orphan_assistant_id)
3059 .fetch_all(memory_arc.sqlite().pool())
3060 .await
3061 .unwrap();
3062
3063 let user_deleted_count: Vec<i64> = zeph_db::query_scalar(
3064 "SELECT COUNT(*) FROM messages WHERE id = ? AND deleted_at IS NOT NULL",
3065 )
3066 .bind(orphan_user_id)
3067 .fetch_all(memory_arc.sqlite().pool())
3068 .await
3069 .unwrap();
3070
3071 assert_eq!(
3072 assistant_deleted_count.first().copied().unwrap_or(0),
3073 1,
3074 "orphaned assistant[ToolUse] with legacy-only content must be soft-deleted (deleted_at IS NOT NULL)"
3075 );
3076 assert_eq!(
3077 user_deleted_count.first().copied().unwrap_or(0),
3078 1,
3079 "orphaned user[ToolResult] with legacy-only content must be soft-deleted (deleted_at IS NOT NULL)"
3080 );
3081 }
3082
3083 #[tokio::test]
3087 async fn issue_2529_soft_delete_is_idempotent_across_sessions() {
3088 use zeph_llm::provider::MessagePart;
3089
3090 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3091 let cid = memory.sqlite().create_conversation().await.unwrap();
3092 let sqlite = memory.sqlite();
3093
3094 sqlite
3096 .save_message(cid, "user", "do something")
3097 .await
3098 .unwrap();
3099
3100 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
3102 id: "call_idem".to_string(),
3103 name: "shell".to_string(),
3104 input: serde_json::json!({"command": "ls"}),
3105 }])
3106 .unwrap();
3107 sqlite
3108 .save_message_with_parts(cid, "assistant", "[tool_use: shell(call_idem)]", &use_parts)
3109 .await
3110 .unwrap();
3111
3112 sqlite
3114 .save_message(cid, "assistant", "continuing")
3115 .await
3116 .unwrap();
3117
3118 let result_parts = serde_json::to_string(&[MessagePart::ToolResult {
3120 tool_use_id: "call_idem".to_string(),
3121 content: "output".to_string(),
3122 is_error: false,
3123 }])
3124 .unwrap();
3125 sqlite
3126 .save_message_with_parts(
3127 cid,
3128 "user",
3129 "[tool_result: call_idem]\noutput",
3130 &result_parts,
3131 )
3132 .await
3133 .unwrap();
3134
3135 sqlite
3136 .save_message(cid, "assistant", "final")
3137 .await
3138 .unwrap();
3139
3140 let memory_arc = std::sync::Arc::new(memory);
3141
3142 let mut agent1 = Agent::new(
3144 mock_provider(vec![]),
3145 MockChannel::new(vec![]),
3146 create_test_registry(),
3147 None,
3148 5,
3149 MockToolExecutor::no_tools(),
3150 )
3151 .with_memory(memory_arc.clone(), cid, 50, 5, 100);
3152 agent1.load_history().await.unwrap();
3153 let count_after_first = agent1.msg.messages.len();
3154
3155 let mut agent2 = Agent::new(
3158 mock_provider(vec![]),
3159 MockChannel::new(vec![]),
3160 create_test_registry(),
3161 None,
3162 5,
3163 MockToolExecutor::no_tools(),
3164 )
3165 .with_memory(memory_arc.clone(), cid, 50, 5, 100);
3166 agent2.load_history().await.unwrap();
3167 let count_after_second = agent2.msg.messages.len();
3168
3169 assert_eq!(
3171 count_after_first, count_after_second,
3172 "second load_history must load the same message count as the first (soft-deleted orphans excluded)"
3173 );
3174 }
3175
3176 #[tokio::test]
3180 async fn issue_2529_message_with_text_and_tool_tag_is_kept_after_part_strip() {
3181 use zeph_llm::provider::MessagePart;
3182
3183 let provider = mock_provider(vec![]);
3184 let channel = MockChannel::new(vec![]);
3185 let registry = create_test_registry();
3186 let executor = MockToolExecutor::no_tools();
3187
3188 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3189 let cid = memory.sqlite().create_conversation().await.unwrap();
3190 let sqlite = memory.sqlite();
3191
3192 sqlite
3194 .save_message(cid, "user", "check the files")
3195 .await
3196 .unwrap();
3197
3198 let use_parts = serde_json::to_string(&[MessagePart::ToolUse {
3201 id: "call_mixed".to_string(),
3202 name: "shell".to_string(),
3203 input: serde_json::json!({"command": "ls"}),
3204 }])
3205 .unwrap();
3206 sqlite
3207 .save_message_with_parts(
3208 cid,
3209 "assistant",
3210 "Let me list the directory. [tool_use: shell(call_mixed)]",
3211 &use_parts,
3212 )
3213 .await
3214 .unwrap();
3215
3216 sqlite.save_message(cid, "user", "thanks").await.unwrap();
3218 sqlite
3219 .save_message(cid, "assistant", "you are welcome")
3220 .await
3221 .unwrap();
3222
3223 let memory_arc = std::sync::Arc::new(memory);
3224 let mut agent = Agent::new(provider, channel, registry, None, 5, executor).with_memory(
3225 memory_arc.clone(),
3226 cid,
3227 50,
3228 5,
3229 100,
3230 );
3231
3232 let messages_before = agent.msg.messages.len();
3233 agent.load_history().await.unwrap();
3234
3235 assert_eq!(
3237 agent.msg.messages.len(),
3238 messages_before + 4,
3239 "assistant message with text + tool tag must not be removed after ToolUse strip"
3240 );
3241
3242 let mixed_msg = agent
3244 .msg
3245 .messages
3246 .iter()
3247 .find(|m| m.content.contains("Let me list the directory"))
3248 .expect("mixed-content assistant message must still be in history");
3249 assert!(
3250 !mixed_msg
3251 .parts
3252 .iter()
3253 .any(|p| matches!(p, MessagePart::ToolUse { .. })),
3254 "orphaned ToolUse parts must be stripped even when message has meaningful text"
3255 );
3256 assert_eq!(
3257 mixed_msg.content, "Let me list the directory. [tool_use: shell(call_mixed)]",
3258 "content field must be unchanged — only parts are stripped"
3259 );
3260 }
3261
3262 #[tokio::test]
3265 async fn persist_message_skipped_tool_result_does_not_embed() {
3266 use zeph_llm::provider::MessagePart;
3267
3268 let provider = mock_provider(vec![]);
3269 let channel = MockChannel::new(vec![]);
3270 let registry = create_test_registry();
3271 let executor = MockToolExecutor::no_tools();
3272
3273 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
3274 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3275 let cid = memory.sqlite().create_conversation().await.unwrap();
3276
3277 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
3278 .with_metrics(tx)
3279 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
3280 .with_autosave_config(true, 0);
3281
3282 let parts = vec![MessagePart::ToolResult {
3283 tool_use_id: "tu1".into(),
3284 content: "[skipped] bash tool was blocked by utility gate".into(),
3285 is_error: false,
3286 }];
3287
3288 agent
3289 .persist_message(
3290 Role::User,
3291 "[skipped] bash tool was blocked by utility gate",
3292 &parts,
3293 false,
3294 )
3295 .await;
3296
3297 assert_eq!(
3298 rx.borrow().embeddings_generated,
3299 0,
3300 "[skipped] ToolResult must not be embedded into Qdrant"
3301 );
3302 }
3303
3304 #[tokio::test]
3305 async fn persist_message_stopped_tool_result_does_not_embed() {
3306 use zeph_llm::provider::MessagePart;
3307
3308 let provider = mock_provider(vec![]);
3309 let channel = MockChannel::new(vec![]);
3310 let registry = create_test_registry();
3311 let executor = MockToolExecutor::no_tools();
3312
3313 let (tx, rx) = tokio::sync::watch::channel(MetricsSnapshot::default());
3314 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3315 let cid = memory.sqlite().create_conversation().await.unwrap();
3316
3317 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
3318 .with_metrics(tx)
3319 .with_memory(std::sync::Arc::new(memory), cid, 50, 5, 100)
3320 .with_autosave_config(true, 0);
3321
3322 let parts = vec![MessagePart::ToolResult {
3323 tool_use_id: "tu2".into(),
3324 content: "[stopped] execution limit reached".into(),
3325 is_error: false,
3326 }];
3327
3328 agent
3329 .persist_message(
3330 Role::User,
3331 "[stopped] execution limit reached",
3332 &parts,
3333 false,
3334 )
3335 .await;
3336
3337 assert_eq!(
3338 rx.borrow().embeddings_generated,
3339 0,
3340 "[stopped] ToolResult must not be embedded into Qdrant"
3341 );
3342 }
3343
3344 #[tokio::test]
3345 async fn persist_message_normal_tool_result_is_saved_not_blocked_by_guard() {
3346 use zeph_llm::provider::MessagePart;
3349
3350 let provider = mock_provider(vec![]);
3351 let channel = MockChannel::new(vec![]);
3352 let registry = create_test_registry();
3353 let executor = MockToolExecutor::no_tools();
3354
3355 let memory = test_memory(&AnyProvider::Mock(zeph_llm::mock::MockProvider::default())).await;
3356 let cid = memory.sqlite().create_conversation().await.unwrap();
3357 let memory_arc = std::sync::Arc::new(memory);
3358
3359 let mut agent = Agent::new(provider, channel, registry, None, 5, executor)
3360 .with_memory(memory_arc.clone(), cid, 50, 5, 100)
3361 .with_autosave_config(true, 0);
3362
3363 let content = "total 42\ndrwxr-xr-x 5 user group";
3364 let parts = vec![MessagePart::ToolResult {
3365 tool_use_id: "tu3".into(),
3366 content: content.into(),
3367 is_error: false,
3368 }];
3369
3370 agent
3371 .persist_message(Role::User, content, &parts, false)
3372 .await;
3373
3374 let history = memory_arc.sqlite().load_history(cid, 50).await.unwrap();
3376 assert_eq!(
3377 history.len(),
3378 1,
3379 "normal ToolResult must be saved to SQLite"
3380 );
3381 assert_eq!(history[0].content, content);
3382 }
3383}