1use std::future::Future;
5use std::pin::Pin;
6use std::{
7 any::TypeId,
8 collections::HashMap,
9 sync::{LazyLock, Mutex},
10};
11
12use futures_core::Stream;
13use serde::{Deserialize, Serialize};
14
15use zeph_common::ToolName;
16
17pub use zeph_common::ToolDefinition;
18
19use crate::embed::owned_strs;
20use crate::error::LlmError;
21
22static SCHEMA_CACHE: LazyLock<Mutex<HashMap<TypeId, (serde_json::Value, String)>>> =
23 LazyLock::new(|| Mutex::new(HashMap::new()));
24
25pub(crate) fn cached_schema<T: schemars::JsonSchema + 'static>()
31-> Result<(serde_json::Value, String), crate::LlmError> {
32 let type_id = TypeId::of::<T>();
33 if let Ok(cache) = SCHEMA_CACHE.lock()
34 && let Some(entry) = cache.get(&type_id)
35 {
36 return Ok(entry.clone());
37 }
38 let schema = schemars::schema_for!(T);
39 let value = serde_json::to_value(&schema)
40 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
41 let pretty = serde_json::to_string_pretty(&schema)
42 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
43 if let Ok(mut cache) = SCHEMA_CACHE.lock() {
44 cache.insert(type_id, (value.clone(), pretty.clone()));
45 }
46 Ok((value, pretty))
47}
48
/// Returns a short, human-readable name for `T`, suitable for embedding in prompts.
///
/// Only the outermost type path is shortened:
/// - `my_crate::out::Summary` → `Summary`
/// - `alloc::vec::Vec<my_crate::Item>` → `Vec<my_crate::Item>`
///
/// Generic arguments are kept intact rather than being split at their inner `::`. Names that
/// begin with a bracket (tuples, arrays, slices) are returned unchanged.
///
/// The result borrows from [`std::any::type_name`], whose exact format is not guaranteed to be
/// stable across compiler versions.
pub(crate) fn short_type_name<T: ?Sized>() -> &'static str {
    let full = std::any::type_name::<T>();
    // Limit the `::` search to the head of the name, before any generic/tuple/array syntax.
    let head_len = full
        .find(|c: char| matches!(c, '<' | '(' | '['))
        .unwrap_or(full.len());
    let start = full[..head_len].rfind("::").map_or(0, |i| i + 2);
    let short = &full[start..];
    // `type_name` never yields an empty string; keep a sensible label just in case.
    if short.is_empty() { "Output" } else { short }
}
68
/// Optional side-channel data returned next to a chat reply by
/// [`LlmProvider::chat_with_extras`].
///
/// Marked `#[non_exhaustive]` so new fields can be added without breaking downstream crates.
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct ChatExtras {
    /// Entropy reported for the response, or `None` when the provider did not compute one.
    pub entropy: Option<f64>,
}

