walrus_core/model/
message.rs1use crate::model::{StreamChunk, ToolCall};
4use compact_str::CompactString;
5use serde::{Deserialize, Serialize};
6use smallvec::SmallVec;
7use std::collections::BTreeMap;
8
9#[derive(Debug, Clone, Deserialize, Serialize, Default)]
11pub struct Message {
12 pub role: Role,
14
15 #[serde(skip_serializing_if = "String::is_empty")]
17 pub content: String,
18
19 #[serde(skip_serializing_if = "String::is_empty")]
21 pub reasoning_content: String,
22
23 #[serde(skip_serializing_if = "CompactString::is_empty")]
25 pub tool_call_id: CompactString,
26
27 #[serde(skip_serializing_if = "SmallVec::is_empty")]
29 pub tool_calls: SmallVec<[ToolCall; 4]>,
30
31 #[serde(skip)]
36 pub sender: CompactString,
37}
38
39impl Message {
40 pub fn system(content: impl Into<String>) -> Self {
42 Self {
43 role: Role::System,
44 content: content.into(),
45 ..Default::default()
46 }
47 }
48
49 pub fn user(content: impl Into<String>) -> Self {
51 Self {
52 role: Role::User,
53 content: content.into(),
54 ..Default::default()
55 }
56 }
57
58 pub fn user_with_sender(content: impl Into<String>, sender: impl Into<CompactString>) -> Self {
60 Self {
61 role: Role::User,
62 content: content.into(),
63 sender: sender.into(),
64 ..Default::default()
65 }
66 }
67
68 pub fn assistant(
70 content: impl Into<String>,
71 reasoning: Option<String>,
72 tool_calls: Option<&[ToolCall]>,
73 ) -> Self {
74 Self {
75 role: Role::Assistant,
76 content: content.into(),
77 reasoning_content: reasoning.unwrap_or_default(),
78 tool_calls: tool_calls
79 .map(|tc| tc.iter().cloned().collect())
80 .unwrap_or_default(),
81 ..Default::default()
82 }
83 }
84
85 pub fn tool(content: impl Into<String>, call: impl Into<CompactString>) -> Self {
87 Self {
88 role: Role::Tool,
89 content: content.into(),
90 tool_call_id: call.into(),
91 ..Default::default()
92 }
93 }
94
95 pub fn builder(role: Role) -> MessageBuilder {
97 MessageBuilder::new(role)
98 }
99
100 pub fn estimate_tokens(&self) -> usize {
104 let chars = self.content.len()
105 + self.reasoning_content.len()
106 + self.tool_call_id.len()
107 + self
108 .tool_calls
109 .iter()
110 .map(|tc| tc.function.name.len() + tc.function.arguments.len())
111 .sum::<usize>();
112 (chars / 4).max(1)
113 }
114}
115
116pub fn estimate_tokens(messages: &[Message]) -> usize {
118 messages.iter().map(|m| m.estimate_tokens()).sum()
119}
120
121pub struct MessageBuilder {
123 message: Message,
125 calls: BTreeMap<u32, ToolCall>,
127}
128
129impl MessageBuilder {
130 pub fn new(role: Role) -> Self {
132 Self {
133 message: Message {
134 role,
135 ..Default::default()
136 },
137 calls: BTreeMap::new(),
138 }
139 }
140
141 pub fn accept(&mut self, chunk: &StreamChunk) -> bool {
143 if let Some(calls) = chunk.tool_calls() {
144 for call in calls {
145 let entry = self.calls.entry(call.index).or_default();
146 entry.merge(call);
147 }
148 }
149
150 let mut has_content = false;
151 if let Some(content) = chunk.content() {
152 self.message.content.push_str(content);
153 has_content = true;
154 }
155
156 if let Some(reason) = chunk.reasoning_content() {
157 self.message.reasoning_content.push_str(reason);
158 }
159
160 has_content
161 }
162
163 pub fn build(mut self) -> Message {
165 if !self.calls.is_empty() {
166 self.message.tool_calls = self.calls.into_values().collect();
167 }
168 self.message
169 }
170}
171
172#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, Default)]
174pub enum Role {
175 #[serde(rename = "user")]
177 #[default]
178 User,
179 #[serde(rename = "assistant")]
181 Assistant,
182 #[serde(rename = "system")]
184 System,
185 #[serde(rename = "tool")]
187 Tool,
188}