// rustant_core/brain.rs

1//! Brain module — LLM provider abstraction and interaction.
2//!
3//! Defines the `LlmProvider` trait for model-agnostic LLM interactions,
4//! and provides an OpenAI-compatible implementation with streaming support.
5
6use crate::error::LlmError;
7use crate::types::{
8    CompletionRequest, CompletionResponse, Content, CostEstimate, Message, Role, StreamEvent,
9    TokenUsage, ToolDefinition,
10};
11use async_trait::async_trait;
12use std::collections::HashSet;
13use std::sync::Arc;
14use tokio::sync::mpsc;
15use tracing::{debug, info, warn};
16
/// Trait for LLM providers, supporting both full and streaming completions.
///
/// Implementors must be `Send + Sync` so a provider can be shared across
/// async tasks behind an `Arc<dyn LlmProvider>` (as `Brain` does).
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Perform a full completion and return the response.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError>;

    /// Perform a streaming completion, sending events to the channel.
    ///
    /// NOTE(review): whether implementors must always emit a terminal
    /// `StreamEvent::Done` is not specified by this trait — the mock
    /// implementation does; confirm for real providers before relying on it.
    async fn complete_streaming(
        &self,
        request: CompletionRequest,
        tx: mpsc::Sender<StreamEvent>,
    ) -> Result<(), LlmError>;

    /// Estimate the token count for a set of messages.
    fn estimate_tokens(&self, messages: &[Message]) -> usize;

    /// Return the context window size for this provider/model.
    fn context_window(&self) -> usize;

    /// Return whether this provider supports tool/function calling.
    fn supports_tools(&self) -> bool;

    /// Return the cost per token (input, output) in USD.
    fn cost_per_token(&self) -> (f64, f64);

    /// Return the model name.
    fn model_name(&self) -> &str;
}
45
/// Token counter using tiktoken-rs for accurate BPE tokenization.
pub struct TokenCounter {
    // BPE encoder selected for the configured model; falls back to
    // cl100k_base when the model is not recognized (see `for_model`).
    bpe: tiktoken_rs::CoreBPE,
}
50
51impl TokenCounter {
52    /// Create a token counter for the given model.
53    /// Falls back to cl100k_base if the model isn't recognized.
54    pub fn for_model(model: &str) -> Self {
55        let bpe = tiktoken_rs::get_bpe_from_model(model).unwrap_or_else(|_| {
56            tiktoken_rs::cl100k_base().expect("cl100k_base should be available")
57        });
58        Self { bpe }
59    }
60
61    /// Count the number of tokens in a string.
62    pub fn count(&self, text: &str) -> usize {
63        self.bpe.encode_with_special_tokens(text).len()
64    }
65
66    /// Estimate the token count for a set of tool definitions.
67    ///
68    /// Each tool definition adds overhead for the JSON schema structure
69    /// (type/function wrapper), plus the name, description, and parameters.
70    pub fn count_tool_definitions(&self, tools: &[ToolDefinition]) -> usize {
71        let mut total = 0;
72        for tool in tools {
73            total += 10; // struct overhead (type, function wrapper, required fields)
74            total += self.count(&tool.name);
75            total += self.count(&tool.description);
76            total += self.count(&tool.parameters.to_string());
77        }
78        total
79    }
80
81    /// Estimate the token count for a set of messages.
82    /// Adds overhead for message structure (role, separators).
83    pub fn count_messages(&self, messages: &[Message]) -> usize {
84        let mut total = 0;
85        for msg in messages {
86            // Each message has overhead: role token + separators (~4 tokens)
87            total += 4;
88            match &msg.content {
89                Content::Text { text } => total += self.count(text),
90                Content::ToolCall {
91                    name, arguments, ..
92                } => {
93                    total += self.count(name);
94                    total += self.count(&arguments.to_string());
95                }
96                Content::ToolResult { output, .. } => {
97                    total += self.count(output);
98                }
99                Content::MultiPart { parts } => {
100                    for part in parts {
101                        match part {
102                            Content::Text { text } => total += self.count(text),
103                            Content::ToolCall {
104                                name, arguments, ..
105                            } => {
106                                total += self.count(name);
107                                total += self.count(&arguments.to_string());
108                            }
109                            Content::ToolResult { output, .. } => {
110                                total += self.count(output);
111                            }
112                            _ => total += 10,
113                        }
114                    }
115                }
116            }
117        }
118        total + 3 // reply priming overhead
119    }
120}
121
122/// Sanitize tool_call → tool_result ordering in a message sequence.
123///
124/// This runs provider-agnostically *before* messages are sent to any LLM provider,
125/// ensuring that:
126/// 1. Every tool_result has a matching tool_call earlier in the sequence.
127/// 2. No non-tool messages (system hints, summaries) appear between an assistant's
128///    tool_call message and its corresponding user/tool tool_result message.
129/// 3. Orphaned tool_results (no matching tool_call) are removed.
130///
131/// This fixes issues caused by compression moving pinned tool_results out of order,
132/// system routing hints persisting between call/result, and summary injection.
133pub fn sanitize_tool_sequence(messages: &mut Vec<Message>) {
134    // --- Pass 1: Collect all tool_call IDs from assistant messages ---
135    let mut tool_call_ids: HashSet<String> = HashSet::new();
136    for msg in messages.iter() {
137        if msg.role != Role::Assistant {
138            continue;
139        }
140        collect_tool_call_ids(&msg.content, &mut tool_call_ids);
141    }
142
143    // --- Pass 2: Remove orphaned tool_results (no matching tool_call) ---
144    messages.retain(|msg| {
145        if msg.role != Role::Tool {
146            // Check user messages too — Anthropic sends tool_result as user role
147            if msg.role == Role::User
148                && let Content::ToolResult { call_id, .. } = &msg.content
149                && !tool_call_ids.contains(call_id)
150            {
151                warn!(
152                    call_id = call_id.as_str(),
153                    "Removing orphaned tool_result (no matching tool_call)"
154                );
155                return false;
156            }
157            return true;
158        }
159        match &msg.content {
160            Content::ToolResult { call_id, .. } => {
161                if tool_call_ids.contains(call_id) {
162                    true
163                } else {
164                    warn!(
165                        call_id = call_id.as_str(),
166                        "Removing orphaned tool_result (no matching tool_call)"
167                    );
168                    false
169                }
170            }
171            Content::MultiPart { parts } => {
172                // Keep the message if at least one tool_result has a matching call
173                let has_valid = parts.iter().any(|p| {
174                    if let Content::ToolResult { call_id, .. } = p {
175                        tool_call_ids.contains(call_id)
176                    } else {
177                        true
178                    }
179                });
180                if !has_valid {
181                    warn!("Removing multipart tool message with all orphaned tool_results");
182                }
183                has_valid
184            }
185            _ => true,
186        }
187    });
188
189    // --- Pass 3: Relocate system messages that appear between tool_call and tool_result ---
190    // Strategy: find each assistant message with tool_call(s), then check if the
191    // immediately following message is a system message. If so, move the system message
192    // before the assistant message.
193    let mut i = 0;
194    while i + 1 < messages.len() {
195        let has_tool_call =
196            messages[i].role == Role::Assistant && content_has_tool_call(&messages[i].content);
197
198        if has_tool_call {
199            // Check if next message is a system message (should be tool_result instead)
200            let mut j = i + 1;
201            let mut system_messages_to_relocate = Vec::new();
202            while j < messages.len() && messages[j].role == Role::System {
203                system_messages_to_relocate.push(j);
204                j += 1;
205            }
206            // Move system messages before the assistant tool_call message
207            if !system_messages_to_relocate.is_empty() {
208                // Extract system messages in reverse order to maintain relative order
209                let mut extracted: Vec<Message> = Vec::new();
210                for &idx in system_messages_to_relocate.iter().rev() {
211                    extracted.push(messages.remove(idx));
212                }
213                extracted.reverse();
214                // Insert them before position i
215                for (offset, msg) in extracted.into_iter().enumerate() {
216                    messages.insert(i + offset, msg);
217                    i += 1; // adjust i to still point to the assistant message
218                }
219            }
220        }
221        i += 1;
222    }
223}
224
225/// Extract all tool_call IDs from a Content value into the given set.
226fn collect_tool_call_ids(content: &Content, ids: &mut HashSet<String>) {
227    match content {
228        Content::ToolCall { id, .. } => {
229            ids.insert(id.clone());
230        }
231        Content::MultiPart { parts } => {
232            for part in parts {
233                collect_tool_call_ids(part, ids);
234            }
235        }
236        _ => {}
237    }
238}
239
240/// Check whether a Content value contains at least one tool_call.
241fn content_has_tool_call(content: &Content) -> bool {
242    match content {
243        Content::ToolCall { .. } => true,
244        Content::MultiPart { parts } => parts.iter().any(content_has_tool_call),
245        _ => false,
246    }
247}
248
/// The Brain wraps an LLM provider and adds higher-level logic:
/// prompt construction, cost tracking, and model selection.
pub struct Brain {
    // Underlying model backend; Arc so callers can hold cloneable handles
    // (see `provider_arc`).
    provider: Arc<dyn LlmProvider>,
    // Base system prompt prepended to every conversation in `build_messages`.
    system_prompt: String,
    // Cumulative token usage across all completions made through this Brain.
    total_usage: TokenUsage,
    // Cumulative cost across all completions, priced via `cost_per_token`.
    total_cost: CostEstimate,
    // Local tokenizer used for pre-flight token estimates.
    token_counter: TokenCounter,
    /// Optional knowledge addendum appended to system prompt from distilled rules.
    knowledge_addendum: String,
}
260
261impl Brain {
262    pub fn new(provider: Arc<dyn LlmProvider>, system_prompt: impl Into<String>) -> Self {
263        let model_name = provider.model_name().to_string();
264        Self {
265            provider,
266            system_prompt: system_prompt.into(),
267            total_usage: TokenUsage::default(),
268            total_cost: CostEstimate::default(),
269            token_counter: TokenCounter::for_model(&model_name),
270            knowledge_addendum: String::new(),
271        }
272    }
273
274    /// Set knowledge addendum (distilled rules) to append to the system prompt.
275    pub fn set_knowledge_addendum(&mut self, addendum: String) {
276        self.knowledge_addendum = addendum;
277    }
278
279    /// Estimate token count for messages using tiktoken-rs.
280    pub fn estimate_tokens(&self, messages: &[Message]) -> usize {
281        self.token_counter.count_messages(messages)
282    }
283
284    /// Estimate token count for messages plus tool definitions.
285    pub fn estimate_tokens_with_tools(
286        &self,
287        messages: &[Message],
288        tools: Option<&[ToolDefinition]>,
289    ) -> usize {
290        let mut total = self.token_counter.count_messages(messages);
291        if let Some(tool_defs) = tools {
292            total += self.token_counter.count_tool_definitions(tool_defs);
293        }
294        total
295    }
296
297    /// Construct messages for the LLM with system prompt prepended.
298    ///
299    /// If a knowledge addendum has been set via `set_knowledge_addendum()`,
300    /// it is automatically appended to the system prompt.
301    ///
302    /// After assembly, [`sanitize_tool_sequence`] runs to ensure tool_call→tool_result
303    /// ordering is never broken regardless of compression, pinning, or system message injection.
304    pub fn build_messages(&self, conversation: &[Message]) -> Vec<Message> {
305        let mut messages = Vec::with_capacity(conversation.len() + 1);
306        if self.knowledge_addendum.is_empty() {
307            messages.push(Message::system(&self.system_prompt));
308        } else {
309            let augmented = format!("{}{}", self.system_prompt, self.knowledge_addendum);
310            messages.push(Message::system(&augmented));
311        }
312        messages.extend_from_slice(conversation);
313        sanitize_tool_sequence(&mut messages);
314        messages
315    }
316
317    /// Send a completion request and return the response, tracking usage.
318    pub async fn think(
319        &mut self,
320        conversation: &[Message],
321        tools: Option<Vec<ToolDefinition>>,
322    ) -> Result<CompletionResponse, LlmError> {
323        let messages = self.build_messages(conversation);
324        let mut token_estimate = self.provider.estimate_tokens(&messages);
325        if let Some(ref tool_defs) = tools {
326            token_estimate += self.token_counter.count_tool_definitions(tool_defs);
327        }
328        let context_limit = self.provider.context_window();
329
330        if token_estimate > context_limit {
331            return Err(LlmError::ContextOverflow {
332                used: token_estimate,
333                limit: context_limit,
334            });
335        }
336
337        debug!(
338            model = self.provider.model_name(),
339            estimated_tokens = token_estimate,
340            "Sending completion request"
341        );
342
343        let request = CompletionRequest {
344            messages,
345            tools,
346            temperature: 0.7,
347            max_tokens: None,
348            stop_sequences: Vec::new(),
349            model: None,
350        };
351
352        let response = self.provider.complete(request).await?;
353
354        // Track usage
355        self.total_usage.accumulate(&response.usage);
356        let (input_rate, output_rate) = self.provider.cost_per_token();
357        let cost = CostEstimate {
358            input_cost: response.usage.input_tokens as f64 * input_rate,
359            output_cost: response.usage.output_tokens as f64 * output_rate,
360        };
361        self.total_cost.accumulate(&cost);
362
363        info!(
364            input_tokens = response.usage.input_tokens,
365            output_tokens = response.usage.output_tokens,
366            cost = format!("${:.4}", cost.total()),
367            "Completion received"
368        );
369
370        Ok(response)
371    }
372
373    /// Send a completion request with retry logic and exponential backoff.
374    ///
375    /// Retries on transient errors (RateLimited, Timeout, Connection) up to
376    /// `max_retries` times with exponential backoff (1s, 2s, 4s, ..., capped at 32s).
377    /// Non-transient errors are returned immediately.
378    pub async fn think_with_retry(
379        &mut self,
380        conversation: &[Message],
381        tools: Option<Vec<ToolDefinition>>,
382        max_retries: usize,
383    ) -> Result<CompletionResponse, LlmError> {
384        let mut last_error = None;
385
386        for attempt in 0..=max_retries {
387            match self.think(conversation, tools.clone()).await {
388                Ok(response) => return Ok(response),
389                Err(e) if Self::is_retryable(&e) => {
390                    if attempt < max_retries {
391                        let backoff_secs = std::cmp::min(1u64 << attempt, 32);
392                        let wait = match &e {
393                            LlmError::RateLimited { retry_after_secs } => {
394                                std::cmp::max(*retry_after_secs, backoff_secs)
395                            }
396                            _ => backoff_secs,
397                        };
398                        info!(
399                            attempt = attempt + 1,
400                            max_retries,
401                            backoff_secs = wait,
402                            error = %e,
403                            "Retrying after transient error"
404                        );
405                        tokio::time::sleep(std::time::Duration::from_secs(wait)).await;
406                        last_error = Some(e);
407                    } else {
408                        return Err(e);
409                    }
410                }
411                Err(e) => return Err(e),
412            }
413        }
414
415        Err(last_error.unwrap_or(LlmError::Connection {
416            message: "Max retries exceeded".to_string(),
417        }))
418    }
419
420    /// Check if an LLM error is transient and should be retried.
421    pub fn is_retryable(error: &LlmError) -> bool {
422        matches!(
423            error,
424            LlmError::RateLimited { .. } | LlmError::Timeout { .. } | LlmError::Connection { .. }
425        )
426    }
427
428    /// Send a streaming completion request, returning events via channel.
429    pub async fn think_streaming(
430        &mut self,
431        conversation: &[Message],
432        tools: Option<Vec<ToolDefinition>>,
433        tx: mpsc::Sender<StreamEvent>,
434    ) -> Result<(), LlmError> {
435        let messages = self.build_messages(conversation);
436
437        let request = CompletionRequest {
438            messages,
439            tools,
440            temperature: 0.7,
441            max_tokens: None,
442            stop_sequences: Vec::new(),
443            model: None,
444        };
445
446        self.provider.complete_streaming(request, tx).await
447    }
448
449    /// Get total token usage across all calls.
450    pub fn total_usage(&self) -> &TokenUsage {
451        &self.total_usage
452    }
453
454    /// Get total cost across all calls.
455    pub fn total_cost(&self) -> &CostEstimate {
456        &self.total_cost
457    }
458
459    /// Get the model name.
460    pub fn model_name(&self) -> &str {
461        self.provider.model_name()
462    }
463
464    /// Get the context window size.
465    pub fn context_window(&self) -> usize {
466        self.provider.context_window()
467    }
468
469    /// Get cost rates (input_per_token, output_per_token) from the provider.
470    pub fn provider_cost_rates(&self) -> (f64, f64) {
471        self.provider.cost_per_token()
472    }
473
474    /// Get a reference to the underlying LLM provider.
475    pub fn provider(&self) -> &dyn LlmProvider {
476        &*self.provider
477    }
478
479    /// Get a cloneable Arc handle to the underlying LLM provider.
480    pub fn provider_arc(&self) -> Arc<dyn LlmProvider> {
481        Arc::clone(&self.provider)
482    }
483
484    /// Track usage and cost from an external completion (e.g., streaming).
485    pub fn track_usage(&mut self, usage: &TokenUsage) {
486        self.total_usage.accumulate(usage);
487        let (input_rate, output_rate) = self.provider.cost_per_token();
488        let cost = CostEstimate {
489            input_cost: usage.input_tokens as f64 * input_rate,
490            output_cost: usage.output_tokens as f64 * output_rate,
491        };
492        self.total_cost.accumulate(&cost);
493    }
494
495    /// Get the current token usage as a fraction of the context window.
496    pub fn context_usage_ratio(&self, conversation: &[Message]) -> f32 {
497        let messages = self.build_messages(conversation);
498        let tokens = self.provider.estimate_tokens(&messages);
499        tokens as f32 / self.provider.context_window() as f32
500    }
501}
502
/// A mock LLM provider for testing and development.
pub struct MockLlmProvider {
    // Model name reported by `model_name()` ("mock-model").
    model: String,
    // Context window size reported by `context_window()`.
    context_window: usize,
    // FIFO queue of canned responses; Mutex for interior mutability behind &self.
    responses: std::sync::Mutex<Vec<CompletionResponse>>,
}
509
510impl MockLlmProvider {
511    pub fn new() -> Self {
512        Self {
513            model: "mock-model".to_string(),
514            context_window: 128_000,
515            responses: std::sync::Mutex::new(Vec::new()),
516        }
517    }
518
519    /// Create a MockLlmProvider that always returns the given text.
520    ///
521    /// Queues multiple copies of the response so it can handle multiple calls.
522    pub fn with_response(text: &str) -> Self {
523        let provider = Self::new();
524        for _ in 0..20 {
525            provider.queue_response(Self::text_response(text));
526        }
527        provider
528    }
529
530    /// Queue a response to be returned by the next `complete` call.
531    pub fn queue_response(&self, response: CompletionResponse) {
532        self.responses.lock().unwrap().push(response);
533    }
534
535    /// Create a simple text response for testing.
536    pub fn text_response(text: &str) -> CompletionResponse {
537        CompletionResponse {
538            message: Message::assistant(text),
539            usage: TokenUsage {
540                input_tokens: 100,
541                output_tokens: 50,
542            },
543            model: "mock-model".to_string(),
544            finish_reason: Some("stop".to_string()),
545        }
546    }
547
548    /// Create a tool call response for testing.
549    pub fn tool_call_response(tool_name: &str, arguments: serde_json::Value) -> CompletionResponse {
550        let call_id = format!("call_{}", uuid::Uuid::new_v4());
551        CompletionResponse {
552            message: Message::new(
553                Role::Assistant,
554                Content::tool_call(&call_id, tool_name, arguments),
555            ),
556            usage: TokenUsage {
557                input_tokens: 100,
558                output_tokens: 30,
559            },
560            model: "mock-model".to_string(),
561            finish_reason: Some("tool_calls".to_string()),
562        }
563    }
564
565    /// Create a multipart response (text + tool call) for testing.
566    pub fn multipart_response(
567        text: &str,
568        tool_name: &str,
569        arguments: serde_json::Value,
570    ) -> CompletionResponse {
571        let call_id = format!("call_{}", uuid::Uuid::new_v4());
572        CompletionResponse {
573            message: Message::new(
574                Role::Assistant,
575                Content::MultiPart {
576                    parts: vec![
577                        Content::text(text),
578                        Content::tool_call(&call_id, tool_name, arguments),
579                    ],
580                },
581            ),
582            usage: TokenUsage {
583                input_tokens: 100,
584                output_tokens: 50,
585            },
586            model: "mock-model".to_string(),
587            finish_reason: Some("tool_calls".to_string()),
588        }
589    }
590}
591
impl Default for MockLlmProvider {
    /// Same as [`MockLlmProvider::new`]: empty queue, 128k context window.
    fn default() -> Self {
        Self::new()
    }
}
597
598#[async_trait]
599impl LlmProvider for MockLlmProvider {
600    async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
601        let mut responses = self.responses.lock().unwrap();
602        if responses.is_empty() {
603            Ok(MockLlmProvider::text_response(
604                "I'm a mock LLM. No queued responses available.",
605            ))
606        } else {
607            Ok(responses.remove(0))
608        }
609    }
610
611    async fn complete_streaming(
612        &self,
613        request: CompletionRequest,
614        tx: mpsc::Sender<StreamEvent>,
615    ) -> Result<(), LlmError> {
616        let response = self.complete(request).await?;
617        if let Some(text) = response.message.content.as_text() {
618            for word in text.split_whitespace() {
619                let _ = tx.send(StreamEvent::Token(format!("{} ", word))).await;
620            }
621        }
622        let _ = tx
623            .send(StreamEvent::Done {
624                usage: response.usage,
625            })
626            .await;
627        Ok(())
628    }
629
630    fn estimate_tokens(&self, messages: &[Message]) -> usize {
631        // Rough estimate: ~4 chars per token
632        messages
633            .iter()
634            .map(|m| match &m.content {
635                Content::Text { text } => text.len() / 4,
636                Content::ToolCall { arguments, .. } => arguments.to_string().len() / 4,
637                Content::ToolResult { output, .. } => output.len() / 4,
638                Content::MultiPart { parts } => parts
639                    .iter()
640                    .map(|p| match p {
641                        Content::Text { text } => text.len() / 4,
642                        _ => 50,
643                    })
644                    .sum(),
645            })
646            .sum::<usize>()
647            + 100 // overhead for message structure
648    }
649
650    fn context_window(&self) -> usize {
651        self.context_window
652    }
653
654    fn supports_tools(&self) -> bool {
655        true
656    }
657
658    fn cost_per_token(&self) -> (f64, f64) {
659        (0.0, 0.0) // free for mock
660    }
661
662    fn model_name(&self) -> &str {
663        &self.model
664    }
665}
666
/// The system prompt used by default for the Rustant agent.
///
/// Covers tool-selection rules, workflow routing, and security policy.
/// Consumed as the base prompt by `Brain` (see `build_messages`); a knowledge
/// addendum may be appended to it at runtime.
pub const DEFAULT_SYSTEM_PROMPT: &str = r#"You are Rustant, a privacy-first autonomous personal assistant built in Rust. You help users with software engineering, daily productivity, and macOS automation tasks.

