1use std::future::Future;
5use std::pin::Pin;
6use std::{
7 any::TypeId,
8 collections::HashMap,
9 sync::{LazyLock, Mutex},
10};
11
12use futures_core::Stream;
13use serde::{Deserialize, Serialize};
14
15use zeph_common::ToolName;
16
17pub use zeph_common::ToolDefinition;
18
19use crate::embed::owned_strs;
20use crate::error::LlmError;
21
22static SCHEMA_CACHE: LazyLock<Mutex<HashMap<TypeId, (serde_json::Value, String)>>> =
23 LazyLock::new(|| Mutex::new(HashMap::new()));
24
25pub(crate) fn cached_schema<T: schemars::JsonSchema + 'static>()
31-> Result<(serde_json::Value, String), crate::LlmError> {
32 let type_id = TypeId::of::<T>();
33 if let Ok(cache) = SCHEMA_CACHE.lock()
34 && let Some(entry) = cache.get(&type_id)
35 {
36 return Ok(entry.clone());
37 }
38 let schema = schemars::schema_for!(T);
39 let value = serde_json::to_value(&schema)
40 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
41 let pretty = serde_json::to_string_pretty(&schema)
42 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
43 if let Ok(mut cache) = SCHEMA_CACHE.lock() {
44 cache.insert(type_id, (value.clone(), pretty.clone()));
45 }
46 Ok((value, pretty))
47}
48
/// Returns the short (unqualified) name of `T`, e.g. `"String"` for `alloc::string::String`.
///
/// Generic arguments are dropped before the module path is stripped, so
/// `Vec<String>` yields `"Vec"`. Splitting the full name on `::` alone would pick the last
/// segment of the *argument* path and return a fragment such as `"String>"`.
///
/// Falls back to `"Output"` if no non-empty name can be extracted. The exact text produced
/// by `std::any::type_name` is not guaranteed by the standard library, so the result is
/// meant for prompts and diagnostics rather than as a stable identifier.
pub(crate) fn short_type_name<T: ?Sized>() -> &'static str {
    let full = std::any::type_name::<T>();
    // Cut at the first `<` so that paths inside the generic argument list are ignored.
    let base = full.split('<').next().unwrap_or(full);
    base.rsplit("::")
        .next()
        .filter(|segment| !segment.is_empty())
        .unwrap_or("Output")
}
68
/// Extra data returned next to the response text by `LlmProvider::chat_with_extras`.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with a struct
/// literal; use `Default` or [`ChatExtras::with_entropy`].
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct ChatExtras {
    /// Entropy value attached to the response, when one is available.
    pub entropy: Option<f64>,
}
86
87impl ChatExtras {
88 #[must_use]
101 pub fn with_entropy(entropy: f64) -> Self {
102 Self {
103 entropy: Some(entropy),
104 }
105 }
106}
107
108#[derive(Debug, Clone)]
113pub enum StreamChunk {
114 Content(String),
116 Thinking(String),
118 Compaction(String),
121 ToolUse(Vec<ToolUseRequest>),
123}
124
/// Boxed, pinned, `Send` stream whose items are `Result<StreamChunk, LlmError>`.
///
/// This is the stream type returned by `LlmProvider::chat_stream`.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct ToolUseRequest {
138 pub id: String,
140 pub name: ToolName,
142 pub input: serde_json::Value,
144}
145
/// Thinking data carried in `ChatResponse::ToolUse::thinking_blocks`.
#[derive(Clone, Debug)]
pub enum ThinkingBlock {
    /// Thinking text together with its accompanying signature string.
    Thinking { thinking: String, signature: String },
    /// Redacted thinking, kept only as an opaque data string.
    Redacted { data: String },
}
158
/// Marker text `"max_tokens limit reached"`.
///
/// NOTE(review): the producers/consumers of this marker are outside this section;
/// presumably it flags responses cut off by the token limit — confirm against callers.
pub const MAX_TOKENS_TRUNCATION_MARKER: &str = "max_tokens limit reached";
162
163#[derive(Debug, Clone)]
171pub enum ChatResponse {
172 Text(String),
174 ToolUse {
176 text: Option<String>,
178 tool_calls: Vec<ToolUseRequest>,
179 thinking_blocks: Vec<ThinkingBlock>,
182 },
183}
184
/// Boxed, pinned, `Send` future resolving to an embedding vector or an `LlmError`.
pub type EmbedFuture = Pin<Box<dyn Future<Output = Result<Vec<f32>, LlmError>> + Send>>;

/// Boxed, thread-safe callback that maps a text to an [`EmbedFuture`].
pub type EmbedFn = Box<dyn Fn(&str) -> EmbedFuture + Send + Sync>;

