1use crate::error::LlmError;
7use crate::types::{
8 CompletionRequest, CompletionResponse, Content, CostEstimate, Message, Role, StreamEvent,
9 TokenUsage, ToolDefinition,
10};
11use async_trait::async_trait;
12use std::collections::HashSet;
13use std::sync::Arc;
14use tokio::sync::mpsc;
15use tracing::{debug, info, warn};
16
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Sends a completion request and waits for the final response.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError>;

    /// Sends a completion request, pushing incremental `StreamEvent`s
    /// into `tx` as they arrive instead of waiting for the full reply.
    async fn complete_streaming(
        &self,
        request: CompletionRequest,
        tx: mpsc::Sender<StreamEvent>,
    ) -> Result<(), LlmError>;

    /// Estimates how many tokens `messages` will consume for this model.
    fn estimate_tokens(&self, messages: &[Message]) -> usize;

    /// Maximum context size, in tokens, of the underlying model.
    fn context_window(&self) -> usize;

    /// True when the provider accepts tool/function definitions.
    fn supports_tools(&self) -> bool;

    /// Per-token USD rates as `(input_rate, output_rate)`.
    fn cost_per_token(&self) -> (f64, f64);

    /// Identifier of the model served by this provider.
    fn model_name(&self) -> &str;
}
45
/// Local token counter backed by a tiktoken BPE encoding.
pub struct TokenCounter {
    // BPE encoder; `for_model` falls back to cl100k_base for unknown models.
    bpe: tiktoken_rs::CoreBPE,
}
50
51impl TokenCounter {
52 pub fn for_model(model: &str) -> Self {
55 let bpe = tiktoken_rs::get_bpe_from_model(model).unwrap_or_else(|_| {
56 tiktoken_rs::cl100k_base().expect("cl100k_base should be available")
57 });
58 Self { bpe }
59 }
60
61 pub fn count(&self, text: &str) -> usize {
63 self.bpe.encode_with_special_tokens(text).len()
64 }
65
66 pub fn count_tool_definitions(&self, tools: &[ToolDefinition]) -> usize {
71 let mut total = 0;
72 for tool in tools {
73 total += 10; total += self.count(&tool.name);
75 total += self.count(&tool.description);
76 total += self.count(&tool.parameters.to_string());
77 }
78 total
79 }
80
81 pub fn count_messages(&self, messages: &[Message]) -> usize {
84 let mut total = 0;
85 for msg in messages {
86 total += 4;
88 match &msg.content {
89 Content::Text { text } => total += self.count(text),
90 Content::ToolCall {
91 name, arguments, ..
92 } => {
93 total += self.count(name);
94 total += self.count(&arguments.to_string());
95 }
96 Content::ToolResult { output, .. } => {
97 total += self.count(output);
98 }
99 Content::MultiPart { parts } => {
100 for part in parts {
101 match part {
102 Content::Text { text } => total += self.count(text),
103 Content::ToolCall {
104 name, arguments, ..
105 } => {
106 total += self.count(name);
107 total += self.count(&arguments.to_string());
108 }
109 Content::ToolResult { output, .. } => {
110 total += self.count(output);
111 }
112 _ => total += 10,
113 }
114 }
115 }
116 }
117 }
118 total + 3 }
120}
121
/// Repairs a conversation in place so tool sequences are well-formed:
///
/// 1. Drops any `ToolResult` (tool-role, or user-role carrying a
///    `ToolResult`) whose `call_id` has no matching assistant `ToolCall`.
/// 2. Relocates `System` messages that were inserted between an
///    assistant tool call and its result to just *before* the tool
///    call, so the result immediately follows the call.
pub fn sanitize_tool_sequence(messages: &mut Vec<Message>) {
    // Pass 1a: gather every tool-call id issued by assistant messages.
    let mut tool_call_ids: HashSet<String> = HashSet::new();
    for msg in messages.iter() {
        if msg.role != Role::Assistant {
            continue;
        }
        collect_tool_call_ids(&msg.content, &mut tool_call_ids);
    }

    // Pass 1b: drop results that reference an unknown call id.
    messages.retain(|msg| {
        if msg.role != Role::Tool {
            // Some providers deliver tool results as user-role messages;
            // apply the same orphan check to those.
            if msg.role == Role::User
                && let Content::ToolResult { call_id, .. } = &msg.content
                && !tool_call_ids.contains(call_id)
            {
                warn!(
                    call_id = call_id.as_str(),
                    "Removing orphaned tool_result (no matching tool_call)"
                );
                return false;
            }
            return true;
        }
        // Tool-role messages: keep only results with a known call id.
        match &msg.content {
            Content::ToolResult { call_id, .. } => {
                if tool_call_ids.contains(call_id) {
                    true
                } else {
                    warn!(
                        call_id = call_id.as_str(),
                        "Removing orphaned tool_result (no matching tool_call)"
                    );
                    false
                }
            }
            Content::MultiPart { parts } => {
                // Keep a multipart tool message if at least one part is
                // valid (non-ToolResult parts count as valid).
                let has_valid = parts.iter().any(|p| {
                    if let Content::ToolResult { call_id, .. } = p {
                        tool_call_ids.contains(call_id)
                    } else {
                        true
                    }
                });
                if !has_valid {
                    warn!("Removing multipart tool message with all orphaned tool_results");
                }
                has_valid
            }
            _ => true,
        }
    });

    // Pass 2: move system messages that sit directly after an assistant
    // tool call to just before it, so the tool result follows the call.
    let mut i = 0;
    while i + 1 < messages.len() {
        let has_tool_call =
            messages[i].role == Role::Assistant && content_has_tool_call(&messages[i].content);

        if has_tool_call {
            // Collect the run of system messages immediately after i.
            let mut j = i + 1;
            let mut system_messages_to_relocate = Vec::new();
            while j < messages.len() && messages[j].role == Role::System {
                system_messages_to_relocate.push(j);
                j += 1;
            }
            if !system_messages_to_relocate.is_empty() {
                // Remove back-to-front so earlier indices stay valid.
                let mut extracted: Vec<Message> = Vec::new();
                for &idx in system_messages_to_relocate.iter().rev() {
                    extracted.push(messages.remove(idx));
                }
                extracted.reverse();
                // Re-insert before the tool-call message, advancing i so
                // it keeps pointing at the assistant message.
                for (offset, msg) in extracted.into_iter().enumerate() {
                    messages.insert(i + offset, msg);
                    i += 1;
                }
            }
        }
        i += 1;
    }
}
224
225fn collect_tool_call_ids(content: &Content, ids: &mut HashSet<String>) {
227 match content {
228 Content::ToolCall { id, .. } => {
229 ids.insert(id.clone());
230 }
231 Content::MultiPart { parts } => {
232 for part in parts {
233 collect_tool_call_ids(part, ids);
234 }
235 }
236 _ => {}
237 }
238}
239
240fn content_has_tool_call(content: &Content) -> bool {
242 match content {
243 Content::ToolCall { .. } => true,
244 Content::MultiPart { parts } => parts.iter().any(content_has_tool_call),
245 _ => false,
246 }
247}
248
/// Orchestrates LLM calls: builds the outgoing message list (system
/// prompt + conversation, sanitized), enforces the context window,
/// retries transient failures, and accumulates token usage and cost.
pub struct Brain {
    provider: Arc<dyn LlmProvider>,
    // Base system prompt prepended to every conversation.
    system_prompt: String,
    // Running totals across every completion made through this Brain.
    total_usage: TokenUsage,
    total_cost: CostEstimate,
    // Local tokenizer matching the provider's model; used to estimate
    // tool-definition overhead.
    token_counter: TokenCounter,
    // Extra text appended to the system prompt when non-empty.
    knowledge_addendum: String,
}
260
impl Brain {
    /// Wraps `provider`, choosing a tokenizer from its model name.
    pub fn new(provider: Arc<dyn LlmProvider>, system_prompt: impl Into<String>) -> Self {
        let model_name = provider.model_name().to_string();
        Self {
            provider,
            system_prompt: system_prompt.into(),
            total_usage: TokenUsage::default(),
            total_cost: CostEstimate::default(),
            token_counter: TokenCounter::for_model(&model_name),
            knowledge_addendum: String::new(),
        }
    }

