// walrus_core/model/message.rs

use std::collections::BTreeMap;

use compact_str::CompactString;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;

use crate::model::{StreamChunk, ToolCall};
9#[derive(Debug, Clone, Deserialize, Serialize, Default)]
11pub struct Message {
12 pub role: Role,
14
15 #[serde(skip_serializing_if = "String::is_empty")]
17 pub content: String,
18
19 #[serde(skip_serializing_if = "String::is_empty")]
21 pub reasoning_content: String,
22
23 #[serde(skip_serializing_if = "CompactString::is_empty")]
25 pub tool_call_id: CompactString,
26
27 #[serde(skip_serializing_if = "SmallVec::is_empty")]
29 pub tool_calls: SmallVec<[ToolCall; 4]>,
30}
31
32impl Message {
33 pub fn system(content: impl Into<String>) -> Self {
35 Self {
36 role: Role::System,
37 content: content.into(),
38 ..Default::default()
39 }
40 }
41
42 pub fn user(content: impl Into<String>) -> Self {
44 Self {
45 role: Role::User,
46 content: content.into(),
47 ..Default::default()
48 }
49 }
50
51 pub fn assistant(
53 content: impl Into<String>,
54 reasoning: Option<String>,
55 tool_calls: Option<&[ToolCall]>,
56 ) -> Self {
57 Self {
58 role: Role::Assistant,
59 content: content.into(),
60 reasoning_content: reasoning.unwrap_or_default(),
61 tool_calls: tool_calls
62 .map(|tc| tc.iter().cloned().collect())
63 .unwrap_or_default(),
64 ..Default::default()
65 }
66 }
67
68 pub fn tool(content: impl Into<String>, call: impl Into<CompactString>) -> Self {
70 Self {
71 role: Role::Tool,
72 content: content.into(),
73 tool_call_id: call.into(),
74 ..Default::default()
75 }
76 }
77
78 pub fn builder(role: Role) -> MessageBuilder {
80 MessageBuilder::new(role)
81 }
82
83 pub fn estimate_tokens(&self) -> usize {
87 let chars = self.content.len()
88 + self.reasoning_content.len()
89 + self.tool_call_id.len()
90 + self
91 .tool_calls
92 .iter()
93 .map(|tc| tc.function.name.len() + tc.function.arguments.len())
94 .sum::<usize>();
95 (chars / 4).max(1)
96 }
97}
98
99pub fn estimate_tokens(messages: &[Message]) -> usize {
101 messages.iter().map(|m| m.estimate_tokens()).sum()
102}
103
104pub struct MessageBuilder {
106 message: Message,
108 calls: BTreeMap<u32, ToolCall>,
110}
111
112impl MessageBuilder {
113 pub fn new(role: Role) -> Self {
115 Self {
116 message: Message {
117 role,
118 ..Default::default()
119 },
120 calls: BTreeMap::new(),
121 }
122 }
123
124 pub fn accept(&mut self, chunk: &StreamChunk) -> bool {
126 if let Some(calls) = chunk.tool_calls() {
127 for call in calls {
128 let entry = self.calls.entry(call.index).or_default();
129 entry.merge(call);
130 }
131 }
132
133 let mut has_content = false;
134 if let Some(content) = chunk.content() {
135 self.message.content.push_str(content);
136 has_content = true;
137 }
138
139 if let Some(reason) = chunk.reasoning_content() {
140 self.message.reasoning_content.push_str(reason);
141 }
142
143 has_content
144 }
145
146 pub fn build(mut self) -> Message {
148 if !self.calls.is_empty() {
149 self.message.tool_calls = self.calls.into_values().collect();
150 }
151 self.message
152 }
153}
154
155#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, Default)]
157pub enum Role {
158 #[serde(rename = "user")]
160 #[default]
161 User,
162 #[serde(rename = "assistant")]
164 Assistant,
165 #[serde(rename = "system")]
167 System,
168 #[serde(rename = "tool")]
170 Tool,
171}