1use std::future::Future;
5use std::pin::Pin;
6use std::{
7 any::TypeId,
8 collections::HashMap,
9 sync::{LazyLock, Mutex},
10};
11
12use futures_core::Stream;
13use serde::{Deserialize, Serialize};
14
15use crate::error::LlmError;
16
17static SCHEMA_CACHE: LazyLock<Mutex<HashMap<TypeId, (serde_json::Value, String)>>> =
18 LazyLock::new(|| Mutex::new(HashMap::new()));
19
20pub(crate) fn cached_schema<T: schemars::JsonSchema + 'static>()
26-> Result<(serde_json::Value, String), crate::LlmError> {
27 let type_id = TypeId::of::<T>();
28 if let Ok(cache) = SCHEMA_CACHE.lock()
29 && let Some(entry) = cache.get(&type_id)
30 {
31 return Ok(entry.clone());
32 }
33 let schema = schemars::schema_for!(T);
34 let value = serde_json::to_value(&schema)
35 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
36 let pretty = serde_json::to_string_pretty(&schema)
37 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
38 if let Ok(mut cache) = SCHEMA_CACHE.lock() {
39 cache.insert(type_id, (value.clone(), pretty.clone()));
40 }
41 Ok((value, pretty))
42}
43
/// Returns a short, human-readable name for `T`, used to label schemas in prompts.
///
/// Only the module path of the *outermost* type is removed. Generic, tuple and array
/// parameters are kept verbatim:
/// - `my_crate::Output` becomes `Output`;
/// - `std::collections::hash::map::HashMap<u32, u32>` becomes `HashMap<u32, u32>`;
/// - `alloc::vec::Vec<alloc::string::String>` becomes `Vec<alloc::string::String>`, not
///   `String>`.
///
/// Paths nested inside type arguments are left untouched because the result borrows from
/// the `'static` string returned by [`std::any::type_name`], whose exact format is not
/// guaranteed. Treat the result as a hint, never as an identifier.
pub(crate) fn short_type_name<T: ?Sized>() -> &'static str {
    let full = std::any::type_name::<T>();
    // Only the outer type's own path is searched for `::`. Separators inside `<..>`,
    // `(..)` or `[..]` belong to nested types and would cut the name mid-argument.
    let head_end = full.find(['<', '(', '[']).unwrap_or(full.len());
    let start = full[..head_end].rfind("::").map_or(0, |idx| idx + 2);
    &full[start..]
}
63
64#[derive(Debug, Clone)]
66pub enum StreamChunk {
67 Content(String),
69 Thinking(String),
71 Compaction(String),
74 ToolUse(Vec<ToolUseRequest>),
76}
77
78pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct ToolDefinition {
86 pub name: String,
87 pub description: String,
88 pub parameters: serde_json::Value,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct ToolUseRequest {
95 pub id: String,
96 pub name: String,
97 pub input: serde_json::Value,
98}
99
/// A thinking block returned alongside a tool-use response.
#[derive(Debug, Clone)]
pub enum ThinkingBlock {
    /// Readable thinking text plus the signature that accompanied it.
    Thinking { thinking: String, signature: String },
    /// A redacted block that carries only an opaque `data` payload.
    Redacted { data: String },
}

/// Marker text used to flag output truncated at the `max_tokens` limit
/// (NOTE(review): where it is emitted and matched is not visible here — confirm).
pub const MAX_TOKENS_TRUNCATION_MARKER: &str = "max_tokens limit reached";
110
111#[derive(Debug, Clone)]
113pub enum ChatResponse {
114 Text(String),
116 ToolUse {
118 text: Option<String>,
120 tool_calls: Vec<ToolUseRequest>,
121 thinking_blocks: Vec<ThinkingBlock>,
124 },
125}
126
127pub type EmbedFuture = Pin<Box<dyn Future<Output = Result<Vec<f32>, LlmError>> + Send>>;
129
130pub type EmbedFn = Box<dyn Fn(&str) -> EmbedFuture + Send + Sync>;
132
133pub type StatusTx = tokio::sync::mpsc::UnboundedSender<String>;
135
136#[must_use]
139pub fn default_debug_request_json(
140 messages: &[Message],
141 tools: &[ToolDefinition],
142) -> serde_json::Value {
143 serde_json::json!({
144 "model": serde_json::Value::Null,
145 "max_tokens": serde_json::Value::Null,
146 "messages": serde_json::to_value(messages).unwrap_or(serde_json::Value::Array(vec![])),
147 "tools": serde_json::to_value(tools).unwrap_or(serde_json::Value::Array(vec![])),
148 "temperature": serde_json::Value::Null,
149 "cache_control": serde_json::Value::Null,
150 })
151}
152
/// Optional sampling parameters; a `None` field leaves that parameter unset.
///
/// `Default` yields every field as `None`.
#[derive(Debug, Clone, Default)]
pub struct GenerationOverrides {
    /// Sampling temperature.
    pub temperature: Option<f64>,
    /// Nucleus-sampling probability mass.
    pub top_p: Option<f64>,
    /// Top-k candidate cutoff.
    pub top_k: Option<usize>,
    /// Frequency penalty.
    pub frequency_penalty: Option<f64>,
    /// Presence penalty.
    pub presence_penalty: Option<f64>,
}
165
166#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
167#[serde(rename_all = "lowercase")]
168pub enum Role {
169 System,
170 User,
171 Assistant,
172}
173
174#[derive(Clone, Debug, Serialize, Deserialize)]
175#[serde(tag = "kind", rename_all = "snake_case")]
176pub enum MessagePart {
177 Text {
178 text: String,
179 },
180 ToolOutput {
181 tool_name: String,
182 body: String,
183 #[serde(default, skip_serializing_if = "Option::is_none")]
184 compacted_at: Option<i64>,
185 },
186 Recall {
187 text: String,
188 },
189 CodeContext {
190 text: String,
191 },
192 Summary {
193 text: String,
194 },
195 CrossSession {
196 text: String,
197 },
198 ToolUse {
199 id: String,
200 name: String,
201 input: serde_json::Value,
202 },
203 ToolResult {
204 tool_use_id: String,
205 content: String,
206 #[serde(default)]
207 is_error: bool,
208 },
209 Image(Box<ImageData>),
210 ThinkingBlock {
212 thinking: String,
213 signature: String,
214 },
215 RedactedThinkingBlock {
217 data: String,
218 },
219 Compaction {
222 summary: String,
223 },
224}
225
226impl MessagePart {
227 #[must_use]
230 pub fn as_plain_text(&self) -> Option<&str> {
231 match self {
232 Self::Text { text }
233 | Self::Recall { text }
234 | Self::CodeContext { text }
235 | Self::Summary { text }
236 | Self::CrossSession { text } => Some(text.as_str()),
237 _ => None,
238 }
239 }
240
241 #[must_use]
243 pub fn as_image(&self) -> Option<&ImageData> {
244 if let Self::Image(img) = self {
245 Some(img)
246 } else {
247 None
248 }
249 }
250}
251
252#[derive(Clone, Debug, Serialize, Deserialize)]
253pub struct ImageData {
254 #[serde(with = "serde_bytes_base64")]
255 pub data: Vec<u8>,
256 pub mime_type: String,
257}
258
259mod serde_bytes_base64 {
260 use base64::{Engine, engine::general_purpose::STANDARD};
261 use serde::{Deserialize, Deserializer, Serializer};
262
263 pub fn serialize<S>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error>
264 where
265 S: Serializer,
266 {
267 s.serialize_str(&STANDARD.encode(bytes))
268 }
269
270 pub fn deserialize<'de, D>(d: D) -> Result<Vec<u8>, D::Error>
271 where
272 D: Deserializer<'de>,
273 {
274 let s = String::deserialize(d)?;
275 STANDARD.decode(&s).map_err(serde::de::Error::custom)
276 }
277}
278
279#[derive(Clone, Debug, Serialize, Deserialize)]
281pub struct MessageMetadata {
282 pub agent_visible: bool,
283 pub user_visible: bool,
284 #[serde(default, skip_serializing_if = "Option::is_none")]
285 pub compacted_at: Option<i64>,
286 #[serde(default, skip_serializing_if = "Option::is_none")]
289 pub deferred_summary: Option<String>,
290 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
293 pub focus_pinned: bool,
294 #[serde(default, skip_serializing_if = "Option::is_none")]
297 pub focus_marker_id: Option<uuid::Uuid>,
298 #[serde(skip)]
301 pub db_id: Option<i64>,
302}
303
304impl Default for MessageMetadata {
305 fn default() -> Self {
306 Self {
307 agent_visible: true,
308 user_visible: true,
309 compacted_at: None,
310 deferred_summary: None,
311 focus_pinned: false,
312 focus_marker_id: None,
313 db_id: None,
314 }
315 }
316}
317
318impl MessageMetadata {
319 #[must_use]
321 pub fn agent_only() -> Self {
322 Self {
323 agent_visible: true,
324 user_visible: false,
325 compacted_at: None,
326 deferred_summary: None,
327 focus_pinned: false,
328 focus_marker_id: None,
329 db_id: None,
330 }
331 }
332
333 #[must_use]
335 pub fn user_only() -> Self {
336 Self {
337 agent_visible: false,
338 user_visible: true,
339 compacted_at: None,
340 deferred_summary: None,
341 focus_pinned: false,
342 focus_marker_id: None,
343 db_id: None,
344 }
345 }
346
347 #[must_use]
349 pub fn focus_pinned() -> Self {
350 Self {
351 agent_visible: true,
352 user_visible: false,
353 compacted_at: None,
354 deferred_summary: None,
355 focus_pinned: true,
356 focus_marker_id: None,
357 db_id: None,
358 }
359 }
360}
361
362#[derive(Clone, Debug, Serialize, Deserialize)]
363pub struct Message {
364 pub role: Role,
365 pub content: String,
366 #[serde(default)]
367 pub parts: Vec<MessagePart>,
368 #[serde(default)]
369 pub metadata: MessageMetadata,
370}
371
372impl Default for Message {
373 fn default() -> Self {
374 Self {
375 role: Role::User,
376 content: String::new(),
377 parts: vec![],
378 metadata: MessageMetadata::default(),
379 }
380 }
381}
382
383impl Message {
384 #[must_use]
385 pub fn from_legacy(role: Role, content: impl Into<String>) -> Self {
386 Self {
387 role,
388 content: content.into(),
389 parts: vec![],
390 metadata: MessageMetadata::default(),
391 }
392 }
393
394 #[must_use]
395 pub fn from_parts(role: Role, parts: Vec<MessagePart>) -> Self {
396 let content = Self::flatten_parts(&parts);
397 Self {
398 role,
399 content,
400 parts,
401 metadata: MessageMetadata::default(),
402 }
403 }
404
405 #[must_use]
406 pub fn to_llm_content(&self) -> &str {
407 &self.content
408 }
409
410 pub fn rebuild_content(&mut self) {
412 if !self.parts.is_empty() {
413 self.content = Self::flatten_parts(&self.parts);
414 }
415 }
416
417 fn flatten_parts(parts: &[MessagePart]) -> String {
418 use std::fmt::Write;
419 let mut out = String::new();
420 for part in parts {
421 match part {
422 MessagePart::Text { text }
423 | MessagePart::Recall { text }
424 | MessagePart::CodeContext { text }
425 | MessagePart::Summary { text }
426 | MessagePart::CrossSession { text } => out.push_str(text),
427 MessagePart::ToolOutput {
428 tool_name,
429 body,
430 compacted_at,
431 } => {
432 if compacted_at.is_some() {
433 if body.is_empty() {
434 let _ = write!(out, "[tool output: {tool_name}] (pruned)");
435 } else {
436 let _ = write!(out, "[tool output: {tool_name}] {body}");
437 }
438 } else {
439 let _ = write!(out, "[tool output: {tool_name}]\n```\n{body}\n```");
440 }
441 }
442 MessagePart::ToolUse { id, name, .. } => {
443 let _ = write!(out, "[tool_use: {name}({id})]");
444 }
445 MessagePart::ToolResult {
446 tool_use_id,
447 content,
448 ..
449 } => {
450 let _ = write!(out, "[tool_result: {tool_use_id}]\n{content}");
451 }
452 MessagePart::Image(img) => {
453 let _ = write!(out, "[image: {}, {} bytes]", img.mime_type, img.data.len());
454 }
455 MessagePart::ThinkingBlock { .. }
457 | MessagePart::RedactedThinkingBlock { .. }
458 | MessagePart::Compaction { .. } => {}
459 }
460 }
461 out
462 }
463}
464
465pub trait LlmProvider: Send + Sync {
466 fn context_window(&self) -> Option<usize> {
470 None
471 }
472
473 fn chat(&self, messages: &[Message]) -> impl Future<Output = Result<String, LlmError>> + Send;
479
480 fn chat_stream(
486 &self,
487 messages: &[Message],
488 ) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;
489
490 fn supports_streaming(&self) -> bool;
492
493 fn embed(&self, text: &str) -> impl Future<Output = Result<Vec<f32>, LlmError>> + Send;
499
500 fn supports_embeddings(&self) -> bool;
502
503 fn name(&self) -> &str;
505
506 #[allow(clippy::unnecessary_literal_bound)]
509 fn model_identifier(&self) -> &str {
510 ""
511 }
512
513 fn supports_vision(&self) -> bool {
515 false
516 }
517
518 fn supports_tool_use(&self) -> bool {
520 false
521 }
522
523 #[allow(async_fn_in_trait)]
531 async fn chat_with_tools(
532 &self,
533 messages: &[Message],
534 _tools: &[ToolDefinition],
535 ) -> Result<ChatResponse, LlmError> {
536 Ok(ChatResponse::Text(self.chat(messages).await?))
537 }
538
539 fn last_cache_usage(&self) -> Option<(u64, u64)> {
542 None
543 }
544
545 fn last_usage(&self) -> Option<(u64, u64)> {
548 None
549 }
550
551 fn take_compaction_summary(&self) -> Option<String> {
554 None
555 }
556
557 fn record_quality_outcome(&self, _provider_name: &str, _success: bool) {}
563
564 #[must_use]
568 fn debug_request_json(
569 &self,
570 messages: &[Message],
571 tools: &[ToolDefinition],
572 _stream: bool,
573 ) -> serde_json::Value {
574 default_debug_request_json(messages, tools)
575 }
576
577 fn list_models(&self) -> Vec<String> {
580 vec![]
581 }
582
583 fn supports_structured_output(&self) -> bool {
585 false
586 }
587
588 #[allow(async_fn_in_trait)]
593 async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
594 where
595 T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
596 Self: Sized,
597 {
598 let (_, schema_json) = cached_schema::<T>()?;
599 let type_name = short_type_name::<T>();
600
601 let mut augmented = messages.to_vec();
602 let instruction = format!(
603 "Respond with a valid JSON object matching this schema. \
604 Output ONLY the JSON, no markdown fences or extra text.\n\n\
605 Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
606 );
607 augmented.insert(0, Message::from_legacy(Role::System, instruction));
608
609 let raw = self.chat(&augmented).await?;
610 let cleaned = strip_json_fences(&raw);
611 match serde_json::from_str::<T>(cleaned) {
612 Ok(val) => Ok(val),
613 Err(first_err) => {
614 augmented.push(Message::from_legacy(Role::Assistant, &raw));
615 augmented.push(Message::from_legacy(
616 Role::User,
617 format!(
618 "Your response was not valid JSON. Error: {first_err}. \
619 Please output ONLY valid JSON matching the schema."
620 ),
621 ));
622 let retry_raw = self.chat(&augmented).await?;
623 let retry_cleaned = strip_json_fences(&retry_raw);
624 serde_json::from_str::<T>(retry_cleaned).map_err(|e| {
625 LlmError::StructuredParse(format!("parse failed after retry: {e}"))
626 })
627 }
628 }
629 }
630}
631
/// Strips surrounding whitespace and Markdown code fences from an LLM reply.
///
/// Handles an opening fence with or without a `json` info string and a closing fence that
/// may be missing. The info string is matched case-insensitively, so ```` ```JSON ```` and
/// ```` ```Json ```` are accepted too. Text without fences is only trimmed.
fn strip_json_fences(s: &str) -> &str {
    let s = s.trim();
    // Models sometimes label the fence `JSON`/`Json`. Without this step the leftover
    // tag would sit in front of the payload, fail to parse, and waste the retry.
    let s = match s.strip_prefix("```") {
        Some(rest) if rest.get(..4).is_some_and(|tag| tag.eq_ignore_ascii_case("json")) => {
            &rest[4..]
        }
        _ => s,
    };
    s.trim_start_matches("```json")
        .trim_start_matches("```")
        .trim_end_matches("```")
        .trim()
}
642
#[cfg(test)]
mod tests {
    //! Unit tests for the message model, trait defaults, and typed-chat helpers.

    use tokio_stream::StreamExt;

    use super::*;

    /// Provider that always answers with a fixed response.
    struct StubProvider {
        response: String,
    }

    impl LlmProvider for StubProvider {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            Ok(self.response.clone())
        }

        async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
            let response = self.chat(messages).await?;
            Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
                response,
            )))))
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn supports_embeddings(&self) -> bool {
            false
        }

        fn name(&self) -> &'static str {
            "stub"
        }
    }

    #[test]
    fn context_window_default_returns_none() {
        let provider = StubProvider {
            response: String::new(),
        };
        assert!(provider.context_window().is_none());
    }

    // `supports_streaming` is a required method; this checks the stub's own answer.
    #[test]
    fn stub_supports_streaming_returns_false() {
        let provider = StubProvider {
            response: String::new(),
        };
        assert!(!provider.supports_streaming());
    }

    #[tokio::test]
    async fn stub_chat_stream_yields_single_chunk() {
        let provider = StubProvider {
            response: "hello world".into(),
        };
        let messages = vec![Message {
            role: Role::User,
            content: "test".into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        let mut stream = provider.chat_stream(&messages).await.unwrap();
        let chunk = stream.next().await.unwrap().unwrap();
        assert!(matches!(chunk, StreamChunk::Content(s) if s == "hello world"));
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn chat_stream_propagates_chat_error() {
        struct FailProvider;

        impl LlmProvider for FailProvider {
            async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
                Err(LlmError::Unavailable)
            }

            async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
                let response = self.chat(messages).await?;
                Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
                    response,
                )))))
            }

            fn supports_streaming(&self) -> bool {
                false
            }

            async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
                Err(LlmError::Unavailable)
            }

            fn supports_embeddings(&self) -> bool {
                false
            }

            fn name(&self) -> &'static str {
                "fail"
            }
        }

        let provider = FailProvider;
        let messages = vec![Message {
            role: Role::User,
            content: "test".into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // `ChatStream` is not `Debug`, so `unwrap_err` is unavailable here.
        let Err(e) = provider.chat_stream(&messages).await else {
            panic!("expected chat_stream to propagate the chat error");
        };
        assert!(e.to_string().contains("provider unavailable"));
    }

    #[tokio::test]
    async fn stub_provider_embed_returns_vector() {
        let provider = StubProvider {
            response: String::new(),
        };
        let embedding = provider.embed("test").await.unwrap();
        assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
    }

    #[tokio::test]
    async fn fail_provider_embed_propagates_error() {
        struct FailProvider;

        impl LlmProvider for FailProvider {
            async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
                Err(LlmError::Unavailable)
            }

            async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
                let response = self.chat(messages).await?;
                Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
                    response,
                )))))
            }

            fn supports_streaming(&self) -> bool {
                false
            }

            async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
                Err(LlmError::EmbedUnsupported {
                    provider: "fail".into(),
                })
            }

            fn supports_embeddings(&self) -> bool {
                false
            }

            fn name(&self) -> &'static str {
                "fail"
            }
        }

        let provider = FailProvider;
        let result = provider.embed("test").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("embedding not supported")
        );
    }

    #[test]
    fn role_serialization() {
        let system = Role::System;
        let user = Role::User;
        let assistant = Role::Assistant;

        assert_eq!(serde_json::to_string(&system).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&user).unwrap(), "\"user\"");
        assert_eq!(serde_json::to_string(&assistant).unwrap(), "\"assistant\"");
    }

    #[test]
    fn role_deserialization() {
        let system: Role = serde_json::from_str("\"system\"").unwrap();
        let user: Role = serde_json::from_str("\"user\"").unwrap();
        let assistant: Role = serde_json::from_str("\"assistant\"").unwrap();

        assert_eq!(system, Role::System);
        assert_eq!(user, Role::User);
        assert_eq!(assistant, Role::Assistant);
    }

    #[test]
    fn message_clone() {
        let msg = Message {
            role: Role::User,
            content: "test".into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        };
        let cloned = msg.clone();
        assert_eq!(cloned.role, msg.role);
        assert_eq!(cloned.content, msg.content);
    }

    #[test]
    fn message_debug() {
        let msg = Message {
            role: Role::Assistant,
            content: "response".into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        };
        let debug = format!("{msg:?}");
        assert!(debug.contains("Assistant"));
        assert!(debug.contains("response"));
    }

    #[test]
    fn message_serialization() {
        let msg = Message {
            role: Role::User,
            content: "hello".into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"hello\""));
    }

    #[test]
    fn message_part_serde_round_trip() {
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output".into(),
                compacted_at: None,
            },
            MessagePart::Recall {
                text: "recall".into(),
            },
            MessagePart::CodeContext {
                text: "code".into(),
            },
            MessagePart::Summary {
                text: "summary".into(),
            },
        ];
        let json = serde_json::to_string(&parts).unwrap();
        let deserialized: Vec<MessagePart> = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.len(), 5);
    }

    #[test]
    fn from_legacy_creates_empty_parts() {
        let msg = Message::from_legacy(Role::User, "hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "hello");
        assert!(msg.parts.is_empty());
        assert_eq!(msg.to_llm_content(), "hello");
    }

    #[test]
    fn from_parts_flattens_content() {
        let msg = Message::from_parts(
            Role::System,
            vec![MessagePart::Recall {
                text: "recalled data".into(),
            }],
        );
        assert_eq!(msg.content, "recalled data");
        assert_eq!(msg.to_llm_content(), "recalled data");
        assert_eq!(msg.parts.len(), 1);
    }

    #[test]
    fn from_parts_tool_output_format() {
        let msg = Message::from_parts(
            Role::User,
            vec![MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "hello world".into(),
                compacted_at: None,
            }],
        );
        assert!(msg.content.contains("[tool output: bash]"));
        assert!(msg.content.contains("hello world"));
    }

    #[test]
    fn message_deserializes_without_parts() {
        let json = r#"{"role":"user","content":"hello"}"#;
        let msg: Message = serde_json::from_str(json).unwrap();
        assert_eq!(msg.content, "hello");
        assert!(msg.parts.is_empty());
    }

    // A compacted tool output with an empty body is rendered as a "(pruned)" marker
    // rather than being dropped, and the surrounding text is kept.
    #[test]
    fn flatten_compacted_tool_output_empty_body_renders_pruned_marker() {
        let msg = Message::from_parts(
            Role::User,
            vec![
                MessagePart::Text {
                    text: "prefix ".into(),
                },
                MessagePart::ToolOutput {
                    tool_name: "bash".into(),
                    body: String::new(),
                    compacted_at: Some(1234),
                },
                MessagePart::Text {
                    text: " suffix".into(),
                },
            ],
        );
        assert!(msg.content.contains("(pruned)"));
        assert!(msg.content.contains("prefix "));
        assert!(msg.content.contains(" suffix"));
    }

    #[test]
    fn flatten_compacted_tool_output_with_reference_renders_body() {
        let ref_notice = "[tool output pruned; full content at /tmp/overflow/big.txt]";
        let msg = Message::from_parts(
            Role::User,
            vec![MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: ref_notice.into(),
                compacted_at: Some(1234),
            }],
        );
        assert!(msg.content.contains(ref_notice));
        assert!(!msg.content.contains("(pruned)"));
    }

    #[test]
    fn rebuild_content_syncs_after_mutation() {
        let mut msg = Message::from_parts(
            Role::User,
            vec![MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "original".into(),
                compacted_at: None,
            }],
        );
        assert!(msg.content.contains("original"));

        if let MessagePart::ToolOutput {
            ref mut compacted_at,
            ref mut body,
            ..
        } = msg.parts[0]
        {
            *compacted_at = Some(999);
            body.clear();
        }
        msg.rebuild_content();

        assert!(msg.content.contains("(pruned)"));
        assert!(!msg.content.contains("original"));
    }

    #[test]
    fn message_part_tool_use_serde_round_trip() {
        let part = MessagePart::ToolUse {
            id: "toolu_123".into(),
            name: "bash".into(),
            input: serde_json::json!({"command": "ls"}),
        };
        let json = serde_json::to_string(&part).unwrap();
        let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
        if let MessagePart::ToolUse { id, name, input } = deserialized {
            assert_eq!(id, "toolu_123");
            assert_eq!(name, "bash");
            assert_eq!(input["command"], "ls");
        } else {
            panic!("expected ToolUse");
        }
    }

    #[test]
    fn message_part_tool_result_serde_round_trip() {
        let part = MessagePart::ToolResult {
            tool_use_id: "toolu_123".into(),
            content: "file1.rs\nfile2.rs".into(),
            is_error: false,
        };
        let json = serde_json::to_string(&part).unwrap();
        let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
        if let MessagePart::ToolResult {
            tool_use_id,
            content,
            is_error,
        } = deserialized
        {
            assert_eq!(tool_use_id, "toolu_123");
            assert_eq!(content, "file1.rs\nfile2.rs");
            assert!(!is_error);
        } else {
            panic!("expected ToolResult");
        }
    }

    #[test]
    fn message_part_tool_result_is_error_default() {
        let json = r#"{"kind":"tool_result","tool_use_id":"id","content":"err"}"#;
        let part: MessagePart = serde_json::from_str(json).unwrap();
        if let MessagePart::ToolResult { is_error, .. } = part {
            assert!(!is_error);
        } else {
            panic!("expected ToolResult");
        }
    }

    #[test]
    fn chat_response_construction() {
        let text = ChatResponse::Text("hello".into());
        assert!(matches!(text, ChatResponse::Text(s) if s == "hello"));

        let tool_use = ChatResponse::ToolUse {
            text: Some("I'll run that".into()),
            tool_calls: vec![ToolUseRequest {
                id: "1".into(),
                name: "bash".into(),
                input: serde_json::json!({}),
            }],
            thinking_blocks: vec![],
        };
        assert!(matches!(tool_use, ChatResponse::ToolUse { .. }));
    }

    #[test]
    fn flatten_parts_tool_use() {
        let msg = Message::from_parts(
            Role::Assistant,
            vec![MessagePart::ToolUse {
                id: "t1".into(),
                name: "bash".into(),
                input: serde_json::json!({"command": "ls"}),
            }],
        );
        assert!(msg.content.contains("[tool_use: bash(t1)]"));
    }

    #[test]
    fn flatten_parts_tool_result() {
        let msg = Message::from_parts(
            Role::User,
            vec![MessagePart::ToolResult {
                tool_use_id: "t1".into(),
                content: "output here".into(),
                is_error: false,
            }],
        );
        assert!(msg.content.contains("[tool_result: t1]"));
        assert!(msg.content.contains("output here"));
    }

    #[test]
    fn tool_definition_serde_round_trip() {
        let def = ToolDefinition {
            name: "bash".into(),
            description: "Execute a shell command".into(),
            parameters: serde_json::json!({"type": "object"}),
        };
        let json = serde_json::to_string(&def).unwrap();
        let deserialized: ToolDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, "bash");
        assert_eq!(deserialized.description, "Execute a shell command");
    }

    // Synchronous check; no runtime needed.
    #[test]
    fn supports_tool_use_default_returns_false() {
        let provider = StubProvider {
            response: String::new(),
        };
        assert!(!provider.supports_tool_use());
    }

    #[tokio::test]
    async fn chat_with_tools_default_delegates_to_chat() {
        let provider = StubProvider {
            response: "hello".into(),
        };
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result = provider.chat_with_tools(&messages, &[]).await.unwrap();
        assert!(matches!(result, ChatResponse::Text(s) if s == "hello"));
    }

    #[test]
    fn tool_output_compacted_at_serde_default() {
        let json = r#"{"kind":"tool_output","tool_name":"bash","body":"out"}"#;
        let part: MessagePart = serde_json::from_str(json).unwrap();
        if let MessagePart::ToolOutput { compacted_at, .. } = part {
            assert!(compacted_at.is_none());
        } else {
            panic!("expected ToolOutput");
        }
    }

    #[test]
    fn strip_json_fences_plain_json() {
        assert_eq!(strip_json_fences(r#"{"a": 1}"#), r#"{"a": 1}"#);
    }

    #[test]
    fn strip_json_fences_with_json_fence() {
        assert_eq!(strip_json_fences("```json\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
    }

    #[test]
    fn strip_json_fences_with_plain_fence() {
        assert_eq!(strip_json_fences("```\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
    }

    #[test]
    fn strip_json_fences_whitespace() {
        assert_eq!(strip_json_fences(" \n "), "");
    }

    #[test]
    fn strip_json_fences_empty() {
        assert_eq!(strip_json_fences(""), "");
    }

    #[test]
    fn strip_json_fences_outer_whitespace() {
        assert_eq!(
            strip_json_fences(" ```json\n{\"a\": 1}\n``` "),
            r#"{"a": 1}"#
        );
    }

    #[test]
    fn strip_json_fences_only_opening_fence() {
        assert_eq!(strip_json_fences("```json\n{\"a\": 1}"), r#"{"a": 1}"#);
    }

    #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
    struct TestOutput {
        value: String,
    }

    /// Provider that replays a queue of canned results, one per `chat` call.
    struct SequentialStub {
        responses: std::sync::Mutex<Vec<Result<String, LlmError>>>,
    }

    impl SequentialStub {
        fn new(responses: Vec<Result<String, LlmError>>) -> Self {
            Self {
                responses: std::sync::Mutex::new(responses),
            }
        }
    }

    impl LlmProvider for SequentialStub {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            let mut responses = self.responses.lock().unwrap();
            if responses.is_empty() {
                return Err(LlmError::Other("no more responses".into()));
            }
            responses.remove(0)
        }

        async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
            let response = self.chat(messages).await?;
            Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
                response,
            )))))
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Err(LlmError::EmbedUnsupported {
                provider: "sequential-stub".into(),
            })
        }

        fn supports_embeddings(&self) -> bool {
            false
        }

        fn name(&self) -> &'static str {
            "sequential-stub"
        }
    }

    #[tokio::test]
    async fn chat_typed_happy_path() {
        let provider = StubProvider {
            response: r#"{"value": "hello"}"#.into(),
        };
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
        assert_eq!(
            result,
            TestOutput {
                value: "hello".into()
            }
        );
    }

    #[tokio::test]
    async fn chat_typed_retry_succeeds() {
        let provider = SequentialStub::new(vec![
            Ok("not valid json".into()),
            Ok(r#"{"value": "ok"}"#.into()),
        ]);
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
        assert_eq!(result, TestOutput { value: "ok".into() });
    }

    #[tokio::test]
    async fn chat_typed_both_fail() {
        let provider = SequentialStub::new(vec![Ok("bad json".into()), Ok("still bad".into())]);
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result = provider.chat_typed::<TestOutput>(&messages).await;
        let err = result.unwrap_err();
        assert!(err.to_string().contains("parse failed after retry"));
    }

    #[tokio::test]
    async fn chat_typed_chat_error_propagates() {
        let provider = SequentialStub::new(vec![Err(LlmError::Unavailable)]);
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result = provider.chat_typed::<TestOutput>(&messages).await;
        assert!(matches!(result, Err(LlmError::Unavailable)));
    }

    #[tokio::test]
    async fn chat_typed_strips_fences() {
        let provider = StubProvider {
            response: "```json\n{\"value\": \"fenced\"}\n```".into(),
        };
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
        assert_eq!(
            result,
            TestOutput {
                value: "fenced".into()
            }
        );
    }

    #[test]
    fn supports_structured_output_default_false() {
        let provider = StubProvider {
            response: String::new(),
        };
        assert!(!provider.supports_structured_output());
    }

    #[test]
    fn structured_parse_error_display() {
        let err = LlmError::StructuredParse("test error".into());
        assert_eq!(
            err.to_string(),
            "structured output parse failed: test error"
        );
    }

    #[test]
    fn message_part_image_roundtrip_json() {
        let part = MessagePart::Image(Box::new(ImageData {
            data: vec![1, 2, 3, 4],
            mime_type: "image/jpeg".into(),
        }));
        let json = serde_json::to_string(&part).unwrap();
        let decoded: MessagePart = serde_json::from_str(&json).unwrap();
        match decoded {
            MessagePart::Image(img) => {
                assert_eq!(img.data, vec![1, 2, 3, 4]);
                assert_eq!(img.mime_type, "image/jpeg");
            }
            _ => panic!("expected Image variant"),
        }
    }

    #[test]
    fn flatten_parts_includes_image_placeholder() {
        let msg = Message::from_parts(
            Role::User,
            vec![
                MessagePart::Text {
                    text: "see this".into(),
                },
                MessagePart::Image(Box::new(ImageData {
                    data: vec![0u8; 100],
                    mime_type: "image/png".into(),
                })),
            ],
        );
        let content = msg.to_llm_content();
        assert!(content.contains("see this"));
        assert!(content.contains("[image: image/png"));
    }

    #[test]
    fn supports_vision_default_false() {
        let provider = StubProvider {
            response: String::new(),
        };
        assert!(!provider.supports_vision());
    }

    #[test]
    fn message_metadata_default_both_visible() {
        let m = MessageMetadata::default();
        assert!(m.agent_visible);
        assert!(m.user_visible);
        assert!(m.compacted_at.is_none());
    }

    #[test]
    fn message_metadata_agent_only() {
        let m = MessageMetadata::agent_only();
        assert!(m.agent_visible);
        assert!(!m.user_visible);
    }

    #[test]
    fn message_metadata_user_only() {
        let m = MessageMetadata::user_only();
        assert!(!m.agent_visible);
        assert!(m.user_visible);
    }

    #[test]
    fn message_metadata_serde_default() {
        let json = r#"{"role":"user","content":"hello"}"#;
        let msg: Message = serde_json::from_str(json).unwrap();
        assert!(msg.metadata.agent_visible);
        assert!(msg.metadata.user_visible);
    }

    #[test]
    fn message_metadata_round_trip() {
        let msg = Message {
            role: Role::User,
            content: "test".into(),
            parts: vec![],
            metadata: MessageMetadata::agent_only(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let decoded: Message = serde_json::from_str(&json).unwrap();
        assert!(decoded.metadata.agent_visible);
        assert!(!decoded.metadata.user_visible);
    }

    #[test]
    fn message_part_compaction_round_trip() {
        let part = MessagePart::Compaction {
            summary: "Context was summarized.".to_owned(),
        };
        let json = serde_json::to_string(&part).unwrap();
        let decoded: MessagePart = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(decoded, MessagePart::Compaction { summary } if summary == "Context was summarized.")
        );
    }

    #[test]
    fn flatten_parts_compaction_contributes_no_text() {
        let parts = vec![
            MessagePart::Text {
                text: "Hello".to_owned(),
            },
            MessagePart::Compaction {
                summary: "Summary".to_owned(),
            },
        ];
        let msg = Message::from_parts(Role::Assistant, parts);
        assert_eq!(msg.content.trim(), "Hello");
    }

    #[test]
    fn stream_chunk_compaction_variant() {
        let chunk = StreamChunk::Compaction("A summary".to_owned());
        assert!(matches!(chunk, StreamChunk::Compaction(s) if s == "A summary"));
    }

    #[test]
    fn short_type_name_extracts_last_segment() {
        struct MyOutput;
        assert_eq!(short_type_name::<MyOutput>(), "MyOutput");
    }

    #[test]
    fn short_type_name_primitive_returns_full_name() {
        assert_eq!(short_type_name::<u32>(), "u32");
        assert_eq!(short_type_name::<bool>(), "bool");
    }

    #[test]
    fn short_type_name_nested_path_returns_last() {
        assert_eq!(
            short_type_name::<std::collections::HashMap<u32, u32>>(),
            "HashMap<u32, u32>"
        );
    }

    #[test]
    fn summary_roundtrip() {
        let part = MessagePart::Summary {
            text: "hello".to_string(),
        };
        let json = serde_json::to_string(&part).expect("serialization must not fail");
        assert!(
            json.contains("\"kind\":\"summary\""),
            "must use internally-tagged format, got: {json}"
        );
        assert!(
            !json.contains("\"Summary\""),
            "must not use externally-tagged format, got: {json}"
        );
        let decoded: MessagePart =
            serde_json::from_str(&json).expect("deserialization must not fail");
        match decoded {
            MessagePart::Summary { text } => assert_eq!(text, "hello"),
            other => panic!("expected MessagePart::Summary, got {other:?}"),
        }
    }
}