// rustant_core/brain.rs
1//! Brain module — LLM provider abstraction and interaction.
2//!
3//! Defines the `LlmProvider` trait for model-agnostic LLM interactions,
4//! and provides an OpenAI-compatible implementation with streaming support.
5
6use crate::error::LlmError;
7use crate::types::{
8    CompletionRequest, CompletionResponse, Content, CostEstimate, Message, Role, StreamEvent,
9    TokenUsage, ToolDefinition,
10};
11use async_trait::async_trait;
12use std::collections::HashSet;
13use std::sync::Arc;
14use tokio::sync::mpsc;
15use tracing::{debug, info, warn};
16
/// Trait for LLM providers, supporting both full and streaming completions.
///
/// Implementors must be `Send + Sync` so a provider can be shared across
/// async tasks behind an `Arc`.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Perform a full completion and return the response.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError>;

    /// Perform a streaming completion, sending events to the channel.
    ///
    /// Events are delivered through `tx`; the call resolves once the stream
    /// has finished or an error occurred.
    async fn complete_streaming(
        &self,
        request: CompletionRequest,
        tx: mpsc::Sender<StreamEvent>,
    ) -> Result<(), LlmError>;

    /// Estimate the token count for a set of messages.
    fn estimate_tokens(&self, messages: &[Message]) -> usize;

    /// Return the context window size (in tokens) for this provider/model.
    fn context_window(&self) -> usize;

    /// Return whether this provider supports tool/function calling.
    fn supports_tools(&self) -> bool;

    /// Return the cost per token as `(input, output)` in USD.
    fn cost_per_token(&self) -> (f64, f64);

    /// Return the model name.
    fn model_name(&self) -> &str;
}
45
/// Token counter using tiktoken-rs for accurate BPE tokenization.
pub struct TokenCounter {
    // The byte-pair encoder resolved for the configured model
    // (falls back to cl100k_base — see `TokenCounter::for_model`).
    bpe: tiktoken_rs::CoreBPE,
}
50
51impl TokenCounter {
52    /// Create a token counter for the given model.
53    /// Falls back to cl100k_base if the model isn't recognized.
54    pub fn for_model(model: &str) -> Self {
55        let bpe = tiktoken_rs::get_bpe_from_model(model).unwrap_or_else(|_| {
56            tiktoken_rs::cl100k_base().expect("cl100k_base should be available")
57        });
58        Self { bpe }
59    }
60
61    /// Count the number of tokens in a string.
62    pub fn count(&self, text: &str) -> usize {
63        self.bpe.encode_with_special_tokens(text).len()
64    }
65
66    /// Estimate the token count for a set of messages.
67    /// Adds overhead for message structure (role, separators).
68    pub fn count_messages(&self, messages: &[Message]) -> usize {
69        let mut total = 0;
70        for msg in messages {
71            // Each message has overhead: role token + separators (~4 tokens)
72            total += 4;
73            match &msg.content {
74                Content::Text { text } => total += self.count(text),
75                Content::ToolCall {
76                    name, arguments, ..
77                } => {
78                    total += self.count(name);
79                    total += self.count(&arguments.to_string());
80                }
81                Content::ToolResult { output, .. } => {
82                    total += self.count(output);
83                }
84                Content::MultiPart { parts } => {
85                    for part in parts {
86                        match part {
87                            Content::Text { text } => total += self.count(text),
88                            Content::ToolCall {
89                                name, arguments, ..
90                            } => {
91                                total += self.count(name);
92                                total += self.count(&arguments.to_string());
93                            }
94                            Content::ToolResult { output, .. } => {
95                                total += self.count(output);
96                            }
97                            _ => total += 10,
98                        }
99                    }
100                }
101            }
102        }
103        total + 3 // reply priming overhead
104    }
105}
106
107/// Sanitize tool_call → tool_result ordering in a message sequence.
108///
109/// This runs provider-agnostically *before* messages are sent to any LLM provider,
110/// ensuring that:
111/// 1. Every tool_result has a matching tool_call earlier in the sequence.
112/// 2. No non-tool messages (system hints, summaries) appear between an assistant's
113///    tool_call message and its corresponding user/tool tool_result message.
114/// 3. Orphaned tool_results (no matching tool_call) are removed.
115///
116/// This fixes issues caused by compression moving pinned tool_results out of order,
117/// system routing hints persisting between call/result, and summary injection.
118pub fn sanitize_tool_sequence(messages: &mut Vec<Message>) {
119    // --- Pass 1: Collect all tool_call IDs from assistant messages ---
120    let mut tool_call_ids: HashSet<String> = HashSet::new();
121    for msg in messages.iter() {
122        if msg.role != Role::Assistant {
123            continue;
124        }
125        collect_tool_call_ids(&msg.content, &mut tool_call_ids);
126    }
127
128    // --- Pass 2: Remove orphaned tool_results (no matching tool_call) ---
129    messages.retain(|msg| {
130        if msg.role != Role::Tool {
131            // Check user messages too — Anthropic sends tool_result as user role
132            if msg.role == Role::User {
133                if let Content::ToolResult { call_id, .. } = &msg.content {
134                    if !tool_call_ids.contains(call_id) {
135                        warn!(
136                            call_id = call_id.as_str(),
137                            "Removing orphaned tool_result (no matching tool_call)"
138                        );
139                        return false;
140                    }
141                }
142            }
143            return true;
144        }
145        match &msg.content {
146            Content::ToolResult { call_id, .. } => {
147                if tool_call_ids.contains(call_id) {
148                    true
149                } else {
150                    warn!(
151                        call_id = call_id.as_str(),
152                        "Removing orphaned tool_result (no matching tool_call)"
153                    );
154                    false
155                }
156            }
157            Content::MultiPart { parts } => {
158                // Keep the message if at least one tool_result has a matching call
159                let has_valid = parts.iter().any(|p| {
160                    if let Content::ToolResult { call_id, .. } = p {
161                        tool_call_ids.contains(call_id)
162                    } else {
163                        true
164                    }
165                });
166                if !has_valid {
167                    warn!("Removing multipart tool message with all orphaned tool_results");
168                }
169                has_valid
170            }
171            _ => true,
172        }
173    });
174
175    // --- Pass 3: Relocate system messages that appear between tool_call and tool_result ---
176    // Strategy: find each assistant message with tool_call(s), then check if the
177    // immediately following message is a system message. If so, move the system message
178    // before the assistant message.
179    let mut i = 0;
180    while i + 1 < messages.len() {
181        let has_tool_call =
182            messages[i].role == Role::Assistant && content_has_tool_call(&messages[i].content);
183
184        if has_tool_call {
185            // Check if next message is a system message (should be tool_result instead)
186            let mut j = i + 1;
187            let mut system_messages_to_relocate = Vec::new();
188            while j < messages.len() && messages[j].role == Role::System {
189                system_messages_to_relocate.push(j);
190                j += 1;
191            }
192            // Move system messages before the assistant tool_call message
193            if !system_messages_to_relocate.is_empty() {
194                // Extract system messages in reverse order to maintain relative order
195                let mut extracted: Vec<Message> = Vec::new();
196                for &idx in system_messages_to_relocate.iter().rev() {
197                    extracted.push(messages.remove(idx));
198                }
199                extracted.reverse();
200                // Insert them before position i
201                for (offset, msg) in extracted.into_iter().enumerate() {
202                    messages.insert(i + offset, msg);
203                    i += 1; // adjust i to still point to the assistant message
204                }
205            }
206        }
207        i += 1;
208    }
209}
210
211/// Extract all tool_call IDs from a Content value into the given set.
212fn collect_tool_call_ids(content: &Content, ids: &mut HashSet<String>) {
213    match content {
214        Content::ToolCall { id, .. } => {
215            ids.insert(id.clone());
216        }
217        Content::MultiPart { parts } => {
218            for part in parts {
219                collect_tool_call_ids(part, ids);
220            }
221        }
222        _ => {}
223    }
224}
225
226/// Check whether a Content value contains at least one tool_call.
227fn content_has_tool_call(content: &Content) -> bool {
228    match content {
229        Content::ToolCall { .. } => true,
230        Content::MultiPart { parts } => parts.iter().any(content_has_tool_call),
231        _ => false,
232    }
233}
234
/// The Brain wraps an LLM provider and adds higher-level logic:
/// prompt construction, cost tracking, and model selection.
pub struct Brain {
    // The underlying provider, shared via Arc (see `provider_arc`).
    provider: Arc<dyn LlmProvider>,
    // Base system prompt prepended to every conversation.
    system_prompt: String,
    // Cumulative token usage across all calls made through this Brain.
    total_usage: TokenUsage,
    // Cumulative cost across all calls made through this Brain.
    total_cost: CostEstimate,
    // Tokenizer matched to the provider's model at construction time.
    token_counter: TokenCounter,
    /// Optional knowledge addendum appended to system prompt from distilled rules.
    knowledge_addendum: String,
}
246
247impl Brain {
248    pub fn new(provider: Arc<dyn LlmProvider>, system_prompt: impl Into<String>) -> Self {
249        let model_name = provider.model_name().to_string();
250        Self {
251            provider,
252            system_prompt: system_prompt.into(),
253            total_usage: TokenUsage::default(),
254            total_cost: CostEstimate::default(),
255            token_counter: TokenCounter::for_model(&model_name),
256            knowledge_addendum: String::new(),
257        }
258    }
259
260    /// Set knowledge addendum (distilled rules) to append to the system prompt.
261    pub fn set_knowledge_addendum(&mut self, addendum: String) {
262        self.knowledge_addendum = addendum;
263    }
264
265    /// Estimate token count for messages using tiktoken-rs.
266    pub fn estimate_tokens(&self, messages: &[Message]) -> usize {
267        self.token_counter.count_messages(messages)
268    }
269
270    /// Construct messages for the LLM with system prompt prepended.
271    ///
272    /// If a knowledge addendum has been set via `set_knowledge_addendum()`,
273    /// it is automatically appended to the system prompt.
274    ///
275    /// After assembly, [`sanitize_tool_sequence`] runs to ensure tool_call→tool_result
276    /// ordering is never broken regardless of compression, pinning, or system message injection.
277    pub fn build_messages(&self, conversation: &[Message]) -> Vec<Message> {
278        let mut messages = Vec::with_capacity(conversation.len() + 1);
279        if self.knowledge_addendum.is_empty() {
280            messages.push(Message::system(&self.system_prompt));
281        } else {
282            let augmented = format!("{}{}", self.system_prompt, self.knowledge_addendum);
283            messages.push(Message::system(&augmented));
284        }
285        messages.extend_from_slice(conversation);
286        sanitize_tool_sequence(&mut messages);
287        messages
288    }
289
290    /// Send a completion request and return the response, tracking usage.
291    pub async fn think(
292        &mut self,
293        conversation: &[Message],
294        tools: Option<Vec<ToolDefinition>>,
295    ) -> Result<CompletionResponse, LlmError> {
296        let messages = self.build_messages(conversation);
297        let token_estimate = self.provider.estimate_tokens(&messages);
298        let context_limit = self.provider.context_window();
299
300        if token_estimate > context_limit {
301            return Err(LlmError::ContextOverflow {
302                used: token_estimate,
303                limit: context_limit,
304            });
305        }
306
307        debug!(
308            model = self.provider.model_name(),
309            estimated_tokens = token_estimate,
310            "Sending completion request"
311        );
312
313        let request = CompletionRequest {
314            messages,
315            tools,
316            temperature: 0.7,
317            max_tokens: None,
318            stop_sequences: Vec::new(),
319            model: None,
320        };
321
322        let response = self.provider.complete(request).await?;
323
324        // Track usage
325        self.total_usage.accumulate(&response.usage);
326        let (input_rate, output_rate) = self.provider.cost_per_token();
327        let cost = CostEstimate {
328            input_cost: response.usage.input_tokens as f64 * input_rate,
329            output_cost: response.usage.output_tokens as f64 * output_rate,
330        };
331        self.total_cost.accumulate(&cost);
332
333        info!(
334            input_tokens = response.usage.input_tokens,
335            output_tokens = response.usage.output_tokens,
336            cost = format!("${:.4}", cost.total()),
337            "Completion received"
338        );
339
340        Ok(response)
341    }
342
343    /// Send a completion request with retry logic and exponential backoff.
344    ///
345    /// Retries on transient errors (RateLimited, Timeout, Connection) up to
346    /// `max_retries` times with exponential backoff (1s, 2s, 4s, ..., capped at 32s).
347    /// Non-transient errors are returned immediately.
348    pub async fn think_with_retry(
349        &mut self,
350        conversation: &[Message],
351        tools: Option<Vec<ToolDefinition>>,
352        max_retries: usize,
353    ) -> Result<CompletionResponse, LlmError> {
354        let mut last_error = None;
355
356        for attempt in 0..=max_retries {
357            match self.think(conversation, tools.clone()).await {
358                Ok(response) => return Ok(response),
359                Err(e) if Self::is_retryable(&e) => {
360                    if attempt < max_retries {
361                        let backoff_secs = std::cmp::min(1u64 << attempt, 32);
362                        let wait = match &e {
363                            LlmError::RateLimited { retry_after_secs } => {
364                                std::cmp::max(*retry_after_secs, backoff_secs)
365                            }
366                            _ => backoff_secs,
367                        };
368                        info!(
369                            attempt = attempt + 1,
370                            max_retries,
371                            backoff_secs = wait,
372                            error = %e,
373                            "Retrying after transient error"
374                        );
375                        tokio::time::sleep(std::time::Duration::from_secs(wait)).await;
376                        last_error = Some(e);
377                    } else {
378                        return Err(e);
379                    }
380                }
381                Err(e) => return Err(e),
382            }
383        }
384
385        Err(last_error.unwrap_or(LlmError::Connection {
386            message: "Max retries exceeded".to_string(),
387        }))
388    }
389
390    /// Check if an LLM error is transient and should be retried.
391    pub fn is_retryable(error: &LlmError) -> bool {
392        matches!(
393            error,
394            LlmError::RateLimited { .. } | LlmError::Timeout { .. } | LlmError::Connection { .. }
395        )
396    }
397
398    /// Send a streaming completion request, returning events via channel.
399    pub async fn think_streaming(
400        &mut self,
401        conversation: &[Message],
402        tools: Option<Vec<ToolDefinition>>,
403        tx: mpsc::Sender<StreamEvent>,
404    ) -> Result<(), LlmError> {
405        let messages = self.build_messages(conversation);
406
407        let request = CompletionRequest {
408            messages,
409            tools,
410            temperature: 0.7,
411            max_tokens: None,
412            stop_sequences: Vec::new(),
413            model: None,
414        };
415
416        self.provider.complete_streaming(request, tx).await
417    }
418
419    /// Get total token usage across all calls.
420    pub fn total_usage(&self) -> &TokenUsage {
421        &self.total_usage
422    }
423
424    /// Get total cost across all calls.
425    pub fn total_cost(&self) -> &CostEstimate {
426        &self.total_cost
427    }
428
429    /// Get the model name.
430    pub fn model_name(&self) -> &str {
431        self.provider.model_name()
432    }
433
434    /// Get the context window size.
435    pub fn context_window(&self) -> usize {
436        self.provider.context_window()
437    }
438
439    /// Get cost rates (input_per_token, output_per_token) from the provider.
440    pub fn provider_cost_rates(&self) -> (f64, f64) {
441        self.provider.cost_per_token()
442    }
443
444    /// Get a reference to the underlying LLM provider.
445    pub fn provider(&self) -> &dyn LlmProvider {
446        &*self.provider
447    }
448
449    /// Get a cloneable Arc handle to the underlying LLM provider.
450    pub fn provider_arc(&self) -> Arc<dyn LlmProvider> {
451        Arc::clone(&self.provider)
452    }
453
454    /// Track usage and cost from an external completion (e.g., streaming).
455    pub fn track_usage(&mut self, usage: &TokenUsage) {
456        self.total_usage.accumulate(usage);
457        let (input_rate, output_rate) = self.provider.cost_per_token();
458        let cost = CostEstimate {
459            input_cost: usage.input_tokens as f64 * input_rate,
460            output_cost: usage.output_tokens as f64 * output_rate,
461        };
462        self.total_cost.accumulate(&cost);
463    }
464
465    /// Get the current token usage as a fraction of the context window.
466    pub fn context_usage_ratio(&self, conversation: &[Message]) -> f32 {
467        let messages = self.build_messages(conversation);
468        let tokens = self.provider.estimate_tokens(&messages);
469        tokens as f32 / self.provider.context_window() as f32
470    }
471}
472
/// A mock LLM provider for testing and development.
pub struct MockLlmProvider {
    // Model name reported by `model_name()`.
    model: String,
    // Context window size reported by `context_window()`.
    context_window: usize,
    // FIFO queue of canned responses consumed by `complete()`;
    // Mutex allows queueing through a shared reference.
    responses: std::sync::Mutex<Vec<CompletionResponse>>,
}
479
480impl MockLlmProvider {
481    pub fn new() -> Self {
482        Self {
483            model: "mock-model".to_string(),
484            context_window: 128_000,
485            responses: std::sync::Mutex::new(Vec::new()),
486        }
487    }
488
489    /// Create a MockLlmProvider that always returns the given text.
490    ///
491    /// Queues multiple copies of the response so it can handle multiple calls.
492    pub fn with_response(text: &str) -> Self {
493        let provider = Self::new();
494        for _ in 0..20 {
495            provider.queue_response(Self::text_response(text));
496        }
497        provider
498    }
499
500    /// Queue a response to be returned by the next `complete` call.
501    pub fn queue_response(&self, response: CompletionResponse) {
502        self.responses.lock().unwrap().push(response);
503    }
504
505    /// Create a simple text response for testing.
506    pub fn text_response(text: &str) -> CompletionResponse {
507        CompletionResponse {
508            message: Message::assistant(text),
509            usage: TokenUsage {
510                input_tokens: 100,
511                output_tokens: 50,
512            },
513            model: "mock-model".to_string(),
514            finish_reason: Some("stop".to_string()),
515        }
516    }
517
518    /// Create a tool call response for testing.
519    pub fn tool_call_response(tool_name: &str, arguments: serde_json::Value) -> CompletionResponse {
520        let call_id = format!("call_{}", uuid::Uuid::new_v4());
521        CompletionResponse {
522            message: Message::new(
523                Role::Assistant,
524                Content::tool_call(&call_id, tool_name, arguments),
525            ),
526            usage: TokenUsage {
527                input_tokens: 100,
528                output_tokens: 30,
529            },
530            model: "mock-model".to_string(),
531            finish_reason: Some("tool_calls".to_string()),
532        }
533    }
534
535    /// Create a multipart response (text + tool call) for testing.
536    pub fn multipart_response(
537        text: &str,
538        tool_name: &str,
539        arguments: serde_json::Value,
540    ) -> CompletionResponse {
541        let call_id = format!("call_{}", uuid::Uuid::new_v4());
542        CompletionResponse {
543            message: Message::new(
544                Role::Assistant,
545                Content::MultiPart {
546                    parts: vec![
547                        Content::text(text),
548                        Content::tool_call(&call_id, tool_name, arguments),
549                    ],
550                },
551            ),
552            usage: TokenUsage {
553                input_tokens: 100,
554                output_tokens: 50,
555            },
556            model: "mock-model".to_string(),
557            finish_reason: Some("tool_calls".to_string()),
558        }
559    }
560}
561
impl Default for MockLlmProvider {
    // Delegates to `new()`: an empty queue with the "mock-model" name.
    fn default() -> Self {
        Self::new()
    }
}
567
568#[async_trait]
569impl LlmProvider for MockLlmProvider {
570    async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
571        let mut responses = self.responses.lock().unwrap();
572        if responses.is_empty() {
573            Ok(MockLlmProvider::text_response(
574                "I'm a mock LLM. No queued responses available.",
575            ))
576        } else {
577            Ok(responses.remove(0))
578        }
579    }
580
581    async fn complete_streaming(
582        &self,
583        request: CompletionRequest,
584        tx: mpsc::Sender<StreamEvent>,
585    ) -> Result<(), LlmError> {
586        let response = self.complete(request).await?;
587        if let Some(text) = response.message.content.as_text() {
588            for word in text.split_whitespace() {
589                let _ = tx.send(StreamEvent::Token(format!("{} ", word))).await;
590            }
591        }
592        let _ = tx
593            .send(StreamEvent::Done {
594                usage: response.usage,
595            })
596            .await;
597        Ok(())
598    }
599
600    fn estimate_tokens(&self, messages: &[Message]) -> usize {
601        // Rough estimate: ~4 chars per token
602        messages
603            .iter()
604            .map(|m| match &m.content {
605                Content::Text { text } => text.len() / 4,
606                Content::ToolCall { arguments, .. } => arguments.to_string().len() / 4,
607                Content::ToolResult { output, .. } => output.len() / 4,
608                Content::MultiPart { parts } => parts
609                    .iter()
610                    .map(|p| match p {
611                        Content::Text { text } => text.len() / 4,
612                        _ => 50,
613                    })
614                    .sum(),
615            })
616            .sum::<usize>()
617            + 100 // overhead for message structure
618    }
619
620    fn context_window(&self) -> usize {
621        self.context_window
622    }
623
624    fn supports_tools(&self) -> bool {
625        true
626    }
627
628    fn cost_per_token(&self) -> (f64, f64) {
629        (0.0, 0.0) // free for mock
630    }
631
632    fn model_name(&self) -> &str {
633        &self.model
634    }
635}
636
/// The system prompt used by default for the Rustant agent.
///
/// Covers tool-selection rules, available tool categories, named workflows,
/// and security constraints. Kept as a single raw string so the prompt can
/// be reviewed and edited in one place.
pub const DEFAULT_SYSTEM_PROMPT: &str = r#"You are Rustant, a privacy-first autonomous personal assistant built in Rust. You help users with software engineering, daily productivity, and macOS automation tasks.