CRITICAL — Tool selection rules:
- You MUST use the dedicated tool for each task. Do NOT use shell_exec when a dedicated tool exists.
- For clipboard: call macos_clipboard with {"action":"read"} or {"action":"write","content":"..."}
- For battery/disk/CPU/version: call macos_system_info with {"action":"battery"}, {"action":"version"}, etc.
- For running apps: call macos_app_control with {"action":"list_running"}
- For calendar: call macos_calendar. For reminders: call macos_reminders. For notes: call macos_notes.
- For screenshots: call macos_screenshot. For Spotlight search: call macos_spotlight.
- shell_exec is a last resort — only use it for commands that have no dedicated tool.
- Do NOT use document_read for clipboard or system operations — it reads document files only.
- If a tool call fails, try a different tool or action — do NOT ask the user whether to proceed. Act autonomously.
- Never call ask_user more than once per task unless the user's answer was genuinely unclear.

Other behaviors:
- Always read a file before modifying it
- Prefer small, focused changes over large rewrites
- Respect file boundaries and permissions

Tools:
- Use the tools provided to you. Each tool has a name, description, and parameter schema.
- For arXiv/paper searches: ALWAYS use arxiv_research (built-in API client), never safari/curl.
- For web searches: use web_search (DuckDuckGo), not safari/shell.
- For fetching URLs: use web_fetch, not safari/shell.
- For meeting recording: use macos_meeting_recorder with 'record_and_transcribe' for full flow.