    /// Sets text appended to the system prompt on every request.
    pub fn set_knowledge_addendum(&mut self, addendum: String) {
        self.knowledge_addendum = addendum;
    }

    /// Token estimate for `messages` using the local BPE counter.
    pub fn estimate_tokens(&self, messages: &[Message]) -> usize {
        self.token_counter.count_messages(messages)
    }

    /// Token estimate including tool-definition overhead, if any.
    pub fn estimate_tokens_with_tools(
        &self,
        messages: &[Message],
        tools: Option<&[ToolDefinition]>,
    ) -> usize {
        let mut total = self.token_counter.count_messages(messages);
        if let Some(tool_defs) = tools {
            total += self.token_counter.count_tool_definitions(tool_defs);
        }
        total
    }

    /// Builds the outgoing message list: system prompt (plus addendum,
    /// when set) followed by `conversation`, then sanitized so tool
    /// call/result sequences are well-formed.
    pub fn build_messages(&self, conversation: &[Message]) -> Vec<Message> {
        let mut messages = Vec::with_capacity(conversation.len() + 1);
        if self.knowledge_addendum.is_empty() {
            messages.push(Message::system(&self.system_prompt));
        } else {
            let augmented = format!("{}{}", self.system_prompt, self.knowledge_addendum);
            messages.push(Message::system(&augmented));
        }
        messages.extend_from_slice(conversation);
        sanitize_tool_sequence(&mut messages);
        messages
    }

    /// Runs one completion: checks the context window, sends the
    /// request, and records usage/cost into the running totals.
    ///
    /// # Errors
    /// `LlmError::ContextOverflow` when the estimated prompt (messages
    /// plus tool definitions) exceeds the provider's context window;
    /// otherwise whatever the provider returns.
    pub async fn think(
        &mut self,
        conversation: &[Message],
        tools: Option<Vec<ToolDefinition>>,
    ) -> Result<CompletionResponse, LlmError> {
        let messages = self.build_messages(conversation);
        // Provider estimates the messages; tool definitions are added
        // via the local counter since the trait only takes messages.
        let mut token_estimate = self.provider.estimate_tokens(&messages);
        if let Some(ref tool_defs) = tools {
            token_estimate += self.token_counter.count_tool_definitions(tool_defs);
        }
        let context_limit = self.provider.context_window();

        if token_estimate > context_limit {
            return Err(LlmError::ContextOverflow {
                used: token_estimate,
                limit: context_limit,
            });
        }

        debug!(
            model = self.provider.model_name(),
            estimated_tokens = token_estimate,
            "Sending completion request"
        );

        let request = CompletionRequest {
            messages,
            tools,
            temperature: 0.7,
            max_tokens: None,
            stop_sequences: Vec::new(),
            model: None,
        };

        let response = self.provider.complete(request).await?;

        // Accumulate usage and cost at the provider's current rates.
        self.total_usage.accumulate(&response.usage);
        let (input_rate, output_rate) = self.provider.cost_per_token();
        let cost = CostEstimate {
            input_cost: response.usage.input_tokens as f64 * input_rate,
            output_cost: response.usage.output_tokens as f64 * output_rate,
        };
        self.total_cost.accumulate(&cost);

        info!(
            input_tokens = response.usage.input_tokens,
            output_tokens = response.usage.output_tokens,
            cost = format!("${:.4}", cost.total()),
            "Completion received"
        );

        Ok(response)
    }