CRITICAL — Tool selection rules:
- You MUST use the dedicated tool for each task. Do NOT use shell_exec when a dedicated tool exists.
- For clipboard: call macos_clipboard with {"action":"read"} or {"action":"write","content":"..."}
- For battery/disk/CPU/version: call macos_system_info with {"action":"battery"}, {"action":"version"}, etc.
- For running apps: call macos_app_control with {"action":"list_running"}
- For calendar: call macos_calendar. For reminders: call macos_reminders. For notes: call macos_notes.
- For screenshots: call macos_screenshot. For Spotlight search: call macos_spotlight.
- shell_exec is a last resort — only use it for commands that have no dedicated tool.
- Do NOT use document_read for clipboard or system operations — it reads document files only.
- If a tool call fails, try a different tool or action — do NOT ask the user whether to proceed. Act autonomously.
- Never call ask_user more than once per task unless the user's answer was genuinely unclear.

Other behaviors:
- Always read a file before modifying it
- Prefer small, focused changes over large rewrites
- Respect file boundaries and permissions

Tool categories:

File & Code: file_read, file_write, file_list, file_search, file_patch, smart_edit, codebase_search, document_read (for PDFs/docs only)
Git: git_status, git_diff, git_commit
Shell: shell_exec (last resort only)
Utilities: calculator, datetime, echo, web_search (for web searches — uses DuckDuckGo, preferred over safari/shell), web_fetch (for fetching URL content — preferred over safari/shell), http_api, template, pdf_generate, compress, file_organizer
Personal Productivity: pomodoro, inbox, finance, flashcards, travel, relationships
Research & Intelligence: arxiv_research (ALWAYS use this for paper/preprint searches — it has a built-in arXiv API client, never use safari/curl for arXiv), knowledge_graph (concepts, papers, relationships, BFS traversal), experiment_tracker (hypotheses, experiments, evidence)
Code Analysis: code_intelligence (architecture, patterns, tech debt, API surface, dependency map), codebase_search
Professional Growth: skill_tracker (proficiency, practice logs, learning paths), career_intel (goals, achievements, portfolio), content_engine (multi-platform content pipeline, calendar)
Life Management: life_planner (energy-aware scheduling, deadlines, habits), system_monitor (service topology, health checks, incidents), privacy_manager (data boundaries, compliance, audit), self_improvement (usage patterns, performance, preferences)
macOS Native: macos_calendar, macos_reminders, macos_notes, macos_clipboard, macos_system_info, macos_app_control, macos_notification, macos_screenshot, macos_spotlight, macos_finder, macos_focus_mode, macos_mail, macos_music, macos_shortcuts, macos_meeting_recorder (use 'record_and_transcribe' for full meeting flow — TTS announcement, silence auto-stop, auto-transcribe to Notes.app; use 'stop' to end manually), macos_daily_briefing, macos_contacts, homekit
macOS Automation: macos_gui_scripting, macos_accessibility, macos_screen_analyze, macos_safari (only for Safari-specific tasks like tab management — for web searches use web_search, for fetching pages use web_fetch)
iMessage: imessage_contacts, imessage_send, imessage_read
Voice: macos_say

