1use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct AgentInput {
18 pub system_prompt: String,
19 pub user_message: String,
20 pub max_turns: u32,
22}
23
24#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
26pub enum StopReason {
27 FinalAnswer,
29 MaxTurnsReached,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct AgentOutput {
36 pub final_answer: String,
37 pub stop_reason: StopReason,
38 pub turns_used: u32,
39 pub tool_calls: u32,
40}
41
42#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
44#[serde(rename_all = "snake_case")]
45pub enum Role {
46 System,
47 User,
48 Assistant,
49 Tool,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
54pub struct ToolCall {
55 pub id: String,
56 pub name: String,
57 pub args: serde_json::Value,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
62pub struct ToolResult {
63 pub call_id: String,
64 pub output: serde_json::Value,
65 pub error: Option<String>,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72pub struct Message {
73 pub role: Role,
74 #[serde(default)]
75 pub content: String,
76 #[serde(default, skip_serializing_if = "Vec::is_empty")]
78 pub tool_calls: Vec<ToolCall>,
79 #[serde(default, skip_serializing_if = "Option::is_none")]
81 pub tool_call_id: Option<String>,
82}
83
84impl Message {
85 pub fn system(content: impl Into<String>) -> Self {
86 Self {
87 role: Role::System,
88 content: content.into(),
89 tool_calls: vec![],
90 tool_call_id: None,
91 }
92 }
93
94 pub fn user(content: impl Into<String>) -> Self {
95 Self {
96 role: Role::User,
97 content: content.into(),
98 tool_calls: vec![],
99 tool_call_id: None,
100 }
101 }
102
103 pub fn assistant_text(content: impl Into<String>) -> Self {
104 Self {
105 role: Role::Assistant,
106 content: content.into(),
107 tool_calls: vec![],
108 tool_call_id: None,
109 }
110 }
111
112 pub fn assistant_with_tools(calls: Vec<ToolCall>) -> Self {
113 Self {
114 role: Role::Assistant,
115 content: String::new(),
116 tool_calls: calls,
117 tool_call_id: None,
118 }
119 }
120
121 pub fn tool_result(result: &ToolResult) -> Self {
122 let content = match &result.error {
123 Some(err) => format!("ERROR: {err}"),
124 None => result.output.to_string(),
125 };
126 Self {
127 role: Role::Tool,
128 content,
129 tool_calls: vec![],
130 tool_call_id: Some(result.call_id.clone()),
131 }
132 }
133}
134
135#[derive(Debug, Clone, Default, Serialize, Deserialize)]
137pub struct AgentState {
138 pub input: AgentInput,
139 pub history: Vec<Message>,
140 pub turn: u32,
141 pub tool_calls_executed: u32,
142 #[serde(default)]
147 pub pending_user_messages: Vec<String>,
148}
149
150impl AgentState {
151 pub fn new(input: AgentInput) -> Self {
152 let history = vec![
153 Message::system(&input.system_prompt),
154 Message::user(&input.user_message),
155 ];
156 Self {
157 input,
158 history,
159 turn: 0,
160 tool_calls_executed: 0,
161 pending_user_messages: vec![],
162 }
163 }
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
172#[serde(tag = "kind", rename_all = "snake_case")]
173pub enum LlmResponse {
174 Final { answer: String },
176 UseTools { calls: Vec<ToolCall> },
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct LlmChatInput {
183 pub messages: Vec<Message>,
184 pub tools: Vec<ToolSchema>,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct ToolSchema {
190 pub name: String,
191 pub description: String,
192 pub args_schema: serde_json::Value,
193}
194
195pub fn compact(state: &AgentState, keep_recent: usize) -> AgentInput {
199 let mut summary_lines = Vec::new();
200 let total = state.history.len();
201 let drop_until = total.saturating_sub(keep_recent);
202
203 for msg in state.history.iter().take(drop_until) {
204 let line = match msg.role {
205 Role::System if summary_lines.is_empty() => continue,
206 Role::User => format!("user: {}", truncate(&msg.content, 200)),
207 Role::Assistant if !msg.tool_calls.is_empty() => {
208 let names: Vec<&str> = msg.tool_calls.iter().map(|c| c.name.as_str()).collect();
209 format!("assistant: called tools [{}]", names.join(", "))
210 }
211 Role::Assistant => format!("assistant: {}", truncate(&msg.content, 200)),
212 Role::Tool => format!("tool: {}", truncate(&msg.content, 120)),
213 Role::System => continue,
214 };
215 summary_lines.push(line);
216 }
217
218 let summary = if summary_lines.is_empty() {
219 String::new()
220 } else {
221 format!(
222 "\n\n[Prior conversation summary, {} messages dropped]\n{}",
223 drop_until,
224 summary_lines.join("\n")
225 )
226 };
227
228 let recent_user = state
229 .history
230 .iter()
231 .rev()
232 .find(|m| m.role == Role::User)
233 .map(|m| m.content.clone())
234 .unwrap_or_default();
235
236 AgentInput {
237 system_prompt: format!("{}{}", state.input.system_prompt, summary),
238 user_message: recent_user,
239 max_turns: state.input.max_turns,
240 }
241}
242
/// Shortens `s` to at most `max` *bytes*, never splitting a UTF-8 character.
///
/// Strings that already fit are returned unchanged. Otherwise the text is cut at
/// the last char boundary at or below `max` and an ellipsis (`…`) is appended.
/// The result can therefore exceed `max` bytes by the ellipsis length, and it is
/// just `"…"` when no character fits.
fn truncate(s: &str, max: usize) -> String {
    if s.len() > max {
        // Index 0 is always a char boundary, so the search always succeeds.
        let cut = (0..=max)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        let mut shortened = String::with_capacity(cut + '…'.len_utf8());
        shortened.push_str(&s[..cut]);
        shortened.push('…');
        shortened
    } else {
        s.to_owned()
    }
}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a freshly seeded state for tests.
    fn seeded(system_prompt: &str, user_message: &str, max_turns: u32) -> AgentState {
        AgentState::new(AgentInput {
            system_prompt: system_prompt.into(),
            user_message: user_message.into(),
            max_turns,
        })
    }

    #[test]
    fn agent_state_seeds_system_and_user() {
        let s = seeded("be helpful", "hi", 5);
        assert_eq!(s.history.len(), 2);
        assert_eq!(s.history[0].role, Role::System);
        assert_eq!(s.history[0].content, "be helpful");
        assert_eq!(s.history[1].role, Role::User);
        assert_eq!(s.history[1].content, "hi");
        assert_eq!(s.turn, 0);
        assert_eq!(s.tool_calls_executed, 0);
        assert!(s.pending_user_messages.is_empty());
    }

    #[test]
    fn compact_keeps_system_and_recent() {
        let mut state = seeded("sys", "u0", 50);
        for i in 1..30 {
            state.history.push(Message::user(format!("u{i}")));
            state.history.push(Message::assistant_text(format!("a{i}")));
        }
        // 60 messages total; keeping 10 drops the first 50.
        let compacted = compact(&state, 10);
        assert!(compacted.system_prompt.starts_with("sys"));
        assert!(compacted
            .system_prompt
            .contains("[Prior conversation summary, 50 messages dropped]"));
        assert_eq!(compacted.user_message, "u29");
        assert_eq!(compacted.max_turns, state.input.max_turns);
    }

    #[test]
    fn compact_is_noop_when_history_fits() {
        let state = seeded("sys", "hello", 3);
        let compacted = compact(&state, 10);
        assert_eq!(compacted.system_prompt, "sys");
        assert_eq!(compacted.user_message, "hello");
        assert_eq!(compacted.max_turns, 3);
    }

    #[test]
    fn compact_summarizes_prefix_and_carries_latest_user() {
        let mut state = seeded("sys", "u0", 7);
        state.history.push(Message::assistant_with_tools(vec![ToolCall {
            id: "c1".into(),
            name: "add".into(),
            args: serde_json::json!({"a": 1, "b": 2}),
        }]));
        state.history.push(Message::tool_result(&ToolResult {
            call_id: "c1".into(),
            output: serde_json::json!(3),
            error: None,
        }));
        state.history.push(Message::user("u1"));
        state.history.push(Message::assistant_text("a1"));

        // 6 messages; keeping 2 summarises [sys, u0, tool call, tool result].
        let compacted = compact(&state, 2);
        assert_eq!(
            compacted.system_prompt,
            "sys\n\n[Prior conversation summary, 4 messages dropped]\nuser: u0\nassistant: called tools [add]\ntool: 3"
        );
        assert_eq!(compacted.user_message, "u1");
        assert_eq!(compacted.max_turns, 7);
    }

    #[test]
    fn compact_truncates_long_content_in_summary_only() {
        let long = "x".repeat(300);
        let mut state = seeded("sys", &long, 3);
        state.history.push(Message::assistant_text("done"));

        let compacted = compact(&state, 1);
        let expected_tail = format!("\nuser: {}…", "x".repeat(200));
        assert!(compacted.system_prompt.ends_with(&expected_tail));
        // The carried-over user message is not truncated.
        assert_eq!(compacted.user_message, long);
    }

    #[test]
    fn truncate_respects_utf8_char_boundary() {
        // 'é' spans bytes 1..3, so a 2-byte cut backs off to byte 1.
        let t = truncate("héllo world", 2);
        assert_eq!(t, "h…");
        assert_eq!(truncate("hi", 10), "hi");
        // The crab is 4 bytes; nothing fits, leaving only the ellipsis.
        assert_eq!(truncate("🦀rust", 2), "…");
    }

    #[test]
    fn tool_result_message_content() {
        let ok = Message::tool_result(&ToolResult {
            call_id: "c1".into(),
            output: serde_json::json!({"x": 1}),
            error: None,
        });
        assert_eq!(ok.role, Role::Tool);
        assert_eq!(ok.content, r#"{"x":1}"#);
        assert_eq!(ok.tool_call_id.as_deref(), Some("c1"));
        assert!(ok.tool_calls.is_empty());

        let failed = Message::tool_result(&ToolResult {
            call_id: "c2".into(),
            output: serde_json::Value::Null,
            error: Some("boom".into()),
        });
        assert_eq!(failed.content, "ERROR: boom");
        assert_eq!(failed.tool_call_id.as_deref(), Some("c2"));
    }

    #[test]
    fn message_roundtrips_through_json() {
        let m = Message::assistant_with_tools(vec![ToolCall {
            id: "c1".into(),
            name: "add".into(),
            args: serde_json::json!({"a": 1, "b": 2}),
        }]);
        let s = serde_json::to_string(&m).unwrap();
        let back: Message = serde_json::from_str(&s).unwrap();
        assert_eq!(m, back);

        let t = Message::tool_result(&ToolResult {
            call_id: "c1".into(),
            output: serde_json::json!(3),
            error: None,
        });
        let s = serde_json::to_string(&t).unwrap();
        let back: Message = serde_json::from_str(&s).unwrap();
        assert_eq!(t, back);
    }

    #[test]
    fn llm_response_uses_kind_tag() {
        let r: LlmResponse = serde_json::from_str(r#"{"kind":"final","answer":"42"}"#).unwrap();
        assert!(matches!(r, LlmResponse::Final { answer } if answer == "42"));

        let r: LlmResponse = serde_json::from_str(r#"{"kind":"use_tools","calls":[]}"#).unwrap();
        assert!(matches!(r, LlmResponse::UseTools { calls } if calls.is_empty()));
    }
}