impl ChatExtras {
    /// Returns extras carrying only `entropy`; every other field keeps its default value.
    #[must_use]
    pub fn with_entropy(entropy: f64) -> Self {
        Self {
            entropy: Some(entropy),
            ..Self::default()
        }
    }
}
107
/// One item yielded by a [`ChatStream`].
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A fragment of response text.
    Content(String),
    /// A fragment of model reasoning ("thinking") text.
    Thinking(String),
    /// Compaction summary text (presumably the provider-side context compaction also surfaced
    /// via `LlmProvider::take_compaction_summary` — confirm).
    Compaction(String),
    /// Tool calls requested by the model.
    ToolUse(Vec<ToolUseRequest>),
}
124
/// Pinned, boxed, `Send` stream returned by [`LlmProvider::chat_stream`]; each item is either a
/// [`StreamChunk`] or an [`LlmError`].
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
130
/// A single tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUseRequest {
    /// Identifier of this call (presumably provider-assigned and used to pair the call with
    /// its result — confirm).
    pub id: String,
    /// Name of the tool to invoke.
    pub name: ToolName,
    /// Tool arguments as raw JSON.
    pub input: serde_json::Value,
}
145
/// A reasoning block returned with a tool-use response (see `ChatResponse::ToolUse`).
#[derive(Debug, Clone)]
pub enum ThinkingBlock {
    /// Reasoning text together with its signature string.
    Thinking { thinking: String, signature: String },
    /// Reasoning delivered only in redacted form; `data` is kept as-is.
    Redacted { data: String },
}
158
/// Marker text signalling that a response was cut off by the `max_tokens` limit.
///
/// NOTE(review): where this marker is emitted or detected is not visible in this file.
pub const MAX_TOKENS_TRUNCATION_MARKER: &str = "max_tokens limit reached";
162
/// Result of [`LlmProvider::chat_with_tools`].
#[derive(Debug, Clone)]
pub enum ChatResponse {
    /// A plain text reply with no tool calls.
    Text(String),
    /// The model requested one or more tool calls.
    ToolUse {
        /// Text the model emitted alongside the tool calls, if any.
        text: Option<String>,
        /// The requested tool calls.
        tool_calls: Vec<ToolUseRequest>,
        /// Reasoning blocks that accompanied the response.
        thinking_blocks: Vec<ThinkingBlock>,
    },
}
184
/// Pinned, boxed, `Send` future resolving to a single embedding vector.
pub type EmbedFuture = Pin<Box<dyn Future<Output = Result<Vec<f32>, LlmError>> + Send>>;
187
/// Type-erased, thread-safe embedding callback mapping a text to an [`EmbedFuture`].
pub type EmbedFn = Box<dyn Fn(&str) -> EmbedFuture + Send + Sync>;
193
/// Sending half of an unbounded Tokio channel carrying status strings (presumably
/// human-readable progress messages — confirm against consumers).
pub type StatusTx = tokio::sync::mpsc::UnboundedSender<String>;
200
201#[must_use]
204pub fn default_debug_request_json(
205 messages: &[Message],
206 tools: &[ToolDefinition],
207) -> serde_json::Value {
208 serde_json::json!({
209 "model": serde_json::Value::Null,
210 "max_tokens": serde_json::Value::Null,
211 "messages": serde_json::to_value(messages).unwrap_or(serde_json::Value::Array(vec![])),
212 "tools": serde_json::to_value(tools).unwrap_or(serde_json::Value::Array(vec![])),
213 "temperature": serde_json::Value::Null,
214 "cache_control": serde_json::Value::Null,
215 })
216}
217
/// Optional sampling parameters, presumably overriding a provider's configured defaults.
///
/// A `None` field means "no override". NOTE(review): how each provider maps these (and which
/// ones it ignores) is not visible in this file.
#[derive(Debug, Clone, Default)]
pub struct GenerationOverrides {
    /// Sampling temperature.
    pub temperature: Option<f64>,
    /// Nucleus (top-p) sampling probability mass.
    pub top_p: Option<f64>,
    /// Top-k sampling cutoff.
    pub top_k: Option<usize>,
    /// Frequency penalty.
    pub frequency_penalty: Option<f64>,
    /// Presence penalty.
    pub presence_penalty: Option<f64>,
}
239
/// Author role of a [`Message`].
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System / instruction message.
    System,
    /// Message authored by the user.
    User,
    /// Message authored by the model.
    Assistant,
}
253
/// One structured piece of a [`Message`].
///
/// Serialized as an internally tagged object whose `"kind"` field holds the `snake_case`
/// variant name, e.g. `{"kind":"tool_output","tool_name":"bash","body":"..."}`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum MessagePart {
    /// Plain text.
    Text { text: String },
    /// Output produced by a tool run.
    ToolOutput {
        /// Name of the tool that produced the output.
        tool_name: zeph_common::ToolName,
        /// Output text; may be emptied or replaced once the part has been compacted.
        body: String,
        /// Set once the output has been compacted (presumably a Unix timestamp — confirm).
        /// Omitted from serialized output when `None`; defaults to `None` when absent.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        compacted_at: Option<i64>,
    },
    /// Recalled memory text.
    Recall { text: String },
    /// Code context text.
    CodeContext { text: String },
    /// Summary text.
    Summary { text: String },
    /// Text carried over from another session.
    CrossSession { text: String },
    /// A tool call issued by the assistant.
    ToolUse {
        /// Identifier of the call.
        id: String,
        /// Name of the called tool.
        name: String,
        /// Call arguments as raw JSON.
        input: serde_json::Value,
    },
    /// Result of a tool call, linked to it by `tool_use_id`.
    ToolResult {
        /// Identifier of the `ToolUse` this result answers.
        tool_use_id: String,
        /// Result text.
        content: String,
        /// Whether the tool reported an error; defaults to `false` when absent.
        #[serde(default)]
        is_error: bool,
    },
    /// Inline image (boxed, presumably to keep the enum's size down).
    Image(Box<ImageData>),
    /// Reasoning block with its signature.
    ThinkingBlock { thinking: String, signature: String },
    /// Redacted reasoning block.
    RedactedThinkingBlock { data: String },
    /// Compaction summary.
    Compaction { summary: String },
}
311
312impl MessagePart {
313 #[must_use]
316 pub fn as_plain_text(&self) -> Option<&str> {
317 match self {
318 Self::Text { text }
319 | Self::Recall { text }
320 | Self::CodeContext { text }
321 | Self::Summary { text }
322 | Self::CrossSession { text } => Some(text.as_str()),
323 _ => None,
324 }
325 }
326
327 #[must_use]
329 pub fn as_image(&self) -> Option<&ImageData> {
330 if let Self::Image(img) = self {
331 Some(img)
332 } else {
333 None
334 }
335 }
336}
337
/// Raw image bytes together with their MIME type.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImageData {
    /// Image bytes; serialized as a standard-alphabet base64 string via `serde_bytes_base64`.
    #[serde(with = "serde_bytes_base64")]
    pub data: Vec<u8>,
    /// MIME type of `data`, e.g. `image/png`.
    pub mime_type: String,
}
348
349mod serde_bytes_base64 {
350 use base64::{Engine, engine::general_purpose::STANDARD};
351 use serde::{Deserialize, Deserializer, Serializer};
352
353 pub fn serialize<S>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error>
354 where
355 S: Serializer,
356 {
357 s.serialize_str(&STANDARD.encode(bytes))
358 }
359
360 pub fn deserialize<'de, D>(d: D) -> Result<Vec<u8>, D::Error>
361 where
362 D: Deserializer<'de>,
363 {
364 let s = String::deserialize(d)?;
365 STANDARD.decode(&s).map_err(serde::de::Error::custom)
366 }
367}
368
369#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
385#[serde(rename_all = "snake_case")]
386pub enum MessageVisibility {
387 Both,
389 AgentOnly,
391 UserOnly,
393}
394
395impl MessageVisibility {
396 #[must_use]
398 pub fn is_agent_visible(self) -> bool {
399 matches!(self, MessageVisibility::Both | MessageVisibility::AgentOnly)
400 }
401
402 #[must_use]
404 pub fn is_user_visible(self) -> bool {
405 matches!(self, MessageVisibility::Both | MessageVisibility::UserOnly)
406 }
407}
408
409impl Default for MessageVisibility {
410 fn default() -> Self {
412 MessageVisibility::Both
413 }
414}
415
416impl MessageVisibility {
417 #[must_use]
419 pub fn as_db_str(self) -> &'static str {
420 match self {
421 MessageVisibility::Both => "both",
422 MessageVisibility::AgentOnly => "agent_only",
423 MessageVisibility::UserOnly => "user_only",
424 }
425 }
426
427 #[must_use]
431 pub fn from_db_str(s: &str) -> Self {
432 match s {
433 "agent_only" => MessageVisibility::AgentOnly,
434 "user_only" => MessageVisibility::UserOnly,
435 _ => MessageVisibility::Both,
436 }
437 }
438}
439
440#[derive(Clone, Debug, Serialize, Deserialize)]
445pub struct MessageMetadata {
446 pub visibility: MessageVisibility,
448 #[serde(default, skip_serializing_if = "Option::is_none")]
450 pub compacted_at: Option<i64>,
451 #[serde(default, skip_serializing_if = "Option::is_none")]
454 pub deferred_summary: Option<String>,
455 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
458 pub focus_pinned: bool,
459 #[serde(default, skip_serializing_if = "Option::is_none")]
462 pub focus_marker_id: Option<uuid::Uuid>,
463 #[serde(skip)]
466 pub db_id: Option<i64>,
467}
468
469impl Default for MessageMetadata {
470 fn default() -> Self {
471 Self {
472 visibility: MessageVisibility::Both,
473 compacted_at: None,
474 deferred_summary: None,
475 focus_pinned: false,
476 focus_marker_id: None,
477 db_id: None,
478 }
479 }
480}
481
482impl MessageMetadata {
483 #[must_use]
485 pub fn agent_only() -> Self {
486 Self {
487 visibility: MessageVisibility::AgentOnly,
488 compacted_at: None,
489 deferred_summary: None,
490 focus_pinned: false,
491 focus_marker_id: None,
492 db_id: None,
493 }
494 }
495
496 #[must_use]
498 pub fn user_only() -> Self {
499 Self {
500 visibility: MessageVisibility::UserOnly,
501 compacted_at: None,
502 deferred_summary: None,
503 focus_pinned: false,
504 focus_marker_id: None,
505 db_id: None,
506 }
507 }
508
509 #[must_use]
511 pub fn focus_pinned() -> Self {
512 Self {
513 visibility: MessageVisibility::AgentOnly,
514 compacted_at: None,
515 deferred_summary: None,
516 focus_pinned: true,
517 focus_marker_id: None,
518 db_id: None,
519 }
520 }
521}
522
/// A single chat message.
///
/// `content` is the flat text returned by [`Message::to_llm_content`]. When the message is
/// built from structured `parts`, [`Message::from_parts`] and [`Message::rebuild_content`]
/// derive `content` from them. `parts` and `metadata` take their defaults when absent during
/// deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
    /// Author role.
    pub role: Role,
    /// Flat text content.
    pub content: String,
    /// Structured parts; empty for messages built with `from_legacy`.
    #[serde(default)]
    pub parts: Vec<MessagePart>,
    /// Visibility and bookkeeping metadata.
    #[serde(default)]
    pub metadata: MessageMetadata,
}
559
560impl Default for Message {
561 fn default() -> Self {
562 Self {
563 role: Role::User,
564 content: String::new(),
565 parts: vec![],
566 metadata: MessageMetadata::default(),
567 }
568 }
569}
570
571impl Message {
572 #[must_use]
577 pub fn from_legacy(role: Role, content: impl Into<String>) -> Self {
578 Self {
579 role,
580 content: content.into(),
581 parts: vec![],
582 metadata: MessageMetadata::default(),
583 }
584 }
585
586 #[must_use]
591 pub fn from_parts(role: Role, parts: Vec<MessagePart>) -> Self {
592 let content = Self::flatten_parts(&parts);
593 Self {
594 role,
595 content,
596 parts,
597 metadata: MessageMetadata::default(),
598 }
599 }
600
601 #[must_use]
604 pub fn to_llm_content(&self) -> &str {
605 &self.content
606 }
607
608 pub fn rebuild_content(&mut self) {
610 if !self.parts.is_empty() {
611 self.content = Self::flatten_parts(&self.parts);
612 }
613 }
614
615 fn flatten_parts(parts: &[MessagePart]) -> String {
616 use std::fmt::Write;
617 let mut out = String::new();
618 for part in parts {
619 match part {
620 MessagePart::Text { text }
621 | MessagePart::Recall { text }
622 | MessagePart::CodeContext { text }
623 | MessagePart::Summary { text }
624 | MessagePart::CrossSession { text } => out.push_str(text),
625 MessagePart::ToolOutput {
626 tool_name,
627 body,
628 compacted_at,
629 } => {
630 if compacted_at.is_some() {
631 if body.is_empty() {
632 let _ = write!(out, "[tool output: {tool_name}] (pruned)");
633 } else {
634 let _ = write!(out, "[tool output: {tool_name}] {body}");
635 }
636 } else {
637 let _ = write!(out, "[tool output: {tool_name}]\n```\n{body}\n```");
638 }
639 }
640 MessagePart::ToolUse { id, name, .. } => {
641 let _ = write!(out, "[tool_use: {name}({id})]");
642 }
643 MessagePart::ToolResult {
644 tool_use_id,
645 content,
646 ..
647 } => {
648 let _ = write!(out, "[tool_result: {tool_use_id}]\n{content}");
649 }
650 MessagePart::Image(img) => {
651 let _ = write!(out, "[image: {}, {} bytes]", img.mime_type, img.data.len());
652 }
653 MessagePart::ThinkingBlock { .. }
655 | MessagePart::RedactedThinkingBlock { .. }
656 | MessagePart::Compaction { .. } => {}
657 }
658 }
659 out
660 }
661}
662
663pub trait LlmProvider: Send + Sync {
717 fn context_window(&self) -> Option<usize> {
721 None
722 }
723
724 fn chat(&self, messages: &[Message]) -> impl Future<Output = Result<String, LlmError>> + Send;
730
731 fn chat_stream(
737 &self,
738 messages: &[Message],
739 ) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;
740
741 fn supports_streaming(&self) -> bool;
743
744 fn embed(&self, text: &str) -> impl Future<Output = Result<Vec<f32>, LlmError>> + Send;
750
751 fn embed_batch(
761 &self,
762 texts: &[&str],
763 ) -> impl Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
764 let owned = owned_strs(texts);
765 async move {
766 let mut results = Vec::with_capacity(owned.len());
767 for text in &owned {
768 results.push(self.embed(text).await?);
769 }
770 Ok(results)
771 }
772 }
773
774 fn supports_embeddings(&self) -> bool;
776
777 fn name(&self) -> &str;
779
780 #[allow(clippy::unnecessary_literal_bound)]
783 fn model_identifier(&self) -> &str {
784 ""
785 }
786
787 fn supports_vision(&self) -> bool {
789 false
790 }
791
792 fn supports_tool_use(&self) -> bool {
794 true
795 }
796
797 #[allow(async_fn_in_trait)]
805 async fn chat_with_tools(
806 &self,
807 messages: &[Message],
808 _tools: &[ToolDefinition],
809 ) -> Result<ChatResponse, LlmError> {
810 Ok(ChatResponse::Text(self.chat(messages).await?))
811 }
812
813 fn last_cache_usage(&self) -> Option<(u64, u64)> {
816 None
817 }
818
819 fn last_usage(&self) -> Option<(u64, u64)> {
822 None
823 }
824
825 fn take_compaction_summary(&self) -> Option<String> {
828 None
829 }
830
831 fn record_quality_outcome(&self, _provider_name: &str, _success: bool) {}
837
838 fn chat_with_extras(
853 &self,
854 messages: &[Message],
855 ) -> impl Future<Output = Result<(String, ChatExtras), LlmError>> + Send {
856 let msgs = messages.to_vec();
857 async move { Ok((self.chat(&msgs).await?, ChatExtras::default())) }
858 }
859
860 #[must_use]
864 fn debug_request_json(
865 &self,
866 messages: &[Message],
867 tools: &[ToolDefinition],
868 _stream: bool,
869 ) -> serde_json::Value {
870 default_debug_request_json(messages, tools)
871 }
872
873 fn list_models(&self) -> Vec<String> {
876 vec![]
877 }
878
879 fn supports_structured_output(&self) -> bool {
881 false
882 }
883
884 #[allow(async_fn_in_trait)]
889 async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
890 where
891 T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
892 Self: Sized,
893 {
894 let (_, schema_json) = cached_schema::<T>()?;
895 let type_name = short_type_name::<T>();
896
897 let mut augmented = messages.to_vec();
898 let instruction = format!(
899 "Respond with a valid JSON object matching this schema. \
900 Output ONLY the JSON, no markdown fences or extra text.\n\n\
901 Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
902 );
903 augmented.insert(0, Message::from_legacy(Role::System, instruction));
904
905 let raw = self.chat(&augmented).await?;
906 let cleaned = strip_json_fences(&raw);
907 match serde_json::from_str::<T>(cleaned) {
908 Ok(val) => Ok(val),
909 Err(first_err) => {
910 augmented.push(Message::from_legacy(Role::Assistant, &raw));
911 augmented.push(Message::from_legacy(
912 Role::User,
913 format!(
914 "Your response was not valid JSON. Error: {first_err}. \
915 Please output ONLY valid JSON matching the schema."
916 ),
917 ));
918 let retry_raw = self.chat(&augmented).await?;
919 let retry_cleaned = strip_json_fences(&retry_raw);
920 serde_json::from_str::<T>(retry_cleaned).map_err(|e| {
921 LlmError::StructuredParse(format!("parse failed after retry: {e}"))
922 })
923 }
924 }
925 }
926}
927
/// Removes an optional Markdown code fence wrapped around a JSON payload.
///
/// The algorithm:
/// - Trim surrounding whitespace.
/// - If the text starts with a triple-backtick fence, drop the fence together with its
///   optional info string (`json`, `JSON`, `jsonc`, ...).
/// - Remove any trailing triple-backtick run and trim again.
///
/// Input without fences is returned trimmed but otherwise unchanged.
fn strip_json_fences(s: &str) -> &str {
    let trimmed = s.trim();
    let body = match trimmed.strip_prefix("```") {
        Some(after_fence) => strip_fence_info_string(after_fence),
        None => trimmed,
    };
    body.trim_end_matches("```").trim()
}