Workflows (structured multi-step templates — run via shell_exec "rustant workflow run <name>"):
  code_review, refactor, test_generation, documentation, dependency_update,
  security_scan, deployment, incident_response, morning_briefing, pr_review,
  dependency_audit, changelog, meeting_recorder, daily_briefing_full,
  end_of_day_summary, app_automation, email_triage, arxiv_research,
  knowledge_graph, experiment_tracking, code_analysis, content_pipeline,
  skill_development, career_planning, system_monitoring, life_planning,
  privacy_audit, self_improvement_loop
When a user asks for one of these tasks by name or description, execute the workflow or accomplish it step by step.

Security rules:
- Never execute commands that could damage the system or leak credentials
- Do not read or write files containing secrets (.env, *.key, *.pem) unless explicitly asked
- Sanitize all user input before passing to shell or AppleScript commands
- When unsure about a destructive action, use ask_user to confirm first"#;
688
689// ---------------------------------------------------------------------------
690// Token Budget Manager
691// ---------------------------------------------------------------------------
692
/// Tracks token usage against configurable budgets and predicts costs
/// before execution. Can warn or halt when budgets are exceeded.
pub struct TokenBudgetManager {
    // Session-wide spend ceiling in USD (0.0 = unlimited).
    session_limit_usd: f64,
    // Per-task spend ceiling in USD (0.0 = unlimited).
    task_limit_usd: f64,
    // Session-wide token ceiling (0 = unlimited).
    session_token_limit: usize,
    // Whether callers should halt (vs. warn) on an exceeded budget;
    // NOTE(review): stored from config — enforcement appears to happen at
    // the call site, confirm against callers.
    halt_on_exceed: bool,
    // Cost accrued this session so far, in USD.
    session_cost: f64,
    // Cost accrued for the current task, in USD (reset via `reset_task`).
    task_cost: f64,
    // Total tokens consumed this session so far.
    session_tokens: usize,
}
704
/// The result of a pre-call budget check.
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetCheckResult {
    /// Budget is within limits, proceed.
    Ok,
    /// Budget warning — approaching limit but not exceeded.
    /// `usage_pct` is the projected utilization of the nearest limit.
    Warning { message: String, usage_pct: f64 },
    /// Budget exceeded — should halt if configured.
    Exceeded { message: String },
}
715
716impl TokenBudgetManager {
717    /// Create a new budget manager from config. Passing `None` creates
718    /// an unlimited manager that always returns `Ok`.
719    pub fn new(config: Option<&crate::config::BudgetConfig>) -> Self {
720        match config {
721            Some(cfg) => Self {
722                session_limit_usd: cfg.session_limit_usd,
723                task_limit_usd: cfg.task_limit_usd,
724                session_token_limit: cfg.session_token_limit,
725                halt_on_exceed: cfg.halt_on_exceed,
726                session_cost: 0.0,
727                task_cost: 0.0,
728                session_tokens: 0,
729            },
730            None => Self {
731                session_limit_usd: 0.0,
732                task_limit_usd: 0.0,
733                session_token_limit: 0,
734                halt_on_exceed: false,
735                session_cost: 0.0,
736                task_cost: 0.0,
737                session_tokens: 0,
738            },
739        }
740    }
741
    /// Reset task-level tracking (call at start of each new task).
    ///
    /// Only the per-task cost is cleared; session-level totals keep accumulating.
    pub fn reset_task(&mut self) {
        self.task_cost = 0.0;
    }
746
747    /// Record usage after an LLM call completes.
748    pub fn record_usage(&mut self, usage: &TokenUsage, cost: &CostEstimate) {
749        self.session_cost += cost.total();
750        self.task_cost += cost.total();
751        self.session_tokens += usage.total();
752    }
753
754    /// Estimate cost for an upcoming LLM call and check against budgets.
755    ///
756    /// `estimated_input_tokens` is the count of tokens in the request.
757    /// `input_rate` and `output_rate` are the per-token costs from the provider.
758    /// Output tokens are estimated at 0.5x input as a heuristic.
759    pub fn check_budget(
760        &self,
761        estimated_input_tokens: usize,
762        input_rate: f64,
763        output_rate: f64,
764    ) -> BudgetCheckResult {
765        // Predict output tokens as ~50% of input (heuristic)
766        let predicted_output = estimated_input_tokens / 2;
767        let predicted_cost =
768            (estimated_input_tokens as f64 * input_rate) + (predicted_output as f64 * output_rate);
769
770        let projected_session_cost = self.session_cost + predicted_cost;
771        let projected_task_cost = self.task_cost + predicted_cost;
772        let projected_session_tokens =
773            self.session_tokens + estimated_input_tokens + predicted_output;
774
775        // Check session cost limit
776        if self.session_limit_usd > 0.0 && projected_session_cost > self.session_limit_usd {
777            return BudgetCheckResult::Exceeded {
778                message: format!(
779                    "Session cost ${:.4} would exceed limit ${:.4}",
780                    projected_session_cost, self.session_limit_usd
781                ),
782            };
783        }
784
785        // Check task cost limit
786        if self.task_limit_usd > 0.0 && projected_task_cost > self.task_limit_usd {
787            return BudgetCheckResult::Exceeded {
788                message: format!(
789                    "Task cost ${:.4} would exceed limit ${:.4}",
790                    projected_task_cost, self.task_limit_usd
791                ),
792            };
793        }
794
795        // Check session token limit
796        if self.session_token_limit > 0 && projected_session_tokens > self.session_token_limit {
797            return BudgetCheckResult::Exceeded {
798                message: format!(
799                    "Session tokens {} would exceed limit {}",
800                    projected_session_tokens, self.session_token_limit
801                ),
802            };
803        }
804
805        // Check if approaching limits (>80%)
806        if self.session_limit_usd > 0.0 {
807            let pct = projected_session_cost / self.session_limit_usd;
808            if pct > 0.8 {
809                return BudgetCheckResult::Warning {
810                    message: format!(
811                        "Session cost at {:.0}% of ${:.4} limit",
812                        pct * 100.0,
813                        self.session_limit_usd
814                    ),
815                    usage_pct: pct,
816                };
817            }
818        }
819
820        BudgetCheckResult::Ok
821    }
822
823    /// Whether budget enforcement should halt execution on exceed.
824    pub fn should_halt_on_exceed(&self) -> bool {
825        self.halt_on_exceed
826    }
827
828    /// Current session cost.
829    pub fn session_cost(&self) -> f64 {
830        self.session_cost
831    }
832
833    /// Current task cost.
834    pub fn task_cost(&self) -> f64 {
835        self.task_cost
836    }
837
838    /// Current session token count.
839    pub fn session_tokens(&self) -> usize {
840        self.session_tokens
841    }
842}
843
#[cfg(test)]
mod tests {
    //! Unit tests for the brain module: `MockLlmProvider` behavior,
    //! `Brain` orchestration (system prompt, retries, usage tracking),
    //! `TokenCounter`, and `sanitize_tool_sequence`.
    use super::*;

