1use crate::error::LlmError;
7use crate::types::{
8 CompletionRequest, CompletionResponse, Content, CostEstimate, Message, Role, StreamEvent,
9 TokenUsage, ToolDefinition,
10};
11use async_trait::async_trait;
12use std::collections::HashSet;
13use std::sync::Arc;
14use tokio::sync::mpsc;
15use tracing::{debug, info, warn};
16
/// Abstraction over a chat-completion LLM backend (OpenAI, Anthropic, local,
/// mock, …). Implementations must be thread-safe (`Send + Sync`) because a
/// single provider handle is shared via `Arc` across tasks.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Performs a single, non-streaming completion request.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError>;

    /// Performs a completion request, delivering incremental output as
    /// `StreamEvent`s on `tx`.
    async fn complete_streaming(
        &self,
        request: CompletionRequest,
        tx: mpsc::Sender<StreamEvent>,
    ) -> Result<(), LlmError>;

    /// Estimates how many prompt tokens `messages` will consume.
    fn estimate_tokens(&self, messages: &[Message]) -> usize;

    /// Maximum context size supported by the model, in tokens.
    fn context_window(&self) -> usize;

    /// Whether the provider supports tool/function calling.
    fn supports_tools(&self) -> bool;

    /// Per-token cost in USD as `(input_rate, output_rate)`.
    fn cost_per_token(&self) -> (f64, f64);

    /// Identifier of the underlying model (used for logging and tokenizer
    /// selection).
    fn model_name(&self) -> &str;
}
45
/// Local token counter backed by a tiktoken BPE encoding, used to estimate
/// prompt sizes without calling the provider.
pub struct TokenCounter {
    bpe: tiktoken_rs::CoreBPE,
}
50
51impl TokenCounter {
52 pub fn for_model(model: &str) -> Self {
55 let bpe = tiktoken_rs::get_bpe_from_model(model).unwrap_or_else(|_| {
56 tiktoken_rs::cl100k_base().expect("cl100k_base should be available")
57 });
58 Self { bpe }
59 }
60
61 pub fn count(&self, text: &str) -> usize {
63 self.bpe.encode_with_special_tokens(text).len()
64 }
65
66 pub fn count_messages(&self, messages: &[Message]) -> usize {
69 let mut total = 0;
70 for msg in messages {
71 total += 4;
73 match &msg.content {
74 Content::Text { text } => total += self.count(text),
75 Content::ToolCall {
76 name, arguments, ..
77 } => {
78 total += self.count(name);
79 total += self.count(&arguments.to_string());
80 }
81 Content::ToolResult { output, .. } => {
82 total += self.count(output);
83 }
84 Content::MultiPart { parts } => {
85 for part in parts {
86 match part {
87 Content::Text { text } => total += self.count(text),
88 Content::ToolCall {
89 name, arguments, ..
90 } => {
91 total += self.count(name);
92 total += self.count(&arguments.to_string());
93 }
94 Content::ToolResult { output, .. } => {
95 total += self.count(output);
96 }
97 _ => total += 10,
98 }
99 }
100 }
101 }
102 }
103 total + 3 }
105}
106
/// Repairs a conversation so it satisfies provider-side tool-sequencing rules:
///
/// 1. Removes tool_result messages whose `call_id` does not match any
///    tool_call previously issued by the assistant ("orphans" — providers
///    reject them).
/// 2. Moves system messages that were interleaved between an assistant
///    tool_call and its tool_result to *before* the tool_call, so the
///    call/result pair stays adjacent.
pub fn sanitize_tool_sequence(messages: &mut Vec<Message>) {
    // Pass 1: gather every tool-call id the assistant has actually issued.
    let mut tool_call_ids: HashSet<String> = HashSet::new();
    for msg in messages.iter() {
        if msg.role != Role::Assistant {
            continue;
        }
        collect_tool_call_ids(&msg.content, &mut tool_call_ids);
    }

    // Pass 2: drop tool results that have no matching tool call. Orphaned
    // results can appear under either the Tool role or (for some providers)
    // the User role, so both are checked.
    messages.retain(|msg| {
        if msg.role != Role::Tool {
            // A User-role message may still carry a ToolResult payload.
            if msg.role == Role::User {
                if let Content::ToolResult { call_id, .. } = &msg.content {
                    if !tool_call_ids.contains(call_id) {
                        warn!(
                            call_id = call_id.as_str(),
                            "Removing orphaned tool_result (no matching tool_call)"
                        );
                        return false;
                    }
                }
            }
            return true;
        }
        match &msg.content {
            Content::ToolResult { call_id, .. } => {
                if tool_call_ids.contains(call_id) {
                    true
                } else {
                    warn!(
                        call_id = call_id.as_str(),
                        "Removing orphaned tool_result (no matching tool_call)"
                    );
                    false
                }
            }
            Content::MultiPart { parts } => {
                // Keep a multipart tool message if at least one part is a
                // valid tool_result (or a non-result part).
                let has_valid = parts.iter().any(|p| {
                    if let Content::ToolResult { call_id, .. } = p {
                        tool_call_ids.contains(call_id)
                    } else {
                        true
                    }
                });
                if !has_valid {
                    warn!("Removing multipart tool message with all orphaned tool_results");
                }
                has_valid
            }
            _ => true,
        }
    });

    // Pass 3: relocate system messages sitting between a tool_call and its
    // tool_result so the pair becomes adjacent. NOTE: this loop mutates the
    // vec while indexing — the `i` adjustments below keep `i` pointing at
    // the (shifted) assistant tool_call message.
    let mut i = 0;
    while i + 1 < messages.len() {
        let has_tool_call =
            messages[i].role == Role::Assistant && content_has_tool_call(&messages[i].content);

        if has_tool_call {
            // Collect the run of system messages immediately after the call.
            let mut j = i + 1;
            let mut system_messages_to_relocate = Vec::new();
            while j < messages.len() && messages[j].role == Role::System {
                system_messages_to_relocate.push(j);
                j += 1;
            }
            if !system_messages_to_relocate.is_empty() {
                // Remove back-to-front so earlier indices stay valid, then
                // restore original order before reinserting.
                let mut extracted: Vec<Message> = Vec::new();
                for &idx in system_messages_to_relocate.iter().rev() {
                    extracted.push(messages.remove(idx));
                }
                extracted.reverse();
                // Reinsert the system messages *before* the tool_call,
                // bumping `i` so it still indexes the tool_call afterwards.
                for (offset, msg) in extracted.into_iter().enumerate() {
                    messages.insert(i + offset, msg);
                    i += 1;
                }
            }
        }
        i += 1;
    }
}
210
211fn collect_tool_call_ids(content: &Content, ids: &mut HashSet<String>) {
213 match content {
214 Content::ToolCall { id, .. } => {
215 ids.insert(id.clone());
216 }
217 Content::MultiPart { parts } => {
218 for part in parts {
219 collect_tool_call_ids(part, ids);
220 }
221 }
222 _ => {}
223 }
224}
225
226fn content_has_tool_call(content: &Content) -> bool {
228 match content {
229 Content::ToolCall { .. } => true,
230 Content::MultiPart { parts } => parts.iter().any(content_has_tool_call),
231 _ => false,
232 }
233}
234
/// High-level LLM orchestrator: owns a provider handle, the system prompt,
/// a local tokenizer, and running usage/cost totals for the session.
pub struct Brain {
    provider: Arc<dyn LlmProvider>,
    system_prompt: String,
    total_usage: TokenUsage,
    total_cost: CostEstimate,
    token_counter: TokenCounter,
    // Extra text appended verbatim to the system prompt on every request
    // (e.g. retrieved knowledge); empty string means "no addendum".
    knowledge_addendum: String,
}
246
247impl Brain {
248 pub fn new(provider: Arc<dyn LlmProvider>, system_prompt: impl Into<String>) -> Self {
249 let model_name = provider.model_name().to_string();
250 Self {
251 provider,
252 system_prompt: system_prompt.into(),
253 total_usage: TokenUsage::default(),
254 total_cost: CostEstimate::default(),
255 token_counter: TokenCounter::for_model(&model_name),
256 knowledge_addendum: String::new(),
257 }
258 }
259
260 pub fn set_knowledge_addendum(&mut self, addendum: String) {
262 self.knowledge_addendum = addendum;
263 }
264
265 pub fn estimate_tokens(&self, messages: &[Message]) -> usize {
267 self.token_counter.count_messages(messages)
268 }
269
270 pub fn build_messages(&self, conversation: &[Message]) -> Vec<Message> {
278 let mut messages = Vec::with_capacity(conversation.len() + 1);
279 if self.knowledge_addendum.is_empty() {
280 messages.push(Message::system(&self.system_prompt));
281 } else {
282 let augmented = format!("{}{}", self.system_prompt, self.knowledge_addendum);
283 messages.push(Message::system(&augmented));
284 }
285 messages.extend_from_slice(conversation);
286 sanitize_tool_sequence(&mut messages);
287 messages
288 }
289
290 pub async fn think(
292 &mut self,
293 conversation: &[Message],
294 tools: Option<Vec<ToolDefinition>>,
295 ) -> Result<CompletionResponse, LlmError> {
296 let messages = self.build_messages(conversation);
297 let token_estimate = self.provider.estimate_tokens(&messages);
298 let context_limit = self.provider.context_window();
299
300 if token_estimate > context_limit {
301 return Err(LlmError::ContextOverflow {
302 used: token_estimate,
303 limit: context_limit,
304 });
305 }
306
307 debug!(
308 model = self.provider.model_name(),
309 estimated_tokens = token_estimate,
310 "Sending completion request"
311 );
312
313 let request = CompletionRequest {
314 messages,
315 tools,
316 temperature: 0.7,
317 max_tokens: None,
318 stop_sequences: Vec::new(),
319 model: None,
320 };
321
322 let response = self.provider.complete(request).await?;
323
324 self.total_usage.accumulate(&response.usage);
326 let (input_rate, output_rate) = self.provider.cost_per_token();
327 let cost = CostEstimate {
328 input_cost: response.usage.input_tokens as f64 * input_rate,
329 output_cost: response.usage.output_tokens as f64 * output_rate,
330 };
331 self.total_cost.accumulate(&cost);
332
333 info!(
334 input_tokens = response.usage.input_tokens,
335 output_tokens = response.usage.output_tokens,
336 cost = format!("${:.4}", cost.total()),
337 "Completion received"
338 );
339
340 Ok(response)
341 }
342
343 pub async fn think_with_retry(
349 &mut self,
350 conversation: &[Message],
351 tools: Option<Vec<ToolDefinition>>,
352 max_retries: usize,
353 ) -> Result<CompletionResponse, LlmError> {
354 let mut last_error = None;
355
356 for attempt in 0..=max_retries {
357 match self.think(conversation, tools.clone()).await {
358 Ok(response) => return Ok(response),
359 Err(e) if Self::is_retryable(&e) => {
360 if attempt < max_retries {
361 let backoff_secs = std::cmp::min(1u64 << attempt, 32);
362 let wait = match &e {
363 LlmError::RateLimited { retry_after_secs } => {
364 std::cmp::max(*retry_after_secs, backoff_secs)
365 }
366 _ => backoff_secs,
367 };
368 info!(
369 attempt = attempt + 1,
370 max_retries,
371 backoff_secs = wait,
372 error = %e,
373 "Retrying after transient error"
374 );
375 tokio::time::sleep(std::time::Duration::from_secs(wait)).await;
376 last_error = Some(e);
377 } else {
378 return Err(e);
379 }
380 }
381 Err(e) => return Err(e),
382 }
383 }
384
385 Err(last_error.unwrap_or(LlmError::Connection {
386 message: "Max retries exceeded".to_string(),
387 }))
388 }
389
390 pub fn is_retryable(error: &LlmError) -> bool {
392 matches!(
393 error,
394 LlmError::RateLimited { .. } | LlmError::Timeout { .. } | LlmError::Connection { .. }
395 )
396 }
397
398 pub async fn think_streaming(
400 &mut self,
401 conversation: &[Message],
402 tools: Option<Vec<ToolDefinition>>,
403 tx: mpsc::Sender<StreamEvent>,
404 ) -> Result<(), LlmError> {
405 let messages = self.build_messages(conversation);
406
407 let request = CompletionRequest {
408 messages,
409 tools,
410 temperature: 0.7,
411 max_tokens: None,
412 stop_sequences: Vec::new(),
413 model: None,
414 };
415
416 self.provider.complete_streaming(request, tx).await
417 }
418
419 pub fn total_usage(&self) -> &TokenUsage {
421 &self.total_usage
422 }
423
424 pub fn total_cost(&self) -> &CostEstimate {
426 &self.total_cost
427 }
428
429 pub fn model_name(&self) -> &str {
431 self.provider.model_name()
432 }
433
434 pub fn context_window(&self) -> usize {
436 self.provider.context_window()
437 }
438
439 pub fn provider_cost_rates(&self) -> (f64, f64) {
441 self.provider.cost_per_token()
442 }
443
444 pub fn provider(&self) -> &dyn LlmProvider {
446 &*self.provider
447 }
448
449 pub fn provider_arc(&self) -> Arc<dyn LlmProvider> {
451 Arc::clone(&self.provider)
452 }
453
454 pub fn track_usage(&mut self, usage: &TokenUsage) {
456 self.total_usage.accumulate(usage);
457 let (input_rate, output_rate) = self.provider.cost_per_token();
458 let cost = CostEstimate {
459 input_cost: usage.input_tokens as f64 * input_rate,
460 output_cost: usage.output_tokens as f64 * output_rate,
461 };
462 self.total_cost.accumulate(&cost);
463 }
464
465 pub fn context_usage_ratio(&self, conversation: &[Message]) -> f32 {
467 let messages = self.build_messages(conversation);
468 let tokens = self.provider.estimate_tokens(&messages);
469 tokens as f32 / self.provider.context_window() as f32
470 }
471}
472
/// In-memory LLM provider for tests: serves pre-queued responses in FIFO
/// order and falls back to a canned reply when the queue is empty.
pub struct MockLlmProvider {
    model: String,
    context_window: usize,
    // Queue of canned responses; `complete` pops from the front. Mutex makes
    // queueing possible through a shared `&self`/`Arc` handle.
    responses: std::sync::Mutex<Vec<CompletionResponse>>,
}
479
480impl MockLlmProvider {
481 pub fn new() -> Self {
482 Self {
483 model: "mock-model".to_string(),
484 context_window: 128_000,
485 responses: std::sync::Mutex::new(Vec::new()),
486 }
487 }
488
489 pub fn with_response(text: &str) -> Self {
493 let provider = Self::new();
494 for _ in 0..20 {
495 provider.queue_response(Self::text_response(text));
496 }
497 provider
498 }
499
500 pub fn queue_response(&self, response: CompletionResponse) {
502 self.responses.lock().unwrap().push(response);
503 }
504
505 pub fn text_response(text: &str) -> CompletionResponse {
507 CompletionResponse {
508 message: Message::assistant(text),
509 usage: TokenUsage {
510 input_tokens: 100,
511 output_tokens: 50,
512 },
513 model: "mock-model".to_string(),
514 finish_reason: Some("stop".to_string()),
515 }
516 }
517
518 pub fn tool_call_response(tool_name: &str, arguments: serde_json::Value) -> CompletionResponse {
520 let call_id = format!("call_{}", uuid::Uuid::new_v4());
521 CompletionResponse {
522 message: Message::new(
523 Role::Assistant,
524 Content::tool_call(&call_id, tool_name, arguments),
525 ),
526 usage: TokenUsage {
527 input_tokens: 100,
528 output_tokens: 30,
529 },
530 model: "mock-model".to_string(),
531 finish_reason: Some("tool_calls".to_string()),
532 }
533 }
534
535 pub fn multipart_response(
537 text: &str,
538 tool_name: &str,
539 arguments: serde_json::Value,
540 ) -> CompletionResponse {
541 let call_id = format!("call_{}", uuid::Uuid::new_v4());
542 CompletionResponse {
543 message: Message::new(
544 Role::Assistant,
545 Content::MultiPart {
546 parts: vec![
547 Content::text(text),
548 Content::tool_call(&call_id, tool_name, arguments),
549 ],
550 },
551 ),
552 usage: TokenUsage {
553 input_tokens: 100,
554 output_tokens: 50,
555 },
556 model: "mock-model".to_string(),
557 finish_reason: Some("tool_calls".to_string()),
558 }
559 }
560}
561
impl Default for MockLlmProvider {
    /// Equivalent to [`MockLlmProvider::new`].
    fn default() -> Self {
        Self::new()
    }
}
567
568#[async_trait]
569impl LlmProvider for MockLlmProvider {
570 async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
571 let mut responses = self.responses.lock().unwrap();
572 if responses.is_empty() {
573 Ok(MockLlmProvider::text_response(
574 "I'm a mock LLM. No queued responses available.",
575 ))
576 } else {
577 Ok(responses.remove(0))
578 }
579 }
580
581 async fn complete_streaming(
582 &self,
583 request: CompletionRequest,
584 tx: mpsc::Sender<StreamEvent>,
585 ) -> Result<(), LlmError> {
586 let response = self.complete(request).await?;
587 if let Some(text) = response.message.content.as_text() {
588 for word in text.split_whitespace() {
589 let _ = tx.send(StreamEvent::Token(format!("{} ", word))).await;
590 }
591 }
592 let _ = tx
593 .send(StreamEvent::Done {
594 usage: response.usage,
595 })
596 .await;
597 Ok(())
598 }
599
600 fn estimate_tokens(&self, messages: &[Message]) -> usize {
601 messages
603 .iter()
604 .map(|m| match &m.content {
605 Content::Text { text } => text.len() / 4,
606 Content::ToolCall { arguments, .. } => arguments.to_string().len() / 4,
607 Content::ToolResult { output, .. } => output.len() / 4,
608 Content::MultiPart { parts } => parts
609 .iter()
610 .map(|p| match p {
611 Content::Text { text } => text.len() / 4,
612 _ => 50,
613 })
614 .sum(),
615 })
616 .sum::<usize>()
617 + 100 }
619
620 fn context_window(&self) -> usize {
621 self.context_window
622 }
623
624 fn supports_tools(&self) -> bool {
625 true
626 }
627
628 fn cost_per_token(&self) -> (f64, f64) {
629 (0.0, 0.0) }
631
632 fn model_name(&self) -> &str {
633 &self.model
634 }
635}
636
/// Default system prompt for the assistant.
///
/// Runtime data sent verbatim to the model — do not edit for style; changes
/// here change agent behavior. Covers tool-selection rules, available tool
/// categories, workflow names, and security constraints.
pub const DEFAULT_SYSTEM_PROMPT: &str = r#"You are Rustant, a privacy-first autonomous personal assistant built in Rust. You help users with software engineering, daily productivity, and macOS automation tasks.

