1use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
/// The author of a conversation [`Message`].
///
/// Serialized as a lowercase string: `"system"`, `"user"`, `"assistant"`, or `"tool"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System-level content; produced by [`Message::system`].
    System,
    /// User input; produced by [`Message::user`] and [`Message::user_prompt`].
    User,
    /// Model output; produced by [`Message::assistant`] and
    /// [`Message::assistant_with_tool_calls`].
    Assistant,
    /// A tool's output; produced by [`Message::tool_result`], which also sets
    /// `Message::tool_call_id`.
    Tool,
}
15
/// A single conversation turn: who authored it and what it contains.
///
/// Field declaration order is also the key order of the serialized JSON object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Body of the message: plain text, structured blocks, or tool calls.
    pub content: MessageContent,
    /// Optional name attached to the message. None of the constructors in this
    /// module set it. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// ID of the tool call this message answers; set only by
    /// [`Message::tool_result`]. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
26
27impl Message {
28 pub fn system(content: impl Into<String>) -> Self {
29 Self {
30 role: Role::System,
31 content: MessageContent::Text(content.into()),
32 name: None,
33 tool_call_id: None,
34 }
35 }
36
37 pub fn user(content: impl Into<String>) -> Self {
38 Self {
39 role: Role::User,
40 content: MessageContent::Text(content.into()),
41 name: None,
42 tool_call_id: None,
43 }
44 }
45
46 pub fn user_prompt(prompt: impl AsRef<str>) -> Self {
47 Self {
48 role: Role::User,
49 content: MessageContent::from_prompt(prompt.as_ref()),
50 name: None,
51 tool_call_id: None,
52 }
53 }
54
55 pub fn assistant(content: impl Into<String>) -> Self {
56 Self {
57 role: Role::Assistant,
58 content: MessageContent::Text(content.into()),
59 name: None,
60 tool_call_id: None,
61 }
62 }
63
64 pub fn assistant_with_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
65 Self {
66 role: Role::Assistant,
67 content: MessageContent::ToolCalls(tool_calls),
68 name: None,
69 tool_call_id: None,
70 }
71 }
72
73 pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
74 Self {
75 role: Role::Tool,
76 content: MessageContent::Text(content.into()),
77 name: None,
78 tool_call_id: Some(tool_call_id.into()),
79 }
80 }
81
82 pub fn with_blocks(role: Role, blocks: Vec<ContentBlock>) -> Self {
83 Self {
84 role,
85 content: MessageContent::Blocks(blocks),
86 name: None,
87 tool_call_id: None,
88 }
89 }
90
91 pub fn estimate_tokens(&self) -> usize {
93 let content_len = match &self.content {
94 MessageContent::Text(s) => s.len(),
95 MessageContent::Blocks(blocks) => blocks.iter().map(|b| b.estimate_chars()).sum(),
96 MessageContent::ToolCalls(calls) => calls.iter().map(|c| c.estimate_chars()).sum(),
97 };
98 content_len.div_ceil(4)
99 }
100}
101
/// The body of a [`Message`].
///
/// `#[serde(untagged)]`: each variant serializes as its bare inner value, either a
/// JSON string or a JSON array. On deserialization serde tries the variants in
/// declaration order:
/// - a string becomes `Text`;
/// - an array becomes `Blocks` when every element parses as a [`ContentBlock`]
///   (so an empty array is `Blocks`), and `ToolCalls` otherwise.
///
/// Reordering the variants changes how input is interpreted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// Structured content blocks (text, images, tool use/results, thinking).
    Blocks(Vec<ContentBlock>),
    /// Tool calls requested by the assistant.
    ToolCalls(Vec<ToolCall>),
}
110
111impl MessageContent {
112 pub fn as_text(&self) -> Option<&str> {
113 match self {
114 Self::Text(s) => Some(s),
115 _ => None,
116 }
117 }
118
119 pub fn from_prompt(prompt: &str) -> Self {
120 let references = parse_prompt_image_references(prompt);
121
122 if references.is_empty() {
123 return Self::Text(prompt.to_string());
124 }
125
126 let mut blocks = Vec::new();
127 let mut cursor = 0;
128
129 for reference in references {
130 if reference.start > cursor {
131 let text = &prompt[cursor..reference.start];
132 if !text.is_empty() {
133 blocks.push(ContentBlock::Text {
134 text: text.to_string(),
135 });
136 }
137 }
138
139 blocks.push(ContentBlock::Image {
140 source: ImageSource::FilePath {
141 path: reference.path,
142 },
143 });
144 cursor = reference.end;
145 }
146
147 if cursor < prompt.len() {
148 let text = &prompt[cursor..];
149 if !text.is_empty() {
150 blocks.push(ContentBlock::Text {
151 text: text.to_string(),
152 });
153 }
154 }
155
156 if blocks.is_empty() {
157 Self::Text(prompt.to_string())
158 } else {
159 Self::Blocks(blocks)
160 }
161 }
162}
163
/// One image reference located inside a prompt.
///
/// `start..end` is the byte range the whole reference occupies in the prompt,
/// from its opening marker through its closing delimiter.
#[derive(Clone, Debug, PartialEq, Eq)]
struct PromptImageReference {
    /// Byte offset of the reference's first character.
    start: usize,
    /// Byte offset just past the reference's closing delimiter.
    end: usize,
    /// Image path after `normalize_image_path` (trimmed, outer quotes removed).
    path: String,
}
170
171fn parse_prompt_image_references(prompt: &str) -> Vec<PromptImageReference> {
172 let mut refs = Vec::new();
173 let mut cursor = 0;
174
175 while let Some(reference) = find_next_prompt_image_reference(prompt, cursor) {
176 cursor = reference.end;
177 refs.push(reference);
178 }
179
180 refs
181}
182
183fn find_next_prompt_image_reference(prompt: &str, from: usize) -> Option<PromptImageReference> {
184 let markdown = find_markdown_image_reference(prompt, from);
185 let bracket = find_bracket_image_reference(prompt, from);
186
187 match (markdown, bracket) {
188 (Some(m), Some(b)) => {
189 if m.start <= b.start {
190 Some(m)
191 } else {
192 Some(b)
193 }
194 }
195 (Some(m), None) => Some(m),
196 (None, Some(b)) => Some(b),
197 (None, None) => None,
198 }
199}
200
201fn find_markdown_image_reference(prompt: &str, from: usize) -> Option<PromptImageReference> {
202 let mut cursor = from;
203
204 while let Some(relative_start) = prompt[cursor..].find("![") {
205 let start = cursor + relative_start;
206 let after_marker = start + 2;
207
208 let Some(relative_mid) = prompt[after_marker..].find("](") else {
209 cursor = after_marker;
210 continue;
211 };
212
213 let path_start = after_marker + relative_mid + 2;
214 let Some(relative_end) = prompt[path_start..].find(')') else {
215 cursor = path_start;
216 continue;
217 };
218
219 let path_end = path_start + relative_end;
220 let end = path_end + 1;
221 let raw_path = &prompt[path_start..path_end];
222
223 if let Some(path) = normalize_image_path(raw_path) {
224 return Some(PromptImageReference { start, end, path });
225 }
226
227 cursor = end;
228 }
229
230 None
231}
232
233fn find_bracket_image_reference(prompt: &str, from: usize) -> Option<PromptImageReference> {
234 let marker = "[image:";
235 let mut cursor = from;
236
237 while let Some(relative_start) = prompt[cursor..].find(marker) {
238 let start = cursor + relative_start;
239 let path_start = start + marker.len();
240
241 let Some(relative_end) = prompt[path_start..].find(']') else {
242 cursor = path_start;
243 continue;
244 };
245
246 let path_end = path_start + relative_end;
247 let end = path_end + 1;
248 let raw_path = &prompt[path_start..path_end];
249
250 if let Some(path) = normalize_image_path(raw_path) {
251 return Some(PromptImageReference { start, end, path });
252 }
253
254 cursor = end;
255 }
256
257 None
258}
259
/// Cleans up a raw image path taken from a prompt.
///
/// Steps:
/// 1. Trim surrounding whitespace.
/// 2. If the result is wrapped in a matching pair of `"` or `'`, remove that one
///    pair and trim again.
///
/// Returns `None` when nothing is left.
fn normalize_image_path(path: &str) -> Option<String> {
    let trimmed = path.trim();
    let unwrap_quotes = move |quote: char| trimmed.strip_prefix(quote)?.strip_suffix(quote);

    let candidate = unwrap_quotes('"')
        .or_else(|| unwrap_quotes('\''))
        .unwrap_or(trimmed)
        .trim();

    if candidate.is_empty() {
        None
    } else {
        Some(candidate.to_owned())
    }
}
283
/// One typed piece of message or response content.
///
/// Internally tagged: serialized as an object whose `"type"` field is the
/// snake_case variant name (`"text"`, `"image"`, `"tool_use"`, `"tool_result"`,
/// `"thinking"`), alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image, located by its [`ImageSource`].
    Image {
        source: ImageSource,
    },
    /// A request to run tool `name` with JSON `input`; `id` identifies the call
    /// (it becomes [`ToolCall::id`] in [`ModelResponse::tool_calls`]).
    ToolUse {
        id: String,
        name: String,
        input: Value,
    },
    /// Output of a tool. `tool_use_id` is presumably the `id` of the `ToolUse`
    /// being answered (not enforced here). `is_error` is set by
    /// [`ContentBlock::tool_error`] and defaults to `false` when absent from
    /// input JSON.
    ToolResult {
        tool_use_id: String,
        content: String,
        #[serde(default)]
        is_error: bool,
    },
    /// Thinking text with an optional `signature`, which is omitted from
    /// serialized output when `None`.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
}
311
312impl ContentBlock {
313 pub fn text(s: impl Into<String>) -> Self {
314 Self::Text { text: s.into() }
315 }
316
317 pub fn tool_use(id: impl Into<String>, name: impl Into<String>, input: Value) -> Self {
318 Self::ToolUse {
319 id: id.into(),
320 name: name.into(),
321 input,
322 }
323 }
324
325 pub fn tool_result(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
326 Self::ToolResult {
327 tool_use_id: tool_use_id.into(),
328 content: content.into(),
329 is_error: false,
330 }
331 }
332
333 pub fn tool_error(tool_use_id: impl Into<String>, error: impl Into<String>) -> Self {
334 Self::ToolResult {
335 tool_use_id: tool_use_id.into(),
336 content: error.into(),
337 is_error: true,
338 }
339 }
340
341 fn estimate_chars(&self) -> usize {
342 match self {
343 Self::Text { text } => text.len(),
344 Self::Image { .. } => 4000,
345 Self::ToolUse { name, input, .. } => name.len() + input.to_string().len(),
346 Self::ToolResult { content, .. } => content.len(),
347 Self::Thinking { thinking, .. } => thinking.len(),
348 }
349 }
350}
351
/// Where an image's data comes from.
///
/// Internally tagged by `"type"`: `"base64"`, `"url"`, or `"file_path"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Inline image data plus its media type. `data` is presumably base64-encoded;
    /// nothing here validates it.
    Base64 { media_type: String, data: String },
    /// Image referenced by URL.
    Url { url: String },
    /// Image referenced by a file path. This is the variant
    /// [`MessageContent::from_prompt`] produces for inline image references.
    FilePath { path: String },
}
360
/// A tool invocation: which call it is, which tool to run, and with what input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call (copied from [`ContentBlock::ToolUse`] `id`).
    pub id: String,
    /// Name of the tool to run.
    pub name: String,
    /// Tool arguments as arbitrary JSON.
    pub input: Value,
}
368
369impl ToolCall {
370 pub fn new(id: impl Into<String>, name: impl Into<String>, input: Value) -> Self {
371 Self {
372 id: id.into(),
373 name: name.into(),
374 input,
375 }
376 }
377
378 fn estimate_chars(&self) -> usize {
379 self.name.len() + self.input.to_string().len()
380 }
381}
382
/// A model response made of an ordered list of content blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
    /// Response identifier.
    pub id: String,
    /// Name of the model that produced the response.
    pub model: String,
    /// Content blocks in order; see [`ModelResponse::text`] and
    /// [`ModelResponse::tool_calls`] for convenient views.
    pub content: Vec<ContentBlock>,
    /// Why generation stopped, if reported.
    pub stop_reason: Option<crate::StopReason>,
    /// Token usage reported for this response.
    pub usage: crate::TokenUsage,
}
392
393impl ModelResponse {
394 pub fn text(&self) -> String {
396 self.content
397 .iter()
398 .filter_map(|block| {
399 if let ContentBlock::Text { text } = block {
400 Some(text.as_str())
401 } else {
402 None
403 }
404 })
405 .collect::<Vec<_>>()
406 .join("")
407 }
408
409 pub fn tool_calls(&self) -> Vec<ToolCall> {
411 self.content
412 .iter()
413 .filter_map(|block| {
414 if let ContentBlock::ToolUse { id, name, input } = block {
415 Some(ToolCall {
416 id: id.clone(),
417 name: name.clone(),
418 input: input.clone(),
419 })
420 } else {
421 None
422 }
423 })
424 .collect()
425 }
426
427 pub fn has_tool_calls(&self) -> bool {
429 self.content
430 .iter()
431 .any(|block| matches!(block, ContentBlock::ToolUse { .. }))
432 }
433}
434
/// One event of a streamed model response.
///
/// Internally tagged: the `"type"` field carries the snake_case variant name
/// (`"message_start"`, `"content_block_start"`, `"content_block_delta"`,
/// `"content_block_stop"`, `"message_delta"`, `"message_stop"`, `"ping"`, `"error"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamChunk {
    /// Start of the message, with its ID, model, and initial usage.
    MessageStart {
        message: StreamMessageStart,
    },
    /// Start of the content block at position `index`, with its initial contents.
    ContentBlockStart {
        index: usize,
        content_block: ContentBlock,
    },
    /// Incremental update to the content block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// End of the content block at `index`.
    ContentBlockStop {
        index: usize,
    },
    /// Message-level update (e.g. a stop reason), with optional usage figures.
    MessageDelta {
        delta: MessageDelta,
        usage: Option<crate::TokenUsage>,
    },
    /// End of the message.
    MessageStop,
    /// Payload-less event (presumably a keep-alive).
    Ping,
    /// An error reported within the stream.
    Error {
        error: StreamError,
    },
}
463
/// Payload of [`StreamChunk::MessageStart`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamMessageStart {
    /// Message identifier.
    pub id: String,
    /// Name of the model producing the stream.
    pub model: String,
    /// Initial usage; falls back to `TokenUsage::default()` when absent from input.
    #[serde(default)]
    pub usage: crate::TokenUsage,
}
471
/// Incremental update to one content block in a stream.
///
/// Internally tagged by `"type"`: `"text_delta"`, `"input_json_delta"`,
/// `"thinking_delta"`, or `"signature_delta"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// More text for a text block.
    TextDelta { text: String },
    /// A fragment of tool-use input JSON. Presumably the fragments must be
    /// concatenated before parsing; they are not parsed here.
    InputJsonDelta { partial_json: String },
    /// More text for a thinking block.
    ThinkingDelta { thinking: String },
    /// Signature data for a thinking block.
    SignatureDelta { signature: String },
}
480
/// Message-level changes carried by [`StreamChunk::MessageDelta`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageDelta {
    /// Why generation stopped, if known; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<crate::StopReason>,
}
486
/// Error payload of [`StreamChunk::Error`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamError {
    /// Error category; the raw identifier `r#type` serializes as the JSON key `type`.
    pub r#type: String,
    /// Human-readable error description.
    pub message: String,
}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps `MessageContent::Blocks`, panicking on any other variant.
    fn expect_blocks(content: MessageContent) -> Vec<ContentBlock> {
        match content {
            MessageContent::Blocks(blocks) => blocks,
            other => panic!("expected blocks, got {other:?}"),
        }
    }

    /// Returns the text of a `ContentBlock::Text`, panicking otherwise.
    fn text_of(block: &ContentBlock) -> &str {
        match block {
            ContentBlock::Text { text } => text,
            other => panic!("expected text block, got {other:?}"),
        }
    }

    /// Returns the path of a file-path image block, panicking otherwise.
    fn image_path_of(block: &ContentBlock) -> &str {
        match block {
            ContentBlock::Image {
                source: ImageSource::FilePath { path },
            } => path,
            other => panic!("expected file path image, got {other:?}"),
        }
    }

    #[test]
    fn test_message_constructors() {
        let system = Message::system("You are a helpful assistant");
        assert_eq!(system.role, Role::System);

        let user = Message::user("Hello");
        assert_eq!(user.role, Role::User);

        let assistant = Message::assistant("Hi there!");
        assert_eq!(assistant.role, Role::Assistant);
    }

    #[test]
    fn test_tool_result_message() {
        let msg = Message::tool_result("tc_1", "ok");
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.tool_call_id.as_deref(), Some("tc_1"));
        assert_eq!(msg.content.as_text(), Some("ok"));
    }

    #[test]
    fn test_role_serializes_lowercase() {
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            "\"assistant\""
        );
    }

    #[test]
    fn test_content_block_serialization() {
        let block = ContentBlock::text("Hello");
        let json = serde_json::to_string(&block).unwrap();
        assert!(json.contains("\"type\":\"text\""));
    }

    #[test]
    fn test_tool_call() {
        let call = ToolCall::new(
            "tc_123",
            "read_file",
            serde_json::json!({"path": "/tmp/test"}),
        );
        assert_eq!(call.name, "read_file");
    }

    #[test]
    fn test_model_response_text() {
        let response = ModelResponse {
            id: "msg_123".to_string(),
            model: "claude-3-opus".to_string(),
            content: vec![ContentBlock::text("Hello, "), ContentBlock::text("world!")],
            stop_reason: Some(crate::StopReason::EndTurn),
            usage: Default::default(),
        };
        assert_eq!(response.text(), "Hello, world!");
    }

    #[test]
    fn test_model_response_tool_calls() {
        let response = ModelResponse {
            id: "msg_456".to_string(),
            model: "claude-3-opus".to_string(),
            content: vec![
                ContentBlock::text("Reading the file."),
                ContentBlock::tool_use("tc_1", "read_file", serde_json::json!({"path": "/tmp/a"})),
            ],
            stop_reason: None,
            usage: Default::default(),
        };
        assert!(response.has_tool_calls());
        let calls = response.tool_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "tc_1");
        assert_eq!(calls[0].name, "read_file");
        assert_eq!(response.text(), "Reading the file.");
    }

    #[test]
    fn test_estimate_tokens() {
        // "Hello world" is 11 bytes; at roughly 4 bytes per token that is about 3.
        let tokens = Message::user("Hello world").estimate_tokens();
        assert!((2..=4).contains(&tokens));
    }

    #[test]
    fn test_user_prompt_without_images() {
        let msg = Message::user_prompt("Describe this bug");
        assert!(matches!(msg.content, MessageContent::Text(_)));
    }

    #[test]
    fn test_user_prompt_with_markdown_image() {
        let msg =
            Message::user_prompt("Please review ![screenshot](./screenshots/error.png) now");

        let blocks = expect_blocks(msg.content);
        assert_eq!(blocks.len(), 3);
        assert_eq!(text_of(&blocks[0]), "Please review ");
        assert_eq!(image_path_of(&blocks[1]), "./screenshots/error.png");
        assert_eq!(text_of(&blocks[2]), " now");
    }

    #[test]
    fn test_user_prompt_with_bracket_image() {
        let msg = Message::user_prompt("[image: ./screenshots/error.png]");

        let blocks = expect_blocks(msg.content);
        assert_eq!(blocks.len(), 1);
        assert_eq!(image_path_of(&blocks[0]), "./screenshots/error.png");
    }

    #[test]
    fn test_user_prompt_with_mixed_references() {
        let msg = Message::user_prompt("See ![a](one.png) and [image: 'two.png']");

        let blocks = expect_blocks(msg.content);
        assert_eq!(blocks.len(), 4);
        assert_eq!(text_of(&blocks[0]), "See ");
        assert_eq!(image_path_of(&blocks[1]), "one.png");
        assert_eq!(text_of(&blocks[2]), " and ");
        assert_eq!(image_path_of(&blocks[3]), "two.png");
    }

    #[test]
    fn test_user_prompt_with_unclosed_markers_stays_text() {
        let msg = Message::user_prompt("Look at ![broken and [image: nothing");
        assert!(matches!(msg.content, MessageContent::Text(_)));
    }
}