    // --- MockLlmProvider tests ---

    /// An unqueued mock still yields a default textual completion.
    #[tokio::test]
    async fn test_mock_provider_default_response() {
        let provider = MockLlmProvider::new();
        let request = CompletionRequest::default();
        let response = provider.complete(request).await.unwrap();
        assert!(response.message.content.as_text().is_some());
    }

    /// Queued responses are consumed in FIFO order.
    #[tokio::test]
    async fn test_mock_provider_queued_responses() {
        let provider = MockLlmProvider::new();
        provider.queue_response(MockLlmProvider::text_response("first"));
        provider.queue_response(MockLlmProvider::text_response("second"));

        let r1 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r1.message.content.as_text(), Some("first"));

        let r2 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r2.message.content.as_text(), Some("second"));
    }

    /// Streaming emits `Token` events followed by `Done`; the final
    /// assertion encodes the mock's word-by-word tokenization.
    #[tokio::test]
    async fn test_mock_provider_streaming() {
        let provider = MockLlmProvider::new();
        provider.queue_response(MockLlmProvider::text_response("hello world"));

        let (tx, mut rx) = mpsc::channel(32);
        provider
            .complete_streaming(CompletionRequest::default(), tx)
            .await
            .unwrap();

        // Collect tokens until the stream signals completion.
        let mut tokens = Vec::new();
        while let Some(event) = rx.recv().await {
            match event {
                StreamEvent::Token(t) => tokens.push(t),
                StreamEvent::Done { .. } => break,
                _ => {}
            }
        }
        assert_eq!(tokens.len(), 2); // "hello " and "world "
    }

    /// Token estimation returns a non-zero count for a non-empty message.
    #[test]
    fn test_mock_provider_token_estimation() {
        let provider = MockLlmProvider::new();
        let messages = vec![Message::user("Hello, this is a test message.")];
        let tokens = provider.estimate_tokens(&messages);
        assert!(tokens > 0);
    }

    /// The mock reports fixed provider properties: 128k window, tool
    /// support, zero cost, and the "mock-model" name.
    #[test]
    fn test_mock_provider_properties() {
        let provider = MockLlmProvider::new();
        assert_eq!(provider.context_window(), 128_000);
        assert!(provider.supports_tools());
        assert_eq!(provider.cost_per_token(), (0.0, 0.0));
        assert_eq!(provider.model_name(), "mock-model");
    }

    // --- Brain tests ---

    /// `think` returns the provider's response and records token usage.
    #[tokio::test]
    async fn test_brain_think() {
        let provider = Arc::new(MockLlmProvider::new());
        provider.queue_response(MockLlmProvider::text_response("I can help with that."));

        let mut brain = Brain::new(provider, "You are a helpful assistant.");
        let conversation = vec![Message::user("Help me refactor")];

        let response = brain.think(&conversation, None).await.unwrap();
        assert_eq!(
            response.message.content.as_text(),
            Some("I can help with that.")
        );
        // Usage accounting must reflect the completed call.
        assert!(brain.total_usage().total() > 0);
    }

    /// `build_messages` prepends the system prompt, so the request is
    /// [system, ...conversation].
    #[tokio::test]
    async fn test_brain_builds_messages_with_system_prompt() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system prompt");
        let conversation = vec![Message::user("hello")];

        let messages = brain.build_messages(&conversation);
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content.as_text(), Some("system prompt"));
        assert_eq!(messages[1].role, Role::User);
    }

    /// A short conversation uses a strictly positive, sub-100% fraction
    /// of the context window.
    #[test]
    fn test_brain_context_usage_ratio() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("short message")];

        let ratio = brain.context_usage_ratio(&conversation);
        assert!(ratio > 0.0);
        assert!(ratio < 1.0);
    }

    /// `tool_call_response` builds `ToolCall` content carrying the given
    /// tool name and JSON arguments.
    #[test]
    fn test_mock_tool_call_response() {
        let response = MockLlmProvider::tool_call_response(
            "file_read",
            serde_json::json!({"path": "/tmp/test.rs"}),
        );
        match &response.message.content {
            Content::ToolCall {
                name, arguments, ..
            } => {
                assert_eq!(name, "file_read");
                assert_eq!(arguments["path"], "/tmp/test.rs");
            }
            _ => panic!("Expected ToolCall content"),
        }
    }

    /// Sanity-check key phrases in the default system prompt.
    #[test]
    fn test_default_system_prompt() {
        assert!(DEFAULT_SYSTEM_PROMPT.contains("Rustant"));
        assert!(DEFAULT_SYSTEM_PROMPT.contains("autonomous"));
    }

    // --- retry classification and behavior ---

    /// Transient failures (rate limit, timeout, connection) are
    /// retryable; context overflow and auth failures are not.
    #[test]
    fn test_is_retryable() {
        assert!(Brain::is_retryable(&LlmError::RateLimited {
            retry_after_secs: 5
        }));
        assert!(Brain::is_retryable(&LlmError::Timeout { timeout_secs: 30 }));
        assert!(Brain::is_retryable(&LlmError::Connection {
            message: "reset".into()
        }));
        assert!(!Brain::is_retryable(&LlmError::ContextOverflow {
            used: 200_000,
            limit: 128_000
        }));
        assert!(!Brain::is_retryable(&LlmError::AuthFailed {
            provider: "openai".into()
        }));
    }

    /// A mock provider that fails N times before succeeding.
    struct FailingProvider {
        // Remaining forced failures; Mutex because `complete` takes `&self`.
        failures_remaining: std::sync::Mutex<usize>,
        // One of "rate_limited" | "timeout" | "connection"; anything else
        // maps to a non-retryable API error.
        error_type: String,
        // Response returned once the failure budget is exhausted.
        success_response: CompletionResponse,
    }

    impl FailingProvider {
        /// Build a provider that errors `failures` times with the given
        /// error type, then succeeds with a canned response.
        fn new(failures: usize, error_type: &str) -> Self {
            Self {
                failures_remaining: std::sync::Mutex::new(failures),
                error_type: error_type.to_string(),
                success_response: MockLlmProvider::text_response("Success after retry"),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for FailingProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, LlmError> {
            // Consume one failure and map the configured error type.
            let mut remaining = self.failures_remaining.lock().unwrap();
            if *remaining > 0 {
                *remaining -= 1;
                match self.error_type.as_str() {
                    "rate_limited" => Err(LlmError::RateLimited {
                        retry_after_secs: 0,
                    }),
                    "timeout" => Err(LlmError::Timeout { timeout_secs: 5 }),
                    "connection" => Err(LlmError::Connection {
                        message: "connection reset".into(),
                    }),
                    // Any unrecognized type is a non-retryable API error.
                    _ => Err(LlmError::ApiRequest {
                        message: "non-retryable".into(),
                    }),
                }
            } else {
                Ok(self.success_response.clone())
            }
        }

        // Streaming is unused by the retry tests; no-op success.
        async fn complete_streaming(
            &self,
            _request: CompletionRequest,
            _tx: mpsc::Sender<StreamEvent>,
        ) -> Result<(), LlmError> {
            Ok(())
        }

        fn estimate_tokens(&self, _messages: &[Message]) -> usize {
            100
        }
        fn context_window(&self) -> usize {
            128_000
        }
        fn supports_tools(&self) -> bool {
            true
        }
        fn cost_per_token(&self) -> (f64, f64) {
            (0.0, 0.0)
        }
        fn model_name(&self) -> &str {
            "failing-mock"
        }
    }

    /// Two retryable failures with three attempts allowed → success.
    #[tokio::test]
    async fn test_think_with_retry_succeeds_after_failures() {
        let provider = Arc::new(FailingProvider::new(2, "connection"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 3).await;
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().message.content.as_text(),
            Some("Success after retry")
        );
    }

    /// More failures than attempts → the last retryable error surfaces.
    #[tokio::test]
    async fn test_think_with_retry_exhausted() {
        let provider = Arc::new(FailingProvider::new(5, "timeout"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 2).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Timeout { .. }));
    }

    /// Non-retryable errors must not be retried even with attempts left.
    #[tokio::test]
    async fn test_think_with_retry_non_retryable_fails_immediately() {
        let provider = Arc::new(FailingProvider::new(1, "non_retryable"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 3).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::ApiRequest { .. }));
    }

    /// A single rate-limit (retry_after 0) is retried and succeeds.
    #[tokio::test]
    async fn test_think_with_retry_rate_limited() {
        let provider = Arc::new(FailingProvider::new(1, "rate_limited"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 2).await;
        assert!(result.is_ok());
    }

    /// `track_usage` accumulates input and output token counts.
    #[test]
    fn test_track_usage() {
        let provider = Arc::new(MockLlmProvider::new());
        let mut brain = Brain::new(provider, "system");

        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
        };
        brain.track_usage(&usage);

        assert_eq!(brain.total_usage().input_tokens, 100);
        assert_eq!(brain.total_usage().output_tokens, 50);
    }

    // --- TokenCounter tests ---

    /// Counting a short string yields a small, non-zero token count.
    #[test]
    fn test_token_counter_basic() {
        let counter = TokenCounter::for_model("gpt-4o");
        let count = counter.count("Hello, world!");
        assert!(count > 0);
        assert!(count < 20); // should be ~4 tokens
    }

    /// Message counting covers both messages and stays in a sane range.
    #[test]
    fn test_token_counter_messages() {
        let counter = TokenCounter::for_model("gpt-4o");
        let messages = vec![
            Message::system("You are a helpful assistant."),
            Message::user("What is 2 + 2?"),
        ];
        let count = counter.count_messages(&messages);
        assert!(count > 5);
        assert!(count < 100);
    }

    /// Unrecognized model names still tokenize via the fallback encoding.
    #[test]
    fn test_token_counter_unknown_model_falls_back() {
        let counter = TokenCounter::for_model("unknown-model-xyz");
        let count = counter.count("Hello");
        assert!(count > 0); // Should use cl100k_base fallback
    }

    /// `Brain::estimate_tokens` delegates to the provider and is non-zero.
    #[test]
    fn test_brain_estimate_tokens() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system");
        let messages = vec![Message::user("Hello, this is a test.")];
        let estimate = brain.estimate_tokens(&messages);
        assert!(estimate > 0);
    }

    // --- sanitize_tool_sequence tests ---

    /// A tool_result with no matching tool_call must be dropped.
    #[test]
    fn test_sanitize_removes_orphaned_tool_results() {
        let mut messages = vec![
            Message::system("You are a helper."),
            Message::user("do something"),
            // Orphaned tool_result — no matching tool_call
            Message::tool_result("call_orphan_123", "some result", false),
            Message::assistant("Done!"),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // The orphaned tool_result should be removed
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[2].role, Role::Assistant);
    }

    /// A well-formed call → result → answer sequence is left untouched.
    #[test]
    fn test_sanitize_preserves_valid_sequence() {
        let mut messages = vec![
            Message::system("system prompt"),
            Message::user("read main.rs"),
            Message::new(
                Role::Assistant,
                Content::tool_call(
                    "call_1",
                    "file_read",
                    serde_json::json!({"path": "main.rs"}),
                ),
            ),
            Message::tool_result("call_1", "fn main() {}", false),
            Message::assistant("Here is the file content."),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // All messages should be preserved
        assert_eq!(messages.len(), 5);
    }

    /// A system message injected between a tool_call and its tool_result
    /// must not remain there after sanitization.
    #[test]
    fn test_sanitize_handles_system_between_call_and_result() {
        let mut messages = vec![
            Message::system("system prompt"),
            Message::user("do something"),
            Message::new(
                Role::Assistant,
                Content::tool_call("call_1", "file_read", serde_json::json!({"path": "x.rs"})),
            ),
            // System message injected between tool_call and tool_result
            Message::system("routing hint: use file_read"),
            Message::tool_result("call_1", "file contents", false),
            Message::assistant("Done"),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // System message should be moved before the assistant tool_call
        // Find the assistant tool_call message
        let assistant_idx = messages
            .iter()
            .position(|m| m.role == Role::Assistant && super::content_has_tool_call(&m.content))
            .unwrap();

        // The message right after the assistant tool_call should be the tool_result
        let next = &messages[assistant_idx + 1];
        assert!(
            matches!(&next.content, Content::ToolResult { .. })
                || next.role == Role::Tool
                || next.role == Role::User,
            "Expected tool_result after tool_call, got {:?}",
            next.role
        );
    }

    /// MultiPart tool_calls: both matched results survive, the orphan is
    /// removed.
    #[test]
    fn test_sanitize_multipart_tool_call() {
        let mut messages = vec![
            Message::user("do two things"),
            Message::new(
                Role::Assistant,
                Content::MultiPart {
                    parts: vec![
                        Content::text("I'll read both files."),
                        Content::tool_call(
                            "call_a",
                            "file_read",
                            serde_json::json!({"path": "a.rs"}),
                        ),
                        Content::tool_call(
                            "call_b",
                            "file_read",
                            serde_json::json!({"path": "b.rs"}),
                        ),
                    ],
                },
            ),
            Message::tool_result("call_a", "contents of a", false),
            Message::tool_result("call_b", "contents of b", false),
            // Orphaned tool_result
            Message::tool_result("call_nonexistent", "orphan", false),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // Orphaned result removed, valid ones preserved
        assert_eq!(messages.len(), 4);
    }

    /// An empty message list is a no-op.
    #[test]
    fn test_sanitize_empty_messages() {
        let mut messages: Vec<Message> = vec![];
        super::sanitize_tool_sequence(&mut messages);
        assert!(messages.is_empty());
    }

    /// A conversation with no tool messages passes through unchanged.
    #[test]
    fn test_sanitize_no_tool_messages() {
        let mut messages = vec![
            Message::system("prompt"),
            Message::user("hello"),
            Message::assistant("hi"),
        ];
        super::sanitize_tool_sequence(&mut messages);
        assert_eq!(messages.len(), 3);
    }
}