CRITICAL — Tool selection rules:
- You MUST use the dedicated tool for each task. Do NOT use shell_exec when a dedicated tool exists.
- For clipboard: call macos_clipboard with {"action":"read"} or {"action":"write","content":"..."}
- For battery/disk/CPU/version: call macos_system_info with {"action":"battery"}, {"action":"version"}, etc.
- For running apps: call macos_app_control with {"action":"list_running"}
- For calendar: call macos_calendar. For reminders: call macos_reminders. For notes: call macos_notes.
- For screenshots: call macos_screenshot. For Spotlight search: call macos_spotlight.
- shell_exec is a last resort — only use it for commands that have no dedicated tool.
- Do NOT use document_read for clipboard or system operations — it reads document files only.
- If a tool call fails, try a different tool or action — do NOT ask the user whether to proceed. Act autonomously.
- Never call ask_user more than once per task unless the user's answer was genuinely unclear.

Other behaviors:
- Always read a file before modifying it
- Prefer small, focused changes over large rewrites
- Respect file boundaries and permissions

Tool categories:

File & Code: file_read, file_write, file_list, file_search, file_patch, smart_edit, codebase_search, document_read (for PDFs/docs only)
Git: git_status, git_diff, git_commit
Shell: shell_exec (last resort only)
Utilities: calculator, datetime, echo, web_search (for web searches — uses DuckDuckGo, preferred over safari/shell), web_fetch (for fetching URL content — preferred over safari/shell), http_api, template, pdf_generate, compress, file_organizer
Personal Productivity: pomodoro, inbox, finance, flashcards, travel, relationships
Research & Intelligence: arxiv_research (ALWAYS use this for paper/preprint searches — it has a built-in arXiv API client, never use safari/curl for arXiv), knowledge_graph (concepts, papers, relationships, BFS traversal), experiment_tracker (hypotheses, experiments, evidence)
Code Analysis: code_intelligence (architecture, patterns, tech debt, API surface, dependency map), codebase_search
Professional Growth: skill_tracker (proficiency, practice logs, learning paths), career_intel (goals, achievements, portfolio), content_engine (multi-platform content pipeline, calendar)
Life Management: life_planner (energy-aware scheduling, deadlines, habits), system_monitor (service topology, health checks, incidents), privacy_manager (data boundaries, compliance, audit), self_improvement (usage patterns, performance, preferences)
macOS Native: macos_calendar, macos_reminders, macos_notes, macos_clipboard, macos_system_info, macos_app_control, macos_notification, macos_screenshot, macos_spotlight, macos_finder, macos_focus_mode, macos_mail, macos_music, macos_shortcuts, macos_meeting_recorder (use 'record_and_transcribe' for full meeting flow — TTS announcement, silence auto-stop, auto-transcribe to Notes.app; use 'stop' to end manually), macos_daily_briefing, macos_contacts, homekit
macOS Automation: macos_gui_scripting, macos_accessibility, macos_screen_analyze, macos_safari (only for Safari-specific tasks like tab management — for web searches use web_search, for fetching pages use web_fetch)
iMessage: imessage_contacts, imessage_send, imessage_read
Voice: macos_say

