Skip to main content

symbi_runtime/reasoning/
conversation.rs

//! Multi-turn conversation management
//!
//! Provides a `Conversation` type that manages a sequence of messages
//! across the System, User, Assistant, and Tool roles; assistant messages
//! may carry tool calls, and Tool messages carry the matching tool results.
//! Supports serialization to OpenAI and Anthropic API formats and
//! token estimation for context window management.
7
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11/// Role of a message in a conversation.
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum MessageRole {
15    System,
16    User,
17    Assistant,
18    Tool,
19}
20
21/// A single tool call embedded in an assistant message.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ToolCall {
24    /// Unique identifier for this tool call (used to correlate with results).
25    pub id: String,
26    /// Name of the tool being called.
27    pub name: String,
28    /// JSON-encoded arguments for the tool.
29    pub arguments: String,
30}
31
32/// A single message in a conversation.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ConversationMessage {
35    /// The role of the message sender.
36    pub role: MessageRole,
37    /// Text content of the message (may be empty for pure tool-call messages).
38    pub content: String,
39    /// Tool calls made by the assistant (only present when role is Assistant).
40    #[serde(default, skip_serializing_if = "Vec::is_empty")]
41    pub tool_calls: Vec<ToolCall>,
42    /// The tool call ID this message is responding to (only present when role is Tool).
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub tool_call_id: Option<String>,
45    /// The tool name this result corresponds to (only present when role is Tool).
46    #[serde(default, skip_serializing_if = "Option::is_none")]
47    pub tool_name: Option<String>,
48}
49
50impl ConversationMessage {
51    /// Create a system message.
52    pub fn system(content: impl Into<String>) -> Self {
53        Self {
54            role: MessageRole::System,
55            content: content.into(),
56            tool_calls: Vec::new(),
57            tool_call_id: None,
58            tool_name: None,
59        }
60    }
61
62    /// Create a user message.
63    pub fn user(content: impl Into<String>) -> Self {
64        Self {
65            role: MessageRole::User,
66            content: content.into(),
67            tool_calls: Vec::new(),
68            tool_call_id: None,
69            tool_name: None,
70        }
71    }
72
73    /// Create an assistant message with text content.
74    pub fn assistant(content: impl Into<String>) -> Self {
75        Self {
76            role: MessageRole::Assistant,
77            content: content.into(),
78            tool_calls: Vec::new(),
79            tool_call_id: None,
80            tool_name: None,
81        }
82    }
83
84    /// Create an assistant message with tool calls.
85    pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
86        Self {
87            role: MessageRole::Assistant,
88            content: String::new(),
89            tool_calls,
90            tool_call_id: None,
91            tool_name: None,
92        }
93    }
94
95    /// Create a tool result message.
96    pub fn tool_result(
97        tool_call_id: impl Into<String>,
98        tool_name: impl Into<String>,
99        content: impl Into<String>,
100    ) -> Self {
101        Self {
102            role: MessageRole::Tool,
103            content: content.into(),
104            tool_calls: Vec::new(),
105            tool_call_id: Some(tool_call_id.into()),
106            tool_name: Some(tool_name.into()),
107        }
108    }
109
110    /// Estimate token count for this message.
111    ///
112    /// Uses ~3.3 chars/token (Anthropic's tokenizer averages 3-3.5 for mixed content
113    /// including JSON, code, and prose). Adds per-message framing overhead for the
114    /// role field, JSON structure, and tool metadata that the API sees but we don't
115    /// count in raw content length.
116    pub fn estimate_tokens(&self) -> usize {
117        let mut chars = self.content.len();
118        for tc in &self.tool_calls {
119            // Tool call JSON structure adds overhead beyond just name+args
120            chars += tc.name.len() + tc.arguments.len() + tc.id.len() + 30; // JSON framing
121        }
122        if let Some(ref id) = self.tool_call_id {
123            chars += id.len() + 20; // tool_result framing
124        }
125        // ~3.3 chars per token (10 tokens per 33 chars), plus per-message overhead
126        // The overhead covers role, JSON structure, content block wrapping
127        (chars * 10 / 33).max(1) + 7
128    }
129}
130
131/// An ordered sequence of conversation messages with serialization helpers.
132#[derive(Debug, Clone, Default, Serialize, Deserialize)]
133pub struct Conversation {
134    messages: Vec<ConversationMessage>,
135}
136
137impl Conversation {
138    /// Create a new empty conversation.
139    pub fn new() -> Self {
140        Self {
141            messages: Vec::new(),
142        }
143    }
144
145    /// Create a conversation with a system message.
146    pub fn with_system(system_prompt: impl Into<String>) -> Self {
147        Self {
148            messages: vec![ConversationMessage::system(system_prompt)],
149        }
150    }
151
152    /// Append a message to the conversation.
153    pub fn push(&mut self, message: ConversationMessage) {
154        self.messages.push(message);
155    }
156
157    /// Get the messages in the conversation.
158    pub fn messages(&self) -> &[ConversationMessage] {
159        &self.messages
160    }
161
162    /// Get the number of messages.
163    pub fn len(&self) -> usize {
164        self.messages.len()
165    }
166
167    /// Check if the conversation is empty.
168    pub fn is_empty(&self) -> bool {
169        self.messages.is_empty()
170    }
171
172    /// Estimate total token count across all messages.
173    pub fn estimate_tokens(&self) -> usize {
174        self.messages.iter().map(|m| m.estimate_tokens()).sum()
175    }
176
177    /// Get the system message if present (first message with System role).
178    pub fn system_message(&self) -> Option<&ConversationMessage> {
179        self.messages.iter().find(|m| m.role == MessageRole::System)
180    }
181
182    /// Get the last assistant message.
183    pub fn last_assistant_message(&self) -> Option<&ConversationMessage> {
184        self.messages
185            .iter()
186            .rev()
187            .find(|m| m.role == MessageRole::Assistant)
188    }
189
190    /// Serialize to OpenAI chat completions format.
191    ///
192    /// Produces a JSON array of message objects with `role`, `content`,
193    /// and optionally `tool_calls` or `tool_call_id` fields.
194    pub fn to_openai_messages(&self) -> Vec<serde_json::Value> {
195        self.messages
196            .iter()
197            .map(|msg| {
198                let mut obj = serde_json::Map::new();
199                let role_str = match msg.role {
200                    MessageRole::System => "system",
201                    MessageRole::User => "user",
202                    MessageRole::Assistant => "assistant",
203                    MessageRole::Tool => "tool",
204                };
205                obj.insert("role".into(), serde_json::Value::String(role_str.into()));
206
207                if !msg.content.is_empty() {
208                    obj.insert(
209                        "content".into(),
210                        serde_json::Value::String(msg.content.clone()),
211                    );
212                } else if msg.role != MessageRole::Assistant {
213                    // OpenAI requires content for non-assistant messages
214                    obj.insert("content".into(), serde_json::Value::String(String::new()));
215                }
216
217                if !msg.tool_calls.is_empty() {
218                    let tool_calls: Vec<serde_json::Value> = msg
219                        .tool_calls
220                        .iter()
221                        .map(|tc| {
222                            serde_json::json!({
223                                "id": tc.id,
224                                "type": "function",
225                                "function": {
226                                    "name": tc.name,
227                                    "arguments": tc.arguments,
228                                }
229                            })
230                        })
231                        .collect();
232                    obj.insert("tool_calls".into(), serde_json::Value::Array(tool_calls));
233                }
234
235                if let Some(ref id) = msg.tool_call_id {
236                    obj.insert("tool_call_id".into(), serde_json::Value::String(id.clone()));
237                }
238
239                serde_json::Value::Object(obj)
240            })
241            .collect()
242    }
243
244    /// Serialize to Anthropic Messages API format.
245    ///
246    /// Returns `(system_prompt, messages)` because Anthropic takes the system
247    /// message as a separate top-level field.
248    pub fn to_anthropic_messages(&self) -> (Option<String>, Vec<serde_json::Value>) {
249        let system = self
250            .messages
251            .iter()
252            .find(|m| m.role == MessageRole::System)
253            .map(|m| m.content.clone());
254
255        // Build raw messages first, then merge consecutive same-role messages.
256        // Anthropic requires that all tool_result blocks for a given assistant
257        // message's tool_use blocks appear in the immediately following user message.
258        let mut raw_messages: Vec<serde_json::Value> = Vec::new();
259
260        for msg in self
261            .messages
262            .iter()
263            .filter(|m| m.role != MessageRole::System)
264        {
265            let role_str = match msg.role {
266                MessageRole::User | MessageRole::Tool => "user",
267                MessageRole::Assistant => "assistant",
268                MessageRole::System => unreachable!(),
269            };
270
271            let serialized = if msg.role == MessageRole::Tool {
272                // Anthropic tool results go as user messages with tool_result content blocks
273                serde_json::json!({
274                    "role": "user",
275                    "content": [{
276                        "type": "tool_result",
277                        "tool_use_id": msg.tool_call_id.as_deref().unwrap_or(""),
278                        "content": msg.content,
279                    }]
280                })
281            } else if !msg.tool_calls.is_empty() {
282                // Assistant message with tool use
283                let mut content_blocks: Vec<serde_json::Value> = Vec::new();
284                if !msg.content.is_empty() {
285                    content_blocks.push(serde_json::json!({
286                        "type": "text",
287                        "text": msg.content,
288                    }));
289                }
290                for tc in &msg.tool_calls {
291                    let args: serde_json::Value =
292                        serde_json::from_str(&tc.arguments).unwrap_or(serde_json::json!({}));
293                    content_blocks.push(serde_json::json!({
294                        "type": "tool_use",
295                        "id": tc.id,
296                        "name": tc.name,
297                        "input": args,
298                    }));
299                }
300                serde_json::json!({
301                    "role": role_str,
302                    "content": content_blocks,
303                })
304            } else {
305                serde_json::json!({
306                    "role": role_str,
307                    "content": msg.content,
308                })
309            };
310
311            // Merge consecutive messages with the same role by combining content blocks.
312            // This is critical for tool_result blocks: Anthropic requires all tool_results
313            // for a set of tool_use blocks to be in a single user message.
314            if let Some(last) = raw_messages.last_mut() {
315                let last_role = last.get("role").and_then(|r| r.as_str()).unwrap_or("");
316                if last_role == role_str {
317                    // Merge content into the previous message
318                    let prev_content = last.get_mut("content").unwrap();
319                    let new_content = serialized.get("content").unwrap();
320
321                    // Ensure both are arrays for merging
322                    let prev_arr = if prev_content.is_array() {
323                        prev_content.as_array_mut().unwrap()
324                    } else {
325                        // Convert string content to a text block array
326                        let text = prev_content.as_str().unwrap_or("").to_string();
327                        *prev_content = serde_json::json!([{"type": "text", "text": text}]);
328                        prev_content.as_array_mut().unwrap()
329                    };
330
331                    if new_content.is_array() {
332                        prev_arr.extend(new_content.as_array().unwrap().iter().cloned());
333                    } else {
334                        let text = new_content.as_str().unwrap_or("").to_string();
335                        prev_arr.push(serde_json::json!({"type": "text", "text": text}));
336                    }
337
338                    continue;
339                }
340            }
341
342            raw_messages.push(serialized);
343        }
344
345        (system, raw_messages)
346    }
347
348    /// Truncate the conversation to fit within a token budget.
349    ///
350    /// Preserves the system message and the most recent messages.
351    /// Removes older messages from the middle until the budget is met.
352    pub fn truncate_to_budget(&mut self, max_tokens: usize) {
353        if self.estimate_tokens() <= max_tokens {
354            return;
355        }
356
357        let system_msg = if self
358            .messages
359            .first()
360            .is_some_and(|m| m.role == MessageRole::System)
361        {
362            Some(self.messages[0].clone())
363        } else {
364            None
365        };
366
367        let system_tokens = system_msg.as_ref().map_or(0, |m| m.estimate_tokens());
368        let remaining_budget = max_tokens.saturating_sub(system_tokens);
369
370        // Keep messages from the end until we exceed the budget
371        let start_idx = if system_msg.is_some() { 1 } else { 0 };
372        let non_system: Vec<ConversationMessage> = self.messages.drain(start_idx..).rev().collect();
373
374        let mut kept = Vec::new();
375        let mut used_tokens = 0;
376        for msg in non_system {
377            let msg_tokens = msg.estimate_tokens();
378            if used_tokens + msg_tokens > remaining_budget {
379                break;
380            }
381            used_tokens += msg_tokens;
382            kept.push(msg);
383        }
384        kept.reverse();
385
386        self.messages.clear();
387        if let Some(sys) = system_msg {
388            self.messages.push(sys);
389        }
390        self.messages.extend(kept);
391    }
392
393    /// Insert a system-level knowledge context message after the initial system message.
394    /// If a previous knowledge context exists (identified by a marker prefix), replace it.
395    pub fn inject_knowledge_context(&mut self, context: impl Into<String>) {
396        let marker = "[KNOWLEDGE_CONTEXT]";
397        let content = format!("{}\n{}", marker, context.into());
398        let msg = ConversationMessage::system(content);
399
400        // Find and replace existing knowledge context, or insert after system message
401        if let Some(pos) = self
402            .messages
403            .iter()
404            .position(|m| m.role == MessageRole::System && m.content.starts_with(marker))
405        {
406            self.messages[pos] = msg;
407        } else {
408            // Insert after first system message (position 1), or at 0 if no system msg
409            let insert_pos = if self
410                .messages
411                .first()
412                .is_some_and(|m| m.role == MessageRole::System)
413            {
414                1
415            } else {
416                0
417            };
418            self.messages.insert(insert_pos, msg);
419        }
420    }
421
422    /// Get metadata about the conversation for logging.
423    pub fn metadata(&self) -> HashMap<String, String> {
424        let mut meta = HashMap::new();
425        meta.insert("message_count".into(), self.messages.len().to_string());
426        meta.insert(
427            "estimated_tokens".into(),
428            self.estimate_tokens().to_string(),
429        );
430        meta.insert(
431            "has_system".into(),
432            self.system_message().is_some().to_string(),
433        );
434        let tool_call_count: usize = self.messages.iter().map(|m| m.tool_calls.len()).sum();
435        meta.insert("tool_call_count".into(), tool_call_count.to_string());
436        let tool_result_count = self
437            .messages
438            .iter()
439            .filter(|m| m.role == MessageRole::Tool)
440            .count();
441        meta.insert("tool_result_count".into(), tool_result_count.to_string());
442        meta
443    }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`ToolCall`] fixture.
    fn tool_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.to_owned(),
            name: name.to_owned(),
            arguments: arguments.to_owned(),
        }
    }

    /// Collect the `role` string of every serialized message.
    fn roles(wire: &[serde_json::Value]) -> Vec<&str> {
        wire.iter()
            .map(|m| m["role"].as_str().expect("role must be a string"))
            .collect()
    }

    #[test]
    fn test_conversation_creation() {
        let conversation = Conversation::with_system("You are a helpful assistant.");
        assert!(!conversation.is_empty());
        assert_eq!(conversation.len(), 1);
        assert!(conversation.system_message().is_some());
    }

    #[test]
    fn test_message_constructors() {
        let system = ConversationMessage::system("system");
        assert_eq!(system.role, MessageRole::System);
        assert_eq!(system.content, "system");

        assert_eq!(ConversationMessage::user("hello").role, MessageRole::User);
        assert_eq!(
            ConversationMessage::assistant("hi there").role,
            MessageRole::Assistant
        );

        let result = ConversationMessage::tool_result("call_1", "search", "results here");
        assert_eq!(result.role, MessageRole::Tool);
        assert_eq!(result.tool_call_id.as_deref(), Some("call_1"));
        assert_eq!(result.tool_name.as_deref(), Some("search"));
    }

    #[test]
    fn test_openai_serialization_roundtrip() {
        let mut conversation = Conversation::with_system("You are a test agent.");
        conversation.push(ConversationMessage::user("Search for rust crates"));
        conversation.push(ConversationMessage::assistant_tool_calls(vec![tool_call(
            "call_1",
            "web_search",
            r#"{"query":"rust crates"}"#,
        )]));
        conversation.push(ConversationMessage::tool_result(
            "call_1",
            "web_search",
            "Found: serde, tokio, reqwest",
        ));
        conversation.push(ConversationMessage::assistant(
            "I found serde, tokio, and reqwest.",
        ));

        let wire = conversation.to_openai_messages();
        assert_eq!(wire.len(), 5);
        assert_eq!(
            roles(&wire),
            ["system", "user", "assistant", "tool", "assistant"]
        );

        assert_eq!(wire[0]["content"], "You are a test agent.");

        let calls = wire[2]["tool_calls"]
            .as_array()
            .expect("assistant turn should serialize tool_calls as an array");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0]["function"]["name"], "web_search");

        assert_eq!(wire[3]["tool_call_id"], "call_1");
    }

    #[test]
    fn test_anthropic_serialization() {
        let mut conversation = Conversation::with_system("System prompt here.");
        conversation.push(ConversationMessage::user("Hello"));
        conversation.push(ConversationMessage::assistant_tool_calls(vec![tool_call(
            "tu_1",
            "calculator",
            r#"{"expr":"2+2"}"#,
        )]));
        conversation.push(ConversationMessage::tool_result("tu_1", "calculator", "4"));
        conversation.push(ConversationMessage::assistant("The result is 4."));

        let (system, wire) = conversation.to_anthropic_messages();
        assert_eq!(system.as_deref(), Some("System prompt here."));
        // The system prompt travels separately rather than as a message.
        assert_eq!(wire.len(), 4);
        assert_eq!(roles(&wire), ["user", "assistant", "user", "assistant"]);

        assert_eq!(wire[0]["content"], "Hello");

        let tool_use = &wire[1]["content"].as_array().unwrap()[0];
        assert_eq!(tool_use["type"], "tool_use");
        assert_eq!(tool_use["name"], "calculator");

        let tool_result = &wire[2]["content"].as_array().unwrap()[0];
        assert_eq!(tool_result["type"], "tool_result");
        assert_eq!(tool_result["tool_use_id"], "tu_1");
    }

    #[test]
    fn test_token_estimation() {
        // "Hello, world!" is 13 bytes: 13 * 10 / 33 = 3, then + 7 overhead = 10.
        assert_eq!(ConversationMessage::user("Hello, world!").estimate_tokens(), 10);
    }

    #[test]
    fn test_conversation_token_estimation() {
        let mut conversation = Conversation::with_system("Be helpful.");
        conversation.push(ConversationMessage::user("Hi"));
        conversation.push(ConversationMessage::assistant("Hello!"));
        assert!(conversation.estimate_tokens() > 0);
    }

    #[test]
    fn test_truncate_to_budget() {
        let mut conversation = Conversation::with_system("sys");
        for i in 0..20 {
            conversation.push(ConversationMessage::user(format!(
                "Message number {} with some extra text to take up tokens",
                i
            )));
            conversation.push(ConversationMessage::assistant(format!("Reply {}", i)));
        }

        let before = conversation.len();
        assert!(before > 10);

        conversation.truncate_to_budget(100);
        assert!(conversation.len() < before);
        // The system prompt must survive truncation.
        assert_eq!(conversation.messages()[0].role, MessageRole::System);
        assert!(conversation.estimate_tokens() <= 100);
    }

    #[test]
    fn test_metadata() {
        let mut conversation = Conversation::with_system("sys");
        conversation.push(ConversationMessage::user("hi"));
        conversation.push(ConversationMessage::assistant_tool_calls(vec![
            tool_call("c1", "t1", "{}"),
            tool_call("c2", "t2", "{}"),
        ]));
        for (id, name) in [("c1", "t1"), ("c2", "t2")] {
            conversation.push(ConversationMessage::tool_result(id, name, "ok"));
        }

        let meta = conversation.metadata();
        assert_eq!(meta["message_count"], "5");
        assert_eq!(meta["has_system"], "true");
        assert_eq!(meta["tool_call_count"], "2");
        assert_eq!(meta["tool_result_count"], "2");
    }

    #[test]
    fn test_last_assistant_message() {
        let mut conversation = Conversation::new();
        assert!(conversation.last_assistant_message().is_none());

        for (user_text, reply) in [("hi", "first"), ("more", "second")] {
            conversation.push(ConversationMessage::user(user_text));
            conversation.push(ConversationMessage::assistant(reply));
        }

        assert_eq!(
            conversation
                .last_assistant_message()
                .map(|m| m.content.as_str()),
            Some("second")
        );
    }

    #[test]
    fn test_inject_knowledge_context_after_system() {
        let mut conversation = Conversation::with_system("You are helpful.");
        conversation.push(ConversationMessage::user("hello"));
        conversation.inject_knowledge_context("Some knowledge here");

        let messages = conversation.messages();
        assert_eq!(messages.len(), 3);
        // The original system prompt stays first; the knowledge context follows it.
        assert_eq!(messages[0].role, MessageRole::System);
        assert_eq!(messages[0].content, "You are helpful.");
        assert!(messages[1].content.contains("[KNOWLEDGE_CONTEXT]"));
        assert!(messages[1].content.contains("Some knowledge here"));
        assert_eq!(messages[2].role, MessageRole::User);
    }

    #[test]
    fn test_inject_knowledge_context_replaces_existing() {
        let mut conversation = Conversation::with_system("System prompt");
        conversation.inject_knowledge_context("First knowledge");
        conversation.inject_knowledge_context("Updated knowledge");

        // A second injection replaces the first rather than adding another message.
        let knowledge: Vec<&ConversationMessage> = conversation
            .messages()
            .iter()
            .filter(|m| m.content.contains("[KNOWLEDGE_CONTEXT]"))
            .collect();
        assert_eq!(knowledge.len(), 1);
        assert!(knowledge[0].content.contains("Updated knowledge"));
    }

    #[test]
    fn test_inject_knowledge_context_no_system_message() {
        let mut conversation = Conversation::new();
        conversation.push(ConversationMessage::user("hello"));
        conversation.inject_knowledge_context("Knowledge without system");

        assert_eq!(conversation.len(), 2);
        // Without a system prompt the knowledge context becomes the first message.
        assert!(conversation.messages()[0]
            .content
            .contains("[KNOWLEDGE_CONTEXT]"));
        assert_eq!(conversation.messages()[1].role, MessageRole::User);
    }

    #[test]
    fn test_serde_roundtrip() {
        let mut conversation = Conversation::with_system("test");
        conversation.push(ConversationMessage::user("hello"));
        conversation.push(ConversationMessage::assistant_tool_calls(vec![tool_call(
            "tc1",
            "search",
            r#"{"q":"test"}"#,
        )]));

        let json = serde_json::to_string(&conversation).expect("conversation serializes");
        let restored: Conversation =
            serde_json::from_str(&json).expect("conversation deserializes");
        assert_eq!(restored.len(), conversation.len());
        assert_eq!(restored.messages()[2].tool_calls[0].name, "search");
    }
}