1use std::future::Future;
5use std::pin::Pin;
6use std::{
7 any::TypeId,
8 collections::HashMap,
9 sync::{LazyLock, Mutex},
10};
11
12use futures_core::Stream;
13use serde::{Deserialize, Serialize};
14
15use zeph_common::ToolName;
16
17pub use zeph_common::ToolDefinition;
18
19use crate::embed::owned_strs;
20use crate::error::LlmError;
21
/// Process-wide cache of generated JSON schemas, keyed by the `TypeId` of the Rust type.
///
/// Each entry holds the schema twice: as a `serde_json::Value` and as a pretty-printed
/// JSON string. The map is created lazily on first access.
static SCHEMA_CACHE: LazyLock<Mutex<HashMap<TypeId, (serde_json::Value, String)>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
24
25pub(crate) fn cached_schema<T: schemars::JsonSchema + 'static>()
31-> Result<(serde_json::Value, String), crate::LlmError> {
32 let type_id = TypeId::of::<T>();
33 if let Ok(cache) = SCHEMA_CACHE.lock()
34 && let Some(entry) = cache.get(&type_id)
35 {
36 return Ok(entry.clone());
37 }
38 let schema = schemars::schema_for!(T);
39 let value = serde_json::to_value(&schema)
40 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
41 let pretty = serde_json::to_string_pretty(&schema)
42 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
43 if let Ok(mut cache) = SCHEMA_CACHE.lock() {
44 cache.insert(type_id, (value.clone(), pretty.clone()));
45 }
46 Ok((value, pretty))
47}
48
/// Returns the last path segment of `T`'s type name, without generic arguments.
///
/// `std::any::type_name` yields fully qualified names such as
/// `alloc::vec::Vec<alloc::string::String>`. Splitting that on `::` directly would
/// return the tail of the *generic argument* (`String>`), so the generic argument list
/// is dropped first and only then is the final path segment taken (`Vec`).
///
/// Falls back to `"Output"` when no non-empty segment can be extracted.
pub(crate) fn short_type_name<T: ?Sized>() -> &'static str {
    let full = std::any::type_name::<T>();
    // Everything from the first `<` onwards belongs to generic arguments.
    let base = full.split('<').next().unwrap_or(full);
    base.rsplit("::")
        .next()
        .filter(|segment| !segment.is_empty())
        .unwrap_or("Output")
}
68
/// One item produced by a streaming chat response (the item type of [`ChatStream`]).
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A piece of regular response text.
    Content(String),
    /// A piece of thinking text.
    /// NOTE(review): presumably model reasoning output — confirm with the producers.
    Thinking(String),
    /// Compaction text.
    /// NOTE(review): presumably a provider-generated context summary — confirm.
    Compaction(String),
    /// One or more tool invocations requested in the response.
    ToolUse(Vec<ToolUseRequest>),
}
85
/// A pinned, boxed, `Send` stream whose items are `Result<StreamChunk, LlmError>`.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
91
/// A single tool invocation: an identifier, the tool's name, and its JSON input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUseRequest {
    /// Identifier of this tool call.
    /// NOTE(review): presumably assigned by the provider and echoed back in the
    /// matching tool result — confirm.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: ToolName,
    /// Input for the tool as arbitrary JSON.
    pub input: serde_json::Value,
}
106
/// A thinking block attached to a [`ChatResponse::ToolUse`] response.
#[derive(Debug, Clone)]
pub enum ThinkingBlock {
    /// Thinking text together with its signature string.
    /// NOTE(review): the signature is presumably an opaque provider token — confirm.
    Thinking { thinking: String, signature: String },
    /// A redacted block carrying only an opaque `data` string.
    Redacted { data: String },
}
119
/// Marker text for responses cut off by the `max_tokens` limit.
/// NOTE(review): where this marker is emitted or matched is not visible here — confirm.
pub const MAX_TOKENS_TRUNCATION_MARKER: &str = "max_tokens limit reached";
123
/// Result of a tool-aware chat call (see `LlmProvider::chat_with_tools`).
#[derive(Debug, Clone)]
pub enum ChatResponse {
    /// A plain text answer with no tool calls.
    Text(String),
    /// A response that requests one or more tool invocations.
    ToolUse {
        /// Optional text that accompanies the tool calls.
        text: Option<String>,
        /// The requested tool invocations.
        tool_calls: Vec<ToolUseRequest>,
        /// Thinking blocks returned with the response (may be empty).
        thinking_blocks: Vec<ThinkingBlock>,
    },
}
145
/// A pinned, boxed, `Send` future resolving to an embedding vector or an [`LlmError`].
pub type EmbedFuture = Pin<Box<dyn Future<Output = Result<Vec<f32>, LlmError>> + Send>>;
148
/// A boxed, `Send + Sync` closure that maps a text to an [`EmbedFuture`].
pub type EmbedFn = Box<dyn Fn(&str) -> EmbedFuture + Send + Sync>;
154
/// Sending half of an unbounded Tokio mpsc channel carrying `String` status messages.
pub type StatusTx = tokio::sync::mpsc::UnboundedSender<String>;
161
162#[must_use]
165pub fn default_debug_request_json(
166 messages: &[Message],
167 tools: &[ToolDefinition],
168) -> serde_json::Value {
169 serde_json::json!({
170 "model": serde_json::Value::Null,
171 "max_tokens": serde_json::Value::Null,
172 "messages": serde_json::to_value(messages).unwrap_or(serde_json::Value::Array(vec![])),
173 "tools": serde_json::to_value(tools).unwrap_or(serde_json::Value::Array(vec![])),
174 "temperature": serde_json::Value::Null,
175 "cache_control": serde_json::Value::Null,
176 })
177}
178
/// Optional overrides for generation (sampling) parameters.
///
/// `Default` leaves every field `None`.
/// NOTE(review): presumably `None` means "use the provider's own setting" — confirm
/// with the code that consumes this struct.
#[derive(Debug, Clone, Default)]
pub struct GenerationOverrides {
    /// Sampling temperature override.
    pub temperature: Option<f64>,
    /// `top_p` override.
    pub top_p: Option<f64>,
    /// `top_k` override.
    pub top_k: Option<usize>,
    /// Frequency penalty override.
    pub frequency_penalty: Option<f64>,
    /// Presence penalty override.
    pub presence_penalty: Option<f64>,
}
200
/// Author role of a [`Message`].
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System message.
    System,
    /// Message from the user.
    User,
    /// Message from the assistant.
    Assistant,
}
214
/// A structured piece of a [`Message`].
///
/// Serialized as an internally tagged enum: the variant name, in `snake_case`, is
/// stored under the `kind` key next to the variant's own fields.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum MessagePart {
    /// Plain text.
    Text { text: String },
    /// Output produced by a tool.
    ToolOutput {
        /// Name of the tool that produced the output.
        tool_name: zeph_common::ToolName,
        /// The output text.
        body: String,
        /// Compaction marker; defaults to `None` and is omitted from JSON when `None`.
        /// NOTE(review): presumably a timestamp — confirm the unit with the writer.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        compacted_at: Option<i64>,
    },
    /// Recall text. NOTE(review): presumably retrieved memory — confirm.
    Recall { text: String },
    /// Code context text.
    CodeContext { text: String },
    /// Summary text.
    Summary { text: String },
    /// Cross-session text. NOTE(review): presumably context from other sessions — confirm.
    CrossSession { text: String },
    /// A tool invocation: call id, tool name, and JSON input.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, linked to it through `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        content: String,
        /// Whether the result is an error; defaults to `false` when absent from JSON.
        #[serde(default)]
        is_error: bool,
    },
    /// An image payload (boxed).
    Image(Box<ImageData>),
    /// A thinking block: thinking text plus its signature string.
    ThinkingBlock { thinking: String, signature: String },
    /// A redacted thinking block carrying only an opaque `data` string.
    RedactedThinkingBlock { data: String },
    /// A compaction entry carrying a summary string.
    Compaction { summary: String },
}
272
273impl MessagePart {
274 #[must_use]
277 pub fn as_plain_text(&self) -> Option<&str> {
278 match self {
279 Self::Text { text }
280 | Self::Recall { text }
281 | Self::CodeContext { text }
282 | Self::Summary { text }
283 | Self::CrossSession { text } => Some(text.as_str()),
284 _ => None,
285 }
286 }
287
288 #[must_use]
290 pub fn as_image(&self) -> Option<&ImageData> {
291 if let Self::Image(img) = self {
292 Some(img)
293 } else {
294 None
295 }
296 }
297}
298
/// Raw image bytes together with their MIME type.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImageData {
    /// Image bytes; (de)serialized as a base64 string via `serde_bytes_base64`.
    #[serde(with = "serde_bytes_base64")]
    pub data: Vec<u8>,
    /// MIME type of the image, e.g. `image/png`.
    pub mime_type: String,
}
309
310mod serde_bytes_base64 {
311 use base64::{Engine, engine::general_purpose::STANDARD};
312 use serde::{Deserialize, Deserializer, Serializer};
313
314 pub fn serialize<S>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error>
315 where
316 S: Serializer,
317 {
318 s.serialize_str(&STANDARD.encode(bytes))
319 }
320
321 pub fn deserialize<'de, D>(d: D) -> Result<Vec<u8>, D::Error>
322 where
323 D: Deserializer<'de>,
324 {
325 let s = String::deserialize(d)?;
326 STANDARD.decode(&s).map_err(serde::de::Error::custom)
327 }
328}
329
/// Who may see a message: the agent, the user, or both.
///
/// Serialized in `snake_case`: `"both"`, `"agent_only"`, `"user_only"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageVisibility {
    /// Visible to both the agent and the user.
    Both,
    /// Visible to the agent only.
    AgentOnly,
    /// Visible to the user only.
    UserOnly,
}
355
356impl MessageVisibility {
357 #[must_use]
359 pub fn is_agent_visible(self) -> bool {
360 matches!(self, MessageVisibility::Both | MessageVisibility::AgentOnly)
361 }
362
363 #[must_use]
365 pub fn is_user_visible(self) -> bool {
366 matches!(self, MessageVisibility::Both | MessageVisibility::UserOnly)
367 }
368}
369
370impl Default for MessageVisibility {
371 fn default() -> Self {
373 MessageVisibility::Both
374 }
375}
376
377impl MessageVisibility {
378 #[must_use]
380 pub fn as_db_str(self) -> &'static str {
381 match self {
382 MessageVisibility::Both => "both",
383 MessageVisibility::AgentOnly => "agent_only",
384 MessageVisibility::UserOnly => "user_only",
385 }
386 }
387
388 #[must_use]
392 pub fn from_db_str(s: &str) -> Self {
393 match s {
394 "agent_only" => MessageVisibility::AgentOnly,
395 "user_only" => MessageVisibility::UserOnly,
396 _ => MessageVisibility::Both,
397 }
398 }
399}
400
/// Per-message bookkeeping attached to a [`Message`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Who may see the message.
    pub visibility: MessageVisibility,
    /// Compaction marker; defaults to `None` and is omitted from JSON when `None`.
    /// NOTE(review): presumably a timestamp — confirm the unit with the writer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub compacted_at: Option<i64>,
    /// Deferred summary text; defaults to `None` and is omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub deferred_summary: Option<String>,
    /// Focus-pinned flag; defaults to `false` and is omitted from JSON when `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub focus_pinned: bool,
    /// Focus marker id; defaults to `None` and is omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub focus_marker_id: Option<uuid::Uuid>,
    /// Database row id. Never serialized or deserialized (`#[serde(skip)]`), so it is
    /// `None` after deserialization.
    #[serde(skip)]
    pub db_id: Option<i64>,
}
429
430impl Default for MessageMetadata {
431 fn default() -> Self {
432 Self {
433 visibility: MessageVisibility::Both,
434 compacted_at: None,
435 deferred_summary: None,
436 focus_pinned: false,
437 focus_marker_id: None,
438 db_id: None,
439 }
440 }
441}
442
443impl MessageMetadata {
444 #[must_use]
446 pub fn agent_only() -> Self {
447 Self {
448 visibility: MessageVisibility::AgentOnly,
449 compacted_at: None,
450 deferred_summary: None,
451 focus_pinned: false,
452 focus_marker_id: None,
453 db_id: None,
454 }
455 }
456
457 #[must_use]
459 pub fn user_only() -> Self {
460 Self {
461 visibility: MessageVisibility::UserOnly,
462 compacted_at: None,
463 deferred_summary: None,
464 focus_pinned: false,
465 focus_marker_id: None,
466 db_id: None,
467 }
468 }
469
470 #[must_use]
472 pub fn focus_pinned() -> Self {
473 Self {
474 visibility: MessageVisibility::AgentOnly,
475 compacted_at: None,
476 deferred_summary: None,
477 focus_pinned: true,
478 focus_marker_id: None,
479 db_id: None,
480 }
481 }
482}
483
/// A single conversation message.
///
/// `content` holds the flat text form; `parts` optionally holds the structured form.
/// When deserializing, missing `parts` default to an empty list and missing
/// `metadata` defaults to `MessageMetadata::default()`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
    /// Author role.
    pub role: Role,
    /// Flat text content of the message.
    pub content: String,
    /// Structured parts; empty for messages built from plain text.
    #[serde(default)]
    pub parts: Vec<MessagePart>,
    /// Per-message metadata.
    #[serde(default)]
    pub metadata: MessageMetadata,
}
520
521impl Default for Message {
522 fn default() -> Self {
523 Self {
524 role: Role::User,
525 content: String::new(),
526 parts: vec![],
527 metadata: MessageMetadata::default(),
528 }
529 }
530}
531
532impl Message {
533 #[must_use]
538 pub fn from_legacy(role: Role, content: impl Into<String>) -> Self {
539 Self {
540 role,
541 content: content.into(),
542 parts: vec![],
543 metadata: MessageMetadata::default(),
544 }
545 }
546
547 #[must_use]
552 pub fn from_parts(role: Role, parts: Vec<MessagePart>) -> Self {
553 let content = Self::flatten_parts(&parts);
554 Self {
555 role,
556 content,
557 parts,
558 metadata: MessageMetadata::default(),
559 }
560 }
561
562 #[must_use]
565 pub fn to_llm_content(&self) -> &str {
566 &self.content
567 }
568
569 pub fn rebuild_content(&mut self) {
571 if !self.parts.is_empty() {
572 self.content = Self::flatten_parts(&self.parts);
573 }
574 }
575
576 fn flatten_parts(parts: &[MessagePart]) -> String {
577 use std::fmt::Write;
578 let mut out = String::new();
579 for part in parts {
580 match part {
581 MessagePart::Text { text }
582 | MessagePart::Recall { text }
583 | MessagePart::CodeContext { text }
584 | MessagePart::Summary { text }
585 | MessagePart::CrossSession { text } => out.push_str(text),
586 MessagePart::ToolOutput {
587 tool_name,
588 body,
589 compacted_at,
590 } => {
591 if compacted_at.is_some() {
592 if body.is_empty() {
593 let _ = write!(out, "[tool output: {tool_name}] (pruned)");
594 } else {
595 let _ = write!(out, "[tool output: {tool_name}] {body}");
596 }
597 } else {
598 let _ = write!(out, "[tool output: {tool_name}]\n```\n{body}\n```");
599 }
600 }
601 MessagePart::ToolUse { id, name, .. } => {
602 let _ = write!(out, "[tool_use: {name}({id})]");
603 }
604 MessagePart::ToolResult {
605 tool_use_id,
606 content,
607 ..
608 } => {
609 let _ = write!(out, "[tool_result: {tool_use_id}]\n{content}");
610 }
611 MessagePart::Image(img) => {
612 let _ = write!(out, "[image: {}, {} bytes]", img.mime_type, img.data.len());
613 }
614 MessagePart::ThinkingBlock { .. }
616 | MessagePart::RedactedThinkingBlock { .. }
617 | MessagePart::Compaction { .. } => {}
618 }
619 }
620 out
621 }
622}
623
624pub trait LlmProvider: Send + Sync {
678 fn context_window(&self) -> Option<usize> {
682 None
683 }
684
685 fn chat(&self, messages: &[Message]) -> impl Future<Output = Result<String, LlmError>> + Send;
691
692 fn chat_stream(
698 &self,
699 messages: &[Message],
700 ) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;
701
702 fn supports_streaming(&self) -> bool;
704
705 fn embed(&self, text: &str) -> impl Future<Output = Result<Vec<f32>, LlmError>> + Send;
711
712 fn embed_batch(
722 &self,
723 texts: &[&str],
724 ) -> impl Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
725 let owned = owned_strs(texts);
726 async move {
727 let mut results = Vec::with_capacity(owned.len());
728 for text in &owned {
729 results.push(self.embed(text).await?);
730 }
731 Ok(results)
732 }
733 }
734
735 fn supports_embeddings(&self) -> bool;
737
738 fn name(&self) -> &str;
740
741 #[allow(clippy::unnecessary_literal_bound)]
744 fn model_identifier(&self) -> &str {
745 ""
746 }
747
748 fn supports_vision(&self) -> bool {
750 false
751 }
752
753 fn supports_tool_use(&self) -> bool {
755 true
756 }
757
758 #[allow(async_fn_in_trait)]
766 async fn chat_with_tools(
767 &self,
768 messages: &[Message],
769 _tools: &[ToolDefinition],
770 ) -> Result<ChatResponse, LlmError> {
771 Ok(ChatResponse::Text(self.chat(messages).await?))
772 }
773
774 fn last_cache_usage(&self) -> Option<(u64, u64)> {
777 None
778 }
779
780 fn last_usage(&self) -> Option<(u64, u64)> {
783 None
784 }
785
786 fn take_compaction_summary(&self) -> Option<String> {
789 None
790 }
791
792 fn record_quality_outcome(&self, _provider_name: &str, _success: bool) {}
798
799 #[must_use]
803 fn debug_request_json(
804 &self,
805 messages: &[Message],
806 tools: &[ToolDefinition],
807 _stream: bool,
808 ) -> serde_json::Value {
809 default_debug_request_json(messages, tools)
810 }
811
812 fn list_models(&self) -> Vec<String> {
815 vec![]
816 }
817
818 fn supports_structured_output(&self) -> bool {
820 false
821 }
822
823 #[allow(async_fn_in_trait)]
828 async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
829 where
830 T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
831 Self: Sized,
832 {
833 let (_, schema_json) = cached_schema::<T>()?;
834 let type_name = short_type_name::<T>();
835
836 let mut augmented = messages.to_vec();
837 let instruction = format!(
838 "Respond with a valid JSON object matching this schema. \
839 Output ONLY the JSON, no markdown fences or extra text.\n\n\
840 Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
841 );
842 augmented.insert(0, Message::from_legacy(Role::System, instruction));
843
844 let raw = self.chat(&augmented).await?;
845 let cleaned = strip_json_fences(&raw);
846 match serde_json::from_str::<T>(cleaned) {
847 Ok(val) => Ok(val),
848 Err(first_err) => {
849 augmented.push(Message::from_legacy(Role::Assistant, &raw));
850 augmented.push(Message::from_legacy(
851 Role::User,
852 format!(
853 "Your response was not valid JSON. Error: {first_err}. \
854 Please output ONLY valid JSON matching the schema."
855 ),
856 ));
857 let retry_raw = self.chat(&augmented).await?;
858 let retry_cleaned = strip_json_fences(&retry_raw);
859 serde_json::from_str::<T>(retry_cleaned).map_err(|e| {
860 LlmError::StructuredParse(format!("parse failed after retry: {e}"))
861 })
862 }
863 }
864 }
865}
866
/// Removes Markdown code-fence markers and surrounding whitespace from `s`.
///
/// After trimming whitespace, every leading "```json" marker and then every leading
/// "```" marker is removed, followed by every trailing "```" marker; the remainder is
/// trimmed again. Input without fences is returned trimmed.
fn strip_json_fences(s: &str) -> &str {
    let trimmed = s.trim();
    let without_opening = trimmed
        .trim_start_matches("```json")
        .trim_start_matches("```");
    let without_closing = without_opening.trim_end_matches("```");
    without_closing.trim()
}
877
878#[cfg(test)]
879mod tests {
880 use tokio_stream::StreamExt;
881
882 use super::*;
883
/// Test provider with a fixed, canned chat response.
struct StubProvider {
    /// Text handed back by the stub's `chat` implementation.
    response: String,
}
887
888 impl LlmProvider for StubProvider {
889 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
890 Ok(self.response.clone())
891 }
892
893 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
894 let response = self.chat(messages).await?;
895 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
896 response,
897 )))))
898 }
899
900 fn supports_streaming(&self) -> bool {
901 false
902 }
903
904 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
905 Ok(vec![0.1, 0.2, 0.3])
906 }
907
908 fn supports_embeddings(&self) -> bool {
909 false
910 }
911
912 fn name(&self) -> &'static str {
913 "stub"
914 }
915 }
916
/// `StubProvider` does not override `context_window`, so the trait default applies.
#[test]
fn context_window_default_returns_none() {
    let stub = StubProvider {
        response: String::new(),
    };
    assert_eq!(stub.context_window(), None);
}
924
/// Pins the value returned by `StubProvider::supports_streaming` (`false`).
#[test]
fn supports_streaming_default_returns_false() {
    let stub = StubProvider {
        response: String::new(),
    };
    let streaming = stub.supports_streaming();
    assert!(!streaming);
}
932
933 #[tokio::test]
934 async fn chat_stream_default_yields_single_chunk() {
935 let provider = StubProvider {
936 response: "hello world".into(),
937 };
938 let messages = vec![Message {
939 role: Role::User,
940 content: "test".into(),
941 parts: vec![],
942 metadata: MessageMetadata::default(),
943 }];
944
945 let mut stream = provider.chat_stream(&messages).await.unwrap();
946 let chunk = stream.next().await.unwrap().unwrap();
947 assert!(matches!(chunk, StreamChunk::Content(s) if s == "hello world"));
948 assert!(stream.next().await.is_none());
949 }
950
951 #[tokio::test]
952 async fn chat_stream_default_propagates_chat_error() {
953 struct FailProvider;
954
955 impl LlmProvider for FailProvider {
956 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
957 Err(LlmError::Unavailable)
958 }
959
960 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
961 let response = self.chat(messages).await?;
962 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
963 response,
964 )))))
965 }
966
967 fn supports_streaming(&self) -> bool {
968 false
969 }
970
971 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
972 Err(LlmError::Unavailable)
973 }
974
975 fn supports_embeddings(&self) -> bool {
976 false
977 }
978
979 fn name(&self) -> &'static str {
980 "fail"
981 }
982 }
983
984 let provider = FailProvider;
985 let messages = vec![Message {
986 role: Role::User,
987 content: "test".into(),
988 parts: vec![],
989 metadata: MessageMetadata::default(),
990 }];
991
992 let result = provider.chat_stream(&messages).await;
993 assert!(result.is_err());
994 if let Err(e) = result {
995 assert!(e.to_string().contains("provider unavailable"));
996 }
997 }
998
999 #[tokio::test]
1000 async fn stub_provider_embed_returns_vector() {
1001 let provider = StubProvider {
1002 response: String::new(),
1003 };
1004 let embedding = provider.embed("test").await.unwrap();
1005 assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
1006 }
1007
1008 #[tokio::test]
1009 async fn fail_provider_embed_propagates_error() {
1010 struct FailProvider;
1011
1012 impl LlmProvider for FailProvider {
1013 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1014 Err(LlmError::Unavailable)
1015 }
1016
1017 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1018 let response = self.chat(messages).await?;
1019 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1020 response,
1021 )))))
1022 }
1023
1024 fn supports_streaming(&self) -> bool {
1025 false
1026 }
1027
1028 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1029 Err(LlmError::EmbedUnsupported {
1030 provider: "fail".into(),
1031 })
1032 }
1033
1034 fn supports_embeddings(&self) -> bool {
1035 false
1036 }
1037
1038 fn name(&self) -> &'static str {
1039 "fail"
1040 }
1041 }
1042
1043 let provider = FailProvider;
1044 let result = provider.embed("test").await;
1045 assert!(result.is_err());
1046 assert!(
1047 result
1048 .unwrap_err()
1049 .to_string()
1050 .contains("embedding not supported")
1051 );
1052 }
1053
/// Each role serializes to its lowercase JSON string.
#[test]
fn role_serialization() {
    let cases = [
        (Role::System, "\"system\""),
        (Role::User, "\"user\""),
        (Role::Assistant, "\"assistant\""),
    ];
    for (role, expected) in cases {
        assert_eq!(serde_json::to_string(&role).unwrap(), expected);
    }
}
1064
/// Each lowercase JSON string deserializes to the matching role.
#[test]
fn role_deserialization() {
    let cases = [
        ("\"system\"", Role::System),
        ("\"user\"", Role::User),
        ("\"assistant\"", Role::Assistant),
    ];
    for (json, expected) in cases {
        let parsed = serde_json::from_str::<Role>(json).unwrap();
        assert_eq!(parsed, expected);
    }
}
1075
/// Cloning a message preserves its role and content.
#[test]
fn message_clone() {
    let original = Message::from_legacy(Role::User, "test");
    let copy = original.clone();
    assert_eq!(
        (copy.role, copy.content.as_str()),
        (original.role, original.content.as_str())
    );
}
1088
/// The `Debug` rendering of a message mentions both its role and its content.
#[test]
fn message_debug() {
    let msg = Message::from_legacy(Role::Assistant, "response");
    let rendered = format!("{msg:?}");
    for needle in ["Assistant", "response"] {
        assert!(rendered.contains(needle));
    }
}
1101
/// Serialized messages carry the lowercase role and the content verbatim.
#[test]
fn message_serialization() {
    let msg = Message::from_legacy(Role::User, "hello");
    let json = serde_json::to_string(&msg).unwrap();
    for fragment in ["\"role\":\"user\"", "\"content\":\"hello\""] {
        assert!(json.contains(fragment));
    }
}
1114
/// A list of mixed parts survives a JSON round trip with its length intact.
#[test]
fn message_part_serde_round_trip() {
    let hello = MessagePart::Text {
        text: "hello".into(),
    };
    let tool_output = MessagePart::ToolOutput {
        tool_name: "bash".into(),
        body: "output".into(),
        compacted_at: None,
    };
    let recall = MessagePart::Recall {
        text: "recall".into(),
    };
    let code = MessagePart::CodeContext {
        text: "code".into(),
    };
    let summary = MessagePart::Summary {
        text: "summary".into(),
    };
    let parts = vec![hello, tool_output, recall, code, summary];

    let json = serde_json::to_string(&parts).unwrap();
    let restored: Vec<MessagePart> = serde_json::from_str(&json).unwrap();
    assert_eq!(restored.len(), 5);
}
1140
/// `from_legacy` keeps the text as content and leaves `parts` empty.
#[test]
fn from_legacy_creates_empty_parts() {
    let built = Message::from_legacy(Role::User, "hello");
    assert_eq!(built.role, Role::User);
    assert!(built.parts.is_empty());
    assert_eq!(built.content, "hello");
    assert_eq!(built.to_llm_content(), "hello");
}
1149
/// `from_parts` renders a text-like part straight into `content` and keeps the part.
#[test]
fn from_parts_flattens_content() {
    let recall = MessagePart::Recall {
        text: "recalled data".into(),
    };
    let built = Message::from_parts(Role::System, vec![recall]);

    assert_eq!(built.parts.len(), 1);
    assert_eq!(built.content, "recalled data");
    assert_eq!(built.to_llm_content(), "recalled data");
}
1162
/// A non-compacted tool output renders with its header and its body.
#[test]
fn from_parts_tool_output_format() {
    let output = MessagePart::ToolOutput {
        tool_name: "bash".into(),
        body: "hello world".into(),
        compacted_at: None,
    };
    let built = Message::from_parts(Role::User, vec![output]);

    for fragment in ["[tool output: bash]", "hello world"] {
        assert!(built.content.contains(fragment));
    }
}
1176
/// JSON without a `parts` key deserializes to a message with empty parts.
#[test]
fn message_deserializes_without_parts() {
    let parsed = serde_json::from_str::<Message>(r#"{"role":"user","content":"hello"}"#).unwrap();
    assert!(parsed.parts.is_empty());
    assert_eq!(parsed.content, "hello");
}
1184
/// A compacted tool output with an empty body renders as a `(pruned)` placeholder,
/// while the surrounding text parts are kept.
#[test]
fn flatten_skips_compacted_tool_output_empty_body() {
    let before = MessagePart::Text {
        text: "prefix ".into(),
    };
    let pruned = MessagePart::ToolOutput {
        tool_name: "bash".into(),
        body: String::new(),
        compacted_at: Some(1234),
    };
    let after = MessagePart::Text {
        text: " suffix".into(),
    };
    let built = Message::from_parts(Role::User, vec![before, pruned, after]);

    for fragment in ["(pruned)", "prefix ", " suffix"] {
        assert!(built.content.contains(fragment));
    }
}
1208
/// A compacted tool output with a non-empty body renders that body instead of the
/// `(pruned)` placeholder.
#[test]
fn flatten_compacted_tool_output_with_reference_renders_body() {
    let notice = "[tool output pruned; full content at /tmp/overflow/big.txt]";
    let compacted = MessagePart::ToolOutput {
        tool_name: "bash".into(),
        body: notice.into(),
        compacted_at: Some(1234),
    };
    let built = Message::from_parts(Role::User, vec![compacted]);

    assert!(built.content.contains(notice));
    assert!(!built.content.contains("(pruned)"));
}
1224
/// After a part is mutated in place, `rebuild_content` brings `content` back in sync.
#[test]
fn rebuild_content_syncs_after_mutation() {
    let output = MessagePart::ToolOutput {
        tool_name: "bash".into(),
        body: "original".into(),
        compacted_at: None,
    };
    let mut msg = Message::from_parts(Role::User, vec![output]);
    assert!(msg.content.contains("original"));

    // Simulate compaction: mark the part as compacted and drop its body.
    if let Some(MessagePart::ToolOutput {
        compacted_at, body, ..
    }) = msg.parts.first_mut()
    {
        *compacted_at = Some(999);
        body.clear();
    }
    msg.rebuild_content();

    assert!(!msg.content.contains("original"));
    assert!(msg.content.contains("(pruned)"));
}
1251
/// A `ToolUse` part keeps its id, name and input across a JSON round trip.
#[test]
fn message_part_tool_use_serde_round_trip() {
    let original = MessagePart::ToolUse {
        id: "toolu_123".into(),
        name: "bash".into(),
        input: serde_json::json!({"command": "ls"}),
    };
    let json = serde_json::to_string(&original).unwrap();

    let MessagePart::ToolUse { id, name, input } = serde_json::from_str(&json).unwrap() else {
        panic!("expected ToolUse");
    };
    assert_eq!(id, "toolu_123");
    assert_eq!(name, "bash");
    assert_eq!(input["command"], "ls");
}
1269
/// A `ToolResult` part keeps all three fields across a JSON round trip.
#[test]
fn message_part_tool_result_serde_round_trip() {
    let original = MessagePart::ToolResult {
        tool_use_id: "toolu_123".into(),
        content: "file1.rs\nfile2.rs".into(),
        is_error: false,
    };
    let json = serde_json::to_string(&original).unwrap();

    let MessagePart::ToolResult {
        tool_use_id,
        content,
        is_error,
    } = serde_json::from_str(&json).unwrap()
    else {
        panic!("expected ToolResult");
    };
    assert_eq!(tool_use_id, "toolu_123");
    assert_eq!(content, "file1.rs\nfile2.rs");
    assert!(!is_error);
}
1292
/// `is_error` defaults to `false` when the key is missing from the JSON.
#[test]
fn message_part_tool_result_is_error_default() {
    let json = r#"{"kind":"tool_result","tool_use_id":"id","content":"err"}"#;
    match serde_json::from_str::<MessagePart>(json).unwrap() {
        MessagePart::ToolResult { is_error, .. } => assert!(!is_error),
        _ => panic!("expected ToolResult"),
    }
}
1303
/// Both `ChatResponse` variants can be built and matched.
#[test]
fn chat_response_construction() {
    let plain = ChatResponse::Text("hello".into());
    assert!(matches!(plain, ChatResponse::Text(s) if s == "hello"));

    let call = ToolUseRequest {
        id: "1".into(),
        name: "bash".into(),
        input: serde_json::json!({}),
    };
    let with_tools = ChatResponse::ToolUse {
        text: Some("I'll run that".into()),
        tool_calls: vec![call],
        thinking_blocks: Vec::new(),
    };
    assert!(matches!(with_tools, ChatResponse::ToolUse { .. }));
}
1320
/// A `ToolUse` part flattens to `[tool_use: <name>(<id>)]`.
#[test]
fn flatten_parts_tool_use() {
    let call = MessagePart::ToolUse {
        id: "t1".into(),
        name: "bash".into(),
        input: serde_json::json!({"command": "ls"}),
    };
    let built = Message::from_parts(Role::Assistant, vec![call]);
    assert!(built.content.contains("[tool_use: bash(t1)]"));
}
1333
/// A `ToolResult` part flattens to its header followed by the result content.
#[test]
fn flatten_parts_tool_result() {
    let result = MessagePart::ToolResult {
        tool_use_id: "t1".into(),
        content: "output here".into(),
        is_error: false,
    };
    let built = Message::from_parts(Role::User, vec![result]);

    for fragment in ["[tool_result: t1]", "output here"] {
        assert!(built.content.contains(fragment));
    }
}
1347
/// A tool definition keeps its name and description across a JSON round trip.
#[test]
fn tool_definition_serde_round_trip() {
    let original = ToolDefinition {
        name: "bash".into(),
        description: "Execute a shell command".into(),
        parameters: serde_json::json!({"type": "object"}),
    };

    let json = serde_json::to_string(&original).unwrap();
    let restored = serde_json::from_str::<ToolDefinition>(&json).unwrap();

    assert_eq!(restored.name, "bash");
    assert_eq!(restored.description, "Execute a shell command");
}
1360
1361 #[tokio::test]
1362 async fn chat_with_tools_default_delegates_to_chat() {
1363 let provider = StubProvider {
1364 response: "hello".into(),
1365 };
1366 let messages = vec![Message::from_legacy(Role::User, "test")];
1367 let result = provider.chat_with_tools(&messages, &[]).await.unwrap();
1368 assert!(matches!(result, ChatResponse::Text(s) if s == "hello"));
1369 }
1370
/// `compacted_at` defaults to `None` when the key is missing from the JSON.
#[test]
fn tool_output_compacted_at_serde_default() {
    let json = r#"{"kind":"tool_output","tool_name":"bash","body":"out"}"#;
    match serde_json::from_str::<MessagePart>(json).unwrap() {
        MessagePart::ToolOutput { compacted_at, .. } => assert!(compacted_at.is_none()),
        _ => panic!("expected ToolOutput"),
    }
}
1381
/// Unfenced JSON passes through unchanged.
#[test]
fn strip_json_fences_plain_json() {
    let plain = r#"{"a": 1}"#;
    assert_eq!(strip_json_fences(plain), plain);
}
1388
/// A "```json" fenced block is reduced to its JSON body.
#[test]
fn strip_json_fences_with_json_fence() {
    let fenced = "```json\n{\"a\": 1}\n```";
    assert_eq!(strip_json_fences(fenced), r#"{"a": 1}"#);
}
1393
/// A fence without a language tag is reduced to its JSON body.
#[test]
fn strip_json_fences_with_plain_fence() {
    let fenced = "```\n{\"a\": 1}\n```";
    assert_eq!(strip_json_fences(fenced), r#"{"a": 1}"#);
}
1398
/// Whitespace-only input collapses to the empty string.
#[test]
fn strip_json_fences_whitespace() {
    let stripped = strip_json_fences("  \n  ");
    assert_eq!(stripped, "");
}
1403
/// Empty input stays empty.
#[test]
fn strip_json_fences_empty() {
    let stripped = strip_json_fences("");
    assert_eq!(stripped, "");
}
1408
/// Whitespace around the fenced block is ignored.
#[test]
fn strip_json_fences_outer_whitespace() {
    let padded = "  ```json\n{\"a\": 1}\n```  ";
    assert_eq!(strip_json_fences(padded), r#"{"a": 1}"#);
}
1416
/// An opening fence without a closing one is still stripped.
#[test]
fn strip_json_fences_only_opening_fence() {
    let unterminated = "```json\n{\"a\": 1}";
    assert_eq!(strip_json_fences(unterminated), r#"{"a": 1}"#);
}
1421
// Minimal structured-output target used by the `chat_typed` tests.
// Plain `//` comments are used on purpose: `///` doc comments on a type deriving
// `JsonSchema` would be copied into the generated schema as descriptions.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
struct TestOutput {
    // The single field the stub responses are expected to populate.
    value: String,
}
1428
/// Test provider holding a list of scripted `chat` results behind a mutex.
struct SequentialStub {
    /// Scripted chat results.
    responses: std::sync::Mutex<Vec<Result<String, LlmError>>>,
}
1432
1433 impl SequentialStub {
1434 fn new(responses: Vec<Result<String, LlmError>>) -> Self {
1435 Self {
1436 responses: std::sync::Mutex::new(responses),
1437 }
1438 }
1439 }
1440
1441 impl LlmProvider for SequentialStub {
1442 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1443 let mut responses = self.responses.lock().unwrap();
1444 if responses.is_empty() {
1445 return Err(LlmError::Other("no more responses".into()));
1446 }
1447 responses.remove(0)
1448 }
1449
1450 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1451 let response = self.chat(messages).await?;
1452 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1453 response,
1454 )))))
1455 }
1456
1457 fn supports_streaming(&self) -> bool {
1458 false
1459 }
1460
1461 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1462 Err(LlmError::EmbedUnsupported {
1463 provider: "sequential-stub".into(),
1464 })
1465 }
1466
1467 fn supports_embeddings(&self) -> bool {
1468 false
1469 }
1470
1471 fn name(&self) -> &'static str {
1472 "sequential-stub"
1473 }
1474 }
1475
1476 #[tokio::test]
1477 async fn chat_typed_happy_path() {
1478 let provider = StubProvider {
1479 response: r#"{"value": "hello"}"#.into(),
1480 };
1481 let messages = vec![Message::from_legacy(Role::User, "test")];
1482 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1483 assert_eq!(
1484 result,
1485 TestOutput {
1486 value: "hello".into()
1487 }
1488 );
1489 }
1490
1491 #[tokio::test]
1492 async fn chat_typed_retry_succeeds() {
1493 let provider = SequentialStub::new(vec![
1494 Ok("not valid json".into()),
1495 Ok(r#"{"value": "ok"}"#.into()),
1496 ]);
1497 let messages = vec![Message::from_legacy(Role::User, "test")];
1498 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1499 assert_eq!(result, TestOutput { value: "ok".into() });
1500 }
1501
1502 #[tokio::test]
1503 async fn chat_typed_both_fail() {
1504 let provider = SequentialStub::new(vec![Ok("bad json".into()), Ok("still bad".into())]);
1505 let messages = vec![Message::from_legacy(Role::User, "test")];
1506 let result = provider.chat_typed::<TestOutput>(&messages).await;
1507 let err = result.unwrap_err();
1508 assert!(err.to_string().contains("parse failed after retry"));
1509 }
1510
1511 #[tokio::test]
1512 async fn chat_typed_chat_error_propagates() {
1513 let provider = SequentialStub::new(vec![Err(LlmError::Unavailable)]);
1514 let messages = vec![Message::from_legacy(Role::User, "test")];
1515 let result = provider.chat_typed::<TestOutput>(&messages).await;
1516 assert!(matches!(result, Err(LlmError::Unavailable)));
1517 }
1518
1519 #[tokio::test]
1520 async fn chat_typed_strips_fences() {
1521 let provider = StubProvider {
1522 response: "```json\n{\"value\": \"fenced\"}\n```".into(),
1523 };
1524 let messages = vec![Message::from_legacy(Role::User, "test")];
1525 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1526 assert_eq!(
1527 result,
1528 TestOutput {
1529 value: "fenced".into()
1530 }
1531 );
1532 }
1533
/// `StubProvider` reports no structured-output support.
#[test]
fn supports_structured_output_default_false() {
    let stub = StubProvider {
        response: String::new(),
    };
    let advertised = stub.supports_structured_output();
    assert!(!advertised);
}
1541
/// Pins the `Display` text of `LlmError::StructuredParse`.
#[test]
fn structured_parse_error_display() {
    let rendered = LlmError::StructuredParse("test error".into()).to_string();
    assert_eq!(rendered, "structured output parse failed: test error");
}
1550
/// Serializing a `MessagePart::Image` to JSON and reading it back preserves
/// both the raw bytes and the MIME type.
#[test]
fn message_part_image_roundtrip_json() {
    let part = MessagePart::Image(Box::new(ImageData {
        data: vec![1, 2, 3, 4],
        mime_type: "image/jpeg".into(),
    }));
    let json = serde_json::to_string(&part).unwrap();
    let decoded: MessagePart = serde_json::from_str(&json).unwrap();
    match decoded {
        MessagePart::Image(img) => {
            assert_eq!(img.data, vec![1, 2, 3, 4]);
            assert_eq!(img.mime_type, "image/jpeg");
        }
        // Report the variant that was actually decoded so a failure can be
        // diagnosed from the test log alone.
        other => panic!("expected Image variant, got {other:?}"),
    }
}
1567
/// The LLM-facing content of a text + image message contains the text and an
/// `[image: <mime>` placeholder for the image part.
#[test]
fn flatten_parts_includes_image_placeholder() {
    let caption = MessagePart::Text {
        text: "see this".into(),
    };
    let picture = MessagePart::Image(Box::new(ImageData {
        data: vec![0u8; 100],
        mime_type: "image/png".into(),
    }));
    let message = Message::from_parts(Role::User, vec![caption, picture]);

    let flattened = message.to_llm_content();

    assert!(flattened.contains("see this"));
    assert!(flattened.contains("[image: image/png"));
}
1586
/// `StubProvider` reports no vision support.
#[test]
fn supports_vision_default_false() {
    let stub = StubProvider {
        response: String::new(),
    };
    let advertised = stub.supports_vision();
    assert!(!advertised);
}
1594
/// Default metadata is visible to both agent and user and carries no
/// compaction timestamp.
#[test]
fn message_metadata_default_both_visible() {
    let meta = MessageMetadata::default();

    // Predicate view of the visibility...
    assert!(meta.visibility.is_agent_visible());
    assert!(meta.visibility.is_user_visible());
    // ...and the exact variant behind it.
    assert_eq!(meta.visibility, MessageVisibility::Both);

    assert!(meta.compacted_at.is_none());
}
1603
/// `MessageMetadata::agent_only()` is visible to the agent but not the user.
#[test]
fn message_metadata_agent_only() {
    let meta = MessageMetadata::agent_only();

    assert!(meta.visibility.is_agent_visible());
    assert!(!meta.visibility.is_user_visible());
    assert_eq!(meta.visibility, MessageVisibility::AgentOnly);
}
1611
/// `MessageMetadata::user_only()` is visible to the user but not the agent.
#[test]
fn message_metadata_user_only() {
    let meta = MessageMetadata::user_only();

    assert!(!meta.visibility.is_agent_visible());
    assert!(meta.visibility.is_user_visible());
    assert_eq!(meta.visibility, MessageVisibility::UserOnly);
}
1619
/// A JSON message with no `metadata` key deserializes with visibility that
/// is on for both agent and user.
#[test]
fn message_metadata_serde_default() {
    let raw = r#"{"role":"user","content":"hello"}"#;

    let message = serde_json::from_str::<Message>(raw).unwrap();

    assert!(message.metadata.visibility.is_agent_visible());
    assert!(message.metadata.visibility.is_user_visible());
}
1627
/// Agent-only visibility survives a JSON serialize/deserialize cycle of the
/// whole `Message`.
#[test]
fn message_metadata_round_trip() {
    let original = Message {
        role: Role::User,
        content: "test".into(),
        parts: vec![],
        metadata: MessageMetadata::agent_only(),
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let restored = serde_json::from_str::<Message>(&encoded).unwrap();

    assert!(restored.metadata.visibility.is_agent_visible());
    assert!(!restored.metadata.visibility.is_user_visible());
    assert_eq!(restored.metadata.visibility, MessageVisibility::AgentOnly);
}
1642
/// A `MessagePart::Compaction` keeps its summary text across a JSON
/// serialize/deserialize cycle.
#[test]
fn message_part_compaction_round_trip() {
    let original = MessagePart::Compaction {
        summary: String::from("Context was summarized."),
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let restored = serde_json::from_str::<MessagePart>(&encoded).unwrap();

    assert!(matches!(
        restored,
        MessagePart::Compaction { summary } if summary == "Context was summarized."
    ));
}
1654
/// A `Compaction` part adds nothing to the flattened `content` built by
/// `Message::from_parts`; only the text part shows up.
#[test]
fn flatten_parts_compaction_contributes_no_text() {
    let spoken = MessagePart::Text {
        text: String::from("Hello"),
    };
    let compaction = MessagePart::Compaction {
        summary: String::from("Summary"),
    };

    let message = Message::from_parts(Role::Assistant, vec![spoken, compaction]);

    let flattened = message.content.trim();
    assert_eq!(flattened, "Hello");
}
1671
/// `StreamChunk::Compaction` carries the summary string it was built with.
#[test]
fn stream_chunk_compaction_variant() {
    let chunk = StreamChunk::Compaction(String::from("A summary"));
    assert!(matches!(
        chunk,
        StreamChunk::Compaction(text) if text == "A summary"
    ));
}
1677
/// For a type declared inside this function, only the final path segment
/// (`MyOutput`) is returned.
#[test]
fn short_type_name_extracts_last_segment() {
    struct MyOutput;

    let short = short_type_name::<MyOutput>();

    assert_eq!(short, "MyOutput");
}
1683
/// Primitive type names contain no `::`, so they come back unchanged.
#[test]
fn short_type_name_primitive_returns_full_name() {
    let unsigned = short_type_name::<u32>();
    assert_eq!(unsigned, "u32");

    let boolean = short_type_name::<bool>();
    assert_eq!(boolean, "bool");
}
1690
/// For a type behind a multi-segment path, the last `::`-separated segment is
/// returned together with its generic argument list.
#[test]
fn short_type_name_nested_path_returns_last() {
    // An alias names the same type, so the reported type name is unaffected.
    type Map = std::collections::HashMap<u32, u32>;

    let short = short_type_name::<Map>();

    assert_eq!(short, "HashMap<u32, u32>");
}
1699
/// `MessagePart::Summary` serializes with a `"kind":"summary"` tag (not as an
/// externally tagged `"Summary"` key) and deserializes back to the same text.
#[test]
fn summary_roundtrip() {
    let original = MessagePart::Summary {
        text: String::from("hello"),
    };
    let json = serde_json::to_string(&original).expect("serialization must not fail");

    // Wire-format checks on the encoded string.
    let has_kind_tag = json.contains("\"kind\":\"summary\"");
    assert!(
        has_kind_tag,
        "must use internally-tagged format, got: {json}"
    );
    let has_variant_key = json.contains("\"Summary\"");
    assert!(
        !has_variant_key,
        "must not use externally-tagged format, got: {json}"
    );

    // Decoding must land on the same variant with the same payload.
    let restored =
        serde_json::from_str::<MessagePart>(&json).expect("deserialization must not fail");
    match restored {
        MessagePart::Summary { text } => assert_eq!(text, "hello"),
        other => panic!("expected MessagePart::Summary, got {other:?}"),
    }
}
1723
1724 #[tokio::test]
1725 async fn embed_batch_default_empty_returns_empty() {
1726 let provider = StubProvider {
1727 response: String::new(),
1728 };
1729 let result = provider.embed_batch(&[]).await.unwrap();
1730 assert!(result.is_empty());
1731 }
1732
1733 #[tokio::test]
1734 async fn embed_batch_default_calls_embed_sequentially() {
1735 let provider = StubProvider {
1736 response: String::new(),
1737 };
1738 let texts = ["hello", "world", "foo"];
1739 let result = provider.embed_batch(&texts).await.unwrap();
1740 assert_eq!(result.len(), 3);
1741 for vec in &result {
1743 assert_eq!(vec, &[0.1_f32, 0.2, 0.3]);
1744 }
1745 }
1746
/// `MessageVisibility::Both` maps to the database string `"both"` and back.
#[test]
fn message_visibility_db_roundtrip_both() {
    const DB_VALUE: &str = "both";

    assert_eq!(MessageVisibility::Both.as_db_str(), DB_VALUE);
    assert_eq!(
        MessageVisibility::from_db_str(DB_VALUE),
        MessageVisibility::Both
    );
}
1755
/// `MessageVisibility::AgentOnly` maps to the database string `"agent_only"`
/// and back.
#[test]
fn message_visibility_db_roundtrip_agent_only() {
    const DB_VALUE: &str = "agent_only";

    assert_eq!(MessageVisibility::AgentOnly.as_db_str(), DB_VALUE);
    assert_eq!(
        MessageVisibility::from_db_str(DB_VALUE),
        MessageVisibility::AgentOnly
    );
}
1764
/// `MessageVisibility::UserOnly` maps to the database string `"user_only"`
/// and back.
#[test]
fn message_visibility_db_roundtrip_user_only() {
    const DB_VALUE: &str = "user_only";

    assert_eq!(MessageVisibility::UserOnly.as_db_str(), DB_VALUE);
    assert_eq!(
        MessageVisibility::from_db_str(DB_VALUE),
        MessageVisibility::UserOnly
    );
}
1773
/// Unrecognised database strings, including the empty string, decode to
/// `MessageVisibility::Both`.
#[test]
fn message_visibility_from_db_str_unknown_defaults_to_both() {
    for unrecognised in ["unknown_future_value", ""] {
        assert_eq!(
            MessageVisibility::from_db_str(unrecognised),
            MessageVisibility::Both
        );
    }
}
1782}