    /// Like [`Brain::think`], but retries transient errors (rate limit,
    /// timeout, connection) up to `max_retries` extra attempts with
    /// exponential backoff (1s, 2s, 4s, ... capped at 32s). Rate-limit
    /// errors wait at least the server-supplied `retry_after_secs`.
    /// Non-retryable errors are returned immediately.
    pub async fn think_with_retry(
        &mut self,
        conversation: &[Message],
        tools: Option<Vec<ToolDefinition>>,
        max_retries: usize,
    ) -> Result<CompletionResponse, LlmError> {
        let mut last_error = None;

        for attempt in 0..=max_retries {
            match self.think(conversation, tools.clone()).await {
                Ok(response) => return Ok(response),
                Err(e) if Self::is_retryable(&e) => {
                    if attempt < max_retries {
                        let backoff_secs = std::cmp::min(1u64 << attempt, 32);
                        let wait = match &e {
                            LlmError::RateLimited { retry_after_secs } => {
                                std::cmp::max(*retry_after_secs, backoff_secs)
                            }
                            _ => backoff_secs,
                        };
                        info!(
                            attempt = attempt + 1,
                            max_retries,
                            backoff_secs = wait,
                            error = %e,
                            "Retrying after transient error"
                        );
                        tokio::time::sleep(std::time::Duration::from_secs(wait)).await;
                        last_error = Some(e);
                    } else {
                        // Out of attempts: surface the last error.
                        return Err(e);
                    }
                }
                Err(e) => return Err(e),
            }
        }

        // Defensive fallback; the loop returns on every path in practice.
        Err(last_error.unwrap_or(LlmError::Connection {
            message: "Max retries exceeded".to_string(),
        }))
    }

    /// True for errors worth retrying: rate limits, timeouts, and
    /// connection failures.
    pub fn is_retryable(error: &LlmError) -> bool {
        matches!(
            error,
            LlmError::RateLimited { .. } | LlmError::Timeout { .. } | LlmError::Connection { .. }
        )
    }

    /// Streaming variant of `think`; events are delivered through `tx`.
    ///
    /// NOTE(review): unlike `think`, there is no context-window check
    /// and no usage/cost accounting here — usage presumably arrives via
    /// `StreamEvent::Done` and callers record it with `track_usage`;
    /// confirm against call sites.
    pub async fn think_streaming(
        &mut self,
        conversation: &[Message],
        tools: Option<Vec<ToolDefinition>>,
        tx: mpsc::Sender<StreamEvent>,
    ) -> Result<(), LlmError> {
        let messages = self.build_messages(conversation);

        let request = CompletionRequest {
            messages,
            tools,
            temperature: 0.7,
            max_tokens: None,
            stop_sequences: Vec::new(),
            model: None,
        };

        self.provider.complete_streaming(request, tx).await
    }

    /// Accumulated token usage across all completions.
    pub fn total_usage(&self) -> &TokenUsage {
        &self.total_usage
    }

    /// Accumulated cost across all completions.
    pub fn total_cost(&self) -> &CostEstimate {
        &self.total_cost
    }

    /// Model identifier of the wrapped provider.
    pub fn model_name(&self) -> &str {
        self.provider.model_name()
    }

    /// Context window of the wrapped provider, in tokens.
    pub fn context_window(&self) -> usize {
        self.provider.context_window()
    }

    /// Provider's `(input, output)` per-token USD rates.
    pub fn provider_cost_rates(&self) -> (f64, f64) {
        self.provider.cost_per_token()
    }

    /// Borrows the wrapped provider.
    pub fn provider(&self) -> &dyn LlmProvider {
        &*self.provider
    }

    /// Clones the provider handle (cheap refcount bump).
    pub fn provider_arc(&self) -> Arc<dyn LlmProvider> {
        Arc::clone(&self.provider)
    }

    /// Adds `usage` (and its cost at current provider rates) to the
    /// running totals, for completions made outside `think`.
    pub fn track_usage(&mut self, usage: &TokenUsage) {
        self.total_usage.accumulate(usage);
        let (input_rate, output_rate) = self.provider.cost_per_token();
        let cost = CostEstimate {
            input_cost: usage.input_tokens as f64 * input_rate,
            output_cost: usage.output_tokens as f64 * output_rate,
        };
        self.total_cost.accumulate(&cost);
    }

