1use std::future::Future;
5use std::pin::Pin;
6use std::{
7 any::TypeId,
8 collections::HashMap,
9 sync::{LazyLock, Mutex},
10};
11
12use futures_core::Stream;
13use serde::{Deserialize, Serialize};
14
15use crate::embed::owned_strs;
16use crate::error::LlmError;
17
18static SCHEMA_CACHE: LazyLock<Mutex<HashMap<TypeId, (serde_json::Value, String)>>> =
19 LazyLock::new(|| Mutex::new(HashMap::new()));
20
21pub(crate) fn cached_schema<T: schemars::JsonSchema + 'static>()
27-> Result<(serde_json::Value, String), crate::LlmError> {
28 let type_id = TypeId::of::<T>();
29 if let Ok(cache) = SCHEMA_CACHE.lock()
30 && let Some(entry) = cache.get(&type_id)
31 {
32 return Ok(entry.clone());
33 }
34 let schema = schemars::schema_for!(T);
35 let value = serde_json::to_value(&schema)
36 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
37 let pretty = serde_json::to_string_pretty(&schema)
38 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
39 if let Ok(mut cache) = SCHEMA_CACHE.lock() {
40 cache.insert(type_id, (value.clone(), pretty.clone()));
41 }
42 Ok((value, pretty))
43}
44
/// Returns the unqualified name of the outermost type `T`, for use in prompts.
///
/// Module paths are removed (`my_crate::foo::Bar` -> `Bar`). Generic arguments
/// are dropped before splitting (`alloc::vec::Vec<foo::Bar>` -> `Vec`), so the
/// result is never a fragment cut out of the middle of the brackets. If nothing
/// usable remains, `"Output"` is returned.
///
/// The exact text of `std::any::type_name` is not guaranteed by the standard
/// library, so treat the result as a human-readable hint only.
pub(crate) fn short_type_name<T: ?Sized>() -> &'static str {
    let full: &'static str = std::any::type_name::<T>();
    // Cut off generic arguments first; otherwise `Vec<foo::Bar>` would be split
    // inside the brackets and yield the malformed `Bar>`.
    let outer = full.find('<').map_or(full, |idx| &full[..idx]);
    let short = outer.rsplit("::").next().unwrap_or(outer);
    if short.is_empty() { "Output" } else { short }
}
64
65#[derive(Debug, Clone)]
67pub enum StreamChunk {
68 Content(String),
70 Thinking(String),
72 Compaction(String),
75 ToolUse(Vec<ToolUseRequest>),
77}
78
79pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct ToolDefinition {
87 pub name: String,
88 pub description: String,
89 pub parameters: serde_json::Value,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ToolUseRequest {
96 pub id: String,
97 pub name: String,
98 pub input: serde_json::Value,
99}
100
/// A thinking block returned by a provider, either readable or redacted.
#[derive(Clone, Debug)]
pub enum ThinkingBlock {
    /// Readable thinking text with its accompanying signature.
    Thinking {
        thinking: String,
        signature: String,
    },
    /// Redacted thinking, carried as opaque data.
    Redacted {
        data: String,
    },
}
107
/// Marker text associated with hitting the provider's `max_tokens` limit.
/// NOTE(review): how callers match on it is not visible here — confirm at use sites.
pub const MAX_TOKENS_TRUNCATION_MARKER: &'static str = "max_tokens limit reached";
111
112#[derive(Debug, Clone)]
114pub enum ChatResponse {
115 Text(String),
117 ToolUse {
119 text: Option<String>,
121 tool_calls: Vec<ToolUseRequest>,
122 thinking_blocks: Vec<ThinkingBlock>,
125 },
126}
127
128pub type EmbedFuture = Pin<Box<dyn Future<Output = Result<Vec<f32>, LlmError>> + Send>>;
130
131pub type EmbedFn = Box<dyn Fn(&str) -> EmbedFuture + Send + Sync>;
133
134pub type StatusTx = tokio::sync::mpsc::UnboundedSender<String>;
136
137#[must_use]
140pub fn default_debug_request_json(
141 messages: &[Message],
142 tools: &[ToolDefinition],
143) -> serde_json::Value {
144 serde_json::json!({
145 "model": serde_json::Value::Null,
146 "max_tokens": serde_json::Value::Null,
147 "messages": serde_json::to_value(messages).unwrap_or(serde_json::Value::Array(vec![])),
148 "tools": serde_json::to_value(tools).unwrap_or(serde_json::Value::Array(vec![])),
149 "temperature": serde_json::Value::Null,
150 "cache_control": serde_json::Value::Null,
151 })
152}
153
/// Optional sampling parameters; each `None` field leaves that setting unset.
#[derive(Clone, Debug, Default)]
pub struct GenerationOverrides {
    /// Sampling temperature.
    pub temperature: Option<f64>,
    /// Nucleus-sampling probability mass.
    pub top_p: Option<f64>,
    /// Top-k sampling cutoff.
    pub top_k: Option<usize>,
    /// Frequency penalty.
    pub frequency_penalty: Option<f64>,
    /// Presence penalty.
    pub presence_penalty: Option<f64>,
}
166
167#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
168#[serde(rename_all = "lowercase")]
169pub enum Role {
170 System,
171 User,
172 Assistant,
173}
174
175#[derive(Clone, Debug, Serialize, Deserialize)]
176#[serde(tag = "kind", rename_all = "snake_case")]
177pub enum MessagePart {
178 Text {
179 text: String,
180 },
181 ToolOutput {
182 tool_name: String,
183 body: String,
184 #[serde(default, skip_serializing_if = "Option::is_none")]
185 compacted_at: Option<i64>,
186 },
187 Recall {
188 text: String,
189 },
190 CodeContext {
191 text: String,
192 },
193 Summary {
194 text: String,
195 },
196 CrossSession {
197 text: String,
198 },
199 ToolUse {
200 id: String,
201 name: String,
202 input: serde_json::Value,
203 },
204 ToolResult {
205 tool_use_id: String,
206 content: String,
207 #[serde(default)]
208 is_error: bool,
209 },
210 Image(Box<ImageData>),
211 ThinkingBlock {
213 thinking: String,
214 signature: String,
215 },
216 RedactedThinkingBlock {
218 data: String,
219 },
220 Compaction {
223 summary: String,
224 },
225}
226
227impl MessagePart {
228 #[must_use]
231 pub fn as_plain_text(&self) -> Option<&str> {
232 match self {
233 Self::Text { text }
234 | Self::Recall { text }
235 | Self::CodeContext { text }
236 | Self::Summary { text }
237 | Self::CrossSession { text } => Some(text.as_str()),
238 _ => None,
239 }
240 }
241
242 #[must_use]
244 pub fn as_image(&self) -> Option<&ImageData> {
245 if let Self::Image(img) = self {
246 Some(img)
247 } else {
248 None
249 }
250 }
251}
252
253#[derive(Clone, Debug, Serialize, Deserialize)]
254pub struct ImageData {
255 #[serde(with = "serde_bytes_base64")]
256 pub data: Vec<u8>,
257 pub mime_type: String,
258}
259
260mod serde_bytes_base64 {
261 use base64::{Engine, engine::general_purpose::STANDARD};
262 use serde::{Deserialize, Deserializer, Serializer};
263
264 pub fn serialize<S>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error>
265 where
266 S: Serializer,
267 {
268 s.serialize_str(&STANDARD.encode(bytes))
269 }
270
271 pub fn deserialize<'de, D>(d: D) -> Result<Vec<u8>, D::Error>
272 where
273 D: Deserializer<'de>,
274 {
275 let s = String::deserialize(d)?;
276 STANDARD.decode(&s).map_err(serde::de::Error::custom)
277 }
278}
279
280#[derive(Clone, Debug, Serialize, Deserialize)]
282pub struct MessageMetadata {
283 pub agent_visible: bool,
284 pub user_visible: bool,
285 #[serde(default, skip_serializing_if = "Option::is_none")]
286 pub compacted_at: Option<i64>,
287 #[serde(default, skip_serializing_if = "Option::is_none")]
290 pub deferred_summary: Option<String>,
291 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
294 pub focus_pinned: bool,
295 #[serde(default, skip_serializing_if = "Option::is_none")]
298 pub focus_marker_id: Option<uuid::Uuid>,
299 #[serde(skip)]
302 pub db_id: Option<i64>,
303}
304
305impl Default for MessageMetadata {
306 fn default() -> Self {
307 Self {
308 agent_visible: true,
309 user_visible: true,
310 compacted_at: None,
311 deferred_summary: None,
312 focus_pinned: false,
313 focus_marker_id: None,
314 db_id: None,
315 }
316 }
317}
318
319impl MessageMetadata {
320 #[must_use]
322 pub fn agent_only() -> Self {
323 Self {
324 agent_visible: true,
325 user_visible: false,
326 compacted_at: None,
327 deferred_summary: None,
328 focus_pinned: false,
329 focus_marker_id: None,
330 db_id: None,
331 }
332 }
333
334 #[must_use]
336 pub fn user_only() -> Self {
337 Self {
338 agent_visible: false,
339 user_visible: true,
340 compacted_at: None,
341 deferred_summary: None,
342 focus_pinned: false,
343 focus_marker_id: None,
344 db_id: None,
345 }
346 }
347
348 #[must_use]
350 pub fn focus_pinned() -> Self {
351 Self {
352 agent_visible: true,
353 user_visible: false,
354 compacted_at: None,
355 deferred_summary: None,
356 focus_pinned: true,
357 focus_marker_id: None,
358 db_id: None,
359 }
360 }
361}
362
363#[derive(Clone, Debug, Serialize, Deserialize)]
364pub struct Message {
365 pub role: Role,
366 pub content: String,
367 #[serde(default)]
368 pub parts: Vec<MessagePart>,
369 #[serde(default)]
370 pub metadata: MessageMetadata,
371}
372
373impl Default for Message {
374 fn default() -> Self {
375 Self {
376 role: Role::User,
377 content: String::new(),
378 parts: vec![],
379 metadata: MessageMetadata::default(),
380 }
381 }
382}
383
384impl Message {
385 #[must_use]
386 pub fn from_legacy(role: Role, content: impl Into<String>) -> Self {
387 Self {
388 role,
389 content: content.into(),
390 parts: vec![],
391 metadata: MessageMetadata::default(),
392 }
393 }
394
395 #[must_use]
396 pub fn from_parts(role: Role, parts: Vec<MessagePart>) -> Self {
397 let content = Self::flatten_parts(&parts);
398 Self {
399 role,
400 content,
401 parts,
402 metadata: MessageMetadata::default(),
403 }
404 }
405
406 #[must_use]
407 pub fn to_llm_content(&self) -> &str {
408 &self.content
409 }
410
411 pub fn rebuild_content(&mut self) {
413 if !self.parts.is_empty() {
414 self.content = Self::flatten_parts(&self.parts);
415 }
416 }
417
418 fn flatten_parts(parts: &[MessagePart]) -> String {
419 use std::fmt::Write;
420 let mut out = String::new();
421 for part in parts {
422 match part {
423 MessagePart::Text { text }
424 | MessagePart::Recall { text }
425 | MessagePart::CodeContext { text }
426 | MessagePart::Summary { text }
427 | MessagePart::CrossSession { text } => out.push_str(text),
428 MessagePart::ToolOutput {
429 tool_name,
430 body,
431 compacted_at,
432 } => {
433 if compacted_at.is_some() {
434 if body.is_empty() {
435 let _ = write!(out, "[tool output: {tool_name}] (pruned)");
436 } else {
437 let _ = write!(out, "[tool output: {tool_name}] {body}");
438 }
439 } else {
440 let _ = write!(out, "[tool output: {tool_name}]\n```\n{body}\n```");
441 }
442 }
443 MessagePart::ToolUse { id, name, .. } => {
444 let _ = write!(out, "[tool_use: {name}({id})]");
445 }
446 MessagePart::ToolResult {
447 tool_use_id,
448 content,
449 ..
450 } => {
451 let _ = write!(out, "[tool_result: {tool_use_id}]\n{content}");
452 }
453 MessagePart::Image(img) => {
454 let _ = write!(out, "[image: {}, {} bytes]", img.mime_type, img.data.len());
455 }
456 MessagePart::ThinkingBlock { .. }
458 | MessagePart::RedactedThinkingBlock { .. }
459 | MessagePart::Compaction { .. } => {}
460 }
461 }
462 out
463 }
464}
465
466pub trait LlmProvider: Send + Sync {
467 fn context_window(&self) -> Option<usize> {
471 None
472 }
473
474 fn chat(&self, messages: &[Message]) -> impl Future<Output = Result<String, LlmError>> + Send;
480
481 fn chat_stream(
487 &self,
488 messages: &[Message],
489 ) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;
490
491 fn supports_streaming(&self) -> bool;
493
494 fn embed(&self, text: &str) -> impl Future<Output = Result<Vec<f32>, LlmError>> + Send;
500
501 fn embed_batch(
511 &self,
512 texts: &[&str],
513 ) -> impl Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
514 let owned = owned_strs(texts);
515 async move {
516 let mut results = Vec::with_capacity(owned.len());
517 for text in &owned {
518 results.push(self.embed(text).await?);
519 }
520 Ok(results)
521 }
522 }
523
524 fn supports_embeddings(&self) -> bool;
526
527 fn name(&self) -> &str;
529
530 #[allow(clippy::unnecessary_literal_bound)]
533 fn model_identifier(&self) -> &str {
534 ""
535 }
536
537 fn supports_vision(&self) -> bool {
539 false
540 }
541
542 fn supports_tool_use(&self) -> bool {
544 true
545 }
546
547 #[allow(async_fn_in_trait)]
555 async fn chat_with_tools(
556 &self,
557 messages: &[Message],
558 _tools: &[ToolDefinition],
559 ) -> Result<ChatResponse, LlmError> {
560 Ok(ChatResponse::Text(self.chat(messages).await?))
561 }
562
563 fn last_cache_usage(&self) -> Option<(u64, u64)> {
566 None
567 }
568
569 fn last_usage(&self) -> Option<(u64, u64)> {
572 None
573 }
574
575 fn take_compaction_summary(&self) -> Option<String> {
578 None
579 }
580
581 fn record_quality_outcome(&self, _provider_name: &str, _success: bool) {}
587
588 #[must_use]
592 fn debug_request_json(
593 &self,
594 messages: &[Message],
595 tools: &[ToolDefinition],
596 _stream: bool,
597 ) -> serde_json::Value {
598 default_debug_request_json(messages, tools)
599 }
600
601 fn list_models(&self) -> Vec<String> {
604 vec![]
605 }
606
607 fn supports_structured_output(&self) -> bool {
609 false
610 }
611
612 #[allow(async_fn_in_trait)]
617 async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
618 where
619 T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
620 Self: Sized,
621 {
622 let (_, schema_json) = cached_schema::<T>()?;
623 let type_name = short_type_name::<T>();
624
625 let mut augmented = messages.to_vec();
626 let instruction = format!(
627 "Respond with a valid JSON object matching this schema. \
628 Output ONLY the JSON, no markdown fences or extra text.\n\n\
629 Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
630 );
631 augmented.insert(0, Message::from_legacy(Role::System, instruction));
632
633 let raw = self.chat(&augmented).await?;
634 let cleaned = strip_json_fences(&raw);
635 match serde_json::from_str::<T>(cleaned) {
636 Ok(val) => Ok(val),
637 Err(first_err) => {
638 augmented.push(Message::from_legacy(Role::Assistant, &raw));
639 augmented.push(Message::from_legacy(
640 Role::User,
641 format!(
642 "Your response was not valid JSON. Error: {first_err}. \
643 Please output ONLY valid JSON matching the schema."
644 ),
645 ));
646 let retry_raw = self.chat(&augmented).await?;
647 let retry_cleaned = strip_json_fences(&retry_raw);
648 serde_json::from_str::<T>(retry_cleaned).map_err(|e| {
649 LlmError::StructuredParse(format!("parse failed after retry: {e}"))
650 })
651 }
652 }
653 }
654}
655
/// Removes a surrounding Markdown code fence from a model reply so the rest can
/// be parsed as JSON.
///
/// Steps:
/// 1. Trim leading and trailing whitespace.
/// 2. Strip an opening "```" fence, together with an immediately following `json`
///    info string. The tag is matched case-insensitively, so `JSON` and `Json`
///    are removed too.
/// 3. Strip any trailing fence and trim again.
///
/// Text without fences is returned trimmed but otherwise unchanged.
fn strip_json_fences(s: &str) -> &str {
    const FENCE: &str = "```";
    const TAG: &str = "json";

    let mut body = s.trim();
    if let Some(after_fence) = body.strip_prefix(FENCE) {
        body = after_fence;
        // Models are inconsistent about the tag's case (`json` vs `JSON`).
        if body
            .get(..TAG.len())
            .is_some_and(|tag| tag.eq_ignore_ascii_case(TAG))
        {
            body = &body[TAG.len()..];
        }
    }
    body.trim_end_matches(FENCE).trim()
}
666
667#[cfg(test)]
668mod tests {
669 use tokio_stream::StreamExt;
670
671 use super::*;
672
673 struct StubProvider {
674 response: String,
675 }
676
677 impl LlmProvider for StubProvider {
678 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
679 Ok(self.response.clone())
680 }
681
682 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
683 let response = self.chat(messages).await?;
684 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
685 response,
686 )))))
687 }
688
689 fn supports_streaming(&self) -> bool {
690 false
691 }
692
693 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
694 Ok(vec![0.1, 0.2, 0.3])
695 }
696
697 fn supports_embeddings(&self) -> bool {
698 false
699 }
700
701 fn name(&self) -> &'static str {
702 "stub"
703 }
704 }
705
706 #[test]
707 fn context_window_default_returns_none() {
708 let provider = StubProvider {
709 response: String::new(),
710 };
711 assert!(provider.context_window().is_none());
712 }
713
714 #[test]
715 fn supports_streaming_default_returns_false() {
716 let provider = StubProvider {
717 response: String::new(),
718 };
719 assert!(!provider.supports_streaming());
720 }
721
722 #[tokio::test]
723 async fn chat_stream_default_yields_single_chunk() {
724 let provider = StubProvider {
725 response: "hello world".into(),
726 };
727 let messages = vec![Message {
728 role: Role::User,
729 content: "test".into(),
730 parts: vec![],
731 metadata: MessageMetadata::default(),
732 }];
733
734 let mut stream = provider.chat_stream(&messages).await.unwrap();
735 let chunk = stream.next().await.unwrap().unwrap();
736 assert!(matches!(chunk, StreamChunk::Content(s) if s == "hello world"));
737 assert!(stream.next().await.is_none());
738 }
739
740 #[tokio::test]
741 async fn chat_stream_default_propagates_chat_error() {
742 struct FailProvider;
743
744 impl LlmProvider for FailProvider {
745 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
746 Err(LlmError::Unavailable)
747 }
748
749 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
750 let response = self.chat(messages).await?;
751 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
752 response,
753 )))))
754 }
755
756 fn supports_streaming(&self) -> bool {
757 false
758 }
759
760 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
761 Err(LlmError::Unavailable)
762 }
763
764 fn supports_embeddings(&self) -> bool {
765 false
766 }
767
768 fn name(&self) -> &'static str {
769 "fail"
770 }
771 }
772
773 let provider = FailProvider;
774 let messages = vec![Message {
775 role: Role::User,
776 content: "test".into(),
777 parts: vec![],
778 metadata: MessageMetadata::default(),
779 }];
780
781 let result = provider.chat_stream(&messages).await;
782 assert!(result.is_err());
783 if let Err(e) = result {
784 assert!(e.to_string().contains("provider unavailable"));
785 }
786 }
787
788 #[tokio::test]
789 async fn stub_provider_embed_returns_vector() {
790 let provider = StubProvider {
791 response: String::new(),
792 };
793 let embedding = provider.embed("test").await.unwrap();
794 assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
795 }
796
797 #[tokio::test]
798 async fn fail_provider_embed_propagates_error() {
799 struct FailProvider;
800
801 impl LlmProvider for FailProvider {
802 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
803 Err(LlmError::Unavailable)
804 }
805
806 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
807 let response = self.chat(messages).await?;
808 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
809 response,
810 )))))
811 }
812
813 fn supports_streaming(&self) -> bool {
814 false
815 }
816
817 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
818 Err(LlmError::EmbedUnsupported {
819 provider: "fail".into(),
820 })
821 }
822
823 fn supports_embeddings(&self) -> bool {
824 false
825 }
826
827 fn name(&self) -> &'static str {
828 "fail"
829 }
830 }
831
832 let provider = FailProvider;
833 let result = provider.embed("test").await;
834 assert!(result.is_err());
835 assert!(
836 result
837 .unwrap_err()
838 .to_string()
839 .contains("embedding not supported")
840 );
841 }
842
843 #[test]
844 fn role_serialization() {
845 let system = Role::System;
846 let user = Role::User;
847 let assistant = Role::Assistant;
848
849 assert_eq!(serde_json::to_string(&system).unwrap(), "\"system\"");
850 assert_eq!(serde_json::to_string(&user).unwrap(), "\"user\"");
851 assert_eq!(serde_json::to_string(&assistant).unwrap(), "\"assistant\"");
852 }
853
854 #[test]
855 fn role_deserialization() {
856 let system: Role = serde_json::from_str("\"system\"").unwrap();
857 let user: Role = serde_json::from_str("\"user\"").unwrap();
858 let assistant: Role = serde_json::from_str("\"assistant\"").unwrap();
859
860 assert_eq!(system, Role::System);
861 assert_eq!(user, Role::User);
862 assert_eq!(assistant, Role::Assistant);
863 }
864
865 #[test]
866 fn message_clone() {
867 let msg = Message {
868 role: Role::User,
869 content: "test".into(),
870 parts: vec![],
871 metadata: MessageMetadata::default(),
872 };
873 let cloned = msg.clone();
874 assert_eq!(cloned.role, msg.role);
875 assert_eq!(cloned.content, msg.content);
876 }
877
878 #[test]
879 fn message_debug() {
880 let msg = Message {
881 role: Role::Assistant,
882 content: "response".into(),
883 parts: vec![],
884 metadata: MessageMetadata::default(),
885 };
886 let debug = format!("{msg:?}");
887 assert!(debug.contains("Assistant"));
888 assert!(debug.contains("response"));
889 }
890
891 #[test]
892 fn message_serialization() {
893 let msg = Message {
894 role: Role::User,
895 content: "hello".into(),
896 parts: vec![],
897 metadata: MessageMetadata::default(),
898 };
899 let json = serde_json::to_string(&msg).unwrap();
900 assert!(json.contains("\"role\":\"user\""));
901 assert!(json.contains("\"content\":\"hello\""));
902 }
903
904 #[test]
905 fn message_part_serde_round_trip() {
906 let parts = vec![
907 MessagePart::Text {
908 text: "hello".into(),
909 },
910 MessagePart::ToolOutput {
911 tool_name: "bash".into(),
912 body: "output".into(),
913 compacted_at: None,
914 },
915 MessagePart::Recall {
916 text: "recall".into(),
917 },
918 MessagePart::CodeContext {
919 text: "code".into(),
920 },
921 MessagePart::Summary {
922 text: "summary".into(),
923 },
924 ];
925 let json = serde_json::to_string(&parts).unwrap();
926 let deserialized: Vec<MessagePart> = serde_json::from_str(&json).unwrap();
927 assert_eq!(deserialized.len(), 5);
928 }
929
930 #[test]
931 fn from_legacy_creates_empty_parts() {
932 let msg = Message::from_legacy(Role::User, "hello");
933 assert_eq!(msg.role, Role::User);
934 assert_eq!(msg.content, "hello");
935 assert!(msg.parts.is_empty());
936 assert_eq!(msg.to_llm_content(), "hello");
937 }
938
939 #[test]
940 fn from_parts_flattens_content() {
941 let msg = Message::from_parts(
942 Role::System,
943 vec![MessagePart::Recall {
944 text: "recalled data".into(),
945 }],
946 );
947 assert_eq!(msg.content, "recalled data");
948 assert_eq!(msg.to_llm_content(), "recalled data");
949 assert_eq!(msg.parts.len(), 1);
950 }
951
952 #[test]
953 fn from_parts_tool_output_format() {
954 let msg = Message::from_parts(
955 Role::User,
956 vec![MessagePart::ToolOutput {
957 tool_name: "bash".into(),
958 body: "hello world".into(),
959 compacted_at: None,
960 }],
961 );
962 assert!(msg.content.contains("[tool output: bash]"));
963 assert!(msg.content.contains("hello world"));
964 }
965
966 #[test]
967 fn message_deserializes_without_parts() {
968 let json = r#"{"role":"user","content":"hello"}"#;
969 let msg: Message = serde_json::from_str(json).unwrap();
970 assert_eq!(msg.content, "hello");
971 assert!(msg.parts.is_empty());
972 }
973
974 #[test]
975 fn flatten_skips_compacted_tool_output_empty_body() {
976 let msg = Message::from_parts(
978 Role::User,
979 vec![
980 MessagePart::Text {
981 text: "prefix ".into(),
982 },
983 MessagePart::ToolOutput {
984 tool_name: "bash".into(),
985 body: String::new(),
986 compacted_at: Some(1234),
987 },
988 MessagePart::Text {
989 text: " suffix".into(),
990 },
991 ],
992 );
993 assert!(msg.content.contains("(pruned)"));
994 assert!(msg.content.contains("prefix "));
995 assert!(msg.content.contains(" suffix"));
996 }
997
998 #[test]
999 fn flatten_compacted_tool_output_with_reference_renders_body() {
1000 let ref_notice = "[tool output pruned; full content at /tmp/overflow/big.txt]";
1002 let msg = Message::from_parts(
1003 Role::User,
1004 vec![MessagePart::ToolOutput {
1005 tool_name: "bash".into(),
1006 body: ref_notice.into(),
1007 compacted_at: Some(1234),
1008 }],
1009 );
1010 assert!(msg.content.contains(ref_notice));
1011 assert!(!msg.content.contains("(pruned)"));
1012 }
1013
1014 #[test]
1015 fn rebuild_content_syncs_after_mutation() {
1016 let mut msg = Message::from_parts(
1017 Role::User,
1018 vec![MessagePart::ToolOutput {
1019 tool_name: "bash".into(),
1020 body: "original".into(),
1021 compacted_at: None,
1022 }],
1023 );
1024 assert!(msg.content.contains("original"));
1025
1026 if let MessagePart::ToolOutput {
1027 ref mut compacted_at,
1028 ref mut body,
1029 ..
1030 } = msg.parts[0]
1031 {
1032 *compacted_at = Some(999);
1033 body.clear(); }
1035 msg.rebuild_content();
1036
1037 assert!(msg.content.contains("(pruned)"));
1038 assert!(!msg.content.contains("original"));
1039 }
1040
1041 #[test]
1042 fn message_part_tool_use_serde_round_trip() {
1043 let part = MessagePart::ToolUse {
1044 id: "toolu_123".into(),
1045 name: "bash".into(),
1046 input: serde_json::json!({"command": "ls"}),
1047 };
1048 let json = serde_json::to_string(&part).unwrap();
1049 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1050 if let MessagePart::ToolUse { id, name, input } = deserialized {
1051 assert_eq!(id, "toolu_123");
1052 assert_eq!(name, "bash");
1053 assert_eq!(input["command"], "ls");
1054 } else {
1055 panic!("expected ToolUse");
1056 }
1057 }
1058
1059 #[test]
1060 fn message_part_tool_result_serde_round_trip() {
1061 let part = MessagePart::ToolResult {
1062 tool_use_id: "toolu_123".into(),
1063 content: "file1.rs\nfile2.rs".into(),
1064 is_error: false,
1065 };
1066 let json = serde_json::to_string(&part).unwrap();
1067 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1068 if let MessagePart::ToolResult {
1069 tool_use_id,
1070 content,
1071 is_error,
1072 } = deserialized
1073 {
1074 assert_eq!(tool_use_id, "toolu_123");
1075 assert_eq!(content, "file1.rs\nfile2.rs");
1076 assert!(!is_error);
1077 } else {
1078 panic!("expected ToolResult");
1079 }
1080 }
1081
1082 #[test]
1083 fn message_part_tool_result_is_error_default() {
1084 let json = r#"{"kind":"tool_result","tool_use_id":"id","content":"err"}"#;
1085 let part: MessagePart = serde_json::from_str(json).unwrap();
1086 if let MessagePart::ToolResult { is_error, .. } = part {
1087 assert!(!is_error);
1088 } else {
1089 panic!("expected ToolResult");
1090 }
1091 }
1092
1093 #[test]
1094 fn chat_response_construction() {
1095 let text = ChatResponse::Text("hello".into());
1096 assert!(matches!(text, ChatResponse::Text(s) if s == "hello"));
1097
1098 let tool_use = ChatResponse::ToolUse {
1099 text: Some("I'll run that".into()),
1100 tool_calls: vec![ToolUseRequest {
1101 id: "1".into(),
1102 name: "bash".into(),
1103 input: serde_json::json!({}),
1104 }],
1105 thinking_blocks: vec![],
1106 };
1107 assert!(matches!(tool_use, ChatResponse::ToolUse { .. }));
1108 }
1109
1110 #[test]
1111 fn flatten_parts_tool_use() {
1112 let msg = Message::from_parts(
1113 Role::Assistant,
1114 vec![MessagePart::ToolUse {
1115 id: "t1".into(),
1116 name: "bash".into(),
1117 input: serde_json::json!({"command": "ls"}),
1118 }],
1119 );
1120 assert!(msg.content.contains("[tool_use: bash(t1)]"));
1121 }
1122
1123 #[test]
1124 fn flatten_parts_tool_result() {
1125 let msg = Message::from_parts(
1126 Role::User,
1127 vec![MessagePart::ToolResult {
1128 tool_use_id: "t1".into(),
1129 content: "output here".into(),
1130 is_error: false,
1131 }],
1132 );
1133 assert!(msg.content.contains("[tool_result: t1]"));
1134 assert!(msg.content.contains("output here"));
1135 }
1136
1137 #[test]
1138 fn tool_definition_serde_round_trip() {
1139 let def = ToolDefinition {
1140 name: "bash".into(),
1141 description: "Execute a shell command".into(),
1142 parameters: serde_json::json!({"type": "object"}),
1143 };
1144 let json = serde_json::to_string(&def).unwrap();
1145 let deserialized: ToolDefinition = serde_json::from_str(&json).unwrap();
1146 assert_eq!(deserialized.name, "bash");
1147 assert_eq!(deserialized.description, "Execute a shell command");
1148 }
1149
1150 #[tokio::test]
1151 async fn chat_with_tools_default_delegates_to_chat() {
1152 let provider = StubProvider {
1153 response: "hello".into(),
1154 };
1155 let messages = vec![Message::from_legacy(Role::User, "test")];
1156 let result = provider.chat_with_tools(&messages, &[]).await.unwrap();
1157 assert!(matches!(result, ChatResponse::Text(s) if s == "hello"));
1158 }
1159
1160 #[test]
1161 fn tool_output_compacted_at_serde_default() {
1162 let json = r#"{"kind":"tool_output","tool_name":"bash","body":"out"}"#;
1163 let part: MessagePart = serde_json::from_str(json).unwrap();
1164 if let MessagePart::ToolOutput { compacted_at, .. } = part {
1165 assert!(compacted_at.is_none());
1166 } else {
1167 panic!("expected ToolOutput");
1168 }
1169 }
1170
1171 #[test]
1174 fn strip_json_fences_plain_json() {
1175 assert_eq!(strip_json_fences(r#"{"a": 1}"#), r#"{"a": 1}"#);
1176 }
1177
1178 #[test]
1179 fn strip_json_fences_with_json_fence() {
1180 assert_eq!(strip_json_fences("```json\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1181 }
1182
1183 #[test]
1184 fn strip_json_fences_with_plain_fence() {
1185 assert_eq!(strip_json_fences("```\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1186 }
1187
1188 #[test]
1189 fn strip_json_fences_whitespace() {
1190 assert_eq!(strip_json_fences(" \n "), "");
1191 }
1192
1193 #[test]
1194 fn strip_json_fences_empty() {
1195 assert_eq!(strip_json_fences(""), "");
1196 }
1197
1198 #[test]
1199 fn strip_json_fences_outer_whitespace() {
1200 assert_eq!(
1201 strip_json_fences(" ```json\n{\"a\": 1}\n``` "),
1202 r#"{"a": 1}"#
1203 );
1204 }
1205
1206 #[test]
1207 fn strip_json_fences_only_opening_fence() {
1208 assert_eq!(strip_json_fences("```json\n{\"a\": 1}"), r#"{"a": 1}"#);
1209 }
1210
1211 #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
1214 struct TestOutput {
1215 value: String,
1216 }
1217
1218 struct SequentialStub {
1219 responses: std::sync::Mutex<Vec<Result<String, LlmError>>>,
1220 }
1221
1222 impl SequentialStub {
1223 fn new(responses: Vec<Result<String, LlmError>>) -> Self {
1224 Self {
1225 responses: std::sync::Mutex::new(responses),
1226 }
1227 }
1228 }
1229
1230 impl LlmProvider for SequentialStub {
1231 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1232 let mut responses = self.responses.lock().unwrap();
1233 if responses.is_empty() {
1234 return Err(LlmError::Other("no more responses".into()));
1235 }
1236 responses.remove(0)
1237 }
1238
1239 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1240 let response = self.chat(messages).await?;
1241 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1242 response,
1243 )))))
1244 }
1245
1246 fn supports_streaming(&self) -> bool {
1247 false
1248 }
1249
1250 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1251 Err(LlmError::EmbedUnsupported {
1252 provider: "sequential-stub".into(),
1253 })
1254 }
1255
1256 fn supports_embeddings(&self) -> bool {
1257 false
1258 }
1259
1260 fn name(&self) -> &'static str {
1261 "sequential-stub"
1262 }
1263 }
1264
1265 #[tokio::test]
1266 async fn chat_typed_happy_path() {
1267 let provider = StubProvider {
1268 response: r#"{"value": "hello"}"#.into(),
1269 };
1270 let messages = vec![Message::from_legacy(Role::User, "test")];
1271 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1272 assert_eq!(
1273 result,
1274 TestOutput {
1275 value: "hello".into()
1276 }
1277 );
1278 }
1279
1280 #[tokio::test]
1281 async fn chat_typed_retry_succeeds() {
1282 let provider = SequentialStub::new(vec![
1283 Ok("not valid json".into()),
1284 Ok(r#"{"value": "ok"}"#.into()),
1285 ]);
1286 let messages = vec![Message::from_legacy(Role::User, "test")];
1287 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1288 assert_eq!(result, TestOutput { value: "ok".into() });
1289 }
1290
1291 #[tokio::test]
1292 async fn chat_typed_both_fail() {
1293 let provider = SequentialStub::new(vec![Ok("bad json".into()), Ok("still bad".into())]);
1294 let messages = vec![Message::from_legacy(Role::User, "test")];
1295 let result = provider.chat_typed::<TestOutput>(&messages).await;
1296 let err = result.unwrap_err();
1297 assert!(err.to_string().contains("parse failed after retry"));
1298 }
1299
1300 #[tokio::test]
1301 async fn chat_typed_chat_error_propagates() {
1302 let provider = SequentialStub::new(vec![Err(LlmError::Unavailable)]);
1303 let messages = vec![Message::from_legacy(Role::User, "test")];
1304 let result = provider.chat_typed::<TestOutput>(&messages).await;
1305 assert!(matches!(result, Err(LlmError::Unavailable)));
1306 }
1307
1308 #[tokio::test]
1309 async fn chat_typed_strips_fences() {
1310 let provider = StubProvider {
1311 response: "```json\n{\"value\": \"fenced\"}\n```".into(),
1312 };
1313 let messages = vec![Message::from_legacy(Role::User, "test")];
1314 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1315 assert_eq!(
1316 result,
1317 TestOutput {
1318 value: "fenced".into()
1319 }
1320 );
1321 }
1322
1323 #[test]
1324 fn supports_structured_output_default_false() {
1325 let provider = StubProvider {
1326 response: String::new(),
1327 };
1328 assert!(!provider.supports_structured_output());
1329 }
1330
1331 #[test]
1332 fn structured_parse_error_display() {
1333 let err = LlmError::StructuredParse("test error".into());
1334 assert_eq!(
1335 err.to_string(),
1336 "structured output parse failed: test error"
1337 );
1338 }
1339
1340 #[test]
1341 fn message_part_image_roundtrip_json() {
1342 let part = MessagePart::Image(Box::new(ImageData {
1343 data: vec![1, 2, 3, 4],
1344 mime_type: "image/jpeg".into(),
1345 }));
1346 let json = serde_json::to_string(&part).unwrap();
1347 let decoded: MessagePart = serde_json::from_str(&json).unwrap();
1348 match decoded {
1349 MessagePart::Image(img) => {
1350 assert_eq!(img.data, vec![1, 2, 3, 4]);
1351 assert_eq!(img.mime_type, "image/jpeg");
1352 }
1353 _ => panic!("expected Image variant"),
1354 }
1355 }
1356
1357 #[test]
1358 fn flatten_parts_includes_image_placeholder() {
1359 let msg = Message::from_parts(
1360 Role::User,
1361 vec![
1362 MessagePart::Text {
1363 text: "see this".into(),
1364 },
1365 MessagePart::Image(Box::new(ImageData {
1366 data: vec![0u8; 100],
1367 mime_type: "image/png".into(),
1368 })),
1369 ],
1370 );
1371 let content = msg.to_llm_content();
1372 assert!(content.contains("see this"));
1373 assert!(content.contains("[image: image/png"));
1374 }
1375
1376 #[test]
1377 fn supports_vision_default_false() {
1378 let provider = StubProvider {
1379 response: String::new(),
1380 };
1381 assert!(!provider.supports_vision());
1382 }
1383
1384 #[test]
1385 fn message_metadata_default_both_visible() {
1386 let m = MessageMetadata::default();
1387 assert!(m.agent_visible);
1388 assert!(m.user_visible);
1389 assert!(m.compacted_at.is_none());
1390 }
1391
1392 #[test]
1393 fn message_metadata_agent_only() {
1394 let m = MessageMetadata::agent_only();
1395 assert!(m.agent_visible);
1396 assert!(!m.user_visible);
1397 }
1398
1399 #[test]
1400 fn message_metadata_user_only() {
1401 let m = MessageMetadata::user_only();
1402 assert!(!m.agent_visible);
1403 assert!(m.user_visible);
1404 }
1405
1406 #[test]
1407 fn message_metadata_serde_default() {
1408 let json = r#"{"role":"user","content":"hello"}"#;
1409 let msg: Message = serde_json::from_str(json).unwrap();
1410 assert!(msg.metadata.agent_visible);
1411 assert!(msg.metadata.user_visible);
1412 }
1413
1414 #[test]
1415 fn message_metadata_round_trip() {
1416 let msg = Message {
1417 role: Role::User,
1418 content: "test".into(),
1419 parts: vec![],
1420 metadata: MessageMetadata::agent_only(),
1421 };
1422 let json = serde_json::to_string(&msg).unwrap();
1423 let decoded: Message = serde_json::from_str(&json).unwrap();
1424 assert!(decoded.metadata.agent_visible);
1425 assert!(!decoded.metadata.user_visible);
1426 }
1427
1428 #[test]
1429 fn message_part_compaction_round_trip() {
1430 let part = MessagePart::Compaction {
1431 summary: "Context was summarized.".to_owned(),
1432 };
1433 let json = serde_json::to_string(&part).unwrap();
1434 let decoded: MessagePart = serde_json::from_str(&json).unwrap();
1435 assert!(
1436 matches!(decoded, MessagePart::Compaction { summary } if summary == "Context was summarized.")
1437 );
1438 }
1439
1440 #[test]
1441 fn flatten_parts_compaction_contributes_no_text() {
1442 let parts = vec![
1445 MessagePart::Text {
1446 text: "Hello".to_owned(),
1447 },
1448 MessagePart::Compaction {
1449 summary: "Summary".to_owned(),
1450 },
1451 ];
1452 let msg = Message::from_parts(Role::Assistant, parts);
1453 assert_eq!(msg.content.trim(), "Hello");
1455 }
1456
1457 #[test]
1458 fn stream_chunk_compaction_variant() {
1459 let chunk = StreamChunk::Compaction("A summary".to_owned());
1460 assert!(matches!(chunk, StreamChunk::Compaction(s) if s == "A summary"));
1461 }
1462
1463 #[test]
1464 fn short_type_name_extracts_last_segment() {
1465 struct MyOutput;
1466 assert_eq!(short_type_name::<MyOutput>(), "MyOutput");
1467 }
1468
1469 #[test]
1470 fn short_type_name_primitive_returns_full_name() {
1471 assert_eq!(short_type_name::<u32>(), "u32");
1473 assert_eq!(short_type_name::<bool>(), "bool");
1474 }
1475
1476 #[test]
1477 fn short_type_name_nested_path_returns_last() {
1478 assert_eq!(
1480 short_type_name::<std::collections::HashMap<u32, u32>>(),
1481 "HashMap<u32, u32>"
1482 );
1483 }
1484
1485 #[test]
1488 fn summary_roundtrip() {
1489 let part = MessagePart::Summary {
1490 text: "hello".to_string(),
1491 };
1492 let json = serde_json::to_string(&part).expect("serialization must not fail");
1493 assert!(
1494 json.contains("\"kind\":\"summary\""),
1495 "must use internally-tagged format, got: {json}"
1496 );
1497 assert!(
1498 !json.contains("\"Summary\""),
1499 "must not use externally-tagged format, got: {json}"
1500 );
1501 let decoded: MessagePart =
1502 serde_json::from_str(&json).expect("deserialization must not fail");
1503 match decoded {
1504 MessagePart::Summary { text } => assert_eq!(text, "hello"),
1505 other => panic!("expected MessagePart::Summary, got {other:?}"),
1506 }
1507 }
1508
1509 #[tokio::test]
1510 async fn embed_batch_default_empty_returns_empty() {
1511 let provider = StubProvider {
1512 response: String::new(),
1513 };
1514 let result = provider.embed_batch(&[]).await.unwrap();
1515 assert!(result.is_empty());
1516 }
1517
1518 #[tokio::test]
1519 async fn embed_batch_default_calls_embed_sequentially() {
1520 let provider = StubProvider {
1521 response: String::new(),
1522 };
1523 let texts = ["hello", "world", "foo"];
1524 let result = provider.embed_batch(&texts).await.unwrap();
1525 assert_eq!(result.len(), 3);
1526 for vec in &result {
1528 assert_eq!(vec, &[0.1_f32, 0.2, 0.3]);
1529 }
1530 }
1531}