Skip to main content

uira_core/protocol/
messages.rs

1//! Message types for model communication
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
/// Role of a message participant.
///
/// Serialized as the lowercase variant name (`"system"`, `"user"`,
/// `"assistant"`, `"tool"`) because of `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Role assigned by [`Message::system`].
    System,
    /// Role assigned by [`Message::user`] and [`Message::user_prompt`].
    User,
    /// Role assigned by [`Message::assistant`] and [`Message::assistant_with_tool_calls`].
    Assistant,
    /// Role assigned by [`Message::tool_result`].
    Tool,
}
15
/// A message in the conversation.
///
/// `name` and `tool_call_id` are left out of the serialized form when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced the message.
    pub role: Role,
    /// Payload: plain text, content blocks, or tool calls.
    pub content: MessageContent,
    /// Optional participant name; the `Message` constructors visible in this file leave it `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Identifier of the tool call being answered; populated by [`Message::tool_result`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
26
27impl Message {
28    pub fn system(content: impl Into<String>) -> Self {
29        Self {
30            role: Role::System,
31            content: MessageContent::Text(content.into()),
32            name: None,
33            tool_call_id: None,
34        }
35    }
36
37    pub fn user(content: impl Into<String>) -> Self {
38        Self {
39            role: Role::User,
40            content: MessageContent::Text(content.into()),
41            name: None,
42            tool_call_id: None,
43        }
44    }
45
46    pub fn user_prompt(prompt: impl AsRef<str>) -> Self {
47        Self {
48            role: Role::User,
49            content: MessageContent::from_prompt(prompt.as_ref()),
50            name: None,
51            tool_call_id: None,
52        }
53    }
54
55    pub fn assistant(content: impl Into<String>) -> Self {
56        Self {
57            role: Role::Assistant,
58            content: MessageContent::Text(content.into()),
59            name: None,
60            tool_call_id: None,
61        }
62    }
63
64    pub fn assistant_with_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
65        Self {
66            role: Role::Assistant,
67            content: MessageContent::ToolCalls(tool_calls),
68            name: None,
69            tool_call_id: None,
70        }
71    }
72
73    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
74        Self {
75            role: Role::Tool,
76            content: MessageContent::Text(content.into()),
77            name: None,
78            tool_call_id: Some(tool_call_id.into()),
79        }
80    }
81
82    pub fn with_blocks(role: Role, blocks: Vec<ContentBlock>) -> Self {
83        Self {
84            role,
85            content: MessageContent::Blocks(blocks),
86            name: None,
87            tool_call_id: None,
88        }
89    }
90
91    /// Estimate token count for this message (~4 chars per token)
92    pub fn estimate_tokens(&self) -> usize {
93        let content_len = match &self.content {
94            MessageContent::Text(s) => s.len(),
95            MessageContent::Blocks(blocks) => blocks.iter().map(|b| b.estimate_chars()).sum(),
96            MessageContent::ToolCalls(calls) => calls.iter().map(|c| c.estimate_chars()).sum(),
97        };
98        content_len.div_ceil(4)
99    }
100}
101
/// Content of a message.
///
/// The enum is `untagged`: no variant name is written when serializing, and
/// deserialization tries the variants in declaration order, so the order below
/// is significant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// A single plain-text string.
    Text(String),
    /// A sequence of typed content blocks.
    Blocks(Vec<ContentBlock>),
    /// A sequence of tool calls.
    ToolCalls(Vec<ToolCall>),
}
110
111impl MessageContent {
112    pub fn as_text(&self) -> Option<&str> {
113        match self {
114            Self::Text(s) => Some(s),
115            _ => None,
116        }
117    }
118
119    pub fn from_prompt(prompt: &str) -> Self {
120        let references = parse_prompt_image_references(prompt);
121
122        if references.is_empty() {
123            return Self::Text(prompt.to_string());
124        }
125
126        let mut blocks = Vec::new();
127        let mut cursor = 0;
128
129        for reference in references {
130            if reference.start > cursor {
131                let text = &prompt[cursor..reference.start];
132                if !text.is_empty() {
133                    blocks.push(ContentBlock::Text {
134                        text: text.to_string(),
135                    });
136                }
137            }
138
139            blocks.push(ContentBlock::Image {
140                source: ImageSource::FilePath {
141                    path: reference.path,
142                },
143            });
144            cursor = reference.end;
145        }
146
147        if cursor < prompt.len() {
148            let text = &prompt[cursor..];
149            if !text.is_empty() {
150                blocks.push(ContentBlock::Text {
151                    text: text.to_string(),
152                });
153            }
154        }
155
156        if blocks.is_empty() {
157            Self::Text(prompt.to_string())
158        } else {
159            Self::Blocks(blocks)
160        }
161    }
162}
163
/// Location and path of one image reference found inside a prompt string.
#[derive(Debug, Clone, PartialEq, Eq)]
struct PromptImageReference {
    /// Byte offset in the prompt where the reference syntax begins.
    start: usize,
    /// Byte offset one past the closing delimiter of the reference.
    end: usize,
    /// Image path after `normalize_image_path` has trimmed and unquoted it.
    path: String,
}
170
171fn parse_prompt_image_references(prompt: &str) -> Vec<PromptImageReference> {
172    let mut refs = Vec::new();
173    let mut cursor = 0;
174
175    while let Some(reference) = find_next_prompt_image_reference(prompt, cursor) {
176        cursor = reference.end;
177        refs.push(reference);
178    }
179
180    refs
181}
182
183fn find_next_prompt_image_reference(prompt: &str, from: usize) -> Option<PromptImageReference> {
184    let markdown = find_markdown_image_reference(prompt, from);
185    let bracket = find_bracket_image_reference(prompt, from);
186
187    match (markdown, bracket) {
188        (Some(m), Some(b)) => {
189            if m.start <= b.start {
190                Some(m)
191            } else {
192                Some(b)
193            }
194        }
195        (Some(m), None) => Some(m),
196        (None, Some(b)) => Some(b),
197        (None, None) => None,
198    }
199}
200
201fn find_markdown_image_reference(prompt: &str, from: usize) -> Option<PromptImageReference> {
202    let mut cursor = from;
203
204    while let Some(relative_start) = prompt[cursor..].find("![") {
205        let start = cursor + relative_start;
206        let after_marker = start + 2;
207
208        let Some(relative_mid) = prompt[after_marker..].find("](") else {
209            cursor = after_marker;
210            continue;
211        };
212
213        let path_start = after_marker + relative_mid + 2;
214        let Some(relative_end) = prompt[path_start..].find(')') else {
215            cursor = path_start;
216            continue;
217        };
218
219        let path_end = path_start + relative_end;
220        let end = path_end + 1;
221        let raw_path = &prompt[path_start..path_end];
222
223        if let Some(path) = normalize_image_path(raw_path) {
224            return Some(PromptImageReference { start, end, path });
225        }
226
227        cursor = end;
228    }
229
230    None
231}
232
233fn find_bracket_image_reference(prompt: &str, from: usize) -> Option<PromptImageReference> {
234    let marker = "[image:";
235    let mut cursor = from;
236
237    while let Some(relative_start) = prompt[cursor..].find(marker) {
238        let start = cursor + relative_start;
239        let path_start = start + marker.len();
240
241        let Some(relative_end) = prompt[path_start..].find(']') else {
242            cursor = path_start;
243            continue;
244        };
245
246        let path_end = path_start + relative_end;
247        let end = path_end + 1;
248        let raw_path = &prompt[path_start..path_end];
249
250        if let Some(path) = normalize_image_path(raw_path) {
251            return Some(PromptImageReference { start, end, path });
252        }
253
254        cursor = end;
255    }
256
257    None
258}
259
/// Cleans up a raw image path taken from a prompt.
///
/// Surrounding whitespace is trimmed, one pair of matching double or single
/// quotes wrapping the whole value is removed, and the result is trimmed again.
/// Returns `None` when nothing is left, otherwise the cleaned path.
fn normalize_image_path(path: &str) -> Option<String> {
    let outer = path.trim();

    // Double quotes are tried first, then single quotes; a quote that is not
    // matched on both ends is left in place.
    let inner = ['"', '\'']
        .iter()
        .find_map(|&quote| outer.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(outer)
        .trim();

    if inner.is_empty() {
        return None;
    }
    Some(inner.to_string())
}
283
/// A content block within a message.
///
/// Serialized with an internal `"type"` tag holding the snake_case variant
/// name (for example `"text"` or `"tool_use"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image, described by where its data comes from.
    Image {
        source: ImageSource,
    },
    /// A tool invocation: call `id`, tool `name`, and JSON `input`.
    ToolUse {
        id: String,
        name: String,
        input: Value,
    },
    /// Output of a tool.
    ///
    /// NOTE(review): `tool_use_id` presumably matches the `id` of a `ToolUse` block; confirm.
    ToolResult {
        tool_use_id: String,
        content: String,
        /// Falls back to `false` when the field is absent during deserialization.
        #[serde(default)]
        is_error: bool,
    },
    /// Thinking text with an optional signature.
    Thinking {
        thinking: String,
        /// Left out of the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
}
311
312impl ContentBlock {
313    pub fn text(s: impl Into<String>) -> Self {
314        Self::Text { text: s.into() }
315    }
316
317    pub fn tool_use(id: impl Into<String>, name: impl Into<String>, input: Value) -> Self {
318        Self::ToolUse {
319            id: id.into(),
320            name: name.into(),
321            input,
322        }
323    }
324
325    pub fn tool_result(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
326        Self::ToolResult {
327            tool_use_id: tool_use_id.into(),
328            content: content.into(),
329            is_error: false,
330        }
331    }
332
333    pub fn tool_error(tool_use_id: impl Into<String>, error: impl Into<String>) -> Self {
334        Self::ToolResult {
335            tool_use_id: tool_use_id.into(),
336            content: error.into(),
337            is_error: true,
338        }
339    }
340
341    fn estimate_chars(&self) -> usize {
342        match self {
343            Self::Text { text } => text.len(),
344            Self::Image { .. } => 4000,
345            Self::ToolUse { name, input, .. } => name.len() + input.to_string().len(),
346            Self::ToolResult { content, .. } => content.len(),
347            Self::Thinking { thinking, .. } => thinking.len(),
348        }
349    }
350}
351
/// Image source for multimodal messages.
///
/// Serialized with an internal `"type"` tag: `"base64"`, `"url"` or `"file_path"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Inline image data plus its media type.
    /// NOTE(review): presumably `data` is base64 text and `media_type` a MIME type; neither is validated here.
    Base64 { media_type: String, data: String },
    /// Image referenced by URL.
    Url { url: String },
    /// Image referenced by a file path; this is what `MessageContent::from_prompt` produces.
    FilePath { path: String },
}
360
/// A tool call from the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the tool as a JSON value.
    pub input: Value,
}
368
369impl ToolCall {
370    pub fn new(id: impl Into<String>, name: impl Into<String>, input: Value) -> Self {
371        Self {
372            id: id.into(),
373            name: name.into(),
374            input,
375        }
376    }
377
378    fn estimate_chars(&self) -> usize {
379        self.name.len() + self.input.to_string().len()
380    }
381}
382
/// Response from the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
    /// Identifier of the response.
    pub id: String,
    /// Name of the model that produced the response.
    pub model: String,
    /// Content blocks of the response, in order.
    pub content: Vec<ContentBlock>,
    /// Why generation stopped, if known (`crate::StopReason` is defined elsewhere in the crate).
    pub stop_reason: Option<crate::StopReason>,
    /// Token accounting for the response (`crate::TokenUsage` is defined elsewhere in the crate).
    pub usage: crate::TokenUsage,
}
392
393impl ModelResponse {
394    /// Extract text content from the response
395    pub fn text(&self) -> String {
396        self.content
397            .iter()
398            .filter_map(|block| {
399                if let ContentBlock::Text { text } = block {
400                    Some(text.as_str())
401                } else {
402                    None
403                }
404            })
405            .collect::<Vec<_>>()
406            .join("")
407    }
408
409    /// Extract tool calls from the response
410    pub fn tool_calls(&self) -> Vec<ToolCall> {
411        self.content
412            .iter()
413            .filter_map(|block| {
414                if let ContentBlock::ToolUse { id, name, input } = block {
415                    Some(ToolCall {
416                        id: id.clone(),
417                        name: name.clone(),
418                        input: input.clone(),
419                    })
420                } else {
421                    None
422                }
423            })
424            .collect()
425    }
426
427    /// Check if the response contains tool calls
428    pub fn has_tool_calls(&self) -> bool {
429        self.content
430            .iter()
431            .any(|block| matches!(block, ContentBlock::ToolUse { .. }))
432    }
433}
434
/// Streaming chunk from the model.
///
/// Serialized with an internal `"type"` tag holding the snake_case variant
/// name (for example `"message_start"` or `"content_block_delta"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamChunk {
    /// Start of a streamed message; carries its id, model and usage.
    MessageStart {
        message: StreamMessageStart,
    },
    /// A content block begins.
    /// NOTE(review): `index` presumably is the block's position in the response content; confirm.
    ContentBlockStart {
        index: usize,
        content_block: ContentBlock,
    },
    /// An incremental update for the content block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// The content block at `index` is complete.
    ContentBlockStop {
        index: usize,
    },
    /// A message-level update, with optional token usage.
    MessageDelta {
        delta: MessageDelta,
        usage: Option<crate::TokenUsage>,
    },
    /// The message is complete.
    MessageStop,
    /// A chunk that carries no data.
    Ping,
    /// An error reported within the stream.
    Error {
        error: StreamError,
    },
}
463
/// Payload of [`StreamChunk::MessageStart`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamMessageStart {
    /// Identifier of the message being streamed.
    pub id: String,
    /// Name of the model producing the message.
    pub model: String,
    /// Token usage; falls back to `TokenUsage::default()` when absent during deserialization.
    #[serde(default)]
    pub usage: crate::TokenUsage,
}
471
/// Incremental update carried by [`StreamChunk::ContentBlockDelta`].
///
/// Serialized with an internal `"type"` tag: `"text_delta"`,
/// `"input_json_delta"`, `"thinking_delta"` or `"signature_delta"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// A fragment of text.
    TextDelta { text: String },
    /// A fragment of JSON text.
    /// NOTE(review): presumably a piece of a tool call's input; confirm against the consumer.
    InputJsonDelta { partial_json: String },
    /// A fragment of thinking text.
    ThinkingDelta { thinking: String },
    /// A signature value.
    SignatureDelta { signature: String },
}
480
/// Message-level update carried by [`StreamChunk::MessageDelta`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageDelta {
    /// Stop reason, if present; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<crate::StopReason>,
}
486
/// Error payload carried by [`StreamChunk::Error`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamError {
    /// Error kind; written as a raw identifier because `type` is a Rust keyword.
    pub r#type: String,
    /// Error description.
    pub message: String,
}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Each text constructor tags the message with the matching role.
    #[test]
    fn test_message_constructors() {
        assert_eq!(
            Message::system("You are a helpful assistant").role,
            Role::System
        );
        assert_eq!(Message::user("Hello").role, Role::User);
        assert_eq!(Message::assistant("Hi there!").role, Role::Assistant);
    }

    /// A text block serializes with its `type` tag set to `text`.
    #[test]
    fn test_content_block_serialization() {
        let serialized = serde_json::to_string(&ContentBlock::text("Hello")).unwrap();
        assert!(serialized.contains("\"type\":\"text\""));
    }

    /// `ToolCall::new` stores the tool name it was given.
    #[test]
    fn test_tool_call() {
        let input = serde_json::json!({"path": "/tmp/test"});
        let call = ToolCall::new("tc_123", "read_file", input);
        assert_eq!(call.name, "read_file");
    }

    /// `ModelResponse::text` concatenates text blocks without a separator.
    #[test]
    fn test_model_response_text() {
        let content = vec![ContentBlock::text("Hello, "), ContentBlock::text("world!")];
        let response = ModelResponse {
            id: "msg_123".to_string(),
            model: "claude-3-opus".to_string(),
            content,
            stop_reason: Some(crate::StopReason::EndTurn),
            usage: Default::default(),
        };
        assert_eq!(response.text(), "Hello, world!");
    }

    /// "Hello world" is 11 characters, so the estimate should land around 11 / 4.
    #[test]
    fn test_estimate_tokens() {
        let tokens = Message::user("Hello world").estimate_tokens();
        assert!((2..=4).contains(&tokens));
    }

    /// A prompt without image references stays plain text.
    #[test]
    fn test_user_prompt_without_images() {
        let content = Message::user_prompt("Describe this bug").content;
        assert!(matches!(content, MessageContent::Text(_)));
    }

    /// A Markdown image in the middle of a prompt yields text, image, text.
    #[test]
    fn test_user_prompt_with_markdown_image() {
        let msg = Message::user_prompt("Please review ![mockup](./mockup.png) now");

        let MessageContent::Blocks(blocks) = msg.content else {
            panic!("expected blocks");
        };

        assert_eq!(blocks.len(), 3);
        assert!(matches!(blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(
            blocks[1],
            ContentBlock::Image {
                source: ImageSource::FilePath { .. }
            }
        ));
        assert!(matches!(blocks[2], ContentBlock::Text { .. }));
    }

    /// A bracket reference alone yields a single image block with the trimmed path.
    #[test]
    fn test_user_prompt_with_bracket_image() {
        let msg = Message::user_prompt("[image: ./screenshots/error.png]");

        let MessageContent::Blocks(blocks) = msg.content else {
            panic!("expected blocks");
        };
        assert_eq!(blocks.len(), 1);

        let ContentBlock::Image {
            source: ImageSource::FilePath { path },
        } = &blocks[0]
        else {
            panic!("expected file path image");
        };
        assert_eq!(path, "./screenshots/error.png");
    }
}