1use std::future::Future;
5use std::pin::Pin;
6use std::{
7 any::TypeId,
8 collections::HashMap,
9 sync::{LazyLock, Mutex},
10};
11
12use futures_core::Stream;
13use serde::{Deserialize, Serialize};
14
15use zeph_common::ToolName;
16
17pub use zeph_common::ToolDefinition;
18
19use crate::embed::owned_strs;
20use crate::error::LlmError;
21
/// Process-wide cache of rendered JSON schemas, keyed by the Rust type's [`TypeId`].
///
/// Each entry stores the schema both as a `serde_json::Value` and as a pretty-printed
/// JSON string, exactly as produced and inserted by `cached_schema`.
static SCHEMA_CACHE: LazyLock<Mutex<HashMap<TypeId, (serde_json::Value, String)>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
24
25pub(crate) fn cached_schema<T: schemars::JsonSchema + 'static>()
31-> Result<(serde_json::Value, String), crate::LlmError> {
32 let type_id = TypeId::of::<T>();
33 if let Ok(cache) = SCHEMA_CACHE.lock()
34 && let Some(entry) = cache.get(&type_id)
35 {
36 return Ok(entry.clone());
37 }
38 let schema = schemars::schema_for!(T);
39 let value = serde_json::to_value(&schema)
40 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
41 let pretty = serde_json::to_string_pretty(&schema)
42 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
43 if let Ok(mut cache) = SCHEMA_CACHE.lock() {
44 cache.insert(type_id, (value.clone(), pretty.clone()));
45 }
46 Ok((value, pretty))
47}
48
/// Returns the unqualified name of `T` for use in prompts, e.g. `my_crate::Foo` -> `Foo`.
///
/// Generic arguments are dropped before the module path is removed, so
/// `Vec<my_crate::Foo>` yields `Vec` rather than the malformed `Foo>`. Falls back to
/// `"Output"` if no name remains. Based on [`std::any::type_name`], whose exact output
/// format is not guaranteed to be stable across compiler versions.
pub(crate) fn short_type_name<T: ?Sized>() -> &'static str {
    let full = std::any::type_name::<T>();
    // Cut at the first `<` so `::` separators inside generic arguments are ignored.
    let base = full.split('<').next().unwrap_or(full);
    let name = base.rsplit("::").next().unwrap_or(base);
    if name.is_empty() { "Output" } else { name }
}
68
/// Additional data that `LlmProvider::chat_with_extras` returns next to the reply text.
///
/// Every field is optional; `Default` leaves all of them unset.
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct ChatExtras {
    /// Entropy value associated with the response, when one was supplied.
    pub entropy: Option<f64>,
}