Workflows (structured multi-step templates — run via shell_exec "rustant workflow run <name>"):
    code_review, refactor, test_generation, documentation, dependency_update,
    security_scan, deployment, incident_response, morning_briefing, pr_review,
    dependency_audit, changelog, meeting_recorder, daily_briefing_full,
    end_of_day_summary, app_automation, email_triage, arxiv_research,
    knowledge_graph, experiment_tracking, code_analysis, content_pipeline,
    skill_development, career_planning, system_monitoring, life_planning,
    privacy_audit, self_improvement_loop
When a user asks for one of these tasks by name or description, execute the workflow or accomplish it step by step.

Security rules:
- Never execute commands that could damage the system or leak credentials
- Do not read or write files containing secrets (.env, *.key, *.pem) unless explicitly asked
- Sanitize all user input before passing to shell or AppleScript commands
- When unsure about a destructive action, use ask_user to confirm first"#;
688
/// Tracks session- and task-level spend against configured limits.
///
/// A limit of zero (USD or tokens) disables that particular check, so an
/// unconfigured manager never blocks anything.
pub struct TokenBudgetManager {
    // Configured ceilings (0 = unlimited).
    session_limit_usd: f64,
    task_limit_usd: f64,
    session_token_limit: usize,
    // Whether callers should hard-stop when a limit is exceeded.
    halt_on_exceed: bool,
    // Running totals accumulated via `record_usage`.
    session_cost: f64,
    task_cost: f64,
    session_tokens: usize,
}
704
/// Outcome of a pre-flight budget check.
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetCheckResult {
    /// Within all configured limits.
    Ok,
    /// Above 80% of the session budget; `usage_pct` is the projected fraction.
    Warning { message: String, usage_pct: f64 },
    /// A configured limit would be exceeded by the next request.
    Exceeded { message: String },
}
715
716impl TokenBudgetManager {
717 pub fn new(config: Option<&crate::config::BudgetConfig>) -> Self {
720 match config {
721 Some(cfg) => Self {
722 session_limit_usd: cfg.session_limit_usd,
723 task_limit_usd: cfg.task_limit_usd,
724 session_token_limit: cfg.session_token_limit,
725 halt_on_exceed: cfg.halt_on_exceed,
726 session_cost: 0.0,
727 task_cost: 0.0,
728 session_tokens: 0,
729 },
730 None => Self {
731 session_limit_usd: 0.0,
732 task_limit_usd: 0.0,
733 session_token_limit: 0,
734 halt_on_exceed: false,
735 session_cost: 0.0,
736 task_cost: 0.0,
737 session_tokens: 0,
738 },
739 }
740 }
741
742 pub fn reset_task(&mut self) {
744 self.task_cost = 0.0;
745 }
746
747 pub fn record_usage(&mut self, usage: &TokenUsage, cost: &CostEstimate) {
749 self.session_cost += cost.total();
750 self.task_cost += cost.total();
751 self.session_tokens += usage.total();
752 }
753
754 pub fn check_budget(
760 &self,
761 estimated_input_tokens: usize,
762 input_rate: f64,
763 output_rate: f64,
764 ) -> BudgetCheckResult {
765 let predicted_output = estimated_input_tokens / 2;
767 let predicted_cost =
768 (estimated_input_tokens as f64 * input_rate) + (predicted_output as f64 * output_rate);
769
770 let projected_session_cost = self.session_cost + predicted_cost;
771 let projected_task_cost = self.task_cost + predicted_cost;
772 let projected_session_tokens =
773 self.session_tokens + estimated_input_tokens + predicted_output;
774
775 if self.session_limit_usd > 0.0 && projected_session_cost > self.session_limit_usd {
777 return BudgetCheckResult::Exceeded {
778 message: format!(
779 "Session cost ${:.4} would exceed limit ${:.4}",
780 projected_session_cost, self.session_limit_usd
781 ),
782 };
783 }
784
785 if self.task_limit_usd > 0.0 && projected_task_cost > self.task_limit_usd {
787 return BudgetCheckResult::Exceeded {
788 message: format!(
789 "Task cost ${:.4} would exceed limit ${:.4}",
790 projected_task_cost, self.task_limit_usd
791 ),
792 };
793 }
794
795 if self.session_token_limit > 0 && projected_session_tokens > self.session_token_limit {
797 return BudgetCheckResult::Exceeded {
798 message: format!(
799 "Session tokens {} would exceed limit {}",
800 projected_session_tokens, self.session_token_limit
801 ),
802 };
803 }
804
805 if self.session_limit_usd > 0.0 {
807 let pct = projected_session_cost / self.session_limit_usd;
808 if pct > 0.8 {
809 return BudgetCheckResult::Warning {
810 message: format!(
811 "Session cost at {:.0}% of ${:.4} limit",
812 pct * 100.0,
813 self.session_limit_usd
814 ),
815 usage_pct: pct,
816 };
817 }
818 }
819
820 BudgetCheckResult::Ok
821 }
822
823 pub fn should_halt_on_exceed(&self) -> bool {
825 self.halt_on_exceed
826 }
827
828 pub fn session_cost(&self) -> f64 {
830 self.session_cost
831 }
832
833 pub fn task_cost(&self) -> f64 {
835 self.task_cost
836 }
837
838 pub fn session_tokens(&self) -> usize {
840 self.session_tokens
841 }
842}
843
#[cfg(test)]
mod tests {
    use super::*;