    /// Fraction of the context window that `conversation` (with the
    /// system prompt) would occupy, per the provider's estimator.
    pub fn context_usage_ratio(&self, conversation: &[Message]) -> f32 {
        let messages = self.build_messages(conversation);
        let tokens = self.provider.estimate_tokens(&messages);
        tokens as f32 / self.provider.context_window() as f32
    }
}
502
/// Test double for [`LlmProvider`]: serves pre-queued responses in FIFO
/// order and falls back to a canned reply when the queue is empty.
pub struct MockLlmProvider {
    model: String,
    context_window: usize,
    // FIFO queue of canned responses; Mutex lets tests queue through &self.
    responses: std::sync::Mutex<Vec<CompletionResponse>>,
}
509
510impl MockLlmProvider {
511 pub fn new() -> Self {
512 Self {
513 model: "mock-model".to_string(),
514 context_window: 128_000,
515 responses: std::sync::Mutex::new(Vec::new()),
516 }
517 }
518
519 pub fn with_response(text: &str) -> Self {
523 let provider = Self::new();
524 for _ in 0..20 {
525 provider.queue_response(Self::text_response(text));
526 }
527 provider
528 }
529
530 pub fn queue_response(&self, response: CompletionResponse) {
532 self.responses.lock().unwrap().push(response);
533 }
534
535 pub fn text_response(text: &str) -> CompletionResponse {
537 CompletionResponse {
538 message: Message::assistant(text),
539 usage: TokenUsage {
540 input_tokens: 100,
541 output_tokens: 50,
542 },
543 model: "mock-model".to_string(),
544 finish_reason: Some("stop".to_string()),
545 }
546 }
547
548 pub fn tool_call_response(tool_name: &str, arguments: serde_json::Value) -> CompletionResponse {
550 let call_id = format!("call_{}", uuid::Uuid::new_v4());
551 CompletionResponse {
552 message: Message::new(
553 Role::Assistant,
554 Content::tool_call(&call_id, tool_name, arguments),
555 ),
556 usage: TokenUsage {
557 input_tokens: 100,
558 output_tokens: 30,
559 },
560 model: "mock-model".to_string(),
561 finish_reason: Some("tool_calls".to_string()),
562 }
563 }
564
565 pub fn multipart_response(
567 text: &str,
568 tool_name: &str,
569 arguments: serde_json::Value,
570 ) -> CompletionResponse {
571 let call_id = format!("call_{}", uuid::Uuid::new_v4());
572 CompletionResponse {
573 message: Message::new(
574 Role::Assistant,
575 Content::MultiPart {
576 parts: vec![
577 Content::text(text),
578 Content::tool_call(&call_id, tool_name, arguments),
579 ],
580 },
581 ),
582 usage: TokenUsage {
583 input_tokens: 100,
584 output_tokens: 50,
585 },
586 model: "mock-model".to_string(),
587 finish_reason: Some("tool_calls".to_string()),
588 }
589 }
590}
591
592impl Default for MockLlmProvider {
593 fn default() -> Self {
594 Self::new()
595 }
596}
597
598#[async_trait]
599impl LlmProvider for MockLlmProvider {
600 async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
601 let mut responses = self.responses.lock().unwrap();
602 if responses.is_empty() {
603 Ok(MockLlmProvider::text_response(
604 "I'm a mock LLM. No queued responses available.",
605 ))
606 } else {
607 Ok(responses.remove(0))
608 }
609 }
610
611 async fn complete_streaming(
612 &self,
613 request: CompletionRequest,
614 tx: mpsc::Sender<StreamEvent>,
615 ) -> Result<(), LlmError> {
616 let response = self.complete(request).await?;
617 if let Some(text) = response.message.content.as_text() {
618 for word in text.split_whitespace() {
619 let _ = tx.send(StreamEvent::Token(format!("{} ", word))).await;
620 }
621 }
622 let _ = tx
623 .send(StreamEvent::Done {
624 usage: response.usage,
625 })
626 .await;
627 Ok(())
628 }
629
630 fn estimate_tokens(&self, messages: &[Message]) -> usize {
631 messages
633 .iter()
634 .map(|m| match &m.content {
635 Content::Text { text } => text.len() / 4,
636 Content::ToolCall { arguments, .. } => arguments.to_string().len() / 4,
637 Content::ToolResult { output, .. } => output.len() / 4,
638 Content::MultiPart { parts } => parts
639 .iter()
640 .map(|p| match p {
641 Content::Text { text } => text.len() / 4,
642 _ => 50,
643 })
644 .sum(),
645 })
646 .sum::<usize>()
647 + 100 }
649
650 fn context_window(&self) -> usize {
651 self.context_window
652 }
653
654 fn supports_tools(&self) -> bool {
655 true
656 }
657
658 fn cost_per_token(&self) -> (f64, f64) {
659 (0.0, 0.0) }
661
662 fn model_name(&self) -> &str {
663 &self.model
664 }
665}
666
/// Default system prompt: identity, tool-selection rules, workflow
/// catalog, and security constraints. Runtime behavior — keep the
/// string itself unchanged unless deliberately re-tuning the agent.
pub const DEFAULT_SYSTEM_PROMPT: &str = r#"You are Rustant, a privacy-first autonomous personal assistant built in Rust. You help users with software engineering, daily productivity, and macOS automation tasks.

CRITICAL — Tool selection rules:
- You MUST use the dedicated tool for each task. Do NOT use shell_exec when a dedicated tool exists.
- For clipboard: call macos_clipboard with {"action":"read"} or {"action":"write","content":"..."}
- For battery/disk/CPU/version: call macos_system_info with {"action":"battery"}, {"action":"version"}, etc.
- For running apps: call macos_app_control with {"action":"list_running"}
- For calendar: call macos_calendar. For reminders: call macos_reminders. For notes: call macos_notes.
- For screenshots: call macos_screenshot. For Spotlight search: call macos_spotlight.
- shell_exec is a last resort — only use it for commands that have no dedicated tool.
- Do NOT use document_read for clipboard or system operations — it reads document files only.
- If a tool call fails, try a different tool or action — do NOT ask the user whether to proceed. Act autonomously.
- Never call ask_user more than once per task unless the user's answer was genuinely unclear.