impl ChatExtras {
    /// Creates extras that carry only the given entropy value.
    #[must_use]
    pub fn with_entropy(entropy: f64) -> Self {
        let entropy = Some(entropy);
        Self { entropy }
    }
}
107
/// One item yielded by a [`ChatStream`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A piece of the response text.
    Content(String),
    /// A piece of thinking/reasoning text.
    Thinking(String),
    /// Compaction text emitted by the provider (exact semantics are provider-specific —
    /// NOTE(review): confirm at the producing provider).
    Compaction(String),
    /// Tool invocations requested by the model.
    ToolUse(Vec<ToolUseRequest>),
}
125
/// Pinned, boxed, `Send` stream of [`StreamChunk`]s (or errors), as returned by
/// [`LlmProvider::chat_stream`].
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
131
/// A single tool call requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUseRequest {
    /// Identifier of this tool call (presumably echoed back with the tool result — confirm).
    pub id: String,
    /// Name of the tool to invoke.
    pub name: ToolName,
    /// Tool arguments as arbitrary JSON.
    pub input: serde_json::Value,
}
146
/// A reasoning block carried in [`ChatResponse::ToolUse`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum ThinkingBlock {
    /// Thinking text together with its signature string.
    Thinking { thinking: String, signature: String },
    /// Redacted thinking, carried as an opaque data string.
    Redacted { data: String },
}
160
/// Marker text for output cut off by the `max_tokens` limit.
/// NOTE(review): presumably appended to or matched against truncated responses — confirm
/// at the call sites.
pub const MAX_TOKENS_TRUNCATION_MARKER: &str = "max_tokens limit reached";
164
/// Result of [`LlmProvider::chat_with_tools`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum ChatResponse {
    /// A plain text reply without tool calls.
    Text(String),
    /// The model requested one or more tool calls.
    ToolUse {
        /// Optional text emitted alongside the tool calls.
        text: Option<String>,
        /// The requested tool calls.
        tool_calls: Vec<ToolUseRequest>,
        /// Thinking blocks returned with this response.
        thinking_blocks: Vec<ThinkingBlock>,
    },
}
187
/// Pinned, boxed, `Send` future resolving to an embedding vector.
pub type EmbedFuture = Pin<Box<dyn Future<Output = Result<Vec<f32>, LlmError>> + Send>>;
190
/// Boxed `Send + Sync` closure that maps a text to an [`EmbedFuture`].
pub type EmbedFn = Box<dyn Fn(&str) -> EmbedFuture + Send + Sync>;
196
/// Unbounded channel sender for status messages as plain strings.
/// NOTE(review): presumably progress/status text for display — confirm at the receivers.
pub type StatusTx = tokio::sync::mpsc::UnboundedSender<String>;
203
204#[must_use]
207pub fn default_debug_request_json(
208 messages: &[Message],
209 tools: &[ToolDefinition],
210) -> serde_json::Value {
211 serde_json::json!({
212 "model": serde_json::Value::Null,
213 "max_tokens": serde_json::Value::Null,
214 "messages": serde_json::to_value(messages).unwrap_or(serde_json::Value::Array(vec![])),
215 "tools": serde_json::to_value(tools).unwrap_or(serde_json::Value::Array(vec![])),
216 "temperature": serde_json::Value::Null,
217 "cache_control": serde_json::Value::Null,
218 })
219}
220
/// Optional per-request sampling parameter overrides.
///
/// Every field is `None` by default (derived `Default`); `None` presumably leaves the
/// provider's own setting in place — confirm in the provider implementations.
#[derive(Debug, Clone, Default)]
pub struct GenerationOverrides {
    /// Sampling temperature.
    pub temperature: Option<f64>,
    /// Top-p sampling threshold.
    pub top_p: Option<f64>,
    /// Top-k sampling cutoff.
    pub top_k: Option<usize>,
    /// Frequency penalty.
    pub frequency_penalty: Option<f64>,
    /// Presence penalty.
    pub presence_penalty: Option<f64>,
}
242
/// Author role of a chat message.
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
256
/// A structured piece of a [`Message`].
///
/// Serialized as an internally tagged object whose `"kind"` field holds the snake_case
/// variant name, e.g. `{"kind":"tool_result","tool_use_id":"...","content":"..."}`.
#[non_exhaustive]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum MessagePart {
    /// Plain text.
    Text { text: String },
    /// Output produced by a tool.
    ToolOutput {
        /// Tool that produced the output.
        tool_name: zeph_common::ToolName,
        /// Output text. When flattened with `compacted_at` set, an empty body renders as
        /// "(pruned)".
        body: String,
        /// Set once the output has been compacted (presumably a timestamp — confirm);
        /// omitted from JSON when `None` and defaults to `None` when absent.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        compacted_at: Option<i64>,
    },
    /// Recalled text.
    Recall { text: String },
    /// Code context text.
    CodeContext { text: String },
    /// Summary text.
    Summary { text: String },
    /// Text carried over from another session.
    CrossSession { text: String },
    /// A tool invocation issued by the assistant.
    ToolUse {
        /// Identifier of this tool call.
        id: String,
        /// Tool name.
        name: String,
        /// Tool arguments as arbitrary JSON.
        input: serde_json::Value,
    },
    /// The result of a tool invocation.
    ToolResult {
        /// Identifier of the `ToolUse` this result answers.
        tool_use_id: String,
        /// Result text.
        content: String,
        /// Whether the tool reported an error; defaults to `false` when absent.
        #[serde(default)]
        is_error: bool,
    },
    /// Inline image data.
    Image(Box<ImageData>),
    /// Thinking text with its signature; omitted from the flattened content.
    ThinkingBlock { thinking: String, signature: String },
    /// Redacted thinking data; omitted from the flattened content.
    RedactedThinkingBlock { data: String },
    /// Compaction summary; omitted from the flattened content.
    Compaction { summary: String },
}
315
316impl MessagePart {
317 #[must_use]
320 pub fn as_plain_text(&self) -> Option<&str> {
321 match self {
322 Self::Text { text }
323 | Self::Recall { text }
324 | Self::CodeContext { text }
325 | Self::Summary { text }
326 | Self::CrossSession { text } => Some(text.as_str()),
327 _ => None,
328 }
329 }
330
331 #[must_use]
333 pub fn as_image(&self) -> Option<&ImageData> {
334 if let Self::Image(img) = self {
335 Some(img)
336 } else {
337 None
338 }
339 }
340}
341
/// Raw image bytes together with their MIME type.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImageData {
    /// Image bytes; serialized as a standard-alphabet base64 string.
    #[serde(with = "serde_bytes_base64")]
    pub data: Vec<u8>,
    /// MIME type, e.g. `"image/jpeg"`.
    pub mime_type: String,
}
352
353mod serde_bytes_base64 {
354 use base64::{Engine, engine::general_purpose::STANDARD};
355 use serde::{Deserialize, Deserializer, Serializer};
356
357 pub fn serialize<S>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error>
358 where
359 S: Serializer,
360 {
361 s.serialize_str(&STANDARD.encode(bytes))
362 }
363
364 pub fn deserialize<'de, D>(d: D) -> Result<Vec<u8>, D::Error>
365 where
366 D: Deserializer<'de>,
367 {
368 let s = String::deserialize(d)?;
369 STANDARD.decode(&s).map_err(serde::de::Error::custom)
370 }
371}
372
373#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
389#[serde(rename_all = "snake_case")]
390#[non_exhaustive]
391pub enum MessageVisibility {
392 Both,
394 AgentOnly,
396 UserOnly,
398}
399
400impl MessageVisibility {
401 #[must_use]
403 pub fn is_agent_visible(self) -> bool {
404 matches!(self, MessageVisibility::Both | MessageVisibility::AgentOnly)
405 }
406
407 #[must_use]
409 pub fn is_user_visible(self) -> bool {
410 matches!(self, MessageVisibility::Both | MessageVisibility::UserOnly)
411 }
412}
413
414impl Default for MessageVisibility {
415 fn default() -> Self {
417 MessageVisibility::Both
418 }
419}
420
421impl MessageVisibility {
422 #[must_use]
424 pub fn as_db_str(self) -> &'static str {
425 match self {
426 MessageVisibility::Both => "both",
427 MessageVisibility::AgentOnly => "agent_only",
428 MessageVisibility::UserOnly => "user_only",
429 }
430 }
431
432 #[must_use]
436 pub fn from_db_str(s: &str) -> Self {
437 match s {
438 "agent_only" => MessageVisibility::AgentOnly,
439 "user_only" => MessageVisibility::UserOnly,
440 _ => MessageVisibility::Both,
441 }
442 }
443}
444
445#[derive(Clone, Debug, Serialize, Deserialize)]
450pub struct MessageMetadata {
451 pub visibility: MessageVisibility,
453 #[serde(default, skip_serializing_if = "Option::is_none")]
455 pub compacted_at: Option<i64>,
456 #[serde(default, skip_serializing_if = "Option::is_none")]
459 pub deferred_summary: Option<String>,
460 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
463 pub focus_pinned: bool,
464 #[serde(default, skip_serializing_if = "Option::is_none")]
467 pub focus_marker_id: Option<uuid::Uuid>,
468 #[serde(skip)]
471 pub db_id: Option<i64>,
472 #[serde(default, skip_serializing_if = "Option::is_none")]
477 pub fidelity_tag: Option<zeph_common::ContextFidelity>,
478 #[serde(skip)]
482 pub embedding: Option<Vec<f32>>,
483}
484
485impl Default for MessageMetadata {
486 fn default() -> Self {
487 Self {
488 visibility: MessageVisibility::Both,
489 compacted_at: None,
490 deferred_summary: None,
491 focus_pinned: false,
492 focus_marker_id: None,
493 db_id: None,
494 fidelity_tag: None,
495 embedding: None,
496 }
497 }
498}
499
500impl MessageMetadata {
501 #[must_use]
503 pub fn agent_only() -> Self {
504 Self {
505 visibility: MessageVisibility::AgentOnly,
506 compacted_at: None,
507 deferred_summary: None,
508 focus_pinned: false,
509 focus_marker_id: None,
510 db_id: None,
511 fidelity_tag: None,
512 embedding: None,
513 }
514 }
515
516 #[must_use]
518 pub fn user_only() -> Self {
519 Self {
520 visibility: MessageVisibility::UserOnly,
521 compacted_at: None,
522 deferred_summary: None,
523 focus_pinned: false,
524 focus_marker_id: None,
525 db_id: None,
526 fidelity_tag: None,
527 embedding: None,
528 }
529 }
530
531 #[must_use]
533 pub fn focus_pinned() -> Self {
534 Self {
535 visibility: MessageVisibility::AgentOnly,
536 compacted_at: None,
537 deferred_summary: None,
538 focus_pinned: true,
539 focus_marker_id: None,
540 db_id: None,
541 fidelity_tag: None,
542 embedding: None,
543 }
544 }
545}
546
/// A single chat message: author role, flattened text, optional structured parts, and
/// metadata.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
    /// Author role.
    pub role: Role,
    /// Flattened text content; this is what [`Message::to_llm_content`] returns.
    pub content: String,
    /// Structured parts; empty for plain-text messages and when absent in JSON.
    #[serde(default)]
    pub parts: Vec<MessagePart>,
    /// Per-message metadata; `MessageMetadata::default()` when absent in JSON.
    #[serde(default)]
    pub metadata: MessageMetadata,
}
583
584impl Default for Message {
585 fn default() -> Self {
586 Self {
587 role: Role::User,
588 content: String::new(),
589 parts: vec![],
590 metadata: MessageMetadata::default(),
591 }
592 }
593}
594
595impl Message {
596 #[must_use]
601 pub fn from_legacy(role: Role, content: impl Into<String>) -> Self {
602 Self {
603 role,
604 content: content.into(),
605 parts: vec![],
606 metadata: MessageMetadata::default(),
607 }
608 }
609
610 #[must_use]
615 pub fn from_parts(role: Role, parts: Vec<MessagePart>) -> Self {
616 let content = Self::flatten_parts(&parts);
617 Self {
618 role,
619 content,
620 parts,
621 metadata: MessageMetadata::default(),
622 }
623 }
624
625 #[must_use]
628 pub fn to_llm_content(&self) -> &str {
629 &self.content
630 }
631
632 pub fn rebuild_content(&mut self) {
634 if !self.parts.is_empty() {
635 self.content = Self::flatten_parts(&self.parts);
636 }
637 }
638
639 fn flatten_parts(parts: &[MessagePart]) -> String {
640 use std::fmt::Write;
641 let mut out = String::new();
642 for part in parts {
643 match part {
644 MessagePart::Text { text }
645 | MessagePart::Recall { text }
646 | MessagePart::CodeContext { text }
647 | MessagePart::Summary { text }
648 | MessagePart::CrossSession { text } => out.push_str(text),
649 MessagePart::ToolOutput {
650 tool_name,
651 body,
652 compacted_at,
653 } => {
654 if compacted_at.is_some() {
655 if body.is_empty() {
656 let _ = write!(out, "[tool output: {tool_name}] (pruned)");
657 } else {
658 let _ = write!(out, "[tool output: {tool_name}] {body}");
659 }
660 } else {
661 let _ = write!(out, "[tool output: {tool_name}]\n```\n{body}\n```");
662 }
663 }
664 MessagePart::ToolUse { id, name, .. } => {
665 let _ = write!(out, "[tool_use: {name}({id})]");
666 }
667 MessagePart::ToolResult {
668 tool_use_id,
669 content,
670 ..
671 } => {
672 let _ = write!(out, "[tool_result: {tool_use_id}]\n{content}");
673 }
674 MessagePart::Image(img) => {
675 let _ = write!(out, "[image: {}, {} bytes]", img.mime_type, img.data.len());
676 }
677 MessagePart::ThinkingBlock { .. }
679 | MessagePart::RedactedThinkingBlock { .. }
680 | MessagePart::Compaction { .. } => {}
681 }
682 }
683 out
684 }
685}
686
/// Common interface implemented by LLM backends.
///
/// Implementors must provide `chat`, `chat_stream`, `supports_streaming`, `embed`,
/// `supports_embeddings` and `name`. Every other method has a default that either
/// delegates to `chat`/`embed` or reports "not available" (`None`, `false`, empty).
pub trait LlmProvider: Send + Sync {
    /// Size of the model's context window, if known (presumably in tokens — confirm per
    /// provider). Default: `None`.
    fn context_window(&self) -> Option<usize> {
        None
    }