    // --- MockLlmProvider behavior ---

    #[tokio::test]
    async fn test_mock_provider_default_response() {
        let provider = MockLlmProvider::new();
        let request = CompletionRequest::default();
        let response = provider.complete(request).await.unwrap();
        assert!(response.message.content.as_text().is_some());
    }

    // Responses must be served FIFO.
    #[tokio::test]
    async fn test_mock_provider_queued_responses() {
        let provider = MockLlmProvider::new();
        provider.queue_response(MockLlmProvider::text_response("first"));
        provider.queue_response(MockLlmProvider::text_response("second"));

        let r1 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r1.message.content.as_text(), Some("first"));

        let r2 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r2.message.content.as_text(), Some("second"));
    }

    // Streaming splits the text on whitespace, one Token event per word.
    #[tokio::test]
    async fn test_mock_provider_streaming() {
        let provider = MockLlmProvider::new();
        provider.queue_response(MockLlmProvider::text_response("hello world"));

        let (tx, mut rx) = mpsc::channel(32);
        provider
            .complete_streaming(CompletionRequest::default(), tx)
            .await
            .unwrap();

        let mut tokens = Vec::new();
        while let Some(event) = rx.recv().await {
            match event {
                StreamEvent::Token(t) => tokens.push(t),
                StreamEvent::Done { .. } => break,
                _ => {}
            }
        }
        assert_eq!(tokens.len(), 2);
    }

    #[test]
    fn test_mock_provider_token_estimation() {
        let provider = MockLlmProvider::new();
        let messages = vec![Message::user("Hello, this is a test message.")];
        let tokens = provider.estimate_tokens(&messages);
        assert!(tokens > 0);
    }

    #[test]
    fn test_mock_provider_properties() {
        let provider = MockLlmProvider::new();
        assert_eq!(provider.context_window(), 128_000);
        assert!(provider.supports_tools());
        assert_eq!(provider.cost_per_token(), (0.0, 0.0));
        assert_eq!(provider.model_name(), "mock-model");
    }

    // --- Brain: completion, message building, accounting ---

    #[tokio::test]
    async fn test_brain_think() {
        let provider = Arc::new(MockLlmProvider::new());
        provider.queue_response(MockLlmProvider::text_response("I can help with that."));

        let mut brain = Brain::new(provider, "You are a helpful assistant.");
        let conversation = vec![Message::user("Help me refactor")];

        let response = brain.think(&conversation, None).await.unwrap();
        assert_eq!(
            response.message.content.as_text(),
            Some("I can help with that.")
        );
        assert!(brain.total_usage().total() > 0);
    }

    #[tokio::test]
    async fn test_brain_builds_messages_with_system_prompt() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system prompt");
        let conversation = vec![Message::user("hello")];

        let messages = brain.build_messages(&conversation);
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content.as_text(), Some("system prompt"));
        assert_eq!(messages[1].role, Role::User);
    }

    #[test]
    fn test_brain_context_usage_ratio() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("short message")];

        let ratio = brain.context_usage_ratio(&conversation);
        assert!(ratio > 0.0);
        assert!(ratio < 1.0);
    }

    #[test]
    fn test_mock_tool_call_response() {
        let response = MockLlmProvider::tool_call_response(
            "file_read",
            serde_json::json!({"path": "/tmp/test.rs"}),
        );
        match &response.message.content {
            Content::ToolCall {
                name, arguments, ..
            } => {
                assert_eq!(name, "file_read");
                assert_eq!(arguments["path"], "/tmp/test.rs");
            }
            _ => panic!("Expected ToolCall content"),
        }
    }

    #[test]
    fn test_default_system_prompt() {
        assert!(DEFAULT_SYSTEM_PROMPT.contains("Rustant"));
        assert!(DEFAULT_SYSTEM_PROMPT.contains("autonomous"));
    }

    // --- Retry classification and behavior ---

    #[test]
    fn test_is_retryable() {
        assert!(Brain::is_retryable(&LlmError::RateLimited {
            retry_after_secs: 5
        }));
        assert!(Brain::is_retryable(&LlmError::Timeout { timeout_secs: 30 }));
        assert!(Brain::is_retryable(&LlmError::Connection {
            message: "reset".into()
        }));
        assert!(!Brain::is_retryable(&LlmError::ContextOverflow {
            used: 200_000,
            limit: 128_000
        }));
        assert!(!Brain::is_retryable(&LlmError::AuthFailed {
            provider: "openai".into()
        }));
    }

    // Provider that fails a fixed number of times with a chosen error kind,
    // then succeeds — used to exercise `think_with_retry`.
    struct FailingProvider {
        failures_remaining: std::sync::Mutex<usize>,
        error_type: String,
        success_response: CompletionResponse,
    }

    impl FailingProvider {
        fn new(failures: usize, error_type: &str) -> Self {
            Self {
                failures_remaining: std::sync::Mutex::new(failures),
                error_type: error_type.to_string(),
                success_response: MockLlmProvider::text_response("Success after retry"),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for FailingProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, LlmError> {
            let mut remaining = self.failures_remaining.lock().unwrap();
            if *remaining > 0 {
                *remaining -= 1;
                match self.error_type.as_str() {
                    "rate_limited" => Err(LlmError::RateLimited {
                        retry_after_secs: 0,
                    }),
                    "timeout" => Err(LlmError::Timeout { timeout_secs: 5 }),
                    "connection" => Err(LlmError::Connection {
                        message: "connection reset".into(),
                    }),
                    _ => Err(LlmError::ApiRequest {
                        message: "non-retryable".into(),
                    }),
                }
            } else {
                Ok(self.success_response.clone())
            }
        }

        async fn complete_streaming(
            &self,
            _request: CompletionRequest,
            _tx: mpsc::Sender<StreamEvent>,
        ) -> Result<(), LlmError> {
            Ok(())
        }

        fn estimate_tokens(&self, _messages: &[Message]) -> usize {
            100
        }
        fn context_window(&self) -> usize {
            128_000
        }
        fn supports_tools(&self) -> bool {
            true
        }
        fn cost_per_token(&self) -> (f64, f64) {
            (0.0, 0.0)
        }
        fn model_name(&self) -> &str {
            "failing-mock"
        }
    }

    #[tokio::test]
    async fn test_think_with_retry_succeeds_after_failures() {
        let provider = Arc::new(FailingProvider::new(2, "connection"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 3).await;
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().message.content.as_text(),
            Some("Success after retry")
        );
    }

    #[tokio::test]
    async fn test_think_with_retry_exhausted() {
        let provider = Arc::new(FailingProvider::new(5, "timeout"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 2).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Timeout { .. }));
    }

    #[tokio::test]
    async fn test_think_with_retry_non_retryable_fails_immediately() {
        let provider = Arc::new(FailingProvider::new(1, "non_retryable"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 3).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::ApiRequest { .. }));
    }

    #[tokio::test]
    async fn test_think_with_retry_rate_limited() {
        let provider = Arc::new(FailingProvider::new(1, "rate_limited"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 2).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_track_usage() {
        let provider = Arc::new(MockLlmProvider::new());
        let mut brain = Brain::new(provider, "system");

        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
        };
        brain.track_usage(&usage);

        assert_eq!(brain.total_usage().input_tokens, 100);
        assert_eq!(brain.total_usage().output_tokens, 50);
    }

    // --- TokenCounter ---

    #[test]
    fn test_token_counter_basic() {
        let counter = TokenCounter::for_model("gpt-4o");
        let count = counter.count("Hello, world!");
        assert!(count > 0);
        assert!(count < 20);
    }

    #[test]
    fn test_token_counter_messages() {
        let counter = TokenCounter::for_model("gpt-4o");
        let messages = vec![
            Message::system("You are a helpful assistant."),
            Message::user("What is 2 + 2?"),
        ];
        let count = counter.count_messages(&messages);
        assert!(count > 5);
        assert!(count < 100);
    }

    // Unknown model names must fall back to cl100k_base rather than panic.
    #[test]
    fn test_token_counter_unknown_model_falls_back() {
        let counter = TokenCounter::for_model("unknown-model-xyz");
        let count = counter.count("Hello");
        assert!(count > 0);
    }

    #[test]
    fn test_brain_estimate_tokens() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system");
        let messages = vec![Message::user("Hello, this is a test.")];
        let estimate = brain.estimate_tokens(&messages);
        assert!(estimate > 0);
    }

    // --- sanitize_tool_sequence ---

    #[test]
    fn test_sanitize_removes_orphaned_tool_results() {
        let mut messages = vec![
            Message::system("You are a helper."),
            Message::user("do something"),
            Message::tool_result("call_orphan_123", "some result", false),
            Message::assistant("Done!"),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // The orphaned tool_result (no matching tool_call) must be removed.
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[2].role, Role::Assistant);
    }

    #[test]
    fn test_sanitize_preserves_valid_sequence() {
        let mut messages = vec![
            Message::system("system prompt"),
            Message::user("read main.rs"),
            Message::new(
                Role::Assistant,
                Content::tool_call(
                    "call_1",
                    "file_read",
                    serde_json::json!({"path": "main.rs"}),
                ),
            ),
            Message::tool_result("call_1", "fn main() {}", false),
            Message::assistant("Here is the file content."),
        ];

        super::sanitize_tool_sequence(&mut messages);

        assert_eq!(messages.len(), 5);
    }

    // A system message injected between a tool_call and its tool_result must
    // be relocated so the pair stays adjacent.
    #[test]
    fn test_sanitize_handles_system_between_call_and_result() {
        let mut messages = vec![
            Message::system("system prompt"),
            Message::user("do something"),
            Message::new(
                Role::Assistant,
                Content::tool_call("call_1", "file_read", serde_json::json!({"path": "x.rs"})),
            ),
            Message::system("routing hint: use file_read"),
            Message::tool_result("call_1", "file contents", false),
            Message::assistant("Done"),
        ];

        super::sanitize_tool_sequence(&mut messages);

        let assistant_idx = messages
            .iter()
            .position(|m| m.role == Role::Assistant && super::content_has_tool_call(&m.content))
            .unwrap();

        let next = &messages[assistant_idx + 1];
        assert!(
            matches!(&next.content, Content::ToolResult { .. })
                || next.role == Role::Tool
                || next.role == Role::User,
            "Expected tool_result after tool_call, got {:?}",
            next.role
        );
    }

    #[test]
    fn test_sanitize_multipart_tool_call() {
        let mut messages = vec![
            Message::user("do two things"),
            Message::new(
                Role::Assistant,
                Content::MultiPart {
                    parts: vec![
                        Content::text("I'll read both files."),
                        Content::tool_call(
                            "call_a",
                            "file_read",
                            serde_json::json!({"path": "a.rs"}),
                        ),
                        Content::tool_call(
                            "call_b",
                            "file_read",
                            serde_json::json!({"path": "b.rs"}),
                        ),
                    ],
                },
            ),
            Message::tool_result("call_a", "contents of a", false),
            Message::tool_result("call_b", "contents of b", false),
            Message::tool_result("call_nonexistent", "orphan", false),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // Only the orphan ("call_nonexistent") should be dropped.
        assert_eq!(messages.len(), 4);
    }

    #[test]
    fn test_sanitize_empty_messages() {
        let mut messages: Vec<Message> = vec![];
        super::sanitize_tool_sequence(&mut messages);
        assert!(messages.is_empty());
    }

    #[test]
    fn test_sanitize_no_tool_messages() {
        let mut messages = vec![
            Message::system("prompt"),
            Message::user("hello"),
            Message::assistant("hi"),
        ];
        super::sanitize_tool_sequence(&mut messages);
        assert_eq!(messages.len(), 3);
    }
}