Other behaviors:
- Always read a file before modifying it
- Prefer small, focused changes over large rewrites
- Respect file boundaries and permissions

Tools:
- Use the tools provided to you. Each tool has a name, description, and parameter schema.
- For arXiv/paper searches: ALWAYS use arxiv_research (built-in API client), never safari/curl.
- For web searches: use web_search (DuckDuckGo), not safari/shell.
- For fetching URLs: use web_fetch, not safari/shell.
- For meeting recording: use macos_meeting_recorder with 'record_and_transcribe' for full flow.

Workflows (structured multi-step templates — run via shell_exec "rustant workflow run <name>"):
  code_review, refactor, test_generation, documentation, dependency_update,
  security_scan, deployment, incident_response, morning_briefing, pr_review,
  dependency_audit, changelog, meeting_recorder, daily_briefing_full,
  end_of_day_summary, app_automation, email_triage, arxiv_research,
  knowledge_graph, experiment_tracking, code_analysis, content_pipeline,
  skill_development, career_planning, system_monitoring, life_planning,
  privacy_audit, self_improvement_loop
When a user asks for one of these tasks by name or description, execute the workflow or accomplish it step by step.

Security rules:
- Never execute commands that could damage the system or leak credentials
- Do not read or write files containing secrets (.env, *.key, *.pem) unless explicitly asked
- Sanitize all user input before passing to shell or AppleScript commands
- When unsure about a destructive action, use ask_user to confirm first"#;
709
/// Tracks session/task spend and enforces cost and token budgets.
/// A limit of zero disables the corresponding check.
pub struct TokenBudgetManager {
    // Hard limits (0 = disabled).
    session_limit_usd: f64,
    task_limit_usd: f64,
    session_token_limit: usize,
    // Whether callers should stop work when a check returns Exceeded.
    halt_on_exceed: bool,
    // Running totals.
    session_cost: f64,
    task_cost: f64,
    session_tokens: usize,
}
725
/// Outcome of a pre-request budget check.
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetCheckResult {
    /// Within all configured limits.
    Ok,
    /// Above 80% of the session cost limit, but not over any limit.
    Warning { message: String, usage_pct: f64 },
    /// The projected request would cross a configured limit.
    Exceeded { message: String },
}
736
737impl TokenBudgetManager {
738 pub fn new(config: Option<&crate::config::BudgetConfig>) -> Self {
741 match config {
742 Some(cfg) => Self {
743 session_limit_usd: cfg.session_limit_usd,
744 task_limit_usd: cfg.task_limit_usd,
745 session_token_limit: cfg.session_token_limit,
746 halt_on_exceed: cfg.halt_on_exceed,
747 session_cost: 0.0,
748 task_cost: 0.0,
749 session_tokens: 0,
750 },
751 None => Self {
752 session_limit_usd: 0.0,
753 task_limit_usd: 0.0,
754 session_token_limit: 0,
755 halt_on_exceed: false,
756 session_cost: 0.0,
757 task_cost: 0.0,
758 session_tokens: 0,
759 },
760 }
761 }
762
763 pub fn reset_task(&mut self) {
765 self.task_cost = 0.0;
766 }
767
768 pub fn record_usage(&mut self, usage: &TokenUsage, cost: &CostEstimate) {
770 self.session_cost += cost.total();
771 self.task_cost += cost.total();
772 self.session_tokens += usage.total();
773 }
774
775 pub fn check_budget(
781 &self,
782 estimated_input_tokens: usize,
783 input_rate: f64,
784 output_rate: f64,
785 ) -> BudgetCheckResult {
786 let predicted_output = estimated_input_tokens / 2;
788 let predicted_cost =
789 (estimated_input_tokens as f64 * input_rate) + (predicted_output as f64 * output_rate);
790
791 let projected_session_cost = self.session_cost + predicted_cost;
792 let projected_task_cost = self.task_cost + predicted_cost;
793 let projected_session_tokens =
794 self.session_tokens + estimated_input_tokens + predicted_output;
795
796 if self.session_limit_usd > 0.0 && projected_session_cost > self.session_limit_usd {
798 return BudgetCheckResult::Exceeded {
799 message: format!(
800 "Session cost ${:.4} would exceed limit ${:.4}",
801 projected_session_cost, self.session_limit_usd
802 ),
803 };
804 }
805
806 if self.task_limit_usd > 0.0 && projected_task_cost > self.task_limit_usd {
808 return BudgetCheckResult::Exceeded {
809 message: format!(
810 "Task cost ${:.4} would exceed limit ${:.4}",
811 projected_task_cost, self.task_limit_usd
812 ),
813 };
814 }
815
816 if self.session_token_limit > 0 && projected_session_tokens > self.session_token_limit {
818 return BudgetCheckResult::Exceeded {
819 message: format!(
820 "Session tokens {} would exceed limit {}",
821 projected_session_tokens, self.session_token_limit
822 ),
823 };
824 }
825
826 if self.session_limit_usd > 0.0 {
828 let pct = projected_session_cost / self.session_limit_usd;
829 if pct > 0.8 {
830 return BudgetCheckResult::Warning {
831 message: format!(
832 "Session cost at {:.0}% of ${:.4} limit",
833 pct * 100.0,
834 self.session_limit_usd
835 ),
836 usage_pct: pct,
837 };
838 }
839 }
840
841 BudgetCheckResult::Ok
842 }
843
844 pub fn should_halt_on_exceed(&self) -> bool {
846 self.halt_on_exceed
847 }
848
849 pub fn session_cost(&self) -> f64 {
851 self.session_cost
852 }
853
854 pub fn task_cost(&self) -> f64 {
856 self.task_cost
857 }
858
859 pub fn session_tokens(&self) -> usize {
861 self.session_tokens
862 }
863}
864
865#[cfg(test)]
866mod tests {
867 use super::*;
868
    // An empty queue yields the canned fallback text response.
    #[tokio::test]
    async fn test_mock_provider_default_response() {
        let provider = MockLlmProvider::new();
        let request = CompletionRequest::default();
        let response = provider.complete(request).await.unwrap();
        assert!(response.message.content.as_text().is_some());
    }
876
    // Queued responses are served in FIFO order.
    #[tokio::test]
    async fn test_mock_provider_queued_responses() {
        let provider = MockLlmProvider::new();
        provider.queue_response(MockLlmProvider::text_response("first"));
        provider.queue_response(MockLlmProvider::text_response("second"));

        let r1 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r1.message.content.as_text(), Some("first"));

        let r2 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r2.message.content.as_text(), Some("second"));
    }