Workflows (structured multi-step templates — run via shell_exec "rustant workflow run <name>"):
  code_review, refactor, test_generation, documentation, dependency_update,
  security_scan, deployment, incident_response, morning_briefing, pr_review,
  dependency_audit, changelog, meeting_recorder, daily_briefing_full,
  end_of_day_summary, app_automation, email_triage, arxiv_research,
  knowledge_graph, experiment_tracking, code_analysis, content_pipeline,
  skill_development, career_planning, system_monitoring, life_planning,
  privacy_audit, self_improvement_loop
When a user asks for one of these tasks by name or description, execute the workflow or accomplish it step by step.

Security rules:
- Never execute commands that could damage the system or leak credentials
- Do not read or write files containing secrets (.env, *.key, *.pem) unless explicitly asked
- Sanitize all user input before passing to shell or AppleScript commands
- When unsure about a destructive action, use ask_user to confirm first"#;
709
710// ---------------------------------------------------------------------------
711// Token Budget Manager
712// ---------------------------------------------------------------------------
713
/// Tracks token usage against configurable budgets and predicts costs
/// before execution. Can warn or halt when budgets are exceeded.
///
/// A limit of `0.0` / `0` means "unlimited" — `check_budget` skips that check
/// (see the `new(None)` constructor, which zeroes every limit).
pub struct TokenBudgetManager {
    // Maximum spend (USD) for the whole session; 0.0 = unlimited.
    session_limit_usd: f64,
    // Maximum spend (USD) per task; 0.0 = unlimited.
    task_limit_usd: f64,
    // Maximum total tokens for the session; 0 = unlimited.
    session_token_limit: usize,
    // Whether callers should halt (vs. just warn) when a budget is exceeded.
    halt_on_exceed: bool,
    // Running spend for the session, in USD.
    session_cost: f64,
    // Running spend for the current task, in USD; reset by `reset_task`.
    task_cost: f64,
    // Running token total for the session.
    session_tokens: usize,
}
725
/// The result of a pre-call budget check.
///
/// Produced by `TokenBudgetManager::check_budget` before each LLM call.
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetCheckResult {
    /// Budget is within limits, proceed.
    Ok,
    /// Budget warning — approaching limit but not exceeded.
    /// `usage_pct` is the projected fraction of the limit (e.g. 0.85 = 85%).
    Warning { message: String, usage_pct: f64 },
    /// Budget exceeded — should halt if configured.
    Exceeded { message: String },
}
736
737impl TokenBudgetManager {
738    /// Create a new budget manager from config. Passing `None` creates
739    /// an unlimited manager that always returns `Ok`.
740    pub fn new(config: Option<&crate::config::BudgetConfig>) -> Self {
741        match config {
742            Some(cfg) => Self {
743                session_limit_usd: cfg.session_limit_usd,
744                task_limit_usd: cfg.task_limit_usd,
745                session_token_limit: cfg.session_token_limit,
746                halt_on_exceed: cfg.halt_on_exceed,
747                session_cost: 0.0,
748                task_cost: 0.0,
749                session_tokens: 0,
750            },
751            None => Self {
752                session_limit_usd: 0.0,
753                task_limit_usd: 0.0,
754                session_token_limit: 0,
755                halt_on_exceed: false,
756                session_cost: 0.0,
757                task_cost: 0.0,
758                session_tokens: 0,
759            },
760        }
761    }
762
    /// Reset task-level tracking (call at start of each new task).
    ///
    /// Only task cost is tracked per-task; tokens are counted per-session,
    /// so there is nothing else to reset here.
    pub fn reset_task(&mut self) {
        self.task_cost = 0.0;
    }
767
768    /// Record usage after an LLM call completes.
769    pub fn record_usage(&mut self, usage: &TokenUsage, cost: &CostEstimate) {
770        self.session_cost += cost.total();
771        self.task_cost += cost.total();
772        self.session_tokens += usage.total();
773    }
774
775    /// Estimate cost for an upcoming LLM call and check against budgets.
776    ///
777    /// `estimated_input_tokens` is the count of tokens in the request.
778    /// `input_rate` and `output_rate` are the per-token costs from the provider.
779    /// Output tokens are estimated at 0.5x input as a heuristic.
780    pub fn check_budget(
781        &self,
782        estimated_input_tokens: usize,
783        input_rate: f64,
784        output_rate: f64,
785    ) -> BudgetCheckResult {
786        // Predict output tokens as ~50% of input (heuristic)
787        let predicted_output = estimated_input_tokens / 2;
788        let predicted_cost =
789            (estimated_input_tokens as f64 * input_rate) + (predicted_output as f64 * output_rate);
790
791        let projected_session_cost = self.session_cost + predicted_cost;
792        let projected_task_cost = self.task_cost + predicted_cost;
793        let projected_session_tokens =
794            self.session_tokens + estimated_input_tokens + predicted_output;
795
796        // Check session cost limit
797        if self.session_limit_usd > 0.0 && projected_session_cost > self.session_limit_usd {
798            return BudgetCheckResult::Exceeded {
799                message: format!(
800                    "Session cost ${:.4} would exceed limit ${:.4}",
801                    projected_session_cost, self.session_limit_usd
802                ),
803            };
804        }
805
806        // Check task cost limit
807        if self.task_limit_usd > 0.0 && projected_task_cost > self.task_limit_usd {
808            return BudgetCheckResult::Exceeded {
809                message: format!(
810                    "Task cost ${:.4} would exceed limit ${:.4}",
811                    projected_task_cost, self.task_limit_usd
812                ),
813            };
814        }
815
816        // Check session token limit
817        if self.session_token_limit > 0 && projected_session_tokens > self.session_token_limit {
818            return BudgetCheckResult::Exceeded {
819                message: format!(
820                    "Session tokens {} would exceed limit {}",
821                    projected_session_tokens, self.session_token_limit
822                ),
823            };
824        }
825
826        // Check if approaching limits (>80%)
827        if self.session_limit_usd > 0.0 {
828            let pct = projected_session_cost / self.session_limit_usd;
829            if pct > 0.8 {
830                return BudgetCheckResult::Warning {
831                    message: format!(
832                        "Session cost at {:.0}% of ${:.4} limit",
833                        pct * 100.0,
834                        self.session_limit_usd
835                    ),
836                    usage_pct: pct,
837                };
838            }
839        }
840
841        BudgetCheckResult::Ok
842    }
843
    /// Whether budget enforcement should halt execution on exceed.
    ///
    /// Returns the `halt_on_exceed` flag captured at construction time.
    pub fn should_halt_on_exceed(&self) -> bool {
        self.halt_on_exceed
    }
848
    /// Current session cost in USD, accumulated via `record_usage`.
    pub fn session_cost(&self) -> f64 {
        self.session_cost
    }
853
    /// Current task cost in USD; cleared by `reset_task`.
    pub fn task_cost(&self) -> f64 {
        self.task_cost
    }
858
    /// Current session token count (input + output), accumulated via `record_usage`.
    pub fn session_tokens(&self) -> usize {
        self.session_tokens
    }
863}
864
#[cfg(test)]
mod tests {
    use super::*;