    /// Sends `messages` and resolves to the complete text reply.
    fn chat(&self, messages: &[Message]) -> impl Future<Output = Result<String, LlmError>> + Send;

    /// Sends `messages` and resolves to a [`ChatStream`] of response chunks.
    fn chat_stream(
        &self,
        messages: &[Message],
    ) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;

    /// Reports whether this provider supports streaming responses.
    fn supports_streaming(&self) -> bool;

    /// Computes an embedding vector for `text`.
    fn embed(&self, text: &str) -> impl Future<Output = Result<Vec<f32>, LlmError>> + Send;

    /// Computes embeddings for several texts, returned in input order.
    ///
    /// The default copies the inputs into owned strings (via `owned_strs`) and then calls
    /// [`LlmProvider::embed`] once per text, sequentially, returning the first error
    /// encountered. Providers with a native batch endpoint may override it.
    fn embed_batch(
        &self,
        texts: &[&str],
    ) -> impl Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
        let owned = owned_strs(texts);
        async move {
            let mut results = Vec::with_capacity(owned.len());
            for text in &owned {
                results.push(self.embed(text).await?);
            }
            Ok(results)
        }
    }

    /// Reports whether this provider supports embeddings.
    fn supports_embeddings(&self) -> bool;

    /// Short name identifying this provider.
    fn name(&self) -> &str;

    /// Identifier of the model in use. Default: an empty string.
    // The lint would suggest `&'static str` for this literal-returning default; the
    // signature presumably stays `&str` so overriding impls can return borrowed data.
    #[allow(clippy::unnecessary_literal_bound)]
    fn model_identifier(&self) -> &str {
        ""
    }

    /// Whether image inputs are supported. Default: `false`.
    fn supports_vision(&self) -> bool {
        false
    }

    /// Whether tool use is supported. Default: `true`.
    fn supports_tool_use(&self) -> bool {
        true
    }

    /// Chats with tool definitions available to the model.
    ///
    /// The default ignores `_tools`: it calls [`LlmProvider::chat`] on an owned copy of
    /// `messages` and wraps the reply in [`ChatResponse::Text`], so it never yields tool
    /// calls.
    fn chat_with_tools(
        &self,
        messages: &[Message],
        _tools: &[ToolDefinition],
    ) -> impl std::future::Future<Output = Result<ChatResponse, LlmError>> + Send {
        let msgs = messages.to_vec();
        async move { Ok(ChatResponse::Text(self.chat(&msgs).await?)) }
    }

    /// Cache usage of the most recent request as a pair of counts, if tracked.
    /// NOTE(review): the meaning/order of the pair is defined by implementations — confirm.
    /// Default: `None`.
    fn last_cache_usage(&self) -> Option<(u64, u64)> {
        None
    }

    /// Token usage of the most recent request as a pair of counts, if tracked.
    /// NOTE(review): the meaning/order of the pair is defined by implementations — confirm.
    /// Default: `None`.
    fn last_usage(&self) -> Option<(u64, u64)> {
        None
    }

    /// Reasoning-token count of the most recent request, if tracked. Default: `None`.
    fn last_reasoning_tokens(&self) -> Option<u64> {
        None
    }

    /// Takes a pending compaction summary, if any (the name suggests it is cleared once
    /// taken — confirm in implementations). Default: `None`.
    fn take_compaction_summary(&self) -> Option<String> {
        None
    }

    /// Like [`LlmProvider::chat`], but also returns [`ChatExtras`].
    ///
    /// The default calls `chat` on an owned copy of `messages` and pairs the reply with
    /// `ChatExtras::default()`.
    fn chat_with_extras(
        &self,
        messages: &[Message],
    ) -> impl Future<Output = Result<(String, ChatExtras), LlmError>> + Send {
        let msgs = messages.to_vec();
        async move { Ok((self.chat(&msgs).await?, ChatExtras::default())) }
    }

    /// Builds a JSON view of the request that would be sent, for debugging.
    ///
    /// The default delegates to [`default_debug_request_json`] and ignores `_stream`.
    #[must_use]
    fn debug_request_json(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
        _stream: bool,
    ) -> serde_json::Value {
        default_debug_request_json(messages, tools)
    }

    /// Names of models known to this provider. Default: empty.
    fn list_models(&self) -> Vec<String> {
        vec![]
    }

    /// Whether native structured output is supported. Default: `false`.
    ///
    /// Note that the default [`LlmProvider::chat_typed`] does not consult this flag.
    fn supports_structured_output(&self) -> bool {
        false
    }

    /// Asks the model for JSON matching `T`'s schema and deserializes the reply.
    ///
    /// A system message containing the schema (from `cached_schema`) is prepended to a copy
    /// of `messages`. Code fences are stripped from the reply before parsing. If parsing
    /// fails, the bad reply and a correction request are appended and the model is asked
    /// exactly once more.
    ///
    /// # Errors
    ///
    /// Propagates errors from `cached_schema` and from every `chat` call, and returns
    /// `LlmError::StructuredParse` when the retried reply still fails to parse.
    // Silences the `async_fn_in_trait` lint: unlike the other async-returning methods, this
    // future carries no explicit `Send` bound.
    #[allow(async_fn_in_trait)]
    async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
    where
        T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
        Self: Sized,
    {
        let (_, schema_json) = cached_schema::<T>()?;
        let type_name = short_type_name::<T>();

        let mut augmented = messages.to_vec();
        let instruction = format!(
            "Respond with a valid JSON object matching this schema. \
             Output ONLY the JSON, no markdown fences or extra text.\n\n\
             Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
        );
        // The schema instruction goes first, ahead of the caller's messages.
        augmented.insert(0, Message::from_legacy(Role::System, instruction));

        let raw = self.chat(&augmented).await?;
        let cleaned = strip_json_fences(&raw);
        match serde_json::from_str::<T>(cleaned) {
            Ok(val) => Ok(val),
            Err(first_err) => {
                // Single corrective round-trip: echo the bad reply and the parse error.
                augmented.push(Message::from_legacy(Role::Assistant, &raw));
                augmented.push(Message::from_legacy(
                    Role::User,
                    format!(
                        "Your response was not valid JSON. Error: {first_err}. \
                         Please output ONLY valid JSON matching the schema."
                    ),
                ));
                let retry_raw = self.chat(&augmented).await?;
                let retry_cleaned = strip_json_fences(&retry_raw);
                serde_json::from_str::<T>(retry_cleaned).map_err(|e| {
                    LlmError::StructuredParse(format!("parse failed after retry: {e}"))
                })
            }
        }
    }
}
972
/// Strips optional Markdown code fences from an LLM reply so the payload can be parsed
/// as JSON.
///
/// Surrounding whitespace is trimmed first. Then a leading fence with a `json` tag is
/// removed. The tag is matched ASCII case-insensitively, so `JSON` and `Json` work too.
/// Next, a bare leading fence and any trailing fence are removed, and the result is
/// trimmed again. Input without fences is returned trimmed.
fn strip_json_fences(s: &str) -> &str {
    let mut body = s.trim();
    // Models sometimes emit the language tag in upper or mixed case.
    while let Some(rest) = strip_prefix_ignore_ascii_case(body, "```json") {
        body = rest;
    }
    body.trim_start_matches("```")
        .trim_end_matches("```")
        .trim()
}

