Skip to main content

walrus_core/model/
message.rs

1//! Turbofish LLM message
2
3use crate::model::{StreamChunk, ToolCall};
4use compact_str::CompactString;
5use serde::{Deserialize, Serialize};
6use smallvec::SmallVec;
7use std::collections::BTreeMap;
8
9/// A message in the chat
10#[derive(Debug, Clone, Deserialize, Serialize, Default)]
11pub struct Message {
12    /// The role of the message
13    pub role: Role,
14
15    /// The content of the message
16    #[serde(skip_serializing_if = "String::is_empty")]
17    pub content: String,
18
19    /// The reasoning content
20    #[serde(skip_serializing_if = "String::is_empty")]
21    pub reasoning_content: String,
22
23    /// The tool call id
24    #[serde(skip_serializing_if = "CompactString::is_empty")]
25    pub tool_call_id: CompactString,
26
27    /// The tool calls
28    #[serde(skip_serializing_if = "SmallVec::is_empty")]
29    pub tool_calls: SmallVec<[ToolCall; 4]>,
30}
31
32impl Message {
33    /// Create a new system message
34    pub fn system(content: impl Into<String>) -> Self {
35        Self {
36            role: Role::System,
37            content: content.into(),
38            ..Default::default()
39        }
40    }
41
42    /// Create a new user message
43    pub fn user(content: impl Into<String>) -> Self {
44        Self {
45            role: Role::User,
46            content: content.into(),
47            ..Default::default()
48        }
49    }
50
51    /// Create a new assistant message
52    pub fn assistant(
53        content: impl Into<String>,
54        reasoning: Option<String>,
55        tool_calls: Option<&[ToolCall]>,
56    ) -> Self {
57        Self {
58            role: Role::Assistant,
59            content: content.into(),
60            reasoning_content: reasoning.unwrap_or_default(),
61            tool_calls: tool_calls
62                .map(|tc| tc.iter().cloned().collect())
63                .unwrap_or_default(),
64            ..Default::default()
65        }
66    }
67
68    /// Create a new tool message
69    pub fn tool(content: impl Into<String>, call: impl Into<CompactString>) -> Self {
70        Self {
71            role: Role::Tool,
72            content: content.into(),
73            tool_call_id: call.into(),
74            ..Default::default()
75        }
76    }
77
78    /// Create a new message builder
79    pub fn builder(role: Role) -> MessageBuilder {
80        MessageBuilder::new(role)
81    }
82
83    /// Estimate the number of tokens in this message.
84    ///
85    /// Uses a simple heuristic: ~4 characters per token.
86    pub fn estimate_tokens(&self) -> usize {
87        let chars = self.content.len()
88            + self.reasoning_content.len()
89            + self.tool_call_id.len()
90            + self
91                .tool_calls
92                .iter()
93                .map(|tc| tc.function.name.len() + tc.function.arguments.len())
94                .sum::<usize>();
95        (chars / 4).max(1)
96    }
97}
98
99/// Estimate total tokens across a slice of messages.
100pub fn estimate_tokens(messages: &[Message]) -> usize {
101    messages.iter().map(|m| m.estimate_tokens()).sum()
102}
103
104/// A builder for messages
105pub struct MessageBuilder {
106    /// The message
107    message: Message,
108    /// The tool calls
109    calls: BTreeMap<u32, ToolCall>,
110}
111
112impl MessageBuilder {
113    /// Create a new message builder
114    pub fn new(role: Role) -> Self {
115        Self {
116            message: Message {
117                role,
118                ..Default::default()
119            },
120            calls: BTreeMap::new(),
121        }
122    }
123
124    /// Accept a chunk from the stream
125    pub fn accept(&mut self, chunk: &StreamChunk) -> bool {
126        if let Some(calls) = chunk.tool_calls() {
127            for call in calls {
128                let entry = self.calls.entry(call.index).or_default();
129                entry.merge(call);
130            }
131        }
132
133        let mut has_content = false;
134        if let Some(content) = chunk.content() {
135            self.message.content.push_str(content);
136            has_content = true;
137        }
138
139        if let Some(reason) = chunk.reasoning_content() {
140            self.message.reasoning_content.push_str(reason);
141        }
142
143        has_content
144    }
145
146    /// Build the message
147    pub fn build(mut self) -> Message {
148        if !self.calls.is_empty() {
149            self.message.tool_calls = self.calls.into_values().collect();
150        }
151        self.message
152    }
153}
154
155/// The role of a message
156#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, Default)]
157pub enum Role {
158    /// The user role
159    #[serde(rename = "user")]
160    #[default]
161    User,
162    /// The assistant role
163    #[serde(rename = "assistant")]
164    Assistant,
165    /// The system role
166    #[serde(rename = "system")]
167    System,
168    /// The tool role
169    #[serde(rename = "tool")]
170    Tool,
171}