895
    // "hello world" streams as two word tokens followed by Done.
    #[tokio::test]
    async fn test_mock_provider_streaming() {
        let provider = MockLlmProvider::new();
        provider.queue_response(MockLlmProvider::text_response("hello world"));

        let (tx, mut rx) = mpsc::channel(32);
        provider
            .complete_streaming(CompletionRequest::default(), tx)
            .await
            .unwrap();

        let mut tokens = Vec::new();
        while let Some(event) = rx.recv().await {
            match event {
                StreamEvent::Token(t) => tokens.push(t),
                StreamEvent::Done { .. } => break,
                _ => {}
            }
        }
        assert_eq!(tokens.len(), 2);
    }
917
    // The chars/4 + overhead heuristic always yields a positive estimate.
    #[test]
    fn test_mock_provider_token_estimation() {
        let provider = MockLlmProvider::new();
        let messages = vec![Message::user("Hello, this is a test message.")];
        let tokens = provider.estimate_tokens(&messages);
        assert!(tokens > 0);
    }
925
    // Pins the mock's static trait properties.
    #[test]
    fn test_mock_provider_properties() {
        let provider = MockLlmProvider::new();
        assert_eq!(provider.context_window(), 128_000);
        assert!(provider.supports_tools());
        assert_eq!(provider.cost_per_token(), (0.0, 0.0));
        assert_eq!(provider.model_name(), "mock-model");
    }
934
    // think() returns the queued reply and accumulates token usage.
    #[tokio::test]
    async fn test_brain_think() {
        let provider = Arc::new(MockLlmProvider::new());
        provider.queue_response(MockLlmProvider::text_response("I can help with that."));

        let mut brain = Brain::new(provider, "You are a helpful assistant.");
        let conversation = vec![Message::user("Help me refactor")];

        let response = brain.think(&conversation, None).await.unwrap();
        assert_eq!(
            response.message.content.as_text(),
            Some("I can help with that.")
        );
        assert!(brain.total_usage().total() > 0);
    }
950
    // build_messages prepends exactly one system message.
    #[tokio::test]
    async fn test_brain_builds_messages_with_system_prompt() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system prompt");
        let conversation = vec![Message::user("hello")];

        let messages = brain.build_messages(&conversation);
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content.as_text(), Some("system prompt"));
        assert_eq!(messages[1].role, Role::User);
    }
963
    // A short conversation uses a nonzero but small fraction of context.
    #[test]
    fn test_brain_context_usage_ratio() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("short message")];

        let ratio = brain.context_usage_ratio(&conversation);
        assert!(ratio > 0.0);
        assert!(ratio < 1.0);
    }
974
    // tool_call_response carries the requested name and arguments.
    #[test]
    fn test_mock_tool_call_response() {
        let response = MockLlmProvider::tool_call_response(
            "file_read",
            serde_json::json!({"path": "/tmp/test.rs"}),
        );
        match &response.message.content {
            Content::ToolCall {
                name, arguments, ..
            } => {
                assert_eq!(name, "file_read");
                assert_eq!(arguments["path"], "/tmp/test.rs");
            }
            _ => panic!("Expected ToolCall content"),
        }
    }
991
    // Smoke-checks key phrases in the default prompt.
    #[test]
    fn test_default_system_prompt() {
        assert!(DEFAULT_SYSTEM_PROMPT.contains("Rustant"));
        assert!(DEFAULT_SYSTEM_PROMPT.contains("autonomous"));
    }
997
    // Transient errors retry; overflow and auth failures do not.
    #[test]
    fn test_is_retryable() {
        assert!(Brain::is_retryable(&LlmError::RateLimited {
            retry_after_secs: 5
        }));
        assert!(Brain::is_retryable(&LlmError::Timeout { timeout_secs: 30 }));
        assert!(Brain::is_retryable(&LlmError::Connection {
            message: "reset".into()
        }));
        assert!(!Brain::is_retryable(&LlmError::ContextOverflow {
            used: 200_000,
            limit: 128_000
        }));
        assert!(!Brain::is_retryable(&LlmError::AuthFailed {
            provider: "openai".into()
        }));
    }