/// Like [`str::strip_prefix`], but compares ASCII letters case-insensitively.
///
/// Returns `None` if `s` is shorter than `prefix`, the cut would split a UTF-8 character,
/// or the leading bytes differ.
fn strip_prefix_ignore_ascii_case<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
    let head = s.get(..prefix.len())?;
    if head.eq_ignore_ascii_case(prefix) {
        s.get(prefix.len()..)
    } else {
        None
    }
}
983
984#[cfg(test)]
985mod tests {
986 use tokio_stream::StreamExt;
987
988 use super::*;
989
990 struct StubProvider {
991 response: String,
992 }
993
994 impl LlmProvider for StubProvider {
995 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
996 Ok(self.response.clone())
997 }
998
999 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1000 let response = self.chat(messages).await?;
1001 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1002 response,
1003 )))))
1004 }
1005
1006 fn supports_streaming(&self) -> bool {
1007 false
1008 }
1009
1010 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1011 Ok(vec![0.1, 0.2, 0.3])
1012 }
1013
1014 fn supports_embeddings(&self) -> bool {
1015 false
1016 }
1017
1018 fn name(&self) -> &'static str {
1019 "stub"
1020 }
1021 }
1022
1023 #[test]
1024 fn context_window_default_returns_none() {
1025 let provider = StubProvider {
1026 response: String::new(),
1027 };
1028 assert!(provider.context_window().is_none());
1029 }
1030
1031 #[test]
1032 fn supports_streaming_default_returns_false() {
1033 let provider = StubProvider {
1034 response: String::new(),
1035 };
1036 assert!(!provider.supports_streaming());
1037 }
1038
1039 #[tokio::test]
1040 async fn chat_stream_default_yields_single_chunk() {
1041 let provider = StubProvider {
1042 response: "hello world".into(),
1043 };
1044 let messages = vec![Message {
1045 role: Role::User,
1046 content: "test".into(),
1047 parts: vec![],
1048 metadata: MessageMetadata::default(),
1049 }];
1050
1051 let mut stream = provider.chat_stream(&messages).await.unwrap();
1052 let chunk = stream.next().await.unwrap().unwrap();
1053 assert!(matches!(chunk, StreamChunk::Content(s) if s == "hello world"));
1054 assert!(stream.next().await.is_none());
1055 }
1056
1057 #[tokio::test]
1058 async fn chat_stream_default_propagates_chat_error() {
1059 struct FailProvider;
1060
1061 impl LlmProvider for FailProvider {
1062 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1063 Err(LlmError::Unavailable)
1064 }
1065
1066 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1067 let response = self.chat(messages).await?;
1068 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1069 response,
1070 )))))
1071 }
1072
1073 fn supports_streaming(&self) -> bool {
1074 false
1075 }
1076
1077 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1078 Err(LlmError::Unavailable)
1079 }
1080
1081 fn supports_embeddings(&self) -> bool {
1082 false
1083 }
1084
1085 fn name(&self) -> &'static str {
1086 "fail"
1087 }
1088 }
1089
1090 let provider = FailProvider;
1091 let messages = vec![Message {
1092 role: Role::User,
1093 content: "test".into(),
1094 parts: vec![],
1095 metadata: MessageMetadata::default(),
1096 }];
1097
1098 let result = provider.chat_stream(&messages).await;
1099 assert!(result.is_err());
1100 if let Err(e) = result {
1101 assert!(e.to_string().contains("provider unavailable"));
1102 }
1103 }
1104
1105 #[tokio::test]
1106 async fn stub_provider_embed_returns_vector() {
1107 let provider = StubProvider {
1108 response: String::new(),
1109 };
1110 let embedding = provider.embed("test").await.unwrap();
1111 assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
1112 }
1113
1114 #[tokio::test]
1115 async fn fail_provider_embed_propagates_error() {
1116 struct FailProvider;
1117
1118 impl LlmProvider for FailProvider {
1119 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1120 Err(LlmError::Unavailable)
1121 }
1122
1123 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1124 let response = self.chat(messages).await?;
1125 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1126 response,
1127 )))))
1128 }
1129
1130 fn supports_streaming(&self) -> bool {
1131 false
1132 }
1133
1134 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1135 Err(LlmError::EmbedUnsupported {
1136 provider: "fail".into(),
1137 })
1138 }
1139
1140 fn supports_embeddings(&self) -> bool {
1141 false
1142 }
1143
1144 fn name(&self) -> &'static str {
1145 "fail"
1146 }
1147 }
1148
1149 let provider = FailProvider;
1150 let result = provider.embed("test").await;
1151 assert!(result.is_err());
1152 assert!(
1153 result
1154 .unwrap_err()
1155 .to_string()
1156 .contains("embedding not supported")
1157 );
1158 }
1159
1160 #[test]
1161 fn role_serialization() {
1162 let system = Role::System;
1163 let user = Role::User;
1164 let assistant = Role::Assistant;
1165
1166 assert_eq!(serde_json::to_string(&system).unwrap(), "\"system\"");
1167 assert_eq!(serde_json::to_string(&user).unwrap(), "\"user\"");
1168 assert_eq!(serde_json::to_string(&assistant).unwrap(), "\"assistant\"");
1169 }
1170
1171 #[test]
1172 fn role_deserialization() {
1173 let system: Role = serde_json::from_str("\"system\"").unwrap();
1174 let user: Role = serde_json::from_str("\"user\"").unwrap();
1175 let assistant: Role = serde_json::from_str("\"assistant\"").unwrap();
1176
1177 assert_eq!(system, Role::System);
1178 assert_eq!(user, Role::User);
1179 assert_eq!(assistant, Role::Assistant);
1180 }
1181
1182 #[test]
1183 fn message_clone() {
1184 let msg = Message {
1185 role: Role::User,
1186 content: "test".into(),
1187 parts: vec![],
1188 metadata: MessageMetadata::default(),
1189 };
1190 let cloned = msg.clone();
1191 assert_eq!(cloned.role, msg.role);
1192 assert_eq!(cloned.content, msg.content);
1193 }
1194
1195 #[test]
1196 fn message_debug() {
1197 let msg = Message {
1198 role: Role::Assistant,
1199 content: "response".into(),
1200 parts: vec![],
1201 metadata: MessageMetadata::default(),
1202 };
1203 let debug = format!("{msg:?}");
1204 assert!(debug.contains("Assistant"));
1205 assert!(debug.contains("response"));
1206 }
1207
1208 #[test]
1209 fn message_serialization() {
1210 let msg = Message {
1211 role: Role::User,
1212 content: "hello".into(),
1213 parts: vec![],
1214 metadata: MessageMetadata::default(),
1215 };
1216 let json = serde_json::to_string(&msg).unwrap();
1217 assert!(json.contains("\"role\":\"user\""));
1218 assert!(json.contains("\"content\":\"hello\""));
1219 }
1220
1221 #[test]
1222 fn message_part_serde_round_trip() {
1223 let parts = vec![
1224 MessagePart::Text {
1225 text: "hello".into(),
1226 },
1227 MessagePart::ToolOutput {
1228 tool_name: "bash".into(),
1229 body: "output".into(),
1230 compacted_at: None,
1231 },
1232 MessagePart::Recall {
1233 text: "recall".into(),
1234 },
1235 MessagePart::CodeContext {
1236 text: "code".into(),
1237 },
1238 MessagePart::Summary {
1239 text: "summary".into(),
1240 },
1241 ];
1242 let json = serde_json::to_string(&parts).unwrap();
1243 let deserialized: Vec<MessagePart> = serde_json::from_str(&json).unwrap();
1244 assert_eq!(deserialized.len(), 5);
1245 }
1246
1247 #[test]
1248 fn from_legacy_creates_empty_parts() {
1249 let msg = Message::from_legacy(Role::User, "hello");
1250 assert_eq!(msg.role, Role::User);
1251 assert_eq!(msg.content, "hello");
1252 assert!(msg.parts.is_empty());
1253 assert_eq!(msg.to_llm_content(), "hello");
1254 }
1255
1256 #[test]
1257 fn from_parts_flattens_content() {
1258 let msg = Message::from_parts(
1259 Role::System,
1260 vec![MessagePart::Recall {
1261 text: "recalled data".into(),
1262 }],
1263 );
1264 assert_eq!(msg.content, "recalled data");
1265 assert_eq!(msg.to_llm_content(), "recalled data");
1266 assert_eq!(msg.parts.len(), 1);
1267 }
1268
1269 #[test]
1270 fn from_parts_tool_output_format() {
1271 let msg = Message::from_parts(
1272 Role::User,
1273 vec![MessagePart::ToolOutput {
1274 tool_name: "bash".into(),
1275 body: "hello world".into(),
1276 compacted_at: None,
1277 }],
1278 );
1279 assert!(msg.content.contains("[tool output: bash]"));
1280 assert!(msg.content.contains("hello world"));
1281 }
1282
1283 #[test]
1284 fn message_deserializes_without_parts() {
1285 let json = r#"{"role":"user","content":"hello"}"#;
1286 let msg: Message = serde_json::from_str(json).unwrap();
1287 assert_eq!(msg.content, "hello");
1288 assert!(msg.parts.is_empty());
1289 }
1290
1291 #[test]
1292 fn flatten_skips_compacted_tool_output_empty_body() {
1293 let msg = Message::from_parts(
1295 Role::User,
1296 vec![
1297 MessagePart::Text {
1298 text: "prefix ".into(),
1299 },
1300 MessagePart::ToolOutput {
1301 tool_name: "bash".into(),
1302 body: String::new(),
1303 compacted_at: Some(1234),
1304 },
1305 MessagePart::Text {
1306 text: " suffix".into(),
1307 },
1308 ],
1309 );
1310 assert!(msg.content.contains("(pruned)"));
1311 assert!(msg.content.contains("prefix "));
1312 assert!(msg.content.contains(" suffix"));
1313 }
1314
1315 #[test]
1316 fn flatten_compacted_tool_output_with_reference_renders_body() {
1317 let ref_notice = "[tool output pruned; full content at /tmp/overflow/big.txt]";
1319 let msg = Message::from_parts(
1320 Role::User,
1321 vec![MessagePart::ToolOutput {
1322 tool_name: "bash".into(),
1323 body: ref_notice.into(),
1324 compacted_at: Some(1234),
1325 }],
1326 );
1327 assert!(msg.content.contains(ref_notice));
1328 assert!(!msg.content.contains("(pruned)"));
1329 }
1330
1331 #[test]
1332 fn rebuild_content_syncs_after_mutation() {
1333 let mut msg = Message::from_parts(
1334 Role::User,
1335 vec![MessagePart::ToolOutput {
1336 tool_name: "bash".into(),
1337 body: "original".into(),
1338 compacted_at: None,
1339 }],
1340 );
1341 assert!(msg.content.contains("original"));
1342
1343 if let MessagePart::ToolOutput {
1344 ref mut compacted_at,
1345 ref mut body,
1346 ..
1347 } = msg.parts[0]
1348 {
1349 *compacted_at = Some(999);
1350 body.clear(); }
1352 msg.rebuild_content();
1353
1354 assert!(msg.content.contains("(pruned)"));
1355 assert!(!msg.content.contains("original"));
1356 }
1357
1358 #[test]
1359 fn message_part_tool_use_serde_round_trip() {
1360 let part = MessagePart::ToolUse {
1361 id: "toolu_123".into(),
1362 name: "bash".into(),
1363 input: serde_json::json!({"command": "ls"}),
1364 };
1365 let json = serde_json::to_string(&part).unwrap();
1366 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1367 if let MessagePart::ToolUse { id, name, input } = deserialized {
1368 assert_eq!(id, "toolu_123");
1369 assert_eq!(name, "bash");
1370 assert_eq!(input["command"], "ls");
1371 } else {
1372 panic!("expected ToolUse");
1373 }
1374 }
1375
1376 #[test]
1377 fn message_part_tool_result_serde_round_trip() {
1378 let part = MessagePart::ToolResult {
1379 tool_use_id: "toolu_123".into(),
1380 content: "file1.rs\nfile2.rs".into(),
1381 is_error: false,
1382 };
1383 let json = serde_json::to_string(&part).unwrap();
1384 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1385 if let MessagePart::ToolResult {
1386 tool_use_id,
1387 content,
1388 is_error,
1389 } = deserialized
1390 {
1391 assert_eq!(tool_use_id, "toolu_123");
1392 assert_eq!(content, "file1.rs\nfile2.rs");
1393 assert!(!is_error);
1394 } else {
1395 panic!("expected ToolResult");
1396 }
1397 }
1398
1399 #[test]
1400 fn message_part_tool_result_is_error_default() {
1401 let json = r#"{"kind":"tool_result","tool_use_id":"id","content":"err"}"#;
1402 let part: MessagePart = serde_json::from_str(json).unwrap();
1403 if let MessagePart::ToolResult { is_error, .. } = part {
1404 assert!(!is_error);
1405 } else {
1406 panic!("expected ToolResult");
1407 }
1408 }
1409
1410 #[test]
1411 fn chat_response_construction() {
1412 let text = ChatResponse::Text("hello".into());
1413 assert!(matches!(text, ChatResponse::Text(s) if s == "hello"));
1414
1415 let tool_use = ChatResponse::ToolUse {
1416 text: Some("I'll run that".into()),
1417 tool_calls: vec![ToolUseRequest {
1418 id: "1".into(),
1419 name: "bash".into(),
1420 input: serde_json::json!({}),
1421 }],
1422 thinking_blocks: vec![],
1423 };
1424 assert!(matches!(tool_use, ChatResponse::ToolUse { .. }));
1425 }
1426
1427 #[test]
1428 fn flatten_parts_tool_use() {
1429 let msg = Message::from_parts(
1430 Role::Assistant,
1431 vec![MessagePart::ToolUse {
1432 id: "t1".into(),
1433 name: "bash".into(),
1434 input: serde_json::json!({"command": "ls"}),
1435 }],
1436 );
1437 assert!(msg.content.contains("[tool_use: bash(t1)]"));
1438 }
1439
1440 #[test]
1441 fn flatten_parts_tool_result() {
1442 let msg = Message::from_parts(
1443 Role::User,
1444 vec![MessagePart::ToolResult {
1445 tool_use_id: "t1".into(),
1446 content: "output here".into(),
1447 is_error: false,
1448 }],
1449 );
1450 assert!(msg.content.contains("[tool_result: t1]"));
1451 assert!(msg.content.contains("output here"));
1452 }
1453
1454 #[test]
1455 fn tool_definition_serde_round_trip() {
1456 let def = ToolDefinition {
1457 name: "bash".into(),
1458 description: "Execute a shell command".into(),
1459 parameters: serde_json::json!({"type": "object"}),
1460 output_schema: None,
1461 };
1462 let json = serde_json::to_string(&def).unwrap();
1463 let deserialized: ToolDefinition = serde_json::from_str(&json).unwrap();
1464 assert_eq!(deserialized.name, "bash");
1465 assert_eq!(deserialized.description, "Execute a shell command");
1466 }
1467
1468 #[tokio::test]
1469 async fn chat_with_tools_default_delegates_to_chat() {
1470 let provider = StubProvider {
1471 response: "hello".into(),
1472 };
1473 let messages = vec![Message::from_legacy(Role::User, "test")];
1474 let result = provider.chat_with_tools(&messages, &[]).await.unwrap();
1475 assert!(matches!(result, ChatResponse::Text(s) if s == "hello"));
1476 }
1477
1478 #[test]
1479 fn tool_output_compacted_at_serde_default() {
1480 let json = r#"{"kind":"tool_output","tool_name":"bash","body":"out"}"#;
1481 let part: MessagePart = serde_json::from_str(json).unwrap();
1482 if let MessagePart::ToolOutput { compacted_at, .. } = part {
1483 assert!(compacted_at.is_none());
1484 } else {
1485 panic!("expected ToolOutput");
1486 }
1487 }
1488
1489 #[test]
1492 fn strip_json_fences_plain_json() {
1493 assert_eq!(strip_json_fences(r#"{"a": 1}"#), r#"{"a": 1}"#);
1494 }
1495
1496 #[test]
1497 fn strip_json_fences_with_json_fence() {
1498 assert_eq!(strip_json_fences("```json\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1499 }
1500
1501 #[test]
1502 fn strip_json_fences_with_plain_fence() {
1503 assert_eq!(strip_json_fences("```\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1504 }
1505
1506 #[test]
1507 fn strip_json_fences_whitespace() {
1508 assert_eq!(strip_json_fences(" \n "), "");
1509 }
1510
1511 #[test]
1512 fn strip_json_fences_empty() {
1513 assert_eq!(strip_json_fences(""), "");
1514 }
1515
1516 #[test]
1517 fn strip_json_fences_outer_whitespace() {
1518 assert_eq!(
1519 strip_json_fences(" ```json\n{\"a\": 1}\n``` "),
1520 r#"{"a": 1}"#
1521 );
1522 }
1523
1524 #[test]
1525 fn strip_json_fences_only_opening_fence() {
1526 assert_eq!(strip_json_fences("```json\n{\"a\": 1}"), r#"{"a": 1}"#);
1527 }
1528
1529 #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
1532 struct TestOutput {
1533 value: String,
1534 }
1535
1536 struct SequentialStub {
1537 responses: std::sync::Mutex<Vec<Result<String, LlmError>>>,
1538 }
1539
1540 impl SequentialStub {
1541 fn new(responses: Vec<Result<String, LlmError>>) -> Self {
1542 Self {
1543 responses: std::sync::Mutex::new(responses),
1544 }
1545 }
1546 }
1547
1548 impl LlmProvider for SequentialStub {
1549 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1550 let mut responses = self.responses.lock().unwrap();
1551 if responses.is_empty() {
1552 return Err(LlmError::Other("no more responses".into()));
1553 }
1554 responses.remove(0)
1555 }
1556
1557 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1558 let response = self.chat(messages).await?;
1559 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1560 response,
1561 )))))
1562 }
1563
1564 fn supports_streaming(&self) -> bool {
1565 false
1566 }
1567
1568 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1569 Err(LlmError::EmbedUnsupported {
1570 provider: "sequential-stub".into(),
1571 })
1572 }
1573
1574 fn supports_embeddings(&self) -> bool {
1575 false
1576 }
1577
1578 fn name(&self) -> &'static str {
1579 "sequential-stub"
1580 }
1581 }
1582
1583 #[tokio::test]
1584 async fn chat_typed_happy_path() {
1585 let provider = StubProvider {
1586 response: r#"{"value": "hello"}"#.into(),
1587 };
1588 let messages = vec![Message::from_legacy(Role::User, "test")];
1589 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1590 assert_eq!(
1591 result,
1592 TestOutput {
1593 value: "hello".into()
1594 }
1595 );
1596 }
1597
1598 #[tokio::test]
1599 async fn chat_typed_retry_succeeds() {
1600 let provider = SequentialStub::new(vec![
1601 Ok("not valid json".into()),
1602 Ok(r#"{"value": "ok"}"#.into()),
1603 ]);
1604 let messages = vec![Message::from_legacy(Role::User, "test")];
1605 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1606 assert_eq!(result, TestOutput { value: "ok".into() });
1607 }
1608
1609 #[tokio::test]
1610 async fn chat_typed_both_fail() {
1611 let provider = SequentialStub::new(vec![Ok("bad json".into()), Ok("still bad".into())]);
1612 let messages = vec![Message::from_legacy(Role::User, "test")];
1613 let result = provider.chat_typed::<TestOutput>(&messages).await;
1614 let err = result.unwrap_err();
1615 assert!(err.to_string().contains("parse failed after retry"));
1616 }
1617
1618 #[tokio::test]
1619 async fn chat_typed_chat_error_propagates() {
1620 let provider = SequentialStub::new(vec![Err(LlmError::Unavailable)]);
1621 let messages = vec![Message::from_legacy(Role::User, "test")];
1622 let result = provider.chat_typed::<TestOutput>(&messages).await;
1623 assert!(matches!(result, Err(LlmError::Unavailable)));
1624 }
1625
1626 #[tokio::test]
1627 async fn chat_typed_strips_fences() {
1628 let provider = StubProvider {
1629 response: "```json\n{\"value\": \"fenced\"}\n```".into(),
1630 };
1631 let messages = vec![Message::from_legacy(Role::User, "test")];
1632 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1633 assert_eq!(
1634 result,
1635 TestOutput {
1636 value: "fenced".into()
1637 }
1638 );
1639 }
1640
1641 #[test]
1642 fn supports_structured_output_default_false() {
1643 let provider = StubProvider {
1644 response: String::new(),
1645 };
1646 assert!(!provider.supports_structured_output());
1647 }
1648
1649 #[test]
1650 fn structured_parse_error_display() {
1651 let err = LlmError::StructuredParse("test error".into());
1652 assert_eq!(
1653 err.to_string(),
1654 "structured output parse failed: test error"
1655 );
1656 }
1657
1658 #[test]
1659 fn message_part_image_roundtrip_json() {
1660 let part = MessagePart::Image(Box::new(ImageData {
1661 data: vec![1, 2, 3, 4],
1662 mime_type: "image/jpeg".into(),
1663 }));
1664 let json = serde_json::to_string(&part).unwrap();
1665 let decoded: MessagePart = serde_json::from_str(&json).unwrap();
1666 match decoded {
1667 MessagePart::Image(img) => {
1668 assert_eq!(img.data, vec![1, 2, 3, 4]);
1669 assert_eq!(img.mime_type, "image/jpeg");
1670 }
1671 _ => panic!("expected Image variant"),
1672 }
1673 }
1674
1675 #[test]
1676 fn flatten_parts_includes_image_placeholder() {
1677 let msg = Message::from_parts(
1678 Role::User,
1679 vec![
1680 MessagePart::Text {
1681 text: "see this".into(),
1682 },
1683 MessagePart::Image(Box::new(ImageData {
1684 data: vec![0u8; 100],
1685 mime_type: "image/png".into(),
1686 })),
1687 ],
1688 );
1689 let content = msg.to_llm_content();
1690 assert!(content.contains("see this"));
1691 assert!(content.contains("[image: image/png"));
1692 }
1693
1694 #[test]
1695 fn supports_vision_default_false() {
1696 let provider = StubProvider {
1697 response: String::new(),
1698 };
1699 assert!(!provider.supports_vision());
1700 }
1701
1702 #[test]
1703 fn message_metadata_default_both_visible() {
1704 let m = MessageMetadata::default();
1705 assert!(m.visibility.is_agent_visible());
1706 assert!(m.visibility.is_user_visible());
1707 assert_eq!(m.visibility, MessageVisibility::Both);
1708 assert!(m.compacted_at.is_none());
1709 }
1710
1711 #[test]
1712 fn message_metadata_agent_only() {
1713 let m = MessageMetadata::agent_only();
1714 assert!(m.visibility.is_agent_visible());
1715 assert!(!m.visibility.is_user_visible());
1716 assert_eq!(m.visibility, MessageVisibility::AgentOnly);
1717 }
1718
1719 #[test]
1720 fn message_metadata_user_only() {
1721 let m = MessageMetadata::user_only();
1722 assert!(!m.visibility.is_agent_visible());
1723 assert!(m.visibility.is_user_visible());
1724 assert_eq!(m.visibility, MessageVisibility::UserOnly);
1725 }
1726
1727 #[test]
1728 fn message_metadata_serde_default() {
1729 let json = r#"{"role":"user","content":"hello"}"#;
1730 let msg: Message = serde_json::from_str(json).unwrap();
1731 assert!(msg.metadata.visibility.is_agent_visible());
1732 assert!(msg.metadata.visibility.is_user_visible());
1733 }
1734
1735 #[test]
1736 fn message_metadata_round_trip() {
1737 let msg = Message {
1738 role: Role::User,
1739 content: "test".into(),
1740 parts: vec![],
1741 metadata: MessageMetadata::agent_only(),
1742 };
1743 let json = serde_json::to_string(&msg).unwrap();
1744 let decoded: Message = serde_json::from_str(&json).unwrap();
1745 assert!(decoded.metadata.visibility.is_agent_visible());
1746 assert!(!decoded.metadata.visibility.is_user_visible());
1747 assert_eq!(decoded.metadata.visibility, MessageVisibility::AgentOnly);
1748 }
1749
1750 #[test]
1751 fn message_part_compaction_round_trip() {
1752 let part = MessagePart::Compaction {
1753 summary: "Context was summarized.".to_owned(),
1754 };
1755 let json = serde_json::to_string(&part).unwrap();
1756 let decoded: MessagePart = serde_json::from_str(&json).unwrap();
1757 assert!(
1758 matches!(decoded, MessagePart::Compaction { summary } if summary == "Context was summarized.")
1759 );
1760 }
1761
1762 #[test]
1763 fn flatten_parts_compaction_contributes_no_text() {
1764 let parts = vec![
1767 MessagePart::Text {
1768 text: "Hello".to_owned(),
1769 },
1770 MessagePart::Compaction {
1771 summary: "Summary".to_owned(),
1772 },
1773 ];
1774 let msg = Message::from_parts(Role::Assistant, parts);
1775 assert_eq!(msg.content.trim(), "Hello");
1777 }
1778
1779 #[test]
1780 fn stream_chunk_compaction_variant() {
1781 let chunk = StreamChunk::Compaction("A summary".to_owned());
1782 assert!(matches!(chunk, StreamChunk::Compaction(s) if s == "A summary"));
1783 }
1784
1785 #[test]
1786 fn short_type_name_extracts_last_segment() {
1787 struct MyOutput;
1788 assert_eq!(short_type_name::<MyOutput>(), "MyOutput");
1789 }
1790
1791 #[test]
1792 fn short_type_name_primitive_returns_full_name() {
1793 assert_eq!(short_type_name::<u32>(), "u32");
1795 assert_eq!(short_type_name::<bool>(), "bool");
1796 }
1797
1798 #[test]
1799 fn short_type_name_nested_path_returns_last() {
1800 assert_eq!(
1802 short_type_name::<std::collections::HashMap<u32, u32>>(),
1803 "HashMap<u32, u32>"
1804 );
1805 }
1806
1807 #[test]
1810 fn summary_roundtrip() {
1811 let part = MessagePart::Summary {
1812 text: "hello".to_string(),
1813 };
1814 let json = serde_json::to_string(&part).expect("serialization must not fail");
1815 assert!(
1816 json.contains("\"kind\":\"summary\""),
1817 "must use internally-tagged format, got: {json}"
1818 );
1819 assert!(
1820 !json.contains("\"Summary\""),
1821 "must not use externally-tagged format, got: {json}"
1822 );
1823 let decoded: MessagePart =
1824 serde_json::from_str(&json).expect("deserialization must not fail");
1825 match decoded {
1826 MessagePart::Summary { text } => assert_eq!(text, "hello"),
1827 other => panic!("expected MessagePart::Summary, got {other:?}"),
1828 }
1829 }
1830
1831 #[tokio::test]
1832 async fn embed_batch_default_empty_returns_empty() {
1833 let provider = StubProvider {
1834 response: String::new(),
1835 };
1836 let result = provider.embed_batch(&[]).await.unwrap();
1837 assert!(result.is_empty());
1838 }
1839
1840 #[tokio::test]
1841 async fn embed_batch_default_calls_embed_sequentially() {
1842 let provider = StubProvider {
1843 response: String::new(),
1844 };
1845 let texts = ["hello", "world", "foo"];
1846 let result = provider.embed_batch(&texts).await.unwrap();
1847 assert_eq!(result.len(), 3);
1848 for vec in &result {
1850 assert_eq!(vec, &[0.1_f32, 0.2, 0.3]);
1851 }
1852 }
1853
1854 #[test]
1855 fn message_visibility_db_roundtrip_both() {
1856 assert_eq!(MessageVisibility::Both.as_db_str(), "both");
1857 assert_eq!(
1858 MessageVisibility::from_db_str("both"),
1859 MessageVisibility::Both
1860 );
1861 }
1862
1863 #[test]
1864 fn message_visibility_db_roundtrip_agent_only() {
1865 assert_eq!(MessageVisibility::AgentOnly.as_db_str(), "agent_only");
1866 assert_eq!(
1867 MessageVisibility::from_db_str("agent_only"),
1868 MessageVisibility::AgentOnly
1869 );
1870 }
1871
1872 #[test]
1873 fn message_visibility_db_roundtrip_user_only() {
1874 assert_eq!(MessageVisibility::UserOnly.as_db_str(), "user_only");
1875 assert_eq!(
1876 MessageVisibility::from_db_str("user_only"),
1877 MessageVisibility::UserOnly
1878 );
1879 }
1880
1881 #[test]
1882 fn message_visibility_from_db_str_unknown_defaults_to_both() {
1883 assert_eq!(
1884 MessageVisibility::from_db_str("unknown_future_value"),
1885 MessageVisibility::Both
1886 );
1887 assert_eq!(MessageVisibility::from_db_str(""), MessageVisibility::Both);
1888 }
1889}