1mod custom_message;
7pub mod message_codec;
8mod model;
9
10pub use custom_message::*;
11pub use message_codec::{
12 MessageSlot, SerializedCustomMessage, SerializedMessages, clone_messages_for_send,
13 restore_messages, restore_single_custom, serialize_messages,
14};
15pub use model::*;
16
17use std::collections::HashMap;
18use std::fmt;
19use std::ops::{Add, AddAssign};
20use std::sync::Arc;
21
22use serde::{Deserialize, Serialize};
23
/// One piece of content inside a message.
///
/// Serialized as an internally tagged object: the `type` field holds the
/// snake_case variant name (`text`, `thinking`, `tool_call`, `image`,
/// `extension`) and the variant's own fields sit alongside it.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text. This is the only variant collected by `ContentBlock::extract_text`.
    Text { text: String },

    /// Thinking/reasoning content with an optional `signature` string.
    // NOTE(review): what `signature` carries (e.g. a provider-issued token) is not
    // established in this file — confirm before relying on it.
    Thinking {
        thinking: String,
        signature: Option<String>,
    },

    /// A tool invocation: an `id`, the tool `name`, and JSON `arguments`.
    // NOTE(review): `partial_json` presumably holds not-yet-complete argument text
    // accumulated while streaming — confirm against the stream parser.
    ToolCall {
        id: String,
        name: String,
        arguments: serde_json::Value,
        partial_json: Option<String>,
    },

    /// An image described by an `ImageSource`.
    Image { source: ImageSource },

    /// Open-ended content: a caller-chosen `type_name` plus an arbitrary JSON payload.
    Extension {
        type_name: String,
        data: serde_json::Value,
    },
}
66
67impl ContentBlock {
68 pub fn extract_text(blocks: &[Self]) -> String {
72 let mut result = String::new();
73 for block in blocks {
74 if let Self::Text { text } = block {
75 result.push_str(text);
76 }
77 }
78 result
79 }
80}
81
/// Where an image comes from.
///
/// Serialized as an internally tagged object whose `type` field is `base64`,
/// `url`, or `file`; every variant also carries a `media_type` string.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Inline image data.
    // NOTE(review): `data` is presumably base64-encoded (per the variant name) — confirm.
    Base64 { media_type: String, data: String },

    /// An image referenced by URL.
    Url { url: String, media_type: String },

    /// An image referenced by a filesystem path.
    File {
        path: std::path::PathBuf,
        media_type: String,
    },
}
99
/// A message authored by the user.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Content blocks making up the message.
    pub content: Vec<ContentBlock>,
    /// Message timestamp.
    // NOTE(review): the unit/epoch of `timestamp` is not established here — confirm.
    pub timestamp: u64,
    /// Optional cache hint. Omitted from serialized output when `None`, and
    /// deserialized as `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_hint: Option<crate::context_cache::CacheHint>,
}
111
/// A message produced by a model.
///
/// This type does not derive `PartialEq`. Its `cost` field is a `Cost`, which
/// does not implement `PartialEq`, so the derive would not compile.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssistantMessage {
    /// Content blocks of the reply.
    pub content: Vec<ContentBlock>,
    /// Provider identifier (e.g. `"anthropic"`).
    pub provider: String,
    /// Model identifier within the provider.
    pub model_id: String,
    /// Token usage attributed to this message.
    pub usage: Usage,
    /// Cost attributed to this message.
    pub cost: Cost,
    /// Why generation stopped.
    pub stop_reason: StopReason,
    /// Human-readable error text, if any.
    pub error_message: Option<String>,
    /// Optional structured error classification. Omitted from serialized output
    /// when `None`, and deserialized as `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error_kind: Option<crate::stream_error_kind::StreamErrorKind>,
    /// Message timestamp.
    // NOTE(review): the unit/epoch of `timestamp` is not established here — confirm.
    pub timestamp: u64,
    /// Optional cache hint. Omitted when `None`, and defaulted when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_hint: Option<crate::context_cache::CacheHint>,
}
133
/// The outcome of running a tool.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolResultMessage {
    /// Identifier of the tool call this result answers.
    // NOTE(review): presumably matches the `id` of a `ContentBlock::ToolCall` — confirm.
    pub tool_call_id: String,
    /// Content returned by the tool.
    pub content: Vec<ContentBlock>,
    /// Whether the tool reported failure.
    pub is_error: bool,
    /// Message timestamp.
    pub timestamp: u64,
    /// Free-form structured details. Defaults to JSON `null` when absent on input.
    #[serde(default)]
    pub details: serde_json::Value,
    /// Optional cache hint. Omitted when `None`, and defaulted when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_hint: Option<crate::context_cache::CacheHint>,
}
148
/// A message in the model-facing conversation.
///
/// Serialized as an internally tagged object whose `role` field is `user`,
/// `assistant`, or `tool_result`, with the wrapped message's fields alongside.
/// It does not derive `PartialEq` because `AssistantMessage` does not implement it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum LlmMessage {
    User(UserMessage),
    Assistant(AssistantMessage),
    ToolResult(ToolResultMessage),
}
157
/// Token counters for model usage.
///
/// Records combine with `+`, `+=`, or `Usage::merge`. Every counter is summed
/// independently, including each `extra` entry, which is summed per key.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Input tokens.
    pub input: u64,
    /// Output tokens.
    pub output: u64,
    /// Tokens read from a cache.
    pub cache_read: u64,
    /// Tokens written to a cache.
    pub cache_write: u64,
    /// Reported total.
    // NOTE(review): nothing here enforces `total == input + output + ...`; it is
    // stored and summed as its own field.
    pub total: u64,
    /// Additional named counters (e.g. `"reasoning_tokens"`). Omitted from
    /// serialized output when empty, and defaulted to empty when absent.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra: HashMap<String, u64>,
}
172
173impl Usage {
174 pub fn merge(&mut self, other: &Self) {
176 *self += other.clone();
177 }
178}
179
180impl Add for Usage {
181 type Output = Self;
182
183 fn add(mut self, rhs: Self) -> Self::Output {
184 self += rhs;
185 self
186 }
187}
188
189impl AddAssign for Usage {
190 fn add_assign(&mut self, rhs: Self) {
191 self.input += rhs.input;
192 self.output += rhs.output;
193 self.cache_read += rhs.cache_read;
194 self.cache_write += rhs.cache_write;
195 self.total += rhs.total;
196 for (k, v) in rhs.extra {
197 *self.extra.entry(k).or_insert(0) += v;
198 }
199 }
200}
201
/// Cost figures, mirroring the counters in `Usage`.
///
/// This type does not derive `PartialEq`, so types that embed it (e.g.
/// `AssistantMessage`) cannot derive it either.
// NOTE(review): the currency/unit of these figures is not established here — confirm.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Cost {
    pub input: f64,
    pub output: f64,
    pub cache_read: f64,
    pub cache_write: f64,
    pub total: f64,
    /// Additional named cost entries. Omitted from serialized output when empty,
    /// and defaulted to empty when absent.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra: HashMap<String, f64>,
}
214
215impl Add for Cost {
216 type Output = Self;
217
218 fn add(mut self, rhs: Self) -> Self::Output {
219 self += rhs;
220 self
221 }
222}
223
224impl AddAssign for Cost {
225 fn add_assign(&mut self, rhs: Self) {
226 self.input += rhs.input;
227 self.output += rhs.output;
228 self.cache_read += rhs.cache_read;
229 self.cache_write += rhs.cache_write;
230 self.total += rhs.total;
231 for (k, v) in rhs.extra {
232 *self.extra.entry(k).or_insert(0.0) += v;
233 }
234 }
235}
236
/// Why generation or a run stopped.
///
/// Serialized as a snake_case string: `stop`, `length`, `tool_use`, `aborted`,
/// `error`, or `transfer`.
// NOTE(review): the per-variant meanings below are inferred from the names — confirm
// against the code that produces them.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Finished normally.
    Stop,
    /// Presumably cut off by a length/token limit.
    Length,
    /// Stopped to request tool use.
    ToolUse,
    /// Aborted before completion.
    Aborted,
    /// Stopped because of an error.
    Error,
    /// Presumably stopped to hand off (see `crate::transfer::TransferSignal`).
    Transfer,
}
257
/// Result of an agent run: its messages, stop reason, usage/cost figures, and
/// optional error or transfer signal.
///
/// `Debug` is implemented manually rather than derived.
pub struct AgentResult {
    /// Messages from the run.
    pub messages: Vec<AgentMessage>,
    /// Why the run stopped.
    pub stop_reason: StopReason,
    /// Token usage for the run.
    // NOTE(review): presumably summed across turns — confirm with the producer.
    pub usage: Usage,
    /// Cost for the run.
    pub cost: Cost,
    /// Error description, if the run produced one.
    pub error: Option<String>,
    /// Transfer signal, if the run ended with one.
    pub transfer_signal: Option<crate::transfer::TransferSignal>,
}
275
276impl AgentResult {
277 pub fn assistant_text(&self) -> String {
283 self.messages
284 .iter()
285 .rev()
286 .find_map(|msg| match msg {
287 AgentMessage::Llm(LlmMessage::Assistant(a)) => Some(a),
288 _ => None,
289 })
290 .map(|a| ContentBlock::extract_text(&a.content))
291 .unwrap_or_default()
292 }
293}
294
295impl fmt::Debug for AgentResult {
296 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297 f.debug_struct("AgentResult")
298 .field("messages", &self.messages)
299 .field("stop_reason", &self.stop_reason)
300 .field("usage", &self.usage)
301 .field("cost", &self.cost)
302 .field("error", &self.error)
303 .field("transfer_signal", &self.transfer_signal)
304 .finish()
305 }
306}
307
/// Context handed to the agent: the system prompt, the message history, and the
/// available tools.
///
/// `Debug` is implemented manually and prints only the number of tools.
pub struct AgentContext {
    /// System prompt text.
    pub system_prompt: String,
    /// Conversation history.
    pub messages: Vec<AgentMessage>,
    /// Tools available to the agent, shared via `Arc`.
    pub tools: Vec<Arc<dyn crate::tool::AgentTool>>,
}
321
322impl fmt::Debug for AgentContext {
323 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
324 f.debug_struct("AgentContext")
325 .field("system_prompt", &self.system_prompt)
326 .field("messages", &self.messages)
327 .field("tools", &format_args!("[{} tool(s)]", self.tools.len()))
328 .finish()
329 }
330}
331
332fn serialize_arc_vec<S, T>(value: &Arc<Vec<T>>, serializer: S) -> Result<S::Ok, S::Error>
335where
336 S: serde::Serializer,
337 T: Serialize,
338{
339 value.as_ref().serialize(serializer)
340}
341
342fn deserialize_arc_vec<'de, D, T>(deserializer: D) -> Result<Arc<Vec<T>>, D::Error>
343where
344 D: serde::Deserializer<'de>,
345 T: Deserialize<'de>,
346{
347 let v = Vec::<T>::deserialize(deserializer)?;
348 Ok(Arc::new(v))
349}
350
/// Snapshot of the conversation state associated with one turn.
///
/// `messages` sits behind an `Arc`, so cloning a snapshot bumps a reference count
/// instead of copying the message vector. Serde still reads and writes it as a
/// plain sequence through `serialize_arc_vec` / `deserialize_arc_vec`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurnSnapshot {
    /// Index of the turn this snapshot describes.
    // NOTE(review): zero- vs one-based indexing is not established here — confirm.
    pub turn_index: usize,
    /// Model-facing messages at the time of the snapshot.
    #[serde(
        serialize_with = "serialize_arc_vec",
        deserialize_with = "deserialize_arc_vec"
    )]
    pub messages: Arc<Vec<LlmMessage>>,
    /// Token usage recorded with the snapshot.
    pub usage: Usage,
    /// Cost recorded with the snapshot.
    pub cost: Cost,
    /// Stop reason recorded with the snapshot.
    pub stop_reason: StopReason,
    /// Optional state delta. Omitted from serialized output when `None`, and
    /// deserialized as `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub state_delta: Option<crate::StateDelta>,
}
380
381const _: () = {
384 const fn assert_send_sync<T: Send + Sync>() {}
385
386 assert_send_sync::<ContentBlock>();
387 assert_send_sync::<ImageSource>();
388 assert_send_sync::<UserMessage>();
389 assert_send_sync::<AssistantMessage>();
390 assert_send_sync::<ToolResultMessage>();
391 assert_send_sync::<LlmMessage>();
392 assert_send_sync::<AgentMessage>();
393 assert_send_sync::<Usage>();
394 assert_send_sync::<Cost>();
395 assert_send_sync::<StopReason>();
396 assert_send_sync::<ThinkingLevel>();
397 assert_send_sync::<ThinkingBudgets>();
398 assert_send_sync::<ModelCapabilities>();
399 assert_send_sync::<ModelSpec>();
400 assert_send_sync::<AgentResult>();
401 assert_send_sync::<AgentContext>();
402 assert_send_sync::<TurnSnapshot>();
403 assert_send_sync::<CustomMessageRegistry>();
404 assert_send_sync::<crate::error::DowncastError>();
405};
406
#[cfg(test)]
mod tests {
    // Unit tests for content blocks, usage/cost arithmetic, model specs and
    // capabilities, the custom-message registry, and `AgentResult::assistant_text`.
    use super::*;