1015
    // Test provider that fails a set number of times before succeeding.
    struct FailingProvider {
        // Calls remaining that should error before success.
        failures_remaining: std::sync::Mutex<usize>,
        // Error to emit: "rate_limited", "timeout", "connection", or
        // anything else for a non-retryable ApiRequest.
        error_type: String,
        // Response returned once failures are exhausted.
        success_response: CompletionResponse,
    }
1022
    impl FailingProvider {
        // `failures` errors of the given type, then canned success text.
        fn new(failures: usize, error_type: &str) -> Self {
            Self {
                failures_remaining: std::sync::Mutex::new(failures),
                error_type: error_type.to_string(),
                success_response: MockLlmProvider::text_response("Success after retry"),
            }
        }
    }
1032
    #[async_trait]
    impl LlmProvider for FailingProvider {
        // Errors with the configured variant while failures remain,
        // then returns the canned success response.
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, LlmError> {
            let mut remaining = self.failures_remaining.lock().unwrap();
            if *remaining > 0 {
                *remaining -= 1;
                match self.error_type.as_str() {
                    "rate_limited" => Err(LlmError::RateLimited {
                        retry_after_secs: 0,
                    }),
                    "timeout" => Err(LlmError::Timeout { timeout_secs: 5 }),
                    "connection" => Err(LlmError::Connection {
                        message: "connection reset".into(),
                    }),
                    // Any other tag maps to a non-retryable API error.
                    _ => Err(LlmError::ApiRequest {
                        message: "non-retryable".into(),
                    }),
                }
            } else {
                Ok(self.success_response.clone())
            }
        }

        // Streaming is a no-op for this test double.
        async fn complete_streaming(
            &self,
            _request: CompletionRequest,
            _tx: mpsc::Sender<StreamEvent>,
        ) -> Result<(), LlmError> {
            Ok(())
        }

        fn estimate_tokens(&self, _messages: &[Message]) -> usize {
            100
        }
        fn context_window(&self) -> usize {
            128_000
        }
        fn supports_tools(&self) -> bool {
            true
        }
        fn cost_per_token(&self) -> (f64, f64) {
            (0.0, 0.0)
        }
        fn model_name(&self) -> &str {
            "failing-mock"
        }
    }
1083
    // Two connection failures, then success within three retries.
    #[tokio::test]
    async fn test_think_with_retry_succeeds_after_failures() {
        let provider = Arc::new(FailingProvider::new(2, "connection"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 3).await;
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().message.content.as_text(),
            Some("Success after retry")
        );
    }
1097
    // More failures than retries: the last timeout error surfaces.
    #[tokio::test]
    async fn test_think_with_retry_exhausted() {
        let provider = Arc::new(FailingProvider::new(5, "timeout"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 2).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Timeout { .. }));
    }
1108
    // Non-retryable errors bypass the retry loop entirely.
    #[tokio::test]
    async fn test_think_with_retry_non_retryable_fails_immediately() {
        let provider = Arc::new(FailingProvider::new(1, "non_retryable"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 3).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::ApiRequest { .. }));
    }
1119
    // A single rate-limit error (retry_after 0) recovers on retry.
    #[tokio::test]
    async fn test_think_with_retry_rate_limited() {
        let provider = Arc::new(FailingProvider::new(1, "rate_limited"));
        let mut brain = Brain::new(provider, "system");
        let conversation = vec![Message::user("test")];

        let result = brain.think_with_retry(&conversation, None, 2).await;
        assert!(result.is_ok());
    }
1129
    // track_usage folds externally-reported usage into the totals.
    #[test]
    fn test_track_usage() {
        let provider = Arc::new(MockLlmProvider::new());
        let mut brain = Brain::new(provider, "system");

        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
        };
        brain.track_usage(&usage);

        assert_eq!(brain.total_usage().input_tokens, 100);
        assert_eq!(brain.total_usage().output_tokens, 50);
    }
1144
    // A short phrase tokenizes to a small, nonzero count.
    #[test]
    fn test_token_counter_basic() {
        let counter = TokenCounter::for_model("gpt-4o");
        let count = counter.count("Hello, world!");
        assert!(count > 0);
        assert!(count < 20);
    }
1152
    // Message counting includes per-message overhead; bounds are loose.
    #[test]
    fn test_token_counter_messages() {
        let counter = TokenCounter::for_model("gpt-4o");
        let messages = vec![
            Message::system("You are a helpful assistant."),
            Message::user("What is 2 + 2?"),
        ];
        let count = counter.count_messages(&messages);
        assert!(count > 5);
        assert!(count < 100);
    }
1164
    // Unknown model names fall back to cl100k_base and still count.
    #[test]
    fn test_token_counter_unknown_model_falls_back() {
        let counter = TokenCounter::for_model("unknown-model-xyz");
        let count = counter.count("Hello");
        assert!(count > 0);
    }
1171
    // Brain's local estimator yields a positive count.
    #[test]
    fn test_brain_estimate_tokens() {
        let provider = Arc::new(MockLlmProvider::new());
        let brain = Brain::new(provider, "system");
        let messages = vec![Message::user("Hello, this is a test.")];
        let estimate = brain.estimate_tokens(&messages);
        assert!(estimate > 0);
    }
1180
    // A tool_result with no matching tool_call is dropped.
    #[test]
    fn test_sanitize_removes_orphaned_tool_results() {
        let mut messages = vec![
            Message::system("You are a helper."),
            Message::user("do something"),
            Message::tool_result("call_orphan_123", "some result", false),
            Message::assistant("Done!"),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // The orphaned result is gone; the other three survive in order.
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[2].role, Role::Assistant);
    }