    // ----- MockLlmProvider: completion behavior -----

    #[tokio::test]
    async fn test_mock_provider_default_response() {
        let provider = MockLlmProvider::new();
        let request = CompletionRequest::default();
        let response = provider.complete(request).await.unwrap();
        assert!(response.message.content.as_text().is_some());
    }

    // Queued responses are returned in FIFO order, one per `complete` call.
    #[tokio::test]
    async fn test_mock_provider_queued_responses() {
        let provider = MockLlmProvider::new();
        provider.queue_response(MockLlmProvider::text_response("first"));
        provider.queue_response(MockLlmProvider::text_response("second"));

        let r1 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r1.message.content.as_text(), Some("first"));

        let r2 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r2.message.content.as_text(), Some("second"));
    }

    #[tokio::test]
    async fn test_mock_provider_streaming() {
        let provider = MockLlmProvider::new();
        provider.queue_response(MockLlmProvider::text_response("hello world"));

        let (tx, mut rx) = mpsc::channel(32);
        provider
            .complete_streaming(CompletionRequest::default(), tx)
            .await
            .unwrap();

        // Drain streamed tokens until the terminal Done event.
        let mut tokens = Vec::new();
        while let Some(event) = rx.recv().await {
            match event {
                StreamEvent::Token(t) => tokens.push(t),
                StreamEvent::Done { .. } => break,
                _ => {}
            }
        }
        assert_eq!(tokens.len(), 2); // "hello " and "world "
    }

    #[test]
    fn test_mock_provider_token_estimation() {
        let provider = MockLlmProvider::new();
        let messages = vec![Message::user("Hello, this is a test message.")];
        let tokens = provider.estimate_tokens(&messages);
        assert!(tokens > 0);
    }

    #[test]
    fn test_mock_provider_properties() {
        let provider = MockLlmProvider::new();
        assert_eq!(provider.context_window(), 128_000);
        assert!(provider.supports_tools());
        assert_eq!(provider.cost_per_token(), (0.0, 0.0));
        assert_eq!(provider.model_name(), "mock-model");
    }

    // ----- Brain: orchestration over a provider -----

    #[tokio::test]
    async fn test_brain_think() {
        let provider = Arc::new(MockLlmProvider::new());
        provider.queue_response(MockLlmProvider::text_response("I can help with that."));

        let mut brain = Brain::new(provider, "You are a helpful assistant.");
        let conversation = vec![Message::user("Help me refactor")];

        let response = brain.think(&conversation, None).await.unwrap();
        assert_eq!(
            response.message.content.as_text(),
            Some("I can help with that.")
        );
        // think() should also accumulate token usage as a side effect.
        assert!(brain.total_usage().total() > 0);
    }

    #[tokio::test]
    async fn test_brain_builds_messages_with_system_prompt() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system prompt");
        let conversation = vec![Message::user("hello")];

        // The system prompt is prepended as the first message.
        let messages = brain.build_messages(&conversation);
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content.as_text(), Some("system prompt"));
        assert_eq!(messages[1].role, Role::User);
    }

    #[test]
    fn test_brain_context_usage_ratio() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("short message")];

        // A short conversation should consume a non-zero but small fraction of context.
        let ratio = brain.context_usage_ratio(&conversation);
        assert!(ratio > 0.0);
        assert!(ratio < 1.0);
    }

    #[test]
    fn test_mock_tool_call_response() {
        let response = MockLlmProvider::tool_call_response(
            "file_read",
            serde_json::json!({"path": "/tmp/test.rs"}),
        );
        match &response.message.content {
            Content::ToolCall {
                name, arguments, ..
            } => {
                assert_eq!(name, "file_read");
                assert_eq!(arguments["path"], "/tmp/test.rs");
            }
            _ => panic!("Expected ToolCall content"),
        }
    }

    #[test]
    fn test_default_system_prompt() {
        assert!(DEFAULT_SYSTEM_PROMPT.contains("Rustant"));
        assert!(DEFAULT_SYSTEM_PROMPT.contains("autonomous"));
    }

    // ----- Retry classification and think_with_retry behavior -----

    #[test]
    fn test_is_retryable() {
        // Transient errors (rate limit, timeout, connection) are retryable;
        // permanent ones (context overflow, auth failure) are not.
        assert!(Brain::is_retryable(&LlmError::RateLimited {
            retry_after_secs: 5
        }));
        assert!(Brain::is_retryable(&LlmError::Timeout { timeout_secs: 30 }));
        assert!(Brain::is_retryable(&LlmError::Connection {
            message: "reset".into()
        }));
        assert!(!Brain::is_retryable(&LlmError::ContextOverflow {
            used: 200_000,
            limit: 128_000
        }));
        assert!(!Brain::is_retryable(&LlmError::AuthFailed {
            provider: "openai".into()
        }));
    }

    /// A mock provider that fails N times before succeeding.
    struct FailingProvider {
        // Number of calls that will still fail before success.
        failures_remaining: std::sync::Mutex<usize>,
        // Selects which LlmError variant to produce ("rate_limited", "timeout", ...).
        error_type: String,
        success_response: CompletionResponse,
    }

    impl FailingProvider {
        fn new(failures: usize, error_type: &str) -> Self {
            Self {
                failures_remaining: std::sync::Mutex::new(failures),
                error_type: error_type.to_string(),
                success_response: MockLlmProvider::text_response("Success after retry"),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for FailingProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, LlmError> {
            let mut remaining = self.failures_remaining.lock().unwrap();
            if *remaining > 0 {
                *remaining -= 1;
                // Map the configured error kind onto a concrete LlmError variant;
                // unknown strings fall through to the non-retryable ApiRequest case.
                match self.error_type.as_str() {
                    "rate_limited" => Err(LlmError::RateLimited {
                        retry_after_secs: 0,
                    }),
                    "timeout" => Err(LlmError::Timeout { timeout_secs: 5 }),
                    "connection" => Err(LlmError::Connection {
                        message: "connection reset".into(),
                    }),
                    _ => Err(LlmError::ApiRequest {
                        message: "non-retryable".into(),
                    }),
                }
            } else {
                Ok(self.success_response.clone())
            }
        }

        async fn complete_streaming(
            &self,
            _request: CompletionRequest,
            _tx: mpsc::Sender<StreamEvent>,
        ) -> Result<(), LlmError> {
            Ok(())
        }

        fn estimate_tokens(&self, _messages: &[Message]) -> usize {
            100
        }
        fn context_window(&self) -> usize {
            128_000
        }
        fn supports_tools(&self) -> bool {
            true
        }
        fn cost_per_token(&self) -> (f64, f64) {
            (0.0, 0.0)
        }
        fn model_name(&self) -> &str {
            "failing-mock"
        }
    }

    #[tokio::test]
    async fn test_think_with_retry_succeeds_after_failures() {
        // Two retryable failures with a budget of 3 attempts -> eventual success.
        let provider = Arc::new(FailingProvider::new(2, "connection"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 3).await;
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().message.content.as_text(),
            Some("Success after retry")
        );
    }

    #[tokio::test]
    async fn test_think_with_retry_exhausted() {
        // More failures than the retry budget -> the error is surfaced.
        let provider = Arc::new(FailingProvider::new(5, "timeout"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 2).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Timeout { .. }));
    }

    #[tokio::test]
    async fn test_think_with_retry_non_retryable_fails_immediately() {
        // ApiRequest is not retryable, so even with retries left it must fail.
        let provider = Arc::new(FailingProvider::new(1, "non_retryable"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 3).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::ApiRequest { .. }));
    }

    #[tokio::test]
    async fn test_think_with_retry_rate_limited() {
        // One rate-limit failure (mock uses retry_after_secs = 0), then success.
        let provider = Arc::new(FailingProvider::new(1, "rate_limited"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 2).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_track_usage() {
        let provider = Arc::new(MockLlmProvider::new());
        let mut brain = Brain::new(provider, "system");

        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
        };
        brain.track_usage(&usage);

        assert_eq!(brain.total_usage().input_tokens, 100);
        assert_eq!(brain.total_usage().output_tokens, 50);
    }

    // ----- TokenCounter (tiktoken-backed) -----

    #[test]
    fn test_token_counter_basic() {
        let counter = TokenCounter::for_model("gpt-4o");
        let count = counter.count("Hello, world!");
        assert!(count > 0);
        assert!(count < 20); // should be ~4 tokens
    }

    #[test]
    fn test_token_counter_messages() {
        let counter = TokenCounter::for_model("gpt-4o");
        let messages = vec![
            Message::system("You are a helpful assistant."),
            Message::user("What is 2 + 2?"),
        ];
        let count = counter.count_messages(&messages);
        assert!(count > 5);
        assert!(count < 100);
    }

    #[test]
    fn test_token_counter_unknown_model_falls_back() {
        let counter = TokenCounter::for_model("unknown-model-xyz");
        let count = counter.count("Hello");
        assert!(count > 0); // Should use cl100k_base fallback
    }

    #[test]
    fn test_brain_estimate_tokens() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system");
        let messages = vec![Message::user("Hello, this is a test.")];
        let estimate = brain.estimate_tokens(&messages);
        assert!(estimate > 0);
    }

    // --- sanitize_tool_sequence tests ---

    #[test]
    fn test_sanitize_removes_orphaned_tool_results() {
        let mut messages = vec![
            Message::system("You are a helper."),
            Message::user("do something"),
            // Orphaned tool_result — no matching tool_call
            Message::tool_result("call_orphan_123", "some result", false),
            Message::assistant("Done!"),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // The orphaned tool_result should be removed
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[2].role, Role::Assistant);
    }

    #[test]
    fn test_sanitize_preserves_valid_sequence() {
        let mut messages = vec![
            Message::system("system prompt"),
            Message::user("read main.rs"),
            Message::new(
                Role::Assistant,
                Content::tool_call(
                    "call_1",
                    "file_read",
                    serde_json::json!({"path": "main.rs"}),
                ),
            ),
            Message::tool_result("call_1", "fn main() {}", false),
            Message::assistant("Here is the file content."),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // All messages should be preserved
        assert_eq!(messages.len(), 5);
    }

    #[test]
    fn test_sanitize_handles_system_between_call_and_result() {
        let mut messages = vec![
            Message::system("system prompt"),
            Message::user("do something"),
            Message::new(
                Role::Assistant,
                Content::tool_call("call_1", "file_read", serde_json::json!({"path": "x.rs"})),
            ),
            // System message injected between tool_call and tool_result
            Message::system("routing hint: use file_read"),
            Message::tool_result("call_1", "file contents", false),
            Message::assistant("Done"),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // System message should be moved before the assistant tool_call
        // Find the assistant tool_call message
        let assistant_idx = messages
            .iter()
            .position(|m| m.role == Role::Assistant && super::content_has_tool_call(&m.content))
            .unwrap();

        // The message right after the assistant tool_call should be the tool_result
        let next = &messages[assistant_idx + 1];
        assert!(
            matches!(&next.content, Content::ToolResult { .. })
                || next.role == Role::Tool
                || next.role == Role::User,
            "Expected tool_result after tool_call, got {:?}",
            next.role
        );
    }

    #[test]
    fn test_sanitize_multipart_tool_call() {
        let mut messages = vec![
            Message::user("do two things"),
            Message::new(
                Role::Assistant,
                Content::MultiPart {
                    parts: vec![
                        Content::text("I'll read both files."),
                        Content::tool_call(
                            "call_a",
                            "file_read",
                            serde_json::json!({"path": "a.rs"}),
                        ),
                        Content::tool_call(
                            "call_b",
                            "file_read",
                            serde_json::json!({"path": "b.rs"}),
                        ),
                    ],
                },
            ),
            Message::tool_result("call_a", "contents of a", false),
            Message::tool_result("call_b", "contents of b", false),
            // Orphaned tool_result
            Message::tool_result("call_nonexistent", "orphan", false),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // Orphaned result removed, valid ones preserved
        assert_eq!(messages.len(), 4);
    }

    #[test]
    fn test_sanitize_empty_messages() {
        let mut messages: Vec<Message> = vec![];
        super::sanitize_tool_sequence(&mut messages);
        assert!(messages.is_empty());
    }

    #[test]
    fn test_sanitize_no_tool_messages() {
        // A conversation with no tool traffic passes through untouched.
        let mut messages = vec![
            Message::system("prompt"),
            Message::user("hello"),
            Message::assistant("hi"),
        ];
        super::sanitize_tool_sequence(&mut messages);
        assert_eq!(messages.len(), 3);
    }

    // ----- Tool-definition token accounting -----

    #[test]
    fn test_count_tool_definitions() {
        let counter = TokenCounter::for_model("gpt-4");
        let tools = vec![
            ToolDefinition {
                name: "calculator".to_string(),
                description: "Perform arithmetic calculations".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "expression": { "type": "string", "description": "Math expression" }
                    },
                    "required": ["expression"]
                }),
            },
            ToolDefinition {
                name: "file_read".to_string(),
                description: "Read a file from the filesystem".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "path": { "type": "string", "description": "File path" }
                    },
                    "required": ["path"]
                }),
            },
        ];

        let token_count = counter.count_tool_definitions(&tools);
        // Each tool: 10 overhead + name tokens + description tokens + params tokens
        // Should be a meaningful positive number
        assert!(
            token_count > 40,
            "Two tool definitions should count as >40 tokens, got {}",
            token_count
        );
        // Sanity: shouldn't be absurdly large
        assert!(
            token_count < 500,
            "Two simple tool definitions should be <500 tokens, got {}",
            token_count
        );
    }

    #[test]
    fn test_count_tool_definitions_empty() {
        let counter = TokenCounter::for_model("gpt-4");
        assert_eq!(counter.count_tool_definitions(&[]), 0);
    }

    #[test]
    fn test_estimate_tokens_with_tools() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system prompt");

        let messages = vec![Message::user("hello")];
        let tools = vec![ToolDefinition {
            name: "echo".to_string(),
            description: "Echo text back".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];

        // Declaring tools must strictly increase the token estimate.
        let without_tools = brain.estimate_tokens(&messages);
        let with_tools = brain.estimate_tokens_with_tools(&messages, Some(&tools));
        assert!(
            with_tools > without_tools,
            "Token estimate with tools ({}) should be greater than without ({})",
            with_tools,
            without_tools
        );
    }
}