    // ---- Shared fixtures -------------------------------------------------

    /// Builds a `ContentBlock::Text` from a string slice.
    fn text_block(text: &str) -> ContentBlock {
        ContentBlock::Text {
            text: text.to_string(),
        }
    }

    /// Builds a user message holding a single text block.
    fn user_message(text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::User(UserMessage {
            content: vec![text_block(text)],
            timestamp: 0,
            cache_hint: None,
        }))
    }

    /// Builds an assistant message holding a single text block.
    fn assistant_message(text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::Assistant(AssistantMessage {
            content: vec![text_block(text)],
            provider: "test".to_string(),
            model_id: "m".to_string(),
            usage: Usage::default(),
            cost: Cost::default(),
            stop_reason: StopReason::Stop,
            error_message: None,
            error_kind: None,
            timestamp: 0,
            cache_hint: None,
        }))
    }

    /// Wraps `messages` in an otherwise-default, successful `AgentResult`.
    fn agent_result(messages: Vec<AgentMessage>) -> AgentResult {
        AgentResult {
            messages,
            stop_reason: StopReason::Stop,
            usage: Usage::default(),
            cost: Cost::default(),
            error: None,
            transfer_signal: None,
        }
    }

    // ---- ContentBlock ----------------------------------------------------

    #[test]
    fn content_block_extension_serde_roundtrip() {
        let block = ContentBlock::Extension {
            type_name: "audio_clip".into(),
            data: serde_json::json!({"duration_ms": 1500, "codec": "opus"}),
        };
        let json = serde_json::to_string(&block).unwrap();
        let parsed: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, parsed);
    }

    #[test]
    fn extract_text_ignores_extension() {
        let blocks = [
            text_block("hello "),
            ContentBlock::Extension {
                type_name: "image".into(),
                data: serde_json::json!({"url": "https://example.com/img.png"}),
            },
            text_block("world"),
        ];
        assert_eq!(ContentBlock::extract_text(&blocks), "hello world");
    }

    // ---- Usage / Cost arithmetic ----------------------------------------

    #[test]
    fn usage_extra_add_merges_maps() {
        let a = Usage {
            input: 10,
            output: 5,
            extra: HashMap::from([
                ("reasoning_tokens".into(), 100),
                ("search_tokens".into(), 50),
            ]),
            ..Default::default()
        };
        let b = Usage {
            input: 20,
            output: 10,
            extra: HashMap::from([("reasoning_tokens".into(), 200), ("new_metric".into(), 30)]),
            ..Default::default()
        };
        let c = a + b;
        assert_eq!(c.input, 30);
        assert_eq!(c.output, 15);
        assert_eq!(c.extra["reasoning_tokens"], 300);
        assert_eq!(c.extra["search_tokens"], 50);
        assert_eq!(c.extra["new_metric"], 30);
    }

    #[test]
    fn cost_extra_add_merges_maps() {
        let a = Cost {
            input: 0.01,
            output: 0.02,
            extra: HashMap::from([("reasoning_cost".into(), 0.05)]),
            ..Default::default()
        };
        let b = Cost {
            input: 0.03,
            output: 0.04,
            extra: HashMap::from([
                ("reasoning_cost".into(), 0.10),
                ("search_cost".into(), 0.02),
            ]),
            ..Default::default()
        };
        let c = a + b;
        assert!((c.input - 0.04).abs() < f64::EPSILON);
        assert!((c.output - 0.06).abs() < f64::EPSILON);
        assert!((c.extra["reasoning_cost"] - 0.15).abs() < f64::EPSILON);
        assert!((c.extra["search_cost"] - 0.02).abs() < f64::EPSILON);
    }

    // ---- ModelSpec / ModelCapabilities ----------------------------------

    #[test]
    fn model_spec_with_provider_config() {
        let config = serde_json::json!({
            "temperature": 0.7,
            "top_p": 0.9,
        });

        let spec = ModelSpec::new("anthropic", "claude-3").with_provider_config(config.clone());

        assert_eq!(spec.provider_config, Some(config));
        assert_eq!(spec.provider, "anthropic");
        assert_eq!(spec.model_id, "claude-3");
    }

    #[test]
    fn provider_config_as_typed() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct MyConfig {
            temperature: f64,
            max_tokens: u32,
        }

        let spec = ModelSpec::new("openai", "gpt-4").with_provider_config(serde_json::json!({
            "temperature": 0.5,
            "max_tokens": 1024,
        }));

        let config: Option<MyConfig> = spec.provider_config_as();
        assert_eq!(
            config,
            Some(MyConfig {
                temperature: 0.5,
                max_tokens: 1024,
            })
        );

        let spec_none = ModelSpec::new("openai", "gpt-4");
        let config_none: Option<MyConfig> = spec_none.provider_config_as();
        assert!(config_none.is_none());
    }

    #[test]
    fn model_capabilities_builder_chain() {
        let caps = ModelCapabilities::none()
            .with_thinking(true)
            .with_vision(true)
            .with_tool_use(true)
            .with_streaming(true)
            .with_structured_output(true)
            .with_max_context_window(200_000)
            .with_max_output_tokens(16384);

        assert!(caps.supports_thinking);
        assert!(caps.supports_vision);
        assert!(caps.supports_tool_use);
        assert!(caps.supports_streaming);
        assert!(caps.supports_structured_output);
        assert_eq!(caps.max_context_window, Some(200_000));
        assert_eq!(caps.max_output_tokens, Some(16384));
    }

    #[test]
    fn model_capabilities_serde_roundtrip() {
        let caps = ModelCapabilities::none()
            .with_thinking(true)
            .with_tool_use(true)
            .with_max_context_window(128_000);
        let json = serde_json::to_string(&caps).unwrap();
        let parsed: ModelCapabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(caps, parsed);
    }

    #[test]
    fn model_spec_with_capabilities() {
        let caps = ModelCapabilities::none()
            .with_thinking(true)
            .with_streaming(true);
        let spec = ModelSpec::new("test", "model-1").with_capabilities(caps.clone());
        assert_eq!(spec.capabilities, Some(caps.clone()));
        assert_eq!(spec.capabilities(), caps);
    }

    #[test]
    fn model_spec_capabilities_defaults_when_none() {
        let spec = ModelSpec::new("test", "model-1");
        assert!(spec.capabilities.is_none());
        let caps = spec.capabilities();
        assert!(!caps.supports_thinking);
        assert_eq!(caps.max_context_window, None);
    }

    // ---- Custom messages -------------------------------------------------

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct MockNotification {
        title: String,
        body: String,
    }

    impl CustomMessage for MockNotification {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn type_name(&self) -> Option<&str> {
            Some("mock_notification")
        }

        fn to_json(&self) -> Option<serde_json::Value> {
            serde_json::to_value(self).ok()
        }
    }

    #[test]
    fn custom_message_serialize_roundtrip() {
        let msg = MockNotification {
            title: "Hello".into(),
            body: "World".into(),
        };

        let envelope = serialize_custom_message(&msg).expect("serialization supported");
        assert_eq!(envelope["type"], "mock_notification");
        assert_eq!(envelope["data"]["title"], "Hello");

        let mut registry = CustomMessageRegistry::new();
        registry.register_type::<MockNotification>("mock_notification");

        let restored = deserialize_custom_message(&registry, &envelope).unwrap();
        let downcasted = restored
            .as_any()
            .downcast_ref::<MockNotification>()
            .unwrap();
        assert_eq!(downcasted, &msg);
    }

    #[test]
    fn custom_message_default_returns_none() {
        #[derive(Debug)]
        struct Bare;
        impl CustomMessage for Bare {
            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }
        let bare = Bare;
        assert!(bare.type_name().is_none());
        assert!(bare.to_json().is_none());
        assert!(serialize_custom_message(&bare).is_none());
    }

    #[test]
    fn registry_unknown_type_returns_error() {
        let registry = CustomMessageRegistry::new();
        let envelope = serde_json::json!({"type": "unknown", "data": {}});
        let result = deserialize_custom_message(&registry, &envelope);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("no deserializer registered"));
    }

    #[test]
    fn registry_contains_check() {
        let mut registry = CustomMessageRegistry::new();
        assert!(!registry.has_type_name("mock_notification"));
        registry.register_type::<MockNotification>("mock_notification");
        assert!(registry.has_type_name("mock_notification"));
    }

    #[test]
    fn deserialize_custom_message_missing_fields() {
        let registry = CustomMessageRegistry::new();

        let no_type = serde_json::json!({"data": {}});
        assert!(
            deserialize_custom_message(&registry, &no_type)
                .unwrap_err()
                .contains("missing 'type'")
        );

        let no_data = serde_json::json!({"type": "foo"});
        assert!(
            deserialize_custom_message(&registry, &no_data)
                .unwrap_err()
                .contains("missing 'data'")
        );
    }

    // ---- AgentResult::assistant_text -------------------------------------

    #[test]
    fn assistant_text_extracts_last_assistant_message() {
        let result = agent_result(vec![
            user_message("hi"),
            assistant_message("first"),
            assistant_message("second"),
        ]);
        assert_eq!(result.assistant_text(), "second");
    }

    #[test]
    fn assistant_text_returns_empty_when_no_assistant() {
        let result = agent_result(vec![user_message("hi")]);
        assert_eq!(result.assistant_text(), "");
    }

    #[test]
    fn assistant_text_returns_empty_when_no_messages() {
        let result = agent_result(vec![]);
        assert_eq!(result.assistant_text(), "");
    }
}