1201
    // A well-formed call/result pair passes through untouched.
    #[test]
    fn test_sanitize_preserves_valid_sequence() {
        let mut messages = vec![
            Message::system("system prompt"),
            Message::user("read main.rs"),
            Message::new(
                Role::Assistant,
                Content::tool_call(
                    "call_1",
                    "file_read",
                    serde_json::json!({"path": "main.rs"}),
                ),
            ),
            Message::tool_result("call_1", "fn main() {}", false),
            Message::assistant("Here is the file content."),
        ];

        super::sanitize_tool_sequence(&mut messages);

        assert_eq!(messages.len(), 5);
    }
1224
    // A system message injected between call and result gets relocated
    // so the result directly follows the tool call.
    #[test]
    fn test_sanitize_handles_system_between_call_and_result() {
        let mut messages = vec![
            Message::system("system prompt"),
            Message::user("do something"),
            Message::new(
                Role::Assistant,
                Content::tool_call("call_1", "file_read", serde_json::json!({"path": "x.rs"})),
            ),
            Message::system("routing hint: use file_read"),
            Message::tool_result("call_1", "file contents", false),
            Message::assistant("Done"),
        ];

        super::sanitize_tool_sequence(&mut messages);

        // Find the assistant tool call, then check what follows it.
        let assistant_idx = messages
            .iter()
            .position(|m| m.role == Role::Assistant && super::content_has_tool_call(&m.content))
            .unwrap();

        let next = &messages[assistant_idx + 1];
        assert!(
            matches!(&next.content, Content::ToolResult { .. })
                || next.role == Role::Tool
                || next.role == Role::User,
            "Expected tool_result after tool_call, got {:?}",
            next.role
        );
    }
1259
1260 #[test]
1261 fn test_sanitize_multipart_tool_call() {
1262 let mut messages = vec![
1263 Message::user("do two things"),
1264 Message::new(
1265 Role::Assistant,
1266 Content::MultiPart {
1267 parts: vec![
1268 Content::text("I'll read both files."),
1269 Content::tool_call(
1270 "call_a",
1271 "file_read",
1272 serde_json::json!({"path": "a.rs"}),
1273 ),
1274 Content::tool_call(
1275 "call_b",
1276 "file_read",
1277 serde_json::json!({"path": "b.rs"}),
1278 ),
1279 ],
1280 },
1281 ),
1282 Message::tool_result("call_a", "contents of a", false),
1283 Message::tool_result("call_b", "contents of b", false),
1284 Message::tool_result("call_nonexistent", "orphan", false),
1286 ];
1287
1288 super::sanitize_tool_sequence(&mut messages);
1289
1290 assert_eq!(messages.len(), 4);
1292 }
1293
1294 #[test]
1295 fn test_sanitize_empty_messages() {
1296 let mut messages: Vec<Message> = vec![];
1297 super::sanitize_tool_sequence(&mut messages);
1298 assert!(messages.is_empty());
1299 }
1300
1301 #[test]
1302 fn test_sanitize_no_tool_messages() {
1303 let mut messages = vec![
1304 Message::system("prompt"),
1305 Message::user("hello"),
1306 Message::assistant("hi"),
1307 ];
1308 super::sanitize_tool_sequence(&mut messages);
1309 assert_eq!(messages.len(), 3);
1310 }
1311
1312 #[test]
1313 fn test_count_tool_definitions() {
1314 let counter = TokenCounter::for_model("gpt-4");
1315 let tools = vec![
1316 ToolDefinition {
1317 name: "calculator".to_string(),
1318 description: "Perform arithmetic calculations".to_string(),
1319 parameters: serde_json::json!({
1320 "type": "object",
1321 "properties": {
1322 "expression": { "type": "string", "description": "Math expression" }
1323 },
1324 "required": ["expression"]
1325 }),
1326 },
1327 ToolDefinition {
1328 name: "file_read".to_string(),
1329 description: "Read a file from the filesystem".to_string(),
1330 parameters: serde_json::json!({
1331 "type": "object",
1332 "properties": {
1333 "path": { "type": "string", "description": "File path" }
1334 },
1335 "required": ["path"]
1336 }),
1337 },
1338 ];
1339
1340 let token_count = counter.count_tool_definitions(&tools);
1341 assert!(
1344 token_count > 40,
1345 "Two tool definitions should count as >40 tokens, got {}",
1346 token_count
1347 );
1348 assert!(
1350 token_count < 500,
1351 "Two simple tool definitions should be <500 tokens, got {}",
1352 token_count
1353 );
1354 }
1355
1356 #[test]
1357 fn test_count_tool_definitions_empty() {
1358 let counter = TokenCounter::for_model("gpt-4");
1359 assert_eq!(counter.count_tool_definitions(&[]), 0);
1360 }
1361
1362 #[test]
1363 fn test_estimate_tokens_with_tools() {
1364 let provider = Arc::new(MockLlmProvider::new());
1365 let brain = Brain::new(provider, "system prompt");
1366
1367 let messages = vec![Message::user("hello")];
1368 let tools = vec![ToolDefinition {
1369 name: "echo".to_string(),
1370 description: "Echo text back".to_string(),
1371 parameters: serde_json::json!({"type": "object"}),
1372 }];
1373
1374 let without_tools = brain.estimate_tokens(&messages);
1375 let with_tools = brain.estimate_tokens_with_tools(&messages, Some(&tools));
1376 assert!(
1377 with_tools > without_tools,
1378 "Token estimate with tools ({}) should be greater than without ({})",
1379 with_tools,
1380 without_tools
1381 );
1382 }
1383}