/// Sending half of an unbounded tokio channel carrying `String` messages.
///
/// NOTE(review): presumably used for human-readable status updates — confirm with callers.
pub type StatusTx = tokio::sync::mpsc::UnboundedSender<String>;
200
201#[must_use]
204pub fn default_debug_request_json(
205 messages: &[Message],
206 tools: &[ToolDefinition],
207) -> serde_json::Value {
208 serde_json::json!({
209 "model": serde_json::Value::Null,
210 "max_tokens": serde_json::Value::Null,
211 "messages": serde_json::to_value(messages).unwrap_or(serde_json::Value::Array(vec![])),
212 "tools": serde_json::to_value(tools).unwrap_or(serde_json::Value::Array(vec![])),
213 "temperature": serde_json::Value::Null,
214 "cache_control": serde_json::Value::Null,
215 })
216}
217
/// Optional per-request sampling overrides; every field defaults to `None` (no override).
#[derive(Clone, Debug, Default)]
pub struct GenerationOverrides {
    /// Sampling temperature override.
    pub temperature: Option<f64>,
    /// `top_p` override.
    pub top_p: Option<f64>,
    /// `top_k` override.
    pub top_k: Option<usize>,
    /// Frequency penalty override.
    pub frequency_penalty: Option<f64>,
    /// Presence penalty override.
    pub presence_penalty: Option<f64>,
}
239
240#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
247#[serde(rename_all = "lowercase")]
248pub enum Role {
249 System,
250 User,
251 Assistant,
252}
253
254#[derive(Clone, Debug, Serialize, Deserialize)]
269#[serde(tag = "kind", rename_all = "snake_case")]
270pub enum MessagePart {
271 Text { text: String },
273 ToolOutput {
275 tool_name: zeph_common::ToolName,
276 body: String,
277 #[serde(default, skip_serializing_if = "Option::is_none")]
278 compacted_at: Option<i64>,
279 },
280 Recall { text: String },
282 CodeContext { text: String },
284 Summary { text: String },
286 CrossSession { text: String },
288 ToolUse {
290 id: String,
291 name: String,
292 input: serde_json::Value,
293 },
294 ToolResult {
296 tool_use_id: String,
297 content: String,
298 #[serde(default)]
299 is_error: bool,
300 },
301 Image(Box<ImageData>),
303 ThinkingBlock { thinking: String, signature: String },
305 RedactedThinkingBlock { data: String },
307 Compaction { summary: String },
310}
311
312impl MessagePart {
313 #[must_use]
316 pub fn as_plain_text(&self) -> Option<&str> {
317 match self {
318 Self::Text { text }
319 | Self::Recall { text }
320 | Self::CodeContext { text }
321 | Self::Summary { text }
322 | Self::CrossSession { text } => Some(text.as_str()),
323 _ => None,
324 }
325 }
326
327 #[must_use]
329 pub fn as_image(&self) -> Option<&ImageData> {
330 if let Self::Image(img) = self {
331 Some(img)
332 } else {
333 None
334 }
335 }
336}
337
338#[derive(Clone, Debug, Serialize, Deserialize)]
339pub struct ImageData {
344 #[serde(with = "serde_bytes_base64")]
345 pub data: Vec<u8>,
346 pub mime_type: String,
347}
348
349mod serde_bytes_base64 {
350 use base64::{Engine, engine::general_purpose::STANDARD};
351 use serde::{Deserialize, Deserializer, Serializer};
352
353 pub fn serialize<S>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error>
354 where
355 S: Serializer,
356 {
357 s.serialize_str(&STANDARD.encode(bytes))
358 }
359
360 pub fn deserialize<'de, D>(d: D) -> Result<Vec<u8>, D::Error>
361 where
362 D: Deserializer<'de>,
363 {
364 let s = String::deserialize(d)?;
365 STANDARD.decode(&s).map_err(serde::de::Error::custom)
366 }
367}
368
369#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
385#[serde(rename_all = "snake_case")]
386pub enum MessageVisibility {
387 Both,
389 AgentOnly,
391 UserOnly,
393}
394
395impl MessageVisibility {
396 #[must_use]
398 pub fn is_agent_visible(self) -> bool {
399 matches!(self, MessageVisibility::Both | MessageVisibility::AgentOnly)
400 }
401
402 #[must_use]
404 pub fn is_user_visible(self) -> bool {
405 matches!(self, MessageVisibility::Both | MessageVisibility::UserOnly)
406 }
407}
408
409impl Default for MessageVisibility {
410 fn default() -> Self {
412 MessageVisibility::Both
413 }
414}
415
416impl MessageVisibility {
417 #[must_use]
419 pub fn as_db_str(self) -> &'static str {
420 match self {
421 MessageVisibility::Both => "both",
422 MessageVisibility::AgentOnly => "agent_only",
423 MessageVisibility::UserOnly => "user_only",
424 }
425 }
426
427 #[must_use]
431 pub fn from_db_str(s: &str) -> Self {
432 match s {
433 "agent_only" => MessageVisibility::AgentOnly,
434 "user_only" => MessageVisibility::UserOnly,
435 _ => MessageVisibility::Both,
436 }
437 }
438}
439
440#[derive(Clone, Debug, Serialize, Deserialize)]
445pub struct MessageMetadata {
446 pub visibility: MessageVisibility,
448 #[serde(default, skip_serializing_if = "Option::is_none")]
450 pub compacted_at: Option<i64>,
451 #[serde(default, skip_serializing_if = "Option::is_none")]
454 pub deferred_summary: Option<String>,
455 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
458 pub focus_pinned: bool,
459 #[serde(default, skip_serializing_if = "Option::is_none")]
462 pub focus_marker_id: Option<uuid::Uuid>,
463 #[serde(skip)]
466 pub db_id: Option<i64>,
467}
468
469impl Default for MessageMetadata {
470 fn default() -> Self {
471 Self {
472 visibility: MessageVisibility::Both,
473 compacted_at: None,
474 deferred_summary: None,
475 focus_pinned: false,
476 focus_marker_id: None,
477 db_id: None,
478 }
479 }
480}
481
482impl MessageMetadata {
483 #[must_use]
485 pub fn agent_only() -> Self {
486 Self {
487 visibility: MessageVisibility::AgentOnly,
488 compacted_at: None,
489 deferred_summary: None,
490 focus_pinned: false,
491 focus_marker_id: None,
492 db_id: None,
493 }
494 }
495
496 #[must_use]
498 pub fn user_only() -> Self {
499 Self {
500 visibility: MessageVisibility::UserOnly,
501 compacted_at: None,
502 deferred_summary: None,
503 focus_pinned: false,
504 focus_marker_id: None,
505 db_id: None,
506 }
507 }
508
509 #[must_use]
511 pub fn focus_pinned() -> Self {
512 Self {
513 visibility: MessageVisibility::AgentOnly,
514 compacted_at: None,
515 deferred_summary: None,
516 focus_pinned: true,
517 focus_marker_id: None,
518 db_id: None,
519 }
520 }
521}
522
523#[derive(Clone, Debug, Serialize, Deserialize)]
550pub struct Message {
551 pub role: Role,
552 pub content: String,
554 #[serde(default)]
555 pub parts: Vec<MessagePart>,
556 #[serde(default)]
557 pub metadata: MessageMetadata,
558}
559
560impl Default for Message {
561 fn default() -> Self {
562 Self {
563 role: Role::User,
564 content: String::new(),
565 parts: vec![],
566 metadata: MessageMetadata::default(),
567 }
568 }
569}
570
571impl Message {
572 #[must_use]
577 pub fn from_legacy(role: Role, content: impl Into<String>) -> Self {
578 Self {
579 role,
580 content: content.into(),
581 parts: vec![],
582 metadata: MessageMetadata::default(),
583 }
584 }
585
586 #[must_use]
591 pub fn from_parts(role: Role, parts: Vec<MessagePart>) -> Self {
592 let content = Self::flatten_parts(&parts);
593 Self {
594 role,
595 content,
596 parts,
597 metadata: MessageMetadata::default(),
598 }
599 }
600
601 #[must_use]
604 pub fn to_llm_content(&self) -> &str {
605 &self.content
606 }
607
608 pub fn rebuild_content(&mut self) {
610 if !self.parts.is_empty() {
611 self.content = Self::flatten_parts(&self.parts);
612 }
613 }
614
615 fn flatten_parts(parts: &[MessagePart]) -> String {
616 use std::fmt::Write;
617 let mut out = String::new();
618 for part in parts {
619 match part {
620 MessagePart::Text { text }
621 | MessagePart::Recall { text }
622 | MessagePart::CodeContext { text }
623 | MessagePart::Summary { text }
624 | MessagePart::CrossSession { text } => out.push_str(text),
625 MessagePart::ToolOutput {
626 tool_name,
627 body,
628 compacted_at,
629 } => {
630 if compacted_at.is_some() {
631 if body.is_empty() {
632 let _ = write!(out, "[tool output: {tool_name}] (pruned)");
633 } else {
634 let _ = write!(out, "[tool output: {tool_name}] {body}");
635 }
636 } else {
637 let _ = write!(out, "[tool output: {tool_name}]\n```\n{body}\n```");
638 }
639 }
640 MessagePart::ToolUse { id, name, .. } => {
641 let _ = write!(out, "[tool_use: {name}({id})]");
642 }
643 MessagePart::ToolResult {
644 tool_use_id,
645 content,
646 ..
647 } => {
648 let _ = write!(out, "[tool_result: {tool_use_id}]\n{content}");
649 }
650 MessagePart::Image(img) => {
651 let _ = write!(out, "[image: {}, {} bytes]", img.mime_type, img.data.len());
652 }
653 MessagePart::ThinkingBlock { .. }
655 | MessagePart::RedactedThinkingBlock { .. }
656 | MessagePart::Compaction { .. } => {}
657 }
658 }
659 out
660 }
661}
662
663pub trait LlmProvider: Send + Sync {
731 fn context_window(&self) -> Option<usize> {
735 None
736 }
737
738 fn chat(&self, messages: &[Message]) -> impl Future<Output = Result<String, LlmError>> + Send;
744
745 fn chat_stream(
751 &self,
752 messages: &[Message],
753 ) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;
754
755 fn supports_streaming(&self) -> bool;
757
758 fn embed(&self, text: &str) -> impl Future<Output = Result<Vec<f32>, LlmError>> + Send;
764
765 fn embed_batch(
775 &self,
776 texts: &[&str],
777 ) -> impl Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
778 let owned = owned_strs(texts);
779 async move {
780 let mut results = Vec::with_capacity(owned.len());
781 for text in &owned {
782 results.push(self.embed(text).await?);
783 }
784 Ok(results)
785 }
786 }
787
788 fn supports_embeddings(&self) -> bool;
790
791 fn name(&self) -> &str;
793
794 #[allow(clippy::unnecessary_literal_bound)]
797 fn model_identifier(&self) -> &str {
798 ""
799 }
800
801 fn supports_vision(&self) -> bool {
803 false
804 }
805
806 fn supports_tool_use(&self) -> bool {
808 true
809 }
810
811 fn chat_with_tools(
819 &self,
820 messages: &[Message],
821 _tools: &[ToolDefinition],
822 ) -> impl std::future::Future<Output = Result<ChatResponse, LlmError>> + Send {
823 let msgs = messages.to_vec();
824 async move { Ok(ChatResponse::Text(self.chat(&msgs).await?)) }
825 }
826
827 fn last_cache_usage(&self) -> Option<(u64, u64)> {
830 None
831 }
832
833 fn last_usage(&self) -> Option<(u64, u64)> {
836 None
837 }
838
839 fn last_reasoning_tokens(&self) -> Option<u64> {
844 None
845 }
846
847 fn take_compaction_summary(&self) -> Option<String> {
850 None
851 }
852
853 fn chat_with_extras(
868 &self,
869 messages: &[Message],
870 ) -> impl Future<Output = Result<(String, ChatExtras), LlmError>> + Send {
871 let msgs = messages.to_vec();
872 async move { Ok((self.chat(&msgs).await?, ChatExtras::default())) }
873 }
874
875 #[must_use]
879 fn debug_request_json(
880 &self,
881 messages: &[Message],
882 tools: &[ToolDefinition],
883 _stream: bool,
884 ) -> serde_json::Value {
885 default_debug_request_json(messages, tools)
886 }
887
888 fn list_models(&self) -> Vec<String> {
891 vec![]
892 }
893
894 fn supports_structured_output(&self) -> bool {
896 false
897 }
898
899 #[allow(async_fn_in_trait)]
910 async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
911 where
912 T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
913 Self: Sized,
914 {
915 let (_, schema_json) = cached_schema::<T>()?;
916 let type_name = short_type_name::<T>();
917
918 let mut augmented = messages.to_vec();
919 let instruction = format!(
920 "Respond with a valid JSON object matching this schema. \
921 Output ONLY the JSON, no markdown fences or extra text.\n\n\
922 Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
923 );
924 augmented.insert(0, Message::from_legacy(Role::System, instruction));
925
926 let raw = self.chat(&augmented).await?;
927 let cleaned = strip_json_fences(&raw);
928 match serde_json::from_str::<T>(cleaned) {
929 Ok(val) => Ok(val),
930 Err(first_err) => {
931 augmented.push(Message::from_legacy(Role::Assistant, &raw));
932 augmented.push(Message::from_legacy(
933 Role::User,
934 format!(
935 "Your response was not valid JSON. Error: {first_err}. \
936 Please output ONLY valid JSON matching the schema."
937 ),
938 ));
939 let retry_raw = self.chat(&augmented).await?;
940 let retry_cleaned = strip_json_fences(&retry_raw);
941 serde_json::from_str::<T>(retry_cleaned).map_err(|e| {
942 LlmError::StructuredParse(format!("parse failed after retry: {e}"))
943 })
944 }
945 }
946 }
947}
948
/// Removes surrounding whitespace and Markdown code fences (```` ```json ```` / ```` ``` ````)
/// from `s`, returning the inner slice.
///
/// Leading and trailing fences are stripped independently, so input with only an opening
/// fence is handled as well. Input without fences is returned trimmed.
fn strip_json_fences(s: &str) -> &str {
    let body = s.trim();
    let body = body.trim_start_matches("```json");
    let body = body.trim_start_matches("```");
    let body = body.trim_end_matches("```");
    body.trim()
}
959
960#[cfg(test)]
961mod tests {
962 use tokio_stream::StreamExt;
963
964 use super::*;
965
966 struct StubProvider {
967 response: String,
968 }
969
970 impl LlmProvider for StubProvider {
971 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
972 Ok(self.response.clone())
973 }
974
975 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
976 let response = self.chat(messages).await?;
977 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
978 response,
979 )))))
980 }
981
982 fn supports_streaming(&self) -> bool {
983 false
984 }
985
986 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
987 Ok(vec![0.1, 0.2, 0.3])
988 }
989
990 fn supports_embeddings(&self) -> bool {
991 false
992 }
993
994 fn name(&self) -> &'static str {
995 "stub"
996 }
997 }
998
999 #[test]
1000 fn context_window_default_returns_none() {
1001 let provider = StubProvider {
1002 response: String::new(),
1003 };
1004 assert!(provider.context_window().is_none());
1005 }
1006
1007 #[test]
1008 fn supports_streaming_default_returns_false() {
1009 let provider = StubProvider {
1010 response: String::new(),
1011 };
1012 assert!(!provider.supports_streaming());
1013 }
1014
1015 #[tokio::test]
1016 async fn chat_stream_default_yields_single_chunk() {
1017 let provider = StubProvider {
1018 response: "hello world".into(),
1019 };
1020 let messages = vec![Message {
1021 role: Role::User,
1022 content: "test".into(),
1023 parts: vec![],
1024 metadata: MessageMetadata::default(),
1025 }];
1026
1027 let mut stream = provider.chat_stream(&messages).await.unwrap();
1028 let chunk = stream.next().await.unwrap().unwrap();
1029 assert!(matches!(chunk, StreamChunk::Content(s) if s == "hello world"));
1030 assert!(stream.next().await.is_none());
1031 }
1032
1033 #[tokio::test]
1034 async fn chat_stream_default_propagates_chat_error() {
1035 struct FailProvider;
1036
1037 impl LlmProvider for FailProvider {
1038 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1039 Err(LlmError::Unavailable)
1040 }
1041
1042 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1043 let response = self.chat(messages).await?;
1044 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1045 response,
1046 )))))
1047 }
1048
1049 fn supports_streaming(&self) -> bool {
1050 false
1051 }
1052
1053 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1054 Err(LlmError::Unavailable)
1055 }
1056
1057 fn supports_embeddings(&self) -> bool {
1058 false
1059 }
1060
1061 fn name(&self) -> &'static str {
1062 "fail"
1063 }
1064 }
1065
1066 let provider = FailProvider;
1067 let messages = vec![Message {
1068 role: Role::User,
1069 content: "test".into(),
1070 parts: vec![],
1071 metadata: MessageMetadata::default(),
1072 }];
1073
1074 let result = provider.chat_stream(&messages).await;
1075 assert!(result.is_err());
1076 if let Err(e) = result {
1077 assert!(e.to_string().contains("provider unavailable"));
1078 }
1079 }
1080
1081 #[tokio::test]
1082 async fn stub_provider_embed_returns_vector() {
1083 let provider = StubProvider {
1084 response: String::new(),
1085 };
1086 let embedding = provider.embed("test").await.unwrap();
1087 assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
1088 }
1089
1090 #[tokio::test]
1091 async fn fail_provider_embed_propagates_error() {
1092 struct FailProvider;
1093
1094 impl LlmProvider for FailProvider {
1095 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1096 Err(LlmError::Unavailable)
1097 }
1098
1099 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1100 let response = self.chat(messages).await?;
1101 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1102 response,
1103 )))))
1104 }
1105
1106 fn supports_streaming(&self) -> bool {
1107 false
1108 }
1109
1110 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1111 Err(LlmError::EmbedUnsupported {
1112 provider: "fail".into(),
1113 })
1114 }
1115
1116 fn supports_embeddings(&self) -> bool {
1117 false
1118 }
1119
1120 fn name(&self) -> &'static str {
1121 "fail"
1122 }
1123 }
1124
1125 let provider = FailProvider;
1126 let result = provider.embed("test").await;
1127 assert!(result.is_err());
1128 assert!(
1129 result
1130 .unwrap_err()
1131 .to_string()
1132 .contains("embedding not supported")
1133 );
1134 }
1135
1136 #[test]
1137 fn role_serialization() {
1138 let system = Role::System;
1139 let user = Role::User;
1140 let assistant = Role::Assistant;
1141
1142 assert_eq!(serde_json::to_string(&system).unwrap(), "\"system\"");
1143 assert_eq!(serde_json::to_string(&user).unwrap(), "\"user\"");
1144 assert_eq!(serde_json::to_string(&assistant).unwrap(), "\"assistant\"");
1145 }
1146
1147 #[test]
1148 fn role_deserialization() {
1149 let system: Role = serde_json::from_str("\"system\"").unwrap();
1150 let user: Role = serde_json::from_str("\"user\"").unwrap();
1151 let assistant: Role = serde_json::from_str("\"assistant\"").unwrap();
1152
1153 assert_eq!(system, Role::System);
1154 assert_eq!(user, Role::User);
1155 assert_eq!(assistant, Role::Assistant);
1156 }
1157
1158 #[test]
1159 fn message_clone() {
1160 let msg = Message {
1161 role: Role::User,
1162 content: "test".into(),
1163 parts: vec![],
1164 metadata: MessageMetadata::default(),
1165 };
1166 let cloned = msg.clone();
1167 assert_eq!(cloned.role, msg.role);
1168 assert_eq!(cloned.content, msg.content);
1169 }
1170
1171 #[test]
1172 fn message_debug() {
1173 let msg = Message {
1174 role: Role::Assistant,
1175 content: "response".into(),
1176 parts: vec![],
1177 metadata: MessageMetadata::default(),
1178 };
1179 let debug = format!("{msg:?}");
1180 assert!(debug.contains("Assistant"));
1181 assert!(debug.contains("response"));
1182 }
1183
1184 #[test]
1185 fn message_serialization() {
1186 let msg = Message {
1187 role: Role::User,
1188 content: "hello".into(),
1189 parts: vec![],
1190 metadata: MessageMetadata::default(),
1191 };
1192 let json = serde_json::to_string(&msg).unwrap();
1193 assert!(json.contains("\"role\":\"user\""));
1194 assert!(json.contains("\"content\":\"hello\""));
1195 }
1196
1197 #[test]
1198 fn message_part_serde_round_trip() {
1199 let parts = vec![
1200 MessagePart::Text {
1201 text: "hello".into(),
1202 },
1203 MessagePart::ToolOutput {
1204 tool_name: "bash".into(),
1205 body: "output".into(),
1206 compacted_at: None,
1207 },
1208 MessagePart::Recall {
1209 text: "recall".into(),
1210 },
1211 MessagePart::CodeContext {
1212 text: "code".into(),
1213 },
1214 MessagePart::Summary {
1215 text: "summary".into(),
1216 },
1217 ];
1218 let json = serde_json::to_string(&parts).unwrap();
1219 let deserialized: Vec<MessagePart> = serde_json::from_str(&json).unwrap();
1220 assert_eq!(deserialized.len(), 5);
1221 }
1222
1223 #[test]
1224 fn from_legacy_creates_empty_parts() {
1225 let msg = Message::from_legacy(Role::User, "hello");
1226 assert_eq!(msg.role, Role::User);
1227 assert_eq!(msg.content, "hello");
1228 assert!(msg.parts.is_empty());
1229 assert_eq!(msg.to_llm_content(), "hello");
1230 }
1231
1232 #[test]
1233 fn from_parts_flattens_content() {
1234 let msg = Message::from_parts(
1235 Role::System,
1236 vec![MessagePart::Recall {
1237 text: "recalled data".into(),
1238 }],
1239 );
1240 assert_eq!(msg.content, "recalled data");
1241 assert_eq!(msg.to_llm_content(), "recalled data");
1242 assert_eq!(msg.parts.len(), 1);
1243 }
1244
1245 #[test]
1246 fn from_parts_tool_output_format() {
1247 let msg = Message::from_parts(
1248 Role::User,
1249 vec![MessagePart::ToolOutput {
1250 tool_name: "bash".into(),
1251 body: "hello world".into(),
1252 compacted_at: None,
1253 }],
1254 );
1255 assert!(msg.content.contains("[tool output: bash]"));
1256 assert!(msg.content.contains("hello world"));
1257 }
1258
1259 #[test]
1260 fn message_deserializes_without_parts() {
1261 let json = r#"{"role":"user","content":"hello"}"#;
1262 let msg: Message = serde_json::from_str(json).unwrap();
1263 assert_eq!(msg.content, "hello");
1264 assert!(msg.parts.is_empty());
1265 }
1266
1267 #[test]
1268 fn flatten_skips_compacted_tool_output_empty_body() {
1269 let msg = Message::from_parts(
1271 Role::User,
1272 vec![
1273 MessagePart::Text {
1274 text: "prefix ".into(),
1275 },
1276 MessagePart::ToolOutput {
1277 tool_name: "bash".into(),
1278 body: String::new(),
1279 compacted_at: Some(1234),
1280 },
1281 MessagePart::Text {
1282 text: " suffix".into(),
1283 },
1284 ],
1285 );
1286 assert!(msg.content.contains("(pruned)"));
1287 assert!(msg.content.contains("prefix "));
1288 assert!(msg.content.contains(" suffix"));
1289 }
1290
1291 #[test]
1292 fn flatten_compacted_tool_output_with_reference_renders_body() {
1293 let ref_notice = "[tool output pruned; full content at /tmp/overflow/big.txt]";
1295 let msg = Message::from_parts(
1296 Role::User,
1297 vec![MessagePart::ToolOutput {
1298 tool_name: "bash".into(),
1299 body: ref_notice.into(),
1300 compacted_at: Some(1234),
1301 }],
1302 );
1303 assert!(msg.content.contains(ref_notice));
1304 assert!(!msg.content.contains("(pruned)"));
1305 }
1306
1307 #[test]
1308 fn rebuild_content_syncs_after_mutation() {
1309 let mut msg = Message::from_parts(
1310 Role::User,
1311 vec![MessagePart::ToolOutput {
1312 tool_name: "bash".into(),
1313 body: "original".into(),
1314 compacted_at: None,
1315 }],
1316 );
1317 assert!(msg.content.contains("original"));
1318
1319 if let MessagePart::ToolOutput {
1320 ref mut compacted_at,
1321 ref mut body,
1322 ..
1323 } = msg.parts[0]
1324 {
1325 *compacted_at = Some(999);
1326 body.clear(); }
1328 msg.rebuild_content();
1329
1330 assert!(msg.content.contains("(pruned)"));
1331 assert!(!msg.content.contains("original"));
1332 }
1333
1334 #[test]
1335 fn message_part_tool_use_serde_round_trip() {
1336 let part = MessagePart::ToolUse {
1337 id: "toolu_123".into(),
1338 name: "bash".into(),
1339 input: serde_json::json!({"command": "ls"}),
1340 };
1341 let json = serde_json::to_string(&part).unwrap();
1342 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1343 if let MessagePart::ToolUse { id, name, input } = deserialized {
1344 assert_eq!(id, "toolu_123");
1345 assert_eq!(name, "bash");
1346 assert_eq!(input["command"], "ls");
1347 } else {
1348 panic!("expected ToolUse");
1349 }
1350 }
1351
1352 #[test]
1353 fn message_part_tool_result_serde_round_trip() {
1354 let part = MessagePart::ToolResult {
1355 tool_use_id: "toolu_123".into(),
1356 content: "file1.rs\nfile2.rs".into(),
1357 is_error: false,
1358 };
1359 let json = serde_json::to_string(&part).unwrap();
1360 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1361 if let MessagePart::ToolResult {
1362 tool_use_id,
1363 content,
1364 is_error,
1365 } = deserialized
1366 {
1367 assert_eq!(tool_use_id, "toolu_123");
1368 assert_eq!(content, "file1.rs\nfile2.rs");
1369 assert!(!is_error);
1370 } else {
1371 panic!("expected ToolResult");
1372 }
1373 }
1374
1375 #[test]
1376 fn message_part_tool_result_is_error_default() {
1377 let json = r#"{"kind":"tool_result","tool_use_id":"id","content":"err"}"#;
1378 let part: MessagePart = serde_json::from_str(json).unwrap();
1379 if let MessagePart::ToolResult { is_error, .. } = part {
1380 assert!(!is_error);
1381 } else {
1382 panic!("expected ToolResult");
1383 }
1384 }
1385
1386 #[test]
1387 fn chat_response_construction() {
1388 let text = ChatResponse::Text("hello".into());
1389 assert!(matches!(text, ChatResponse::Text(s) if s == "hello"));
1390
1391 let tool_use = ChatResponse::ToolUse {
1392 text: Some("I'll run that".into()),
1393 tool_calls: vec![ToolUseRequest {
1394 id: "1".into(),
1395 name: "bash".into(),
1396 input: serde_json::json!({}),
1397 }],
1398 thinking_blocks: vec![],
1399 };
1400 assert!(matches!(tool_use, ChatResponse::ToolUse { .. }));
1401 }
1402
1403 #[test]
1404 fn flatten_parts_tool_use() {
1405 let msg = Message::from_parts(
1406 Role::Assistant,
1407 vec![MessagePart::ToolUse {
1408 id: "t1".into(),
1409 name: "bash".into(),
1410 input: serde_json::json!({"command": "ls"}),
1411 }],
1412 );
1413 assert!(msg.content.contains("[tool_use: bash(t1)]"));
1414 }
1415
1416 #[test]
1417 fn flatten_parts_tool_result() {
1418 let msg = Message::from_parts(
1419 Role::User,
1420 vec![MessagePart::ToolResult {
1421 tool_use_id: "t1".into(),
1422 content: "output here".into(),
1423 is_error: false,
1424 }],
1425 );
1426 assert!(msg.content.contains("[tool_result: t1]"));
1427 assert!(msg.content.contains("output here"));
1428 }
1429
1430 #[test]
1431 fn tool_definition_serde_round_trip() {
1432 let def = ToolDefinition {
1433 name: "bash".into(),
1434 description: "Execute a shell command".into(),
1435 parameters: serde_json::json!({"type": "object"}),
1436 output_schema: None,
1437 };
1438 let json = serde_json::to_string(&def).unwrap();
1439 let deserialized: ToolDefinition = serde_json::from_str(&json).unwrap();
1440 assert_eq!(deserialized.name, "bash");
1441 assert_eq!(deserialized.description, "Execute a shell command");
1442 }
1443
1444 #[tokio::test]
1445 async fn chat_with_tools_default_delegates_to_chat() {
1446 let provider = StubProvider {
1447 response: "hello".into(),
1448 };
1449 let messages = vec![Message::from_legacy(Role::User, "test")];
1450 let result = provider.chat_with_tools(&messages, &[]).await.unwrap();
1451 assert!(matches!(result, ChatResponse::Text(s) if s == "hello"));
1452 }
1453
1454 #[test]
1455 fn tool_output_compacted_at_serde_default() {
1456 let json = r#"{"kind":"tool_output","tool_name":"bash","body":"out"}"#;
1457 let part: MessagePart = serde_json::from_str(json).unwrap();
1458 if let MessagePart::ToolOutput { compacted_at, .. } = part {
1459 assert!(compacted_at.is_none());
1460 } else {
1461 panic!("expected ToolOutput");
1462 }
1463 }
1464
1465 #[test]
1468 fn strip_json_fences_plain_json() {
1469 assert_eq!(strip_json_fences(r#"{"a": 1}"#), r#"{"a": 1}"#);
1470 }
1471
1472 #[test]
1473 fn strip_json_fences_with_json_fence() {
1474 assert_eq!(strip_json_fences("```json\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1475 }
1476
1477 #[test]
1478 fn strip_json_fences_with_plain_fence() {
1479 assert_eq!(strip_json_fences("```\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1480 }
1481
1482 #[test]
1483 fn strip_json_fences_whitespace() {
1484 assert_eq!(strip_json_fences(" \n "), "");
1485 }
1486
1487 #[test]
1488 fn strip_json_fences_empty() {
1489 assert_eq!(strip_json_fences(""), "");
1490 }
1491
1492 #[test]
1493 fn strip_json_fences_outer_whitespace() {
1494 assert_eq!(
1495 strip_json_fences(" ```json\n{\"a\": 1}\n``` "),
1496 r#"{"a": 1}"#
1497 );
1498 }
1499
1500 #[test]
1501 fn strip_json_fences_only_opening_fence() {
1502 assert_eq!(strip_json_fences("```json\n{\"a\": 1}"), r#"{"a": 1}"#);
1503 }
1504
1505 #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
1508 struct TestOutput {
1509 value: String,
1510 }
1511
1512 struct SequentialStub {
1513 responses: std::sync::Mutex<Vec<Result<String, LlmError>>>,
1514 }
1515
1516 impl SequentialStub {
1517 fn new(responses: Vec<Result<String, LlmError>>) -> Self {
1518 Self {
1519 responses: std::sync::Mutex::new(responses),
1520 }
1521 }
1522 }
1523
1524 impl LlmProvider for SequentialStub {
1525 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1526 let mut responses = self.responses.lock().unwrap();
1527 if responses.is_empty() {
1528 return Err(LlmError::Other("no more responses".into()));
1529 }
1530 responses.remove(0)
1531 }
1532
1533 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1534 let response = self.chat(messages).await?;
1535 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1536 response,
1537 )))))
1538 }
1539
1540 fn supports_streaming(&self) -> bool {
1541 false
1542 }
1543
1544 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1545 Err(LlmError::EmbedUnsupported {
1546 provider: "sequential-stub".into(),
1547 })
1548 }
1549
1550 fn supports_embeddings(&self) -> bool {
1551 false
1552 }
1553
1554 fn name(&self) -> &'static str {
1555 "sequential-stub"
1556 }
1557 }
1558
1559 #[tokio::test]
1560 async fn chat_typed_happy_path() {
1561 let provider = StubProvider {
1562 response: r#"{"value": "hello"}"#.into(),
1563 };
1564 let messages = vec![Message::from_legacy(Role::User, "test")];
1565 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1566 assert_eq!(
1567 result,
1568 TestOutput {
1569 value: "hello".into()
1570 }
1571 );
1572 }
1573
1574 #[tokio::test]
1575 async fn chat_typed_retry_succeeds() {
1576 let provider = SequentialStub::new(vec![
1577 Ok("not valid json".into()),
1578 Ok(r#"{"value": "ok"}"#.into()),
1579 ]);
1580 let messages = vec![Message::from_legacy(Role::User, "test")];
1581 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1582 assert_eq!(result, TestOutput { value: "ok".into() });
1583 }
1584
1585 #[tokio::test]
1586 async fn chat_typed_both_fail() {
1587 let provider = SequentialStub::new(vec![Ok("bad json".into()), Ok("still bad".into())]);
1588 let messages = vec![Message::from_legacy(Role::User, "test")];
1589 let result = provider.chat_typed::<TestOutput>(&messages).await;
1590 let err = result.unwrap_err();
1591 assert!(err.to_string().contains("parse failed after retry"));
1592 }
1593
1594 #[tokio::test]
1595 async fn chat_typed_chat_error_propagates() {
1596 let provider = SequentialStub::new(vec![Err(LlmError::Unavailable)]);
1597 let messages = vec![Message::from_legacy(Role::User, "test")];
1598 let result = provider.chat_typed::<TestOutput>(&messages).await;
1599 assert!(matches!(result, Err(LlmError::Unavailable)));
1600 }
1601
1602 #[tokio::test]
1603 async fn chat_typed_strips_fences() {
1604 let provider = StubProvider {
1605 response: "```json\n{\"value\": \"fenced\"}\n```".into(),
1606 };
1607 let messages = vec![Message::from_legacy(Role::User, "test")];
1608 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1609 assert_eq!(
1610 result,
1611 TestOutput {
1612 value: "fenced".into()
1613 }
1614 );
1615 }
1616
/// `StubProvider` must report no structured-output support.
#[test]
fn supports_structured_output_default_false() {
    let stub = StubProvider {
        response: String::new(),
    };
    let supported = stub.supports_structured_output();
    assert!(!supported);
}
1624
/// Pins the `Display` rendering of `LlmError::StructuredParse`.
#[test]
fn structured_parse_error_display() {
    let rendered = LlmError::StructuredParse("test error".into()).to_string();
    assert_eq!(rendered, "structured output parse failed: test error");
}
1633
/// An image part must survive a JSON serialize/deserialize cycle with its
/// bytes and MIME type intact.
#[test]
fn message_part_image_roundtrip_json() {
    let original = MessagePart::Image(Box::new(ImageData {
        data: vec![1, 2, 3, 4],
        mime_type: "image/jpeg".into(),
    }));

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: MessagePart = serde_json::from_str(&encoded).unwrap();

    // Anything other than the Image variant means the round trip lost the tag.
    let MessagePart::Image(img) = decoded else {
        panic!("expected Image variant");
    };
    assert_eq!(img.data, vec![1, 2, 3, 4]);
    assert_eq!(img.mime_type, "image/jpeg");
}
1650
/// The LLM-facing content of a text + image message must contain both the
/// text and an `[image: <mime>` placeholder for the image.
#[test]
fn flatten_parts_includes_image_placeholder() {
    let caption = MessagePart::Text {
        text: "see this".into(),
    };
    let picture = MessagePart::Image(Box::new(ImageData {
        data: vec![0u8; 100],
        mime_type: "image/png".into(),
    }));
    let msg = Message::from_parts(Role::User, vec![caption, picture]);

    let flattened = msg.to_llm_content();

    assert!(flattened.contains("see this"));
    assert!(flattened.contains("[image: image/png"));
}
1669
/// `StubProvider` must report no vision support.
#[test]
fn supports_vision_default_false() {
    let stub = StubProvider {
        response: String::new(),
    };
    let has_vision = stub.supports_vision();
    assert!(!has_vision);
}
1677
/// Default metadata is visible to both agent and user and carries no
/// compaction timestamp.
#[test]
fn message_metadata_default_both_visible() {
    let meta = MessageMetadata::default();

    assert!(meta.visibility.is_agent_visible());
    assert!(meta.visibility.is_user_visible());
    assert_eq!(meta.visibility, MessageVisibility::Both);
    assert!(meta.compacted_at.is_none());
}
1686
/// `MessageMetadata::agent_only()` is visible to the agent and hidden from
/// the user.
#[test]
fn message_metadata_agent_only() {
    let meta = MessageMetadata::agent_only();

    assert!(meta.visibility.is_agent_visible());
    assert!(!meta.visibility.is_user_visible());
    assert_eq!(meta.visibility, MessageVisibility::AgentOnly);
}
1694
/// `MessageMetadata::user_only()` is visible to the user and hidden from
/// the agent.
#[test]
fn message_metadata_user_only() {
    let meta = MessageMetadata::user_only();

    assert!(!meta.visibility.is_agent_visible());
    assert!(meta.visibility.is_user_visible());
    assert_eq!(meta.visibility, MessageVisibility::UserOnly);
}
1702
/// A JSON message with no metadata field must deserialize with metadata
/// that is visible to both agent and user.
#[test]
fn message_metadata_serde_default() {
    let raw = r#"{"role":"user","content":"hello"}"#;

    let parsed: Message = serde_json::from_str(raw).unwrap();

    let visibility = &parsed.metadata.visibility;
    assert!(visibility.is_agent_visible());
    assert!(visibility.is_user_visible());
}
1710
/// Agent-only visibility must be preserved across a JSON round trip of the
/// whole message.
#[test]
fn message_metadata_round_trip() {
    let original = Message {
        role: Role::User,
        content: "test".into(),
        parts: Vec::new(),
        metadata: MessageMetadata::agent_only(),
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let restored: Message = serde_json::from_str(&encoded).unwrap();

    assert!(restored.metadata.visibility.is_agent_visible());
    assert!(!restored.metadata.visibility.is_user_visible());
    assert_eq!(restored.metadata.visibility, MessageVisibility::AgentOnly);
}
1725
/// A compaction part must come back from JSON as the same variant with the
/// same summary text.
#[test]
fn message_part_compaction_round_trip() {
    const SUMMARY: &str = "Context was summarized.";

    let original = MessagePart::Compaction {
        summary: SUMMARY.to_owned(),
    };
    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: MessagePart = serde_json::from_str(&encoded).unwrap();

    assert!(matches!(decoded, MessagePart::Compaction { summary } if summary == SUMMARY));
}
1737
/// A compaction part must add nothing to the message's `content`: after
/// trimming, only the text part's contents remain.
#[test]
fn flatten_parts_compaction_contributes_no_text() {
    let text_part = MessagePart::Text {
        text: "Hello".to_owned(),
    };
    let compaction_part = MessagePart::Compaction {
        summary: "Summary".to_owned(),
    };

    let msg = Message::from_parts(Role::Assistant, vec![text_part, compaction_part]);

    assert_eq!(msg.content.trim(), "Hello");
}
1754
/// `StreamChunk::Compaction` must be constructible and expose the summary
/// string it was built with.
#[test]
fn stream_chunk_compaction_variant() {
    let chunk = StreamChunk::Compaction("A summary".to_owned());
    let carries_summary = matches!(chunk, StreamChunk::Compaction(text) if text == "A summary");
    assert!(carries_summary);
}
1760
/// For a type declared inside this function, only the final path segment
/// of its type name is returned.
#[test]
fn short_type_name_extracts_last_segment() {
    struct MyOutput;

    let short = short_type_name::<MyOutput>();

    assert_eq!(short, "MyOutput");
}
1766
/// Primitive type names contain no `::`, so they come back unchanged.
#[test]
fn short_type_name_primitive_returns_full_name() {
    let unsigned = short_type_name::<u32>();
    assert_eq!(unsigned, "u32");

    let boolean = short_type_name::<bool>();
    assert_eq!(boolean, "bool");
}
1773
/// For a generic type behind a module path, the module prefix is dropped
/// while the generic arguments are kept.
#[test]
fn short_type_name_nested_path_returns_last() {
    type Nested = std::collections::HashMap<u32, u32>;

    let short = short_type_name::<Nested>();

    assert_eq!(short, "HashMap<u32, u32>");
}
1782
/// `MessagePart::Summary` must serialize with an internal `"kind":"summary"`
/// tag (not an external `"Summary"` key) and deserialize back to the same
/// text.
#[test]
fn summary_roundtrip() {
    let original = MessagePart::Summary {
        text: "hello".to_string(),
    };
    let json = serde_json::to_string(&original).expect("serialization must not fail");

    // Wire-format checks: the tag lives inside the object under "kind".
    let internally_tagged = json.contains("\"kind\":\"summary\"");
    assert!(
        internally_tagged,
        "must use internally-tagged format, got: {json}"
    );
    let externally_tagged = json.contains("\"Summary\"");
    assert!(
        !externally_tagged,
        "must not use externally-tagged format, got: {json}"
    );

    let decoded: MessagePart =
        serde_json::from_str(&json).expect("deserialization must not fail");
    let text = match decoded {
        MessagePart::Summary { text } => text,
        other => panic!("expected MessagePart::Summary, got {other:?}"),
    };
    assert_eq!(text, "hello");
}
1806
1807 #[tokio::test]
1808 async fn embed_batch_default_empty_returns_empty() {
1809 let provider = StubProvider {
1810 response: String::new(),
1811 };
1812 let result = provider.embed_batch(&[]).await.unwrap();
1813 assert!(result.is_empty());
1814 }
1815
1816 #[tokio::test]
1817 async fn embed_batch_default_calls_embed_sequentially() {
1818 let provider = StubProvider {
1819 response: String::new(),
1820 };
1821 let texts = ["hello", "world", "foo"];
1822 let result = provider.embed_batch(&texts).await.unwrap();
1823 assert_eq!(result.len(), 3);
1824 for vec in &result {
1826 assert_eq!(vec, &[0.1_f32, 0.2, 0.3]);
1827 }
1828 }
1829
/// `Both` maps to the database string `"both"` and back.
#[test]
fn message_visibility_db_roundtrip_both() {
    let stored = MessageVisibility::Both.as_db_str();
    let loaded = MessageVisibility::from_db_str("both");

    assert_eq!(stored, "both");
    assert_eq!(loaded, MessageVisibility::Both);
}
1838
/// `AgentOnly` maps to the database string `"agent_only"` and back.
#[test]
fn message_visibility_db_roundtrip_agent_only() {
    let stored = MessageVisibility::AgentOnly.as_db_str();
    let loaded = MessageVisibility::from_db_str("agent_only");

    assert_eq!(stored, "agent_only");
    assert_eq!(loaded, MessageVisibility::AgentOnly);
}
1847
/// `UserOnly` maps to the database string `"user_only"` and back.
#[test]
fn message_visibility_db_roundtrip_user_only() {
    let stored = MessageVisibility::UserOnly.as_db_str();
    let loaded = MessageVisibility::from_db_str("user_only");

    assert_eq!(stored, "user_only");
    assert_eq!(loaded, MessageVisibility::UserOnly);
}
1856
/// Unrecognized database strings, including the empty string, must decode
/// to `MessageVisibility::Both`.
#[test]
fn message_visibility_from_db_str_unknown_defaults_to_both() {
    let from_unknown = MessageVisibility::from_db_str("unknown_future_value");
    assert_eq!(from_unknown, MessageVisibility::Both);

    let from_empty = MessageVisibility::from_db_str("");
    assert_eq!(from_empty, MessageVisibility::Both);
}
1865}