/// Drops the info string (language tag) that may follow an opening code fence.
///
/// Two kinds of tag are removed:
/// - A case-insensitive `json` tag, even when glued directly to the payload (e.g. `json{...}`).
/// - Any other alphabetic tag, but only when it ends the fence line.
///
/// Inline payloads such as a bare `42` are therefore left intact.
fn strip_fence_info_string(after_fence: &str) -> &str {
    let after_fence = after_fence.trim_start_matches(|c: char| c == ' ' || c == '\t');
    let tag_len = after_fence
        .find(|c: char| !c.is_ascii_alphanumeric())
        .unwrap_or(after_fence.len());
    let (tag, rest) = after_fence.split_at(tag_len);
    let is_json_tag = tag.eq_ignore_ascii_case("json");
    let ends_fence_line = tag.starts_with(|c: char| c.is_ascii_alphabetic())
        && rest.starts_with(|c: char| c == '\r' || c == '\n');
    if is_json_tag || ends_fence_line {
        rest
    } else {
        after_fence
    }
}
938
939#[cfg(test)]
940mod tests {
941 use tokio_stream::StreamExt;
942
943 use super::*;
944
945 struct StubProvider {
946 response: String,
947 }
948
949 impl LlmProvider for StubProvider {
950 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
951 Ok(self.response.clone())
952 }
953
954 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
955 let response = self.chat(messages).await?;
956 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
957 response,
958 )))))
959 }
960
961 fn supports_streaming(&self) -> bool {
962 false
963 }
964
965 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
966 Ok(vec![0.1, 0.2, 0.3])
967 }
968
969 fn supports_embeddings(&self) -> bool {
970 false
971 }
972
973 fn name(&self) -> &'static str {
974 "stub"
975 }
976 }
977
978 #[test]
979 fn context_window_default_returns_none() {
980 let provider = StubProvider {
981 response: String::new(),
982 };
983 assert!(provider.context_window().is_none());
984 }
985
986 #[test]
987 fn supports_streaming_default_returns_false() {
988 let provider = StubProvider {
989 response: String::new(),
990 };
991 assert!(!provider.supports_streaming());
992 }
993
994 #[tokio::test]
995 async fn chat_stream_default_yields_single_chunk() {
996 let provider = StubProvider {
997 response: "hello world".into(),
998 };
999 let messages = vec![Message {
1000 role: Role::User,
1001 content: "test".into(),
1002 parts: vec![],
1003 metadata: MessageMetadata::default(),
1004 }];
1005
1006 let mut stream = provider.chat_stream(&messages).await.unwrap();
1007 let chunk = stream.next().await.unwrap().unwrap();
1008 assert!(matches!(chunk, StreamChunk::Content(s) if s == "hello world"));
1009 assert!(stream.next().await.is_none());
1010 }
1011
1012 #[tokio::test]
1013 async fn chat_stream_default_propagates_chat_error() {
1014 struct FailProvider;
1015
1016 impl LlmProvider for FailProvider {
1017 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1018 Err(LlmError::Unavailable)
1019 }
1020
1021 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1022 let response = self.chat(messages).await?;
1023 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1024 response,
1025 )))))
1026 }
1027
1028 fn supports_streaming(&self) -> bool {
1029 false
1030 }
1031
1032 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1033 Err(LlmError::Unavailable)
1034 }
1035
1036 fn supports_embeddings(&self) -> bool {
1037 false
1038 }
1039
1040 fn name(&self) -> &'static str {
1041 "fail"
1042 }
1043 }
1044
1045 let provider = FailProvider;
1046 let messages = vec![Message {
1047 role: Role::User,
1048 content: "test".into(),
1049 parts: vec![],
1050 metadata: MessageMetadata::default(),
1051 }];
1052
1053 let result = provider.chat_stream(&messages).await;
1054 assert!(result.is_err());
1055 if let Err(e) = result {
1056 assert!(e.to_string().contains("provider unavailable"));
1057 }
1058 }
1059
1060 #[tokio::test]
1061 async fn stub_provider_embed_returns_vector() {
1062 let provider = StubProvider {
1063 response: String::new(),
1064 };
1065 let embedding = provider.embed("test").await.unwrap();
1066 assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
1067 }
1068
1069 #[tokio::test]
1070 async fn fail_provider_embed_propagates_error() {
1071 struct FailProvider;
1072
1073 impl LlmProvider for FailProvider {
1074 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1075 Err(LlmError::Unavailable)
1076 }
1077
1078 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1079 let response = self.chat(messages).await?;
1080 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1081 response,
1082 )))))
1083 }
1084
1085 fn supports_streaming(&self) -> bool {
1086 false
1087 }
1088
1089 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1090 Err(LlmError::EmbedUnsupported {
1091 provider: "fail".into(),
1092 })
1093 }
1094
1095 fn supports_embeddings(&self) -> bool {
1096 false
1097 }
1098
1099 fn name(&self) -> &'static str {
1100 "fail"
1101 }
1102 }
1103
1104 let provider = FailProvider;
1105 let result = provider.embed("test").await;
1106 assert!(result.is_err());
1107 assert!(
1108 result
1109 .unwrap_err()
1110 .to_string()
1111 .contains("embedding not supported")
1112 );
1113 }
1114
1115 #[test]
1116 fn role_serialization() {
1117 let system = Role::System;
1118 let user = Role::User;
1119 let assistant = Role::Assistant;
1120
1121 assert_eq!(serde_json::to_string(&system).unwrap(), "\"system\"");
1122 assert_eq!(serde_json::to_string(&user).unwrap(), "\"user\"");
1123 assert_eq!(serde_json::to_string(&assistant).unwrap(), "\"assistant\"");
1124 }
1125
1126 #[test]
1127 fn role_deserialization() {
1128 let system: Role = serde_json::from_str("\"system\"").unwrap();
1129 let user: Role = serde_json::from_str("\"user\"").unwrap();
1130 let assistant: Role = serde_json::from_str("\"assistant\"").unwrap();
1131
1132 assert_eq!(system, Role::System);
1133 assert_eq!(user, Role::User);
1134 assert_eq!(assistant, Role::Assistant);
1135 }
1136
1137 #[test]
1138 fn message_clone() {
1139 let msg = Message {
1140 role: Role::User,
1141 content: "test".into(),
1142 parts: vec![],
1143 metadata: MessageMetadata::default(),
1144 };
1145 let cloned = msg.clone();
1146 assert_eq!(cloned.role, msg.role);
1147 assert_eq!(cloned.content, msg.content);
1148 }
1149
1150 #[test]
1151 fn message_debug() {
1152 let msg = Message {
1153 role: Role::Assistant,
1154 content: "response".into(),
1155 parts: vec![],
1156 metadata: MessageMetadata::default(),
1157 };
1158 let debug = format!("{msg:?}");
1159 assert!(debug.contains("Assistant"));
1160 assert!(debug.contains("response"));
1161 }
1162
1163 #[test]
1164 fn message_serialization() {
1165 let msg = Message {
1166 role: Role::User,
1167 content: "hello".into(),
1168 parts: vec![],
1169 metadata: MessageMetadata::default(),
1170 };
1171 let json = serde_json::to_string(&msg).unwrap();
1172 assert!(json.contains("\"role\":\"user\""));
1173 assert!(json.contains("\"content\":\"hello\""));
1174 }
1175
1176 #[test]
1177 fn message_part_serde_round_trip() {
1178 let parts = vec![
1179 MessagePart::Text {
1180 text: "hello".into(),
1181 },
1182 MessagePart::ToolOutput {
1183 tool_name: "bash".into(),
1184 body: "output".into(),
1185 compacted_at: None,
1186 },
1187 MessagePart::Recall {
1188 text: "recall".into(),
1189 },
1190 MessagePart::CodeContext {
1191 text: "code".into(),
1192 },
1193 MessagePart::Summary {
1194 text: "summary".into(),
1195 },
1196 ];
1197 let json = serde_json::to_string(&parts).unwrap();
1198 let deserialized: Vec<MessagePart> = serde_json::from_str(&json).unwrap();
1199 assert_eq!(deserialized.len(), 5);
1200 }
1201
1202 #[test]
1203 fn from_legacy_creates_empty_parts() {
1204 let msg = Message::from_legacy(Role::User, "hello");
1205 assert_eq!(msg.role, Role::User);
1206 assert_eq!(msg.content, "hello");
1207 assert!(msg.parts.is_empty());
1208 assert_eq!(msg.to_llm_content(), "hello");
1209 }
1210
1211 #[test]
1212 fn from_parts_flattens_content() {
1213 let msg = Message::from_parts(
1214 Role::System,
1215 vec![MessagePart::Recall {
1216 text: "recalled data".into(),
1217 }],
1218 );
1219 assert_eq!(msg.content, "recalled data");
1220 assert_eq!(msg.to_llm_content(), "recalled data");
1221 assert_eq!(msg.parts.len(), 1);
1222 }
1223
1224 #[test]
1225 fn from_parts_tool_output_format() {
1226 let msg = Message::from_parts(
1227 Role::User,
1228 vec![MessagePart::ToolOutput {
1229 tool_name: "bash".into(),
1230 body: "hello world".into(),
1231 compacted_at: None,
1232 }],
1233 );
1234 assert!(msg.content.contains("[tool output: bash]"));
1235 assert!(msg.content.contains("hello world"));
1236 }
1237
1238 #[test]
1239 fn message_deserializes_without_parts() {
1240 let json = r#"{"role":"user","content":"hello"}"#;
1241 let msg: Message = serde_json::from_str(json).unwrap();
1242 assert_eq!(msg.content, "hello");
1243 assert!(msg.parts.is_empty());
1244 }
1245
1246 #[test]
1247 fn flatten_skips_compacted_tool_output_empty_body() {
1248 let msg = Message::from_parts(
1250 Role::User,
1251 vec![
1252 MessagePart::Text {
1253 text: "prefix ".into(),
1254 },
1255 MessagePart::ToolOutput {
1256 tool_name: "bash".into(),
1257 body: String::new(),
1258 compacted_at: Some(1234),
1259 },
1260 MessagePart::Text {
1261 text: " suffix".into(),
1262 },
1263 ],
1264 );
1265 assert!(msg.content.contains("(pruned)"));
1266 assert!(msg.content.contains("prefix "));
1267 assert!(msg.content.contains(" suffix"));
1268 }
1269
1270 #[test]
1271 fn flatten_compacted_tool_output_with_reference_renders_body() {
1272 let ref_notice = "[tool output pruned; full content at /tmp/overflow/big.txt]";
1274 let msg = Message::from_parts(
1275 Role::User,
1276 vec![MessagePart::ToolOutput {
1277 tool_name: "bash".into(),
1278 body: ref_notice.into(),
1279 compacted_at: Some(1234),
1280 }],
1281 );
1282 assert!(msg.content.contains(ref_notice));
1283 assert!(!msg.content.contains("(pruned)"));
1284 }
1285
1286 #[test]
1287 fn rebuild_content_syncs_after_mutation() {
1288 let mut msg = Message::from_parts(
1289 Role::User,
1290 vec![MessagePart::ToolOutput {
1291 tool_name: "bash".into(),
1292 body: "original".into(),
1293 compacted_at: None,
1294 }],
1295 );
1296 assert!(msg.content.contains("original"));
1297
1298 if let MessagePart::ToolOutput {
1299 ref mut compacted_at,
1300 ref mut body,
1301 ..
1302 } = msg.parts[0]
1303 {
1304 *compacted_at = Some(999);
1305 body.clear(); }
1307 msg.rebuild_content();
1308
1309 assert!(msg.content.contains("(pruned)"));
1310 assert!(!msg.content.contains("original"));
1311 }
1312
1313 #[test]
1314 fn message_part_tool_use_serde_round_trip() {
1315 let part = MessagePart::ToolUse {
1316 id: "toolu_123".into(),
1317 name: "bash".into(),
1318 input: serde_json::json!({"command": "ls"}),
1319 };
1320 let json = serde_json::to_string(&part).unwrap();
1321 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1322 if let MessagePart::ToolUse { id, name, input } = deserialized {
1323 assert_eq!(id, "toolu_123");
1324 assert_eq!(name, "bash");
1325 assert_eq!(input["command"], "ls");
1326 } else {
1327 panic!("expected ToolUse");
1328 }
1329 }
1330
1331 #[test]
1332 fn message_part_tool_result_serde_round_trip() {
1333 let part = MessagePart::ToolResult {
1334 tool_use_id: "toolu_123".into(),
1335 content: "file1.rs\nfile2.rs".into(),
1336 is_error: false,
1337 };
1338 let json = serde_json::to_string(&part).unwrap();
1339 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1340 if let MessagePart::ToolResult {
1341 tool_use_id,
1342 content,
1343 is_error,
1344 } = deserialized
1345 {
1346 assert_eq!(tool_use_id, "toolu_123");
1347 assert_eq!(content, "file1.rs\nfile2.rs");
1348 assert!(!is_error);
1349 } else {
1350 panic!("expected ToolResult");
1351 }
1352 }
1353
1354 #[test]
1355 fn message_part_tool_result_is_error_default() {
1356 let json = r#"{"kind":"tool_result","tool_use_id":"id","content":"err"}"#;
1357 let part: MessagePart = serde_json::from_str(json).unwrap();
1358 if let MessagePart::ToolResult { is_error, .. } = part {
1359 assert!(!is_error);
1360 } else {
1361 panic!("expected ToolResult");
1362 }
1363 }
1364
1365 #[test]
1366 fn chat_response_construction() {
1367 let text = ChatResponse::Text("hello".into());
1368 assert!(matches!(text, ChatResponse::Text(s) if s == "hello"));
1369
1370 let tool_use = ChatResponse::ToolUse {
1371 text: Some("I'll run that".into()),
1372 tool_calls: vec![ToolUseRequest {
1373 id: "1".into(),
1374 name: "bash".into(),
1375 input: serde_json::json!({}),
1376 }],
1377 thinking_blocks: vec![],
1378 };
1379 assert!(matches!(tool_use, ChatResponse::ToolUse { .. }));
1380 }
1381
1382 #[test]
1383 fn flatten_parts_tool_use() {
1384 let msg = Message::from_parts(
1385 Role::Assistant,
1386 vec![MessagePart::ToolUse {
1387 id: "t1".into(),
1388 name: "bash".into(),
1389 input: serde_json::json!({"command": "ls"}),
1390 }],
1391 );
1392 assert!(msg.content.contains("[tool_use: bash(t1)]"));
1393 }
1394
1395 #[test]
1396 fn flatten_parts_tool_result() {
1397 let msg = Message::from_parts(
1398 Role::User,
1399 vec![MessagePart::ToolResult {
1400 tool_use_id: "t1".into(),
1401 content: "output here".into(),
1402 is_error: false,
1403 }],
1404 );
1405 assert!(msg.content.contains("[tool_result: t1]"));
1406 assert!(msg.content.contains("output here"));
1407 }
1408
1409 #[test]
1410 fn tool_definition_serde_round_trip() {
1411 let def = ToolDefinition {
1412 name: "bash".into(),
1413 description: "Execute a shell command".into(),
1414 parameters: serde_json::json!({"type": "object"}),
1415 output_schema: None,
1416 };
1417 let json = serde_json::to_string(&def).unwrap();
1418 let deserialized: ToolDefinition = serde_json::from_str(&json).unwrap();
1419 assert_eq!(deserialized.name, "bash");
1420 assert_eq!(deserialized.description, "Execute a shell command");
1421 }
1422
1423 #[tokio::test]
1424 async fn chat_with_tools_default_delegates_to_chat() {
1425 let provider = StubProvider {
1426 response: "hello".into(),
1427 };
1428 let messages = vec![Message::from_legacy(Role::User, "test")];
1429 let result = provider.chat_with_tools(&messages, &[]).await.unwrap();
1430 assert!(matches!(result, ChatResponse::Text(s) if s == "hello"));
1431 }
1432
1433 #[test]
1434 fn tool_output_compacted_at_serde_default() {
1435 let json = r#"{"kind":"tool_output","tool_name":"bash","body":"out"}"#;
1436 let part: MessagePart = serde_json::from_str(json).unwrap();
1437 if let MessagePart::ToolOutput { compacted_at, .. } = part {
1438 assert!(compacted_at.is_none());
1439 } else {
1440 panic!("expected ToolOutput");
1441 }
1442 }
1443
1444 #[test]
1447 fn strip_json_fences_plain_json() {
1448 assert_eq!(strip_json_fences(r#"{"a": 1}"#), r#"{"a": 1}"#);
1449 }
1450
1451 #[test]
1452 fn strip_json_fences_with_json_fence() {
1453 assert_eq!(strip_json_fences("```json\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1454 }
1455
1456 #[test]
1457 fn strip_json_fences_with_plain_fence() {
1458 assert_eq!(strip_json_fences("```\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1459 }
1460
1461 #[test]
1462 fn strip_json_fences_whitespace() {
1463 assert_eq!(strip_json_fences(" \n "), "");
1464 }
1465
1466 #[test]
1467 fn strip_json_fences_empty() {
1468 assert_eq!(strip_json_fences(""), "");
1469 }
1470
1471 #[test]
1472 fn strip_json_fences_outer_whitespace() {
1473 assert_eq!(
1474 strip_json_fences(" ```json\n{\"a\": 1}\n``` "),
1475 r#"{"a": 1}"#
1476 );
1477 }
1478
1479 #[test]
1480 fn strip_json_fences_only_opening_fence() {
1481 assert_eq!(strip_json_fences("```json\n{\"a\": 1}"), r#"{"a": 1}"#);
1482 }
1483
1484 #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
1487 struct TestOutput {
1488 value: String,
1489 }
1490
1491 struct SequentialStub {
1492 responses: std::sync::Mutex<Vec<Result<String, LlmError>>>,
1493 }
1494
1495 impl SequentialStub {
1496 fn new(responses: Vec<Result<String, LlmError>>) -> Self {
1497 Self {
1498 responses: std::sync::Mutex::new(responses),
1499 }
1500 }
1501 }
1502
1503 impl LlmProvider for SequentialStub {
1504 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1505 let mut responses = self.responses.lock().unwrap();
1506 if responses.is_empty() {
1507 return Err(LlmError::Other("no more responses".into()));
1508 }
1509 responses.remove(0)
1510 }
1511
1512 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1513 let response = self.chat(messages).await?;
1514 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1515 response,
1516 )))))
1517 }
1518
1519 fn supports_streaming(&self) -> bool {
1520 false
1521 }
1522
1523 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1524 Err(LlmError::EmbedUnsupported {
1525 provider: "sequential-stub".into(),
1526 })
1527 }
1528
1529 fn supports_embeddings(&self) -> bool {
1530 false
1531 }
1532
1533 fn name(&self) -> &'static str {
1534 "sequential-stub"
1535 }
1536 }
1537
1538 #[tokio::test]
1539 async fn chat_typed_happy_path() {
1540 let provider = StubProvider {
1541 response: r#"{"value": "hello"}"#.into(),
1542 };
1543 let messages = vec![Message::from_legacy(Role::User, "test")];
1544 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1545 assert_eq!(
1546 result,
1547 TestOutput {
1548 value: "hello".into()
1549 }
1550 );
1551 }
1552
1553 #[tokio::test]
1554 async fn chat_typed_retry_succeeds() {
1555 let provider = SequentialStub::new(vec![
1556 Ok("not valid json".into()),
1557 Ok(r#"{"value": "ok"}"#.into()),
1558 ]);
1559 let messages = vec![Message::from_legacy(Role::User, "test")];
1560 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1561 assert_eq!(result, TestOutput { value: "ok".into() });
1562 }
1563
1564 #[tokio::test]
1565 async fn chat_typed_both_fail() {
1566 let provider = SequentialStub::new(vec![Ok("bad json".into()), Ok("still bad".into())]);
1567 let messages = vec![Message::from_legacy(Role::User, "test")];
1568 let result = provider.chat_typed::<TestOutput>(&messages).await;
1569 let err = result.unwrap_err();
1570 assert!(err.to_string().contains("parse failed after retry"));
1571 }
1572
1573 #[tokio::test]
1574 async fn chat_typed_chat_error_propagates() {
1575 let provider = SequentialStub::new(vec![Err(LlmError::Unavailable)]);
1576 let messages = vec![Message::from_legacy(Role::User, "test")];
1577 let result = provider.chat_typed::<TestOutput>(&messages).await;
1578 assert!(matches!(result, Err(LlmError::Unavailable)));
1579 }
1580
1581 #[tokio::test]
1582 async fn chat_typed_strips_fences() {
1583 let provider = StubProvider {
1584 response: "```json\n{\"value\": \"fenced\"}\n```".into(),
1585 };
1586 let messages = vec![Message::from_legacy(Role::User, "test")];
1587 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1588 assert_eq!(
1589 result,
1590 TestOutput {
1591 value: "fenced".into()
1592 }
1593 );
1594 }
1595
1596 #[test]
1597 fn supports_structured_output_default_false() {
1598 let provider = StubProvider {
1599 response: String::new(),
1600 };
1601 assert!(!provider.supports_structured_output());
1602 }
1603
1604 #[test]
1605 fn structured_parse_error_display() {
1606 let err = LlmError::StructuredParse("test error".into());
1607 assert_eq!(
1608 err.to_string(),
1609 "structured output parse failed: test error"
1610 );
1611 }
1612
1613 #[test]
1614 fn message_part_image_roundtrip_json() {
1615 let part = MessagePart::Image(Box::new(ImageData {
1616 data: vec![1, 2, 3, 4],
1617 mime_type: "image/jpeg".into(),
1618 }));
1619 let json = serde_json::to_string(&part).unwrap();
1620 let decoded: MessagePart = serde_json::from_str(&json).unwrap();
1621 match decoded {
1622 MessagePart::Image(img) => {
1623 assert_eq!(img.data, vec![1, 2, 3, 4]);
1624 assert_eq!(img.mime_type, "image/jpeg");
1625 }
1626 _ => panic!("expected Image variant"),
1627 }
1628 }
1629
1630 #[test]
1631 fn flatten_parts_includes_image_placeholder() {
1632 let msg = Message::from_parts(
1633 Role::User,
1634 vec![
1635 MessagePart::Text {
1636 text: "see this".into(),
1637 },
1638 MessagePart::Image(Box::new(ImageData {
1639 data: vec![0u8; 100],
1640 mime_type: "image/png".into(),
1641 })),
1642 ],
1643 );
1644 let content = msg.to_llm_content();
1645 assert!(content.contains("see this"));
1646 assert!(content.contains("[image: image/png"));
1647 }
1648
1649 #[test]
1650 fn supports_vision_default_false() {
1651 let provider = StubProvider {
1652 response: String::new(),
1653 };
1654 assert!(!provider.supports_vision());
1655 }
1656
1657 #[test]
1658 fn message_metadata_default_both_visible() {
1659 let m = MessageMetadata::default();
1660 assert!(m.visibility.is_agent_visible());
1661 assert!(m.visibility.is_user_visible());
1662 assert_eq!(m.visibility, MessageVisibility::Both);
1663 assert!(m.compacted_at.is_none());
1664 }
1665
1666 #[test]
1667 fn message_metadata_agent_only() {
1668 let m = MessageMetadata::agent_only();
1669 assert!(m.visibility.is_agent_visible());
1670 assert!(!m.visibility.is_user_visible());
1671 assert_eq!(m.visibility, MessageVisibility::AgentOnly);
1672 }
1673
1674 #[test]
1675 fn message_metadata_user_only() {
1676 let m = MessageMetadata::user_only();
1677 assert!(!m.visibility.is_agent_visible());
1678 assert!(m.visibility.is_user_visible());
1679 assert_eq!(m.visibility, MessageVisibility::UserOnly);
1680 }
1681
1682 #[test]
1683 fn message_metadata_serde_default() {
1684 let json = r#"{"role":"user","content":"hello"}"#;
1685 let msg: Message = serde_json::from_str(json).unwrap();
1686 assert!(msg.metadata.visibility.is_agent_visible());
1687 assert!(msg.metadata.visibility.is_user_visible());
1688 }
1689
1690 #[test]
1691 fn message_metadata_round_trip() {
1692 let msg = Message {
1693 role: Role::User,
1694 content: "test".into(),
1695 parts: vec![],
1696 metadata: MessageMetadata::agent_only(),
1697 };
1698 let json = serde_json::to_string(&msg).unwrap();
1699 let decoded: Message = serde_json::from_str(&json).unwrap();
1700 assert!(decoded.metadata.visibility.is_agent_visible());
1701 assert!(!decoded.metadata.visibility.is_user_visible());
1702 assert_eq!(decoded.metadata.visibility, MessageVisibility::AgentOnly);
1703 }
1704
1705 #[test]
1706 fn message_part_compaction_round_trip() {
1707 let part = MessagePart::Compaction {
1708 summary: "Context was summarized.".to_owned(),
1709 };
1710 let json = serde_json::to_string(&part).unwrap();
1711 let decoded: MessagePart = serde_json::from_str(&json).unwrap();
1712 assert!(
1713 matches!(decoded, MessagePart::Compaction { summary } if summary == "Context was summarized.")
1714 );
1715 }
1716
1717 #[test]
1718 fn flatten_parts_compaction_contributes_no_text() {
1719 let parts = vec![
1722 MessagePart::Text {
1723 text: "Hello".to_owned(),
1724 },
1725 MessagePart::Compaction {
1726 summary: "Summary".to_owned(),
1727 },
1728 ];
1729 let msg = Message::from_parts(Role::Assistant, parts);
1730 assert_eq!(msg.content.trim(), "Hello");
1732 }
1733
1734 #[test]
1735 fn stream_chunk_compaction_variant() {
1736 let chunk = StreamChunk::Compaction("A summary".to_owned());
1737 assert!(matches!(chunk, StreamChunk::Compaction(s) if s == "A summary"));
1738 }
1739
1740 #[test]
1741 fn short_type_name_extracts_last_segment() {
1742 struct MyOutput;
1743 assert_eq!(short_type_name::<MyOutput>(), "MyOutput");
1744 }
1745
1746 #[test]
1747 fn short_type_name_primitive_returns_full_name() {
1748 assert_eq!(short_type_name::<u32>(), "u32");
1750 assert_eq!(short_type_name::<bool>(), "bool");
1751 }
1752
1753 #[test]
1754 fn short_type_name_nested_path_returns_last() {
1755 assert_eq!(
1757 short_type_name::<std::collections::HashMap<u32, u32>>(),
1758 "HashMap<u32, u32>"
1759 );
1760 }
1761
1762 #[test]
1765 fn summary_roundtrip() {
1766 let part = MessagePart::Summary {
1767 text: "hello".to_string(),
1768 };
1769 let json = serde_json::to_string(&part).expect("serialization must not fail");
1770 assert!(
1771 json.contains("\"kind\":\"summary\""),
1772 "must use internally-tagged format, got: {json}"
1773 );
1774 assert!(
1775 !json.contains("\"Summary\""),
1776 "must not use externally-tagged format, got: {json}"
1777 );
1778 let decoded: MessagePart =
1779 serde_json::from_str(&json).expect("deserialization must not fail");
1780 match decoded {
1781 MessagePart::Summary { text } => assert_eq!(text, "hello"),
1782 other => panic!("expected MessagePart::Summary, got {other:?}"),
1783 }
1784 }
1785
1786 #[tokio::test]
1787 async fn embed_batch_default_empty_returns_empty() {
1788 let provider = StubProvider {
1789 response: String::new(),
1790 };
1791 let result = provider.embed_batch(&[]).await.unwrap();
1792 assert!(result.is_empty());
1793 }
1794
1795 #[tokio::test]
1796 async fn embed_batch_default_calls_embed_sequentially() {
1797 let provider = StubProvider {
1798 response: String::new(),
1799 };
1800 let texts = ["hello", "world", "foo"];
1801 let result = provider.embed_batch(&texts).await.unwrap();
1802 assert_eq!(result.len(), 3);
1803 for vec in &result {
1805 assert_eq!(vec, &[0.1_f32, 0.2, 0.3]);
1806 }
1807 }
1808
1809 #[test]
1810 fn message_visibility_db_roundtrip_both() {
1811 assert_eq!(MessageVisibility::Both.as_db_str(), "both");
1812 assert_eq!(
1813 MessageVisibility::from_db_str("both"),
1814 MessageVisibility::Both
1815 );
1816 }
1817
1818 #[test]
1819 fn message_visibility_db_roundtrip_agent_only() {
1820 assert_eq!(MessageVisibility::AgentOnly.as_db_str(), "agent_only");
1821 assert_eq!(
1822 MessageVisibility::from_db_str("agent_only"),
1823 MessageVisibility::AgentOnly
1824 );
1825 }
1826
1827 #[test]
1828 fn message_visibility_db_roundtrip_user_only() {
1829 assert_eq!(MessageVisibility::UserOnly.as_db_str(), "user_only");
1830 assert_eq!(
1831 MessageVisibility::from_db_str("user_only"),
1832 MessageVisibility::UserOnly
1833 );
1834 }
1835
1836 #[test]
1837 fn message_visibility_from_db_str_unknown_defaults_to_both() {
1838 assert_eq!(
1839 MessageVisibility::from_db_str("unknown_future_value"),
1840 MessageVisibility::Both
1841 );
1842 assert_eq!(MessageVisibility::from_db_str(""), MessageVisibility::Both);
1843 }
1844}