1use crate::brain::{Brain, LlmProvider};
7use crate::config::{AgentConfig, MessagePriority};
8use crate::error::{AgentError, LlmError, RustantError, ToolError};
9use crate::explanation::{DecisionExplanation, DecisionType, ExplanationBuilder, FactorInfluence};
10use crate::memory::MemorySystem;
11use crate::safety::{
12 ActionDetails, ActionRequest, ApprovalContext, ApprovalDecision, ContractCheckResult,
13 PermissionResult, ReversibilityInfo, SafetyGuardian,
14};
15use crate::scheduler::{CronScheduler, HeartbeatManager, JobManager};
16use crate::summarizer::ContextSummarizer;
17use crate::types::{
18 AgentState, AgentStatus, CompletionResponse, Content, CostEstimate, Message, ProgressUpdate,
19 RiskLevel, Role, StreamEvent, TaskClassification, TokenUsage, ToolDefinition, ToolOutput,
20};
21use std::collections::{HashMap, HashSet};
22use std::sync::Arc;
23use std::time::Instant;
24use tokio::sync::{mpsc, oneshot};
25use tokio_util::sync::CancellationToken;
26use tracing::{debug, info, warn};
27use uuid::Uuid;
28
/// Return the longest prefix of `s` containing at most `max_chars` characters
/// (Unicode scalar values), never splitting a multi-byte character.
fn truncate_str(s: &str, max_chars: usize) -> &str {
    // Byte offset where character number `max_chars` begins; if the string has no
    // such character it is already short enough and is returned whole.
    let cut = s
        .char_indices()
        .map(|(byte_idx, _)| byte_idx)
        .nth(max_chars)
        .unwrap_or(s.len());
    &s[..cut]
}
36
/// Messages that can be sent to an agent over a channel.
///
/// NOTE(review): the consumer of these messages is not visible here; the variant
/// notes below describe intent implied by the field types — confirm against the
/// receiving loop.
pub enum AgentMessage {
    /// Run `task`; the resulting `TaskResult` is delivered on `reply`.
    ProcessTask {
        task: String,
        reply: oneshot::Sender<TaskResult>,
    },
    /// Request cancellation of the task identified by `task_id`.
    Cancel {
        task_id: Uuid,
    },
    /// Ask for the agent's current status, delivered on `reply`.
    GetStatus {
        reply: oneshot::Sender<AgentStatus>,
    },
    /// Ask the agent to stop.
    Shutdown,
}
51
/// Outcome of a finished `Agent::process_task` run.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Identifier generated for this task run.
    pub task_id: Uuid,
    /// Whether the task succeeded. The non-plan path of `process_task` always sets
    /// `true`; its failures are returned as `Err` instead.
    pub success: bool,
    /// Final assistant text produced for the task.
    pub response: String,
    /// Number of agent-loop iterations that were used.
    pub iterations: usize,
    /// Token usage as reported by `Brain::total_usage` at completion.
    // NOTE(review): appears to be the brain's running total, not a per-task figure — confirm.
    pub total_usage: TokenUsage,
    /// Cost estimate as reported by `Brain::total_cost` at completion.
    pub total_cost: CostEstimate,
}
62
/// Severity attached to a budget notification sent to `AgentCallback::on_budget_warning`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BudgetSeverity {
    /// The token budget is nearing its limit.
    Warning,
    /// The token budget has been exceeded; whether the task halts depends on the
    /// budget manager's `should_halt_on_exceed` setting.
    Exceeded,
}
71
/// Context-window health notifications delivered via `AgentCallback::on_context_health`.
#[derive(Debug, Clone)]
pub enum ContextHealthEvent {
    /// Context usage is elevated (`process_task` sends this at 70% up to, but not
    /// including, 90% of the window).
    Warning {
        /// Percentage of the context window in use.
        usage_percent: u8,
        /// Token total from the memory system's context breakdown.
        total_tokens: usize,
        /// Size of the provider's context window in tokens.
        context_window: usize,
        /// Suggested user action.
        hint: String,
    },
    /// Context is nearly full (`process_task` sends this at 90% or more).
    Critical {
        /// Percentage of the context window in use.
        usage_percent: u8,
        /// Token total from the memory system's context breakdown.
        total_tokens: usize,
        /// Size of the provider's context window in tokens.
        context_window: usize,
        /// Suggested user action.
        hint: String,
    },
    /// The conversation context was compressed.
    // NOTE(review): emitted outside this chunk; field meanings below are inferred
    // from their names — confirm at the emitting site.
    Compressed {
        /// Number of messages folded into the compressed form.
        messages_compressed: usize,
        /// Whether an LLM produced the summary (vs. a non-LLM fallback).
        was_llm_summarized: bool,
        /// Number of pinned messages kept verbatim.
        pinned_preserved: usize,
    },
}
98
/// Hooks through which the agent reports progress to, and requests decisions from,
/// its host (UI, CLI, tests).
///
/// Methods with a default body are optional. Their defaults do nothing, except
/// `on_clarification_request`, which returns an empty string, and `on_plan_review`,
/// which approves the plan.
#[async_trait::async_trait]
pub trait AgentCallback: Send + Sync {
    /// A complete assistant message: final text, text inside a multi-part reply, or a
    /// `[Routed: …]` notice when a tool call was auto-corrected.
    async fn on_assistant_message(&self, message: &str);

    /// A streamed text token; also used to surface retry notices while streaming.
    async fn on_token(&self, token: &str);

    /// Ask the user to approve `action`; called when the safety guardian reports that
    /// a tool call requires approval.
    async fn request_approval(&self, action: &ActionRequest) -> ApprovalDecision;

    /// A tool is about to run with `args`.
    async fn on_tool_start(&self, tool_name: &str, args: &serde_json::Value);

    /// A tool produced `output`; `duration_ms` is its measured wall-clock run time.
    async fn on_tool_result(&self, tool_name: &str, output: &ToolOutput, duration_ms: u64);

    /// The agent's status changed (thinking, executing, waiting for approval, ...).
    async fn on_status_change(&self, status: AgentStatus);

    /// Updated usage and cost totals as reported by the brain after an LLM call.
    async fn on_usage_update(&self, usage: &TokenUsage, cost: &CostEstimate);

    /// An explanation of a decision (tool choice, denial, contract violation, ...).
    async fn on_decision_explanation(&self, explanation: &DecisionExplanation);

    /// A budget warning or overrun message. Default: no-op.
    async fn on_budget_warning(&self, _message: &str, _severity: BudgetSeverity) {}

    /// A progress update. Default: no-op.
    // NOTE(review): the emitter is not in this chunk.
    async fn on_progress(&self, _progress: &ProgressUpdate) {}

    /// Ask the user `question` on behalf of the built-in `ask_user` tool; the answer
    /// is returned to the LLM as that tool's output. Default: empty answer.
    async fn on_clarification_request(&self, _question: &str) -> String {
        String::new()
    }

    /// A new agent-loop iteration is starting. Default: no-op.
    async fn on_iteration_start(&self, _iteration: usize, _max_iterations: usize) {}

    /// Pre-flight cost estimate for the next LLM call; `process_task` only sends this
    /// when the estimate exceeds 0.05. Default: no-op.
    async fn on_cost_prediction(&self, _estimated_tokens: usize, _estimated_cost: f64) {}

    /// A context-window health event. Default: no-op.
    async fn on_context_health(&self, _event: &ContextHealthEvent) {}

    /// A channel digest payload. Default: no-op.
    // NOTE(review): payload schema and emitter are not visible here.
    async fn on_channel_digest(&self, _digest: &serde_json::Value) {}

    /// A channel alert from `sender` on `channel`, with a short `summary`. Default: no-op.
    // NOTE(review): emitter not visible here.
    async fn on_channel_alert(&self, _channel: &str, _sender: &str, _summary: &str) {}

    /// A reminder payload. Default: no-op.
    // NOTE(review): payload schema and emitter are not visible here.
    async fn on_reminder(&self, _reminder: &serde_json::Value) {}

    /// A plan is being generated for `goal`. Default: no-op.
    async fn on_plan_generating(&self, _goal: &str) {}

    /// Let the user review a generated plan. Default: approve it unchanged.
    async fn on_plan_review(
        &self,
        _plan: &crate::plan::ExecutionPlan,
    ) -> crate::plan::PlanDecision {
        crate::plan::PlanDecision::Approve
    }

    /// A plan step is starting. Default: no-op.
    async fn on_plan_step_start(&self, _step_index: usize, _step: &crate::plan::PlanStep) {}

    /// A plan step has finished. Default: no-op.
    async fn on_plan_step_complete(&self, _step_index: usize, _step: &crate::plan::PlanStep) {}
}
200
/// Type-erased async implementation of a tool.
///
/// Called with the tool's JSON arguments; returns a pinned, boxed, `Send` future
/// that resolves to the tool's output or a `ToolError`. The closure itself is
/// `Send + Sync` so it can live in the agent's shared tool registry.
pub type ToolExecutor = Box<
    dyn Fn(
            serde_json::Value,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<ToolOutput, ToolError>> + Send>,
        > + Send
        + Sync,
>;
210
/// A tool the agent can call: its LLM-facing definition, its risk level, and its
/// implementation.
pub struct RegisteredTool {
    /// Schema advertised to the LLM; `definition.name` is the registry key.
    pub definition: ToolDefinition,
    /// Risk level used by the safety guardian for permission and contract checks.
    pub risk_level: RiskLevel,
    /// Implementation invoked with the call's JSON arguments.
    pub executor: ToolExecutor,
}
217
218pub struct Agent {
220 brain: Brain,
221 memory: MemorySystem,
222 safety: SafetyGuardian,
223 tools: HashMap<String, RegisteredTool>,
224 state: AgentState,
225 #[allow(dead_code)]
226 config: AgentConfig,
227 cancellation: CancellationToken,
228 callback: Arc<dyn AgentCallback>,
229 summarizer: ContextSummarizer,
231 budget: crate::brain::TokenBudgetManager,
233 knowledge: crate::memory::KnowledgeDistiller,
235 tool_token_usage: HashMap<String, usize>,
237 cron_scheduler: Option<CronScheduler>,
239 heartbeat_manager: Option<HeartbeatManager>,
241 job_manager: JobManager,
243 consecutive_failures: (String, usize),
246 recent_explanations: Vec<DecisionExplanation>,
248 plan_mode: bool,
250 current_plan: Option<crate::plan::ExecutionPlan>,
252}
253
254impl Agent {
255 pub fn new(
256 provider: Arc<dyn LlmProvider>,
257 config: AgentConfig,
258 callback: Arc<dyn AgentCallback>,
259 ) -> Self {
260 let summarizer = ContextSummarizer::new(Arc::clone(&provider));
261 let brain = Brain::new(provider, crate::brain::DEFAULT_SYSTEM_PROMPT);
262 let memory = MemorySystem::new(config.memory.window_size);
263 let safety = SafetyGuardian::new(config.safety.clone());
264 let max_iter = config.safety.max_iterations;
265 let budget = crate::brain::TokenBudgetManager::new(config.budget.as_ref());
266 let knowledge = crate::memory::KnowledgeDistiller::new(config.knowledge.as_ref());
267
268 let cron_scheduler = config.scheduler.as_ref().and_then(|sc| {
269 if sc.enabled {
270 let mut scheduler = CronScheduler::new();
271 for job_config in &sc.cron_jobs {
272 if let Err(e) = scheduler.add_job(job_config.clone()) {
273 warn!("Failed to add cron job '{}': {}", job_config.name, e);
274 }
275 }
276 Some(scheduler)
277 } else {
278 None
279 }
280 });
281 let heartbeat_manager = config.scheduler.as_ref().and_then(|sc| {
282 sc.heartbeat
283 .as_ref()
284 .map(|hb| HeartbeatManager::new(hb.clone()))
285 });
286 let max_bg_jobs = config
287 .scheduler
288 .as_ref()
289 .map(|sc| sc.max_background_jobs)
290 .unwrap_or(10);
291 let job_manager = JobManager::new(max_bg_jobs);
292 let plan_mode_enabled = config.plan.as_ref().map(|p| p.enabled).unwrap_or(false);
293
294 Self {
295 brain,
296 memory,
297 safety,
298 tools: HashMap::new(),
299 state: AgentState::new(max_iter),
300 config,
301 cancellation: CancellationToken::new(),
302 callback,
303 summarizer,
304 budget,
305 knowledge,
306 tool_token_usage: HashMap::new(),
307 cron_scheduler,
308 heartbeat_manager,
309 job_manager,
310 consecutive_failures: (String::new(), 0),
311 recent_explanations: Vec::new(),
312 plan_mode: plan_mode_enabled,
313 current_plan: None,
314 }
315 }
316
317 pub fn register_tool(&mut self, tool: RegisteredTool) {
319 self.tools.insert(tool.definition.name.clone(), tool);
320 }
321
322 fn tools_for_classification(
329 classification: &TaskClassification,
330 ) -> Option<HashSet<&'static str>> {
331 let core: [&str; 10] = [
333 "ask_user",
334 "echo",
335 "datetime",
336 "calculator",
337 "shell_exec",
338 "file_read",
339 "file_write",
340 "file_list",
341 "file_search",
342 "web_search",
343 ];
344
345 let extra: &[&str] = match classification {
346 TaskClassification::General | TaskClassification::Workflow(_) => return None,
347 TaskClassification::FileOperation => &[
348 "file_patch",
349 "smart_edit",
350 "codebase_search",
351 "document_read",
352 ],
353 TaskClassification::GitOperation => &[
354 "git_status",
355 "git_diff",
356 "git_commit",
357 "file_patch",
358 "smart_edit",
359 ],
360 TaskClassification::CodeAnalysis => &[
361 "code_intelligence",
362 "codebase_search",
363 "smart_edit",
364 "git_status",
365 "git_diff",
366 ],
367 TaskClassification::Search => &["codebase_search", "web_fetch", "smart_edit"],
368 TaskClassification::WebSearch => &["web_fetch"],
369 TaskClassification::WebFetch => &["web_fetch", "http_api"],
370 TaskClassification::Calendar => &["macos_calendar", "macos_notification"],
371 TaskClassification::Reminders => &["macos_reminders", "macos_notification"],
372 TaskClassification::Notes => &["macos_notes"],
373 TaskClassification::Email => &["macos_mail", "macos_notification"],
374 TaskClassification::Music => &["macos_music"],
375 TaskClassification::AppControl => &[
376 "macos_app_control",
377 "macos_gui_scripting",
378 "macos_accessibility",
379 "macos_screen_analyze",
380 ],
381 TaskClassification::Clipboard => &["macos_clipboard"],
382 TaskClassification::Screenshot => &["macos_screenshot"],
383 TaskClassification::SystemInfo => &["macos_system_info"],
384 TaskClassification::Contacts => &["macos_contacts", "imessage_contacts"],
385 TaskClassification::Safari => &["macos_safari", "web_fetch"],
386 TaskClassification::HomeKit => &["homekit"],
387 TaskClassification::Photos => &["photos"],
388 TaskClassification::Voice => &["macos_say"],
389 TaskClassification::Meeting => &[
390 "macos_meeting_recorder",
391 "macos_notes",
392 "macos_notification",
393 ],
394 TaskClassification::DailyBriefing => &[
395 "macos_daily_briefing",
396 "macos_calendar",
397 "macos_reminders",
398 "macos_mail",
399 "macos_notes",
400 ],
401 TaskClassification::GuiScripting => &[
402 "macos_gui_scripting",
403 "macos_accessibility",
404 "macos_screen_analyze",
405 "macos_app_control",
406 ],
407 TaskClassification::Accessibility => &[
408 "macos_accessibility",
409 "macos_gui_scripting",
410 "macos_screen_analyze",
411 ],
412 TaskClassification::Browser => &[
413 "browser_navigate",
414 "browser_click",
415 "browser_type",
416 "browser_screenshot",
417 "web_fetch",
418 ],
419 TaskClassification::Messaging => {
420 &["imessage_read", "imessage_send", "imessage_contacts"]
421 }
422 TaskClassification::Slack => &["slack"],
423 TaskClassification::ArxivResearch => {
424 &["arxiv_research", "knowledge_graph", "web_fetch"]
425 }
426 TaskClassification::KnowledgeGraph => &["knowledge_graph"],
427 TaskClassification::ExperimentTracking => &["experiment_tracker"],
428 TaskClassification::CodeIntelligence => {
429 &["code_intelligence", "codebase_search", "smart_edit"]
430 }
431 TaskClassification::ContentEngine => &["content_engine"],
432 TaskClassification::SkillTracker => &["skill_tracker"],
433 TaskClassification::CareerIntel => &["career_intel"],
434 TaskClassification::SystemMonitor => &["system_monitor"],
435 TaskClassification::LifePlanner => &["life_planner"],
436 TaskClassification::PrivacyManager => &["privacy_manager"],
437 TaskClassification::SelfImprovement => &["self_improvement"],
438 TaskClassification::Notification => &["macos_notification"],
439 TaskClassification::Spotlight => &["macos_spotlight"],
440 TaskClassification::FocusMode => &["macos_focus_mode"],
441 TaskClassification::Finder => &["macos_finder"],
442 };
443
444 let mut set: HashSet<&str> = core.into_iter().collect();
445 set.extend(extra.iter().copied());
446 Some(set)
447 }
448
449 pub fn tool_definitions(
455 &self,
456 classification: Option<&TaskClassification>,
457 ) -> Vec<ToolDefinition> {
458 let allowed = classification.and_then(Self::tools_for_classification);
459
460 let mut defs: Vec<ToolDefinition> = if let Some(ref allowed_set) = allowed {
461 self.tools
462 .values()
463 .filter(|t| allowed_set.contains(t.definition.name.as_str()))
464 .map(|t| t.definition.clone())
465 .collect()
466 } else {
467 self.tools.values().map(|t| t.definition.clone()).collect()
468 };
469
470 let tool_count = defs.len();
471 let total_registered = self.tools.len();
472 if tool_count < total_registered {
473 debug!(
474 filtered = tool_count,
475 total = total_registered,
476 classification = ?classification,
477 "Filtered tool definitions for LLM request"
478 );
479 }
480
481 defs.push(ToolDefinition {
483 name: "ask_user".to_string(),
484 description: "Ask the user a clarifying question when you need more information to proceed. Use this when the task is ambiguous or you need to confirm something before taking action.".to_string(),
485 parameters: serde_json::json!({
486 "type": "object",
487 "properties": {
488 "question": {
489 "type": "string",
490 "description": "The question to ask the user"
491 }
492 },
493 "required": ["question"]
494 }),
495 });
496
497 defs
498 }
499
500 pub async fn process_task(&mut self, task: &str) -> Result<TaskResult, RustantError> {
502 if self.plan_mode {
504 return self.process_task_with_plan(task).await;
505 }
506
507 let task_id = Uuid::new_v4();
508 info!(task_id = %task_id, task = task, "Starting task processing");
509
510 self.state.start_task(task);
511 self.state.task_id = Some(task_id);
512 self.memory.start_new_task(task);
513 self.budget.reset_task();
514 self.tool_token_usage.clear();
515
516 self.knowledge.distill(&self.memory.long_term);
518 let mut knowledge_addendum = self.knowledge.rules_for_prompt();
519
520 if let Some(ref classification) = self.state.task_classification
525 && let Some(hint) = Self::tool_routing_hint_from_classification(classification)
526 {
527 knowledge_addendum.push_str("\n\n");
528 knowledge_addendum.push_str(&hint);
529 }
530 self.brain.set_knowledge_addendum(knowledge_addendum);
531
532 self.memory.add_message(Message::user(task));
533
534 self.callback.on_status_change(AgentStatus::Thinking).await;
535
536 let mut final_response = String::new();
537
538 loop {
539 if self.cancellation.is_cancelled() {
541 self.state.set_error();
542 return Err(RustantError::Agent(AgentError::Cancelled));
543 }
544
545 if !self.state.increment_iteration() {
547 warn!(
548 task_id = %task_id,
549 iterations = self.state.iteration,
550 "Maximum iterations reached"
551 );
552 self.state.set_error();
553 return Err(RustantError::Agent(AgentError::MaxIterationsReached {
554 max: self.state.max_iterations,
555 }));
556 }
557
558 debug!(
559 task_id = %task_id,
560 iteration = self.state.iteration,
561 "Agent loop iteration"
562 );
563
564 self.callback
566 .on_iteration_start(self.state.iteration, self.state.max_iterations)
567 .await;
568
569 self.state.status = AgentStatus::Thinking;
571 self.callback.on_status_change(AgentStatus::Thinking).await;
572
573 let conversation = self.memory.context_messages();
574 let tools = Some(self.tool_definitions(self.state.task_classification.as_ref()));
575
576 {
578 let context_window = self.brain.provider().context_window();
579 let breakdown = self.memory.context_breakdown(context_window);
580 let usage_percent = (breakdown.usage_ratio() * 100.0) as u8;
581 if usage_percent >= 90 {
582 self.callback
583 .on_context_health(&ContextHealthEvent::Critical {
584 usage_percent,
585 total_tokens: breakdown.total_tokens,
586 context_window: breakdown.context_window,
587 hint: "Context nearly full — auto-compression imminent. Use /pin to protect important messages.".to_string(),
588 })
589 .await;
590 } else if usage_percent >= 70 {
591 self.callback
592 .on_context_health(&ContextHealthEvent::Warning {
593 usage_percent,
594 total_tokens: breakdown.total_tokens,
595 context_window: breakdown.context_window,
596 hint: "Context filling up. Use /compact to compress now, or /pin to protect key messages.".to_string(),
597 })
598 .await;
599 }
600 }
601
602 let estimated_tokens = self
604 .brain
605 .estimate_tokens_with_tools(&conversation, tools.as_deref());
606 let (input_rate, output_rate) = self.brain.provider_cost_rates();
607 let budget_result = self
608 .budget
609 .check_budget(estimated_tokens, input_rate, output_rate);
610 match &budget_result {
611 crate::brain::BudgetCheckResult::Exceeded { message } => {
612 let top = self.top_tool_consumers(3);
613 let enriched = if top.is_empty() {
614 message.clone()
615 } else {
616 format!("{}. Top consumers: {}", message, top)
617 };
618 self.callback
619 .on_budget_warning(&enriched, BudgetSeverity::Exceeded)
620 .await;
621 if self.budget.should_halt_on_exceed() {
622 warn!("Budget exceeded, halting: {}", enriched);
623 return Err(RustantError::Agent(AgentError::BudgetExceeded {
624 message: enriched,
625 }));
626 }
627 warn!("Budget warning (soft limit): {}", enriched);
628 }
629 crate::brain::BudgetCheckResult::Warning { message, .. } => {
630 let top = self.top_tool_consumers(3);
631 let enriched = if top.is_empty() {
632 message.clone()
633 } else {
634 format!("{}. Top consumers: {}", message, top)
635 };
636 self.callback
637 .on_budget_warning(&enriched, BudgetSeverity::Warning)
638 .await;
639 debug!("Budget warning: {}", enriched);
640 }
641 crate::brain::BudgetCheckResult::Ok => {}
642 }
643
644 {
646 let est_tokens = estimated_tokens + 500; let est_cost = est_tokens as f64 * input_rate;
648 if est_cost > 0.05 {
649 self.callback.on_cost_prediction(est_tokens, est_cost).await;
650 }
651 }
652
653 let response = if self.config.llm.use_streaming {
654 self.think_streaming(&conversation, tools).await?
655 } else {
656 self.brain.think_with_retry(&conversation, tools, 3).await?
657 };
658
659 self.budget.record_usage(
661 &response.usage,
662 &CostEstimate {
663 input_cost: response.usage.input_tokens as f64 * input_rate,
664 output_cost: response.usage.output_tokens as f64 * output_rate,
665 },
666 );
667 self.callback
668 .on_usage_update(self.brain.total_usage(), self.brain.total_cost())
669 .await;
670
671 self.state.status = AgentStatus::Deciding;
673 match &response.message.content {
674 Content::Text { text } => {
675 info!(task_id = %task_id, "Agent produced text response");
677 self.callback.on_assistant_message(text).await;
678 self.memory.add_message(response.message.clone());
679 final_response = text.clone();
680 break;
682 }
683 Content::ToolCall {
684 id,
685 name,
686 arguments,
687 } => {
688 info!(
690 task_id = %task_id,
691 tool = name,
692 "Agent requesting tool execution"
693 );
694 self.memory.add_message(response.message.clone());
695
696 let explanation = self.build_decision_explanation(name, arguments);
698 self.callback.on_decision_explanation(&explanation).await;
699 self.record_explanation(explanation);
700
701 let (actual_name, actual_args) = if let Some((corrected_name, corrected_args)) =
708 Self::auto_correct_tool_call(name, arguments, &self.state)
709 {
710 if corrected_name != *name {
711 info!(
712 original_tool = name,
713 corrected_tool = corrected_name,
714 "Auto-routing to correct macOS tool"
715 );
716 self.callback
717 .on_assistant_message(&format!(
718 "[Routed: {} → {}]",
719 name, corrected_name
720 ))
721 .await;
722 (corrected_name, corrected_args)
723 } else {
724 (name.to_string(), arguments.clone())
725 }
726 } else {
727 (name.to_string(), arguments.clone())
728 };
729
730 let result = self.execute_tool(id, &actual_name, &actual_args).await;
732 if let Err(ref e) = result {
733 debug!(tool = %actual_name, error = %e, "Tool execution failed");
734 }
735
736 let result_tokens = match &result {
738 Ok(output) => {
739 let result_msg = Message::tool_result(id, &output.content, false);
740 let tokens = output.content.len() / 4; self.memory.add_message(result_msg);
742 tokens
743 }
744 Err(e) => {
745 let error_msg = format!("Tool error: {}", e);
746 let tokens = error_msg.len() / 4;
747 let result_msg = Message::tool_result(id, &error_msg, true);
748 self.memory.add_message(result_msg);
749 tokens
750 }
751 };
752 *self.tool_token_usage.entry(name.to_string()).or_insert(0) += result_tokens;
753
754 if result.is_err() {
756 if self.consecutive_failures.0 == *name {
757 self.consecutive_failures.1 += 1;
758 } else {
759 self.consecutive_failures = (name.to_string(), 1);
760 }
761 } else {
762 self.consecutive_failures = (String::new(), 0);
763 }
764
765 self.check_and_compress().await;
767
768 }
770 Content::MultiPart { parts } => {
771 self.memory.add_message(response.message.clone());
773
774 let mut has_tool_call = false;
775 for part in parts {
776 match part {
777 Content::Text { text } => {
778 self.callback.on_assistant_message(text).await;
779 final_response = text.clone();
780 }
781 Content::ToolCall {
782 id,
783 name,
784 arguments,
785 } => {
786 has_tool_call = true;
787
788 let explanation = self.build_decision_explanation(name, arguments);
790 self.callback.on_decision_explanation(&explanation).await;
791 self.record_explanation(explanation);
792
793 let (actual_name, actual_args) = if let Some((cn, ca)) =
795 Self::auto_correct_tool_call(name, arguments, &self.state)
796 {
797 if cn != *name {
798 info!(
799 original_tool = name,
800 corrected_tool = cn,
801 "Auto-routing to correct macOS tool (multipart)"
802 );
803 self.callback
804 .on_assistant_message(&format!(
805 "[Routed: {} → {}]",
806 name, cn
807 ))
808 .await;
809 (cn, ca)
810 } else {
811 (name.to_string(), arguments.clone())
812 }
813 } else {
814 (name.to_string(), arguments.clone())
815 };
816
817 let result =
818 self.execute_tool(id, &actual_name, &actual_args).await;
819 let result_tokens = match &result {
820 Ok(output) => {
821 let msg = Message::tool_result(id, &output.content, false);
822 let tokens = output.content.len() / 4;
823 self.memory.add_message(msg);
824 tokens
825 }
826 Err(e) => {
827 let error_msg = format!("Tool error: {}", e);
828 let tokens = error_msg.len() / 4;
829 let msg = Message::tool_result(id, &error_msg, true);
830 self.memory.add_message(msg);
831 tokens
832 }
833 };
834
835 if result.is_err() {
837 if self.consecutive_failures.0 == *name {
838 self.consecutive_failures.1 += 1;
839 } else {
840 self.consecutive_failures = (name.to_string(), 1);
841 }
842 } else {
843 self.consecutive_failures = (String::new(), 0);
844 }
845 *self.tool_token_usage.entry(name.to_string()).or_insert(0) +=
846 result_tokens;
847 }
848 _ => {}
849 }
850 }
851
852 if !has_tool_call {
853 break; }
855
856 self.check_and_compress().await;
858
859 }
861 Content::ToolResult { .. } => {
862 warn!("Received unexpected ToolResult from LLM");
864 break;
865 }
866 }
867 }
868
869 self.state.complete();
870 self.callback.on_status_change(AgentStatus::Complete).await;
871
872 info!(
873 task_id = %task_id,
874 iterations = self.state.iteration,
875 total_tokens = self.brain.total_usage().total(),
876 total_cost = format!("${:.4}", self.brain.total_cost().total()),
877 "Task completed"
878 );
879
880 Ok(TaskResult {
881 task_id,
882 success: true,
883 response: final_response,
884 iterations: self.state.iteration,
885 total_usage: *self.brain.total_usage(),
886 total_cost: *self.brain.total_cost(),
887 })
888 }
889
890 async fn think_streaming(
895 &mut self,
896 conversation: &[Message],
897 tools: Option<Vec<ToolDefinition>>,
898 ) -> Result<CompletionResponse, LlmError> {
899 const MAX_RETRIES: usize = 3;
900 let mut last_error: Option<LlmError> = None;
901
902 for attempt in 0..=MAX_RETRIES {
903 match self.think_streaming_once(conversation, tools.clone()).await {
904 Ok(response) => return Ok(response),
905 Err(e) if Self::is_streaming_retryable(&e) => {
906 if attempt < MAX_RETRIES {
907 let backoff_secs = std::cmp::min(1u64 << attempt, 32);
908 let wait = match &e {
909 LlmError::RateLimited { retry_after_secs } => {
910 std::cmp::max(*retry_after_secs, backoff_secs)
911 }
912 _ => backoff_secs,
913 };
914 info!(
915 attempt = attempt + 1,
916 max_retries = MAX_RETRIES,
917 backoff_secs = wait,
918 error = %e,
919 "Retrying streaming after transient error"
920 );
921 self.callback
922 .on_token(&format!("\n[Retrying in {}s due to: {}]\n", wait, e))
923 .await;
924 tokio::time::sleep(std::time::Duration::from_secs(wait)).await;
925 last_error = Some(e);
926 } else {
927 return Err(e);
928 }
929 }
930 Err(e) => return Err(e),
931 }
932 }
933
934 Err(last_error.unwrap_or(LlmError::Connection {
935 message: "Max streaming retries exceeded".to_string(),
936 }))
937 }
938
939 fn is_streaming_retryable(error: &LlmError) -> bool {
941 if Brain::is_retryable(error) {
942 return true;
943 }
944 if let LlmError::Streaming { message } = error {
946 let msg = message.to_lowercase();
947 return msg.contains("rate limit")
948 || msg.contains("429")
949 || msg.contains("timeout")
950 || msg.contains("timed out")
951 || msg.contains("connection")
952 || msg.contains("temporarily unavailable")
953 || msg.contains("503")
954 || msg.contains("502");
955 }
956 false
957 }
958
959 async fn think_streaming_once(
961 &mut self,
962 conversation: &[Message],
963 tools: Option<Vec<ToolDefinition>>,
964 ) -> Result<CompletionResponse, LlmError> {
965 let (tx, mut rx) = mpsc::channel(64);
966
967 let messages = self.brain.build_messages(conversation);
969 let token_estimate = self.brain.provider().estimate_tokens(&messages);
970 let context_limit = self.brain.provider().context_window();
971
972 if token_estimate > context_limit {
973 return Err(LlmError::ContextOverflow {
974 used: token_estimate,
975 limit: context_limit,
976 });
977 }
978
979 let request = crate::types::CompletionRequest {
980 messages,
981 tools,
982 temperature: 0.7,
983 max_tokens: None,
984 stop_sequences: Vec::new(),
985 model: None,
986 };
987
988 let provider = self.brain.provider_arc();
993 let producer = tokio::spawn(async move { provider.complete_streaming(request, tx).await });
994
995 let mut text_parts = String::new();
997 let mut usage = TokenUsage::default();
998 let mut tool_calls: std::collections::HashMap<String, (String, String)> =
1000 std::collections::HashMap::new();
1001 let mut tool_call_order: Vec<String> = Vec::new(); let mut raw_function_calls: std::collections::HashMap<String, serde_json::Value> =
1004 std::collections::HashMap::new();
1005
1006 while let Some(event) = rx.recv().await {
1007 match event {
1008 StreamEvent::Token(token) => {
1009 self.callback.on_token(&token).await;
1010 text_parts.push_str(&token);
1011 }
1012 StreamEvent::ToolCallStart {
1013 id,
1014 name,
1015 raw_function_call,
1016 } => {
1017 tool_call_order.push(id.clone());
1018 tool_calls.insert(id.clone(), (name, String::new()));
1019 if let Some(raw_fc) = raw_function_call {
1020 raw_function_calls.insert(id, raw_fc);
1021 }
1022 }
1023 StreamEvent::ToolCallDelta {
1024 id,
1025 arguments_delta,
1026 } => {
1027 if let Some((_, args)) = tool_calls.get_mut(&id) {
1028 args.push_str(&arguments_delta);
1029 }
1030 }
1031 StreamEvent::ToolCallEnd { id: _ } => {
1032 }
1034 StreamEvent::Done { usage: u } => {
1035 usage = u;
1036 break;
1037 }
1038 StreamEvent::Error(e) => {
1039 return Err(LlmError::Streaming { message: e });
1040 }
1041 }
1042 }
1043
1044 producer.await.map_err(|e| LlmError::Streaming {
1046 message: format!("Streaming task panicked: {}", e),
1047 })??;
1048
1049 self.brain.track_usage(&usage);
1051
1052 let raw_parts_metadata = if !raw_function_calls.is_empty() {
1055 let mut raw_parts = Vec::new();
1056 if !text_parts.is_empty() {
1057 raw_parts.push(serde_json::json!({"text": &text_parts}));
1058 }
1059 for id in &tool_call_order {
1060 if let Some(raw_fc) = raw_function_calls.get(id) {
1061 raw_parts.push(raw_fc.clone());
1062 }
1063 }
1064 Some(serde_json::Value::Array(raw_parts))
1065 } else {
1066 None
1067 };
1068
1069 let content = if !tool_call_order.is_empty() {
1071 let first_id = &tool_call_order[0];
1073 if let Some((name, args_str)) = tool_calls.get(first_id) {
1074 let arguments: serde_json::Value =
1075 serde_json::from_str(args_str).unwrap_or(serde_json::json!({}));
1076 if text_parts.is_empty() {
1077 Content::tool_call(first_id, name, arguments)
1078 } else {
1079 Content::MultiPart {
1080 parts: vec![
1081 Content::text(&text_parts),
1082 Content::tool_call(first_id, name, arguments),
1083 ],
1084 }
1085 }
1086 } else {
1087 Content::text(text_parts)
1088 }
1089 } else {
1090 Content::text(text_parts)
1091 };
1092 let finish_reason = if tool_call_order.is_empty() {
1093 "stop"
1094 } else {
1095 "tool_calls"
1096 };
1097
1098 let mut message = Message::new(Role::Assistant, content);
1099
1100 if let Some(raw_parts) = raw_parts_metadata {
1103 message = message.with_metadata("gemini_raw_parts", raw_parts);
1104 }
1105
1106 Ok(CompletionResponse {
1107 message,
1108 usage,
1109 model: self.brain.model_name().to_string(),
1110 finish_reason: Some(finish_reason.to_string()),
1111 })
1112 }
1113
1114 async fn execute_tool(
1116 &mut self,
1117 _call_id: &str,
1118 tool_name: &str,
1119 arguments: &serde_json::Value,
1120 ) -> Result<ToolOutput, ToolError> {
1121 if tool_name == "ask_user" {
1124 self.state.status = AgentStatus::WaitingForClarification;
1125 self.callback
1126 .on_status_change(AgentStatus::WaitingForClarification)
1127 .await;
1128 let question = arguments
1129 .get("question")
1130 .and_then(|v| v.as_str())
1131 .unwrap_or("Can you provide more details?");
1132 let answer = self.callback.on_clarification_request(question).await;
1133 self.state.status = AgentStatus::Executing;
1134 self.callback.on_status_change(AgentStatus::Executing).await;
1135 return Ok(ToolOutput::text(answer));
1136 }
1137
1138 let tool = self
1140 .tools
1141 .get(tool_name)
1142 .ok_or_else(|| ToolError::NotFound {
1143 name: tool_name.to_string(),
1144 })?;
1145
1146 let details = Self::parse_action_details(tool_name, arguments);
1148 let approval_context = Self::build_approval_context(tool_name, &details, tool.risk_level);
1149
1150 let action = SafetyGuardian::create_rich_action_request(
1152 tool_name,
1153 tool.risk_level,
1154 format!("Execute tool: {}", tool_name),
1155 details,
1156 approval_context,
1157 );
1158
1159 let perm = self.safety.check_permission(&action);
1161 match perm {
1162 PermissionResult::Allowed => {
1163 }
1165 PermissionResult::Denied { reason } => {
1166 let mut builder = ExplanationBuilder::new(DecisionType::ErrorRecovery {
1168 error: format!("Permission denied for tool '{}'", tool_name),
1169 strategy: "Returning error to LLM for re-planning".to_string(),
1170 });
1171 builder.add_reasoning_step(format!("Denied: {}", reason), None);
1172 builder.set_confidence(1.0);
1173 let explanation = builder.build();
1174 self.callback.on_decision_explanation(&explanation).await;
1175 self.record_explanation(explanation);
1176
1177 return Err(ToolError::PermissionDenied {
1178 name: tool_name.to_string(),
1179 reason,
1180 });
1181 }
1182 PermissionResult::RequiresApproval { context: _ } => {
1183 self.state.status = AgentStatus::WaitingForApproval;
1184 self.callback
1185 .on_status_change(AgentStatus::WaitingForApproval)
1186 .await;
1187
1188 let decision = self.callback.request_approval(&action).await;
1189 let approved = decision != ApprovalDecision::Deny;
1190 self.safety.log_approval_decision(tool_name, approved);
1191
1192 match decision {
1193 ApprovalDecision::Approve => {
1194 }
1196 ApprovalDecision::ApproveAllSimilar => {
1197 self.safety
1199 .add_session_allowlist(tool_name.to_string(), tool.risk_level);
1200 info!(
1201 tool = tool_name,
1202 risk = %tool.risk_level,
1203 "Added tool to session allowlist (approve all similar)"
1204 );
1205 }
1206 ApprovalDecision::Deny => {
1207 let mut builder = ExplanationBuilder::new(DecisionType::ErrorRecovery {
1209 error: format!("User denied approval for tool '{}'", tool_name),
1210 strategy: "Returning error to LLM for re-planning".to_string(),
1211 });
1212 builder.add_reasoning_step(
1213 "User rejected the action in approval dialog".to_string(),
1214 None,
1215 );
1216 builder.set_confidence(1.0);
1217 let explanation = builder.build();
1218 self.callback.on_decision_explanation(&explanation).await;
1219 self.record_explanation(explanation);
1220
1221 self.memory.long_term.add_correction(
1224 format!(
1225 "Attempted tool '{}' with args: {}",
1226 tool_name,
1227 arguments.to_string().chars().take(200).collect::<String>()
1228 ),
1229 "User denied this action".to_string(),
1230 format!(
1231 "Tool '{}' denied by user; goal: {:?}",
1232 tool_name, self.memory.working.current_goal
1233 ),
1234 );
1235
1236 return Err(ToolError::PermissionDenied {
1237 name: tool_name.to_string(),
1238 reason: "User rejected the action".to_string(),
1239 });
1240 }
1241 }
1242 }
1243 }
1244
1245 let tool_entry = self
1247 .tools
1248 .get(tool_name)
1249 .ok_or_else(|| ToolError::NotFound {
1250 name: tool_name.to_string(),
1251 })?;
1252 let risk_level = tool_entry.risk_level;
1253 let contract_result = self
1254 .safety
1255 .contract_enforcer_mut()
1256 .check_pre(tool_name, risk_level, arguments);
1257 if contract_result != ContractCheckResult::Satisfied {
1258 warn!(
1259 tool = tool_name,
1260 result = ?contract_result,
1261 "Safety contract violation (pre-check)"
1262 );
1263
1264 let mut builder = ExplanationBuilder::new(DecisionType::ErrorRecovery {
1266 error: format!("Contract violation: {:?}", contract_result),
1267 strategy: "Returning error to LLM for re-planning".to_string(),
1268 });
1269 builder.set_confidence(1.0);
1270 let explanation = builder.build();
1271 self.callback.on_decision_explanation(&explanation).await;
1272 self.record_explanation(explanation);
1273
1274 return Err(ToolError::PermissionDenied {
1275 name: tool_name.to_string(),
1276 reason: format!("Safety contract violation: {:?}", contract_result),
1277 });
1278 }
1279
1280 self.state.status = AgentStatus::Executing;
1282 self.callback.on_status_change(AgentStatus::Executing).await;
1283 self.callback.on_tool_start(tool_name, arguments).await;
1284
1285 let start = Instant::now();
1286
1287 let executor = &self
1289 .tools
1290 .get(tool_name)
1291 .ok_or_else(|| ToolError::NotFound {
1292 name: tool_name.to_string(),
1293 })?
1294 .executor;
1295 let result = (executor)(arguments.clone()).await;
1296 let duration_ms = start.elapsed().as_millis() as u64;
1297
1298 self.safety
1300 .contract_enforcer_mut()
1301 .record_execution(risk_level, 0.0);
1302
1303 match &result {
1304 Ok(output) => {
1305 self.safety.log_execution(tool_name, true, duration_ms);
1306 self.safety
1307 .record_behavioral_outcome(tool_name, risk_level, true);
1308 self.callback
1309 .on_tool_result(tool_name, output, duration_ms)
1310 .await;
1311
1312 if output.content.len() > 10 && output.content.len() < 5000 {
1316 let summary = if output.content.chars().count() > 200 {
1317 format!("{}...", truncate_str(&output.content, 200))
1318 } else {
1319 output.content.clone()
1320 };
1321 self.memory.long_term.add_fact(
1322 crate::memory::Fact::new(
1323 format!("Tool '{}' result: {}", tool_name, summary),
1324 format!("tool:{}", tool_name),
1325 )
1326 .with_tags(vec!["tool_result".to_string(), tool_name.to_string()]),
1327 );
1328 }
1329 }
1330 Err(e) => {
1331 self.safety.log_execution(tool_name, false, duration_ms);
1332 self.safety
1333 .record_behavioral_outcome(tool_name, risk_level, false);
1334 let error_output = ToolOutput::error(e.to_string());
1335 self.callback
1336 .on_tool_result(tool_name, &error_output, duration_ms)
1337 .await;
1338 }
1339 }
1340
1341 result
1342 }
1343
1344 fn record_explanation(&mut self, explanation: DecisionExplanation) {
1346 if self.recent_explanations.len() >= 50 {
1347 self.recent_explanations.remove(0);
1348 }
1349 self.recent_explanations.push(explanation);
1350 }
1351
1352 fn build_approval_context(
1355 tool_name: &str,
1356 details: &ActionDetails,
1357 risk_level: RiskLevel,
1358 ) -> ApprovalContext {
1359 let mut ctx = ApprovalContext::new();
1360
1361 match details {
1363 ActionDetails::FileWrite { path, size_bytes } => {
1364 ctx = ctx
1365 .with_reasoning(format!(
1366 "Writing {} bytes to {}",
1367 size_bytes,
1368 path.display()
1369 ))
1370 .with_consequence(format!(
1371 "File '{}' will be created or overwritten",
1372 path.display()
1373 ))
1374 .with_reversibility(ReversibilityInfo {
1375 is_reversible: true,
1376 undo_description: Some(
1377 "Revert via git checkout or checkpoint restore".to_string(),
1378 ),
1379 undo_window: None,
1380 });
1381 }
1382 ActionDetails::FileDelete { path } => {
1383 ctx = ctx
1384 .with_reasoning(format!("Deleting file {}", path.display()))
1385 .with_consequence(format!(
1386 "File '{}' will be permanently removed",
1387 path.display()
1388 ))
1389 .with_reversibility(ReversibilityInfo {
1390 is_reversible: true,
1391 undo_description: Some(
1392 "Restore via git checkout or checkpoint".to_string(),
1393 ),
1394 undo_window: None,
1395 });
1396 }
1397 ActionDetails::ShellCommand { command } => {
1398 ctx = ctx
1399 .with_reasoning(format!("Executing shell command: {}", command))
1400 .with_consequence("Shell command will run in the agent workspace".to_string());
1401 if risk_level >= RiskLevel::Execute {
1402 ctx = ctx.with_consequence(
1403 "Command may modify system state or produce side effects".to_string(),
1404 );
1405 }
1406 }
1407 ActionDetails::NetworkRequest { host, method } => {
1408 ctx = ctx
1409 .with_reasoning(format!("Making {} request to {}", method, host))
1410 .with_consequence(format!("Network request will be sent to {}", host));
1411 }
1412 ActionDetails::GitOperation { operation } => {
1413 ctx = ctx
1414 .with_reasoning(format!("Git operation: {}", operation))
1415 .with_reversibility(ReversibilityInfo {
1416 is_reversible: true,
1417 undo_description: Some(
1418 "Git operations are generally reversible via reflog".to_string(),
1419 ),
1420 undo_window: None,
1421 });
1422 }
1423 _ => {
1424 ctx = ctx.with_reasoning(format!("Executing {} tool", tool_name));
1425 }
1426 }
1427
1428 ctx = ctx.with_preview_from_tool(tool_name, details);
1430
1431 ctx
1432 }
1433
1434 fn parse_action_details(tool_name: &str, arguments: &serde_json::Value) -> ActionDetails {
1438 match tool_name {
1439 "file_read" | "file_list" | "file_search" => {
1440 if let Some(path) = arguments.get("path").and_then(|v| v.as_str()) {
1441 ActionDetails::FileRead { path: path.into() }
1442 } else {
1443 ActionDetails::Other {
1444 info: arguments.to_string(),
1445 }
1446 }
1447 }
1448 "file_write" | "file_patch" => {
1449 let path = arguments
1450 .get("path")
1451 .and_then(|v| v.as_str())
1452 .unwrap_or("unknown");
1453 let size = arguments
1454 .get("content")
1455 .and_then(|v| v.as_str())
1456 .map(|s| s.len())
1457 .unwrap_or(0);
1458 ActionDetails::FileWrite {
1459 path: path.into(),
1460 size_bytes: size,
1461 }
1462 }
1463 "shell_exec" => {
1464 let cmd = arguments
1465 .get("command")
1466 .and_then(|v| v.as_str())
1467 .unwrap_or("(unknown)");
1468 ActionDetails::ShellCommand {
1469 command: cmd.to_string(),
1470 }
1471 }
1472 "git_status" | "git_diff" => ActionDetails::GitOperation {
1473 operation: tool_name.to_string(),
1474 },
1475 "git_commit" => {
1476 let msg = arguments
1477 .get("message")
1478 .and_then(|v| v.as_str())
1479 .unwrap_or("");
1480 let truncated = truncate_str(msg, 80);
1481 ActionDetails::GitOperation {
1482 operation: format!("commit: {}", truncated),
1483 }
1484 }
1485 "macos_calendar" | "macos_reminders" | "macos_notes" => {
1487 let action = arguments
1488 .get("action")
1489 .and_then(|v| v.as_str())
1490 .unwrap_or("list");
1491 let title = arguments
1492 .get("title")
1493 .and_then(|v| v.as_str())
1494 .unwrap_or("");
1495 ActionDetails::Other {
1496 info: format!("{} {} {}", tool_name, action, title)
1497 .trim()
1498 .to_string(),
1499 }
1500 }
1501 "macos_app_control" => {
1502 let action = arguments
1503 .get("action")
1504 .and_then(|v| v.as_str())
1505 .unwrap_or("list_running");
1506 let app = arguments
1507 .get("app_name")
1508 .and_then(|v| v.as_str())
1509 .unwrap_or("");
1510 ActionDetails::ShellCommand {
1511 command: format!("{} {}", action, app).trim().to_string(),
1512 }
1513 }
1514 "macos_clipboard" => {
1515 let action = arguments
1516 .get("action")
1517 .and_then(|v| v.as_str())
1518 .unwrap_or("read");
1519 ActionDetails::Other {
1520 info: format!("clipboard {}", action),
1521 }
1522 }
1523 "macos_screenshot" => {
1524 let path = arguments
1525 .get("path")
1526 .and_then(|v| v.as_str())
1527 .unwrap_or("screenshot.png");
1528 ActionDetails::FileWrite {
1529 path: path.into(),
1530 size_bytes: 0,
1531 }
1532 }
1533 "macos_finder" => {
1534 let action = arguments
1535 .get("action")
1536 .and_then(|v| v.as_str())
1537 .unwrap_or("reveal");
1538 let path = arguments
1539 .get("path")
1540 .and_then(|v| v.as_str())
1541 .unwrap_or(".");
1542 if action == "trash" {
1543 ActionDetails::FileDelete { path: path.into() }
1544 } else {
1545 ActionDetails::Other {
1546 info: format!("Finder: {} {}", action, path),
1547 }
1548 }
1549 }
1550 "macos_notification" | "macos_system_info" | "macos_spotlight" => {
1551 ActionDetails::Other {
1552 info: arguments
1553 .as_object()
1554 .map(|o| {
1555 o.iter()
1556 .map(|(k, v)| {
1557 format!("{}={}", k, v.as_str().unwrap_or(&v.to_string()))
1558 })
1559 .collect::<Vec<_>>()
1560 .join(", ")
1561 })
1562 .unwrap_or_default(),
1563 }
1564 }
1565 "macos_mail" => {
1566 let action = arguments["action"]
1567 .as_str()
1568 .unwrap_or("unknown")
1569 .to_string();
1570 if action == "send" {
1571 let to = arguments["to"].as_str().unwrap_or("unknown").to_string();
1572 let subject = arguments["subject"]
1573 .as_str()
1574 .unwrap_or("(no subject)")
1575 .to_string();
1576 ActionDetails::Other {
1577 info: format!("SEND EMAIL to {} — subject: {}", to, subject),
1578 }
1579 } else {
1580 ActionDetails::Other {
1581 info: format!("macos_mail: {}", action),
1582 }
1583 }
1584 }
1585 "macos_safari" => {
1586 let action = arguments["action"]
1587 .as_str()
1588 .unwrap_or("unknown")
1589 .to_string();
1590 if action == "run_javascript" {
1591 ActionDetails::ShellCommand {
1592 command: format!(
1593 "Safari JS: {}",
1594 arguments["script"].as_str().unwrap_or("(unknown)")
1595 ),
1596 }
1597 } else if action == "navigate" {
1598 ActionDetails::BrowserAction {
1599 action: "navigate".to_string(),
1600 url: arguments["url"].as_str().map(|s| s.to_string()),
1601 selector: None,
1602 }
1603 } else {
1604 ActionDetails::Other {
1605 info: format!("macos_safari: {}", action),
1606 }
1607 }
1608 }
1609 "macos_screen_analyze" => {
1610 let action = arguments["action"].as_str().unwrap_or("ocr").to_string();
1611 let app = arguments["app_name"]
1612 .as_str()
1613 .map(|s| s.to_string())
1614 .unwrap_or_else(|| "screen".to_string());
1615 ActionDetails::GuiAction {
1616 app_name: app,
1617 action,
1618 element: None,
1619 }
1620 }
1621 "macos_contacts" => {
1622 let action = arguments["action"].as_str().unwrap_or("search").to_string();
1623 let query = arguments["query"]
1624 .as_str()
1625 .or_else(|| arguments["name"].as_str())
1626 .map(|q| format!("'{}'", q))
1627 .unwrap_or_default();
1628 ActionDetails::Other {
1629 info: format!("Contacts: {} {}", action, query),
1630 }
1631 }
1632 "macos_gui_scripting" | "macos_accessibility" => {
1633 let app_name = arguments["app_name"]
1634 .as_str()
1635 .unwrap_or("unknown")
1636 .to_string();
1637 let action = arguments["action"]
1638 .as_str()
1639 .unwrap_or("unknown")
1640 .to_string();
1641 let element = arguments["element_description"]
1642 .as_str()
1643 .map(|s| s.to_string());
1644 ActionDetails::GuiAction {
1645 app_name,
1646 action,
1647 element,
1648 }
1649 }
1650 name if name.starts_with("browser_") => {
1652 let action = name.strip_prefix("browser_").unwrap_or(name).to_string();
1653 let url = arguments["url"].as_str().map(|s| s.to_string());
1654 let selector = arguments["selector"]
1655 .as_str()
1656 .or_else(|| arguments["ref"].as_str())
1657 .map(|s| s.to_string());
1658 ActionDetails::BrowserAction {
1659 action,
1660 url,
1661 selector,
1662 }
1663 }
1664 "web_search" | "web_fetch" => {
1666 let host = if tool_name == "web_search" {
1667 "api.duckduckgo.com".to_string()
1668 } else {
1669 let url_str = arguments["url"].as_str().unwrap_or("unknown URL");
1671 url_str
1672 .strip_prefix("https://")
1673 .or_else(|| url_str.strip_prefix("http://"))
1674 .and_then(|s| s.split('/').next())
1675 .unwrap_or(url_str)
1676 .to_string()
1677 };
1678 ActionDetails::NetworkRequest {
1679 host,
1680 method: if tool_name == "web_search" {
1681 "SEARCH".to_string()
1682 } else {
1683 "GET".to_string()
1684 },
1685 }
1686 }
1687 "imessage_send" => {
1689 let recipient = arguments["recipient"]
1690 .as_str()
1691 .unwrap_or("unknown")
1692 .to_string();
1693 let preview = arguments["message"]
1694 .as_str()
1695 .map(|s| {
1696 if s.len() > 100 {
1697 format!("{}...", &s[..97])
1698 } else {
1699 s.to_string()
1700 }
1701 })
1702 .unwrap_or_default();
1703 ActionDetails::ChannelReply {
1704 channel: "iMessage".to_string(),
1705 recipient,
1706 preview,
1707 priority: MessagePriority::Normal,
1708 }
1709 }
1710 "slack" => {
1712 let action = arguments
1713 .get("action")
1714 .and_then(|v| v.as_str())
1715 .unwrap_or("send_message");
1716 match action {
1717 "send_message" | "reply_thread" => {
1718 let recipient = arguments["channel"]
1719 .as_str()
1720 .unwrap_or("unknown")
1721 .to_string();
1722 let preview = arguments["message"]
1723 .as_str()
1724 .map(|s| {
1725 if s.len() > 100 {
1726 format!("{}...", &s[..97])
1727 } else {
1728 s.to_string()
1729 }
1730 })
1731 .unwrap_or_default();
1732 ActionDetails::ChannelReply {
1733 channel: "Slack".to_string(),
1734 recipient,
1735 preview,
1736 priority: MessagePriority::Normal,
1737 }
1738 }
1739 "add_reaction" => ActionDetails::ChannelReply {
1740 channel: "Slack".to_string(),
1741 recipient: arguments["channel"]
1742 .as_str()
1743 .unwrap_or("unknown")
1744 .to_string(),
1745 preview: format!(":{}:", arguments["emoji"].as_str().unwrap_or("?")),
1746 priority: MessagePriority::Normal,
1747 },
1748 _ => ActionDetails::Other {
1749 info: format!("slack:{}", action),
1750 },
1751 }
1752 }
1753 "arxiv_research" => {
1755 let action = arguments
1756 .get("action")
1757 .and_then(|v| v.as_str())
1758 .unwrap_or("search");
1759 match action {
1760 "save" | "remove" | "collections" | "digest_config" => {
1761 ActionDetails::FileWrite {
1762 path: ".rustant/arxiv/library.json".into(),
1763 size_bytes: 0,
1764 }
1765 }
1766 _ => ActionDetails::NetworkRequest {
1767 host: "export.arxiv.org".to_string(),
1768 method: "GET".to_string(),
1769 },
1770 }
1771 }
1772 "knowledge_graph" => {
1774 let action = arguments
1775 .get("action")
1776 .and_then(|v| v.as_str())
1777 .unwrap_or("list");
1778 match action {
1779 "add_node" | "update_node" | "remove_node" | "add_edge" | "remove_edge"
1780 | "import_arxiv" => ActionDetails::FileWrite {
1781 path: ".rustant/knowledge/graph.json".into(),
1782 size_bytes: 0,
1783 },
1784 _ => ActionDetails::FileRead {
1785 path: ".rustant/knowledge/graph.json".into(),
1786 },
1787 }
1788 }
1789 "experiment_tracker" => {
1791 let action = arguments
1792 .get("action")
1793 .and_then(|v| v.as_str())
1794 .unwrap_or("list_experiments");
1795 match action {
1796 "add_hypothesis"
1797 | "update_hypothesis"
1798 | "add_experiment"
1799 | "start_experiment"
1800 | "complete_experiment"
1801 | "fail_experiment"
1802 | "record_evidence" => ActionDetails::FileWrite {
1803 path: ".rustant/experiments/tracker.json".into(),
1804 size_bytes: 0,
1805 },
1806 _ => ActionDetails::FileRead {
1807 path: ".rustant/experiments/tracker.json".into(),
1808 },
1809 }
1810 }
1811 "code_intelligence" => {
1813 let path = arguments
1814 .get("path")
1815 .and_then(|v| v.as_str())
1816 .unwrap_or(".");
1817 ActionDetails::FileRead { path: path.into() }
1818 }
1819 "content_engine" => {
1821 let action = arguments
1822 .get("action")
1823 .and_then(|v| v.as_str())
1824 .unwrap_or("list");
1825 match action {
1826 "create" | "update" | "set_status" | "delete" | "schedule" | "calendar_add"
1827 | "calendar_remove" => ActionDetails::FileWrite {
1828 path: ".rustant/content/library.json".into(),
1829 size_bytes: 0,
1830 },
1831 _ => ActionDetails::FileRead {
1832 path: ".rustant/content/library.json".into(),
1833 },
1834 }
1835 }
1836 "skill_tracker" => {
1838 let action = arguments
1839 .get("action")
1840 .and_then(|v| v.as_str())
1841 .unwrap_or("list_skills");
1842 match action {
1843 "add_skill" | "log_practice" | "learning_path" => ActionDetails::FileWrite {
1844 path: ".rustant/skills/tracker.json".into(),
1845 size_bytes: 0,
1846 },
1847 _ => ActionDetails::FileRead {
1848 path: ".rustant/skills/tracker.json".into(),
1849 },
1850 }
1851 }
1852 "career_intel" => {
1854 let action = arguments
1855 .get("action")
1856 .and_then(|v| v.as_str())
1857 .unwrap_or("progress_report");
1858 match action {
1859 "set_goal" | "log_achievement" | "add_portfolio" | "network_note" => {
1860 ActionDetails::FileWrite {
1861 path: ".rustant/career/intel.json".into(),
1862 size_bytes: 0,
1863 }
1864 }
1865 _ => ActionDetails::FileRead {
1866 path: ".rustant/career/intel.json".into(),
1867 },
1868 }
1869 }
1870 "system_monitor" => {
1872 let action = arguments
1873 .get("action")
1874 .and_then(|v| v.as_str())
1875 .unwrap_or("list_services");
1876 match action {
1877 "health_check" => ActionDetails::NetworkRequest {
1878 host: "service health check".to_string(),
1879 method: "GET".to_string(),
1880 },
1881 "add_service" | "log_incident" => ActionDetails::FileWrite {
1882 path: ".rustant/monitoring/topology.json".into(),
1883 size_bytes: 0,
1884 },
1885 _ => ActionDetails::FileRead {
1886 path: ".rustant/monitoring/topology.json".into(),
1887 },
1888 }
1889 }
1890 "life_planner" => {
1892 let action = arguments
1893 .get("action")
1894 .and_then(|v| v.as_str())
1895 .unwrap_or("daily_plan");
1896 match action {
1897 "set_energy_profile" | "add_deadline" | "log_habit" | "context_switch_log" => {
1898 ActionDetails::FileWrite {
1899 path: ".rustant/life/planner.json".into(),
1900 size_bytes: 0,
1901 }
1902 }
1903 _ => ActionDetails::FileRead {
1904 path: ".rustant/life/planner.json".into(),
1905 },
1906 }
1907 }
1908 "privacy_manager" => {
1910 let action = arguments
1911 .get("action")
1912 .and_then(|v| v.as_str())
1913 .unwrap_or("list_boundaries");
1914 match action {
1915 "delete_data" => {
1916 let domain = arguments
1917 .get("domain")
1918 .and_then(|v| v.as_str())
1919 .unwrap_or("unknown");
1920 ActionDetails::FileDelete {
1921 path: format!(".rustant/{}/", domain).into(),
1922 }
1923 }
1924 "set_boundary" | "encrypt_store" => ActionDetails::FileWrite {
1925 path: ".rustant/privacy/config.json".into(),
1926 size_bytes: 0,
1927 },
1928 _ => ActionDetails::FileRead {
1929 path: ".rustant/privacy/config.json".into(),
1930 },
1931 }
1932 }
1933 "self_improvement" => {
1935 let action = arguments
1936 .get("action")
1937 .and_then(|v| v.as_str())
1938 .unwrap_or("analyze_patterns");
1939 match action {
1940 "set_preference" | "feedback" | "reset_baseline" => ActionDetails::FileWrite {
1941 path: ".rustant/meta/improvement.json".into(),
1942 size_bytes: 0,
1943 },
1944 _ => ActionDetails::FileRead {
1945 path: ".rustant/meta/improvement.json".into(),
1946 },
1947 }
1948 }
1949 _ => ActionDetails::Other {
1950 info: arguments.to_string(),
1951 },
1952 }
1953 }
1954
1955 fn workflow_routing_hint(classification: &TaskClassification) -> Option<String> {
1966 let workflow = match classification {
1967 TaskClassification::Workflow(name) => name.as_str(),
1968 _ => return None,
1969 };
1970
1971 Some(format!(
1972 "WORKFLOW ROUTING: For this task, run the '{}' workflow. \
1973 Use shell_exec to run: `rustant workflow run {}` — or accomplish \
1974 the task directly step by step using available tools.",
1975 workflow, workflow
1976 ))
1977 }
1978
#[cfg(target_os = "macos")]
/// Map a task classification to a "TOOL ROUTING" hint for the LLM (macOS build).
///
/// A workflow hint, if any, takes priority. Otherwise the classification is
/// mapped to a fixed instruction naming the dedicated tool, prefixed with
/// `"TOOL ROUTING: "`. Classifications without a dedicated tool return `None`.
fn tool_routing_hint_from_classification(
    classification: &TaskClassification,
) -> Option<String> {
    Self::workflow_routing_hint(classification).or_else(|| {
        let instruction: &str = match classification {
            TaskClassification::Clipboard => "For this task, call the 'macos_clipboard' tool with {\"action\":\"read\"} to read the clipboard or {\"action\":\"write\",\"content\":\"...\"} to write to it.",
            TaskClassification::SystemInfo => "For this task, call the 'macos_system_info' tool with the appropriate action: \"battery\", \"disk\", \"memory\", \"cpu\", \"network\", or \"version\".",
            TaskClassification::AppControl => "For this task, call the 'macos_app_control' tool with the appropriate action: \"list_running\", \"open\", \"quit\", or \"activate\".",
            TaskClassification::Meeting => "For this task, call 'macos_meeting_recorder'. Use action 'record_and_transcribe' to start (announces via TTS, records with silence detection, auto-transcribes to Notes.app). Use 'stop' to stop manually. Use 'status' to check state.",
            TaskClassification::Calendar => "For this task, call the 'macos_calendar' tool with the appropriate action.",
            TaskClassification::Reminders => "For this task, call the 'macos_reminders' tool with the appropriate action.",
            TaskClassification::Notes => "For this task, call the 'macos_notes' tool with the appropriate action.",
            TaskClassification::Screenshot => "For this task, call the 'macos_screenshot' tool with the appropriate action.",
            TaskClassification::Notification => "For this task, call the 'macos_notification' tool.",
            TaskClassification::Spotlight => "For this task, call the 'macos_spotlight' tool to search files using Spotlight.",
            TaskClassification::FocusMode => "For this task, call the 'macos_focus_mode' tool.",
            TaskClassification::Music => "For this task, call the 'macos_music' tool with the appropriate action.",
            TaskClassification::Email => "For this task, call the 'macos_mail' tool with the appropriate action.",
            TaskClassification::Finder => "For this task, call the 'macos_finder' tool with the appropriate action.",
            TaskClassification::Contacts => "For this task, call the 'macos_contacts' tool with the appropriate action.",
            TaskClassification::WebSearch => "For this task, call the 'web_search' tool with {\"query\": \"your search terms\"}. Do NOT use macos_safari or shell_exec for web searches — use the dedicated web_search tool which queries DuckDuckGo.",
            TaskClassification::WebFetch => "For this task, call the 'web_fetch' tool with {\"url\": \"https://...\"} to retrieve page content. Do NOT use macos_safari or shell_exec — use the dedicated web_fetch tool.",
            TaskClassification::Safari => "For this task, call the 'macos_safari' tool with the appropriate action. Note: for simple web searches use 'web_search' instead, and for fetching page content use 'web_fetch' instead.",
            TaskClassification::Slack => "For this task, call the 'slack' tool with the appropriate action (send_message, read_messages, list_channels, reply_thread, list_users, add_reaction). Do NOT use macos_gui_scripting or macos_app_control to interact with Slack.",
            TaskClassification::Messaging => "For this task, call the appropriate iMessage tool: 'imessage_read', 'imessage_send', or 'imessage_contacts'.",
            TaskClassification::ArxivResearch => "For this task, call the 'arxiv_research' tool with {\"action\": \"search\", \"query\": \"your search terms\", \"max_results\": 10}. This tool uses the arXiv API directly — do NOT use macos_safari, shell_exec, or curl. Other actions: fetch (get by ID), analyze (LLM summary), trending (recent papers), paper_to_code, paper_to_notebook, save/library/remove, export_bibtex.",
            TaskClassification::KnowledgeGraph => "For this task, call the 'knowledge_graph' tool. Actions: add_node, get_node, update_node, remove_node, add_edge, remove_edge, neighbors, search, list, path, stats, import_arxiv, export_dot.",
            TaskClassification::ExperimentTracking => "For this task, call the 'experiment_tracker' tool. Actions: add_hypothesis, update_hypothesis, list_hypotheses, get_hypothesis, add_experiment, start_experiment, complete_experiment, fail_experiment, get_experiment, list_experiments, record_evidence, compare_experiments, summary, export_markdown.",
            TaskClassification::CodeIntelligence => "For this task, call the 'code_intelligence' tool. Actions: analyze_architecture, detect_patterns, translate_snippet, compare_implementations, tech_debt_report, api_surface, dependency_map.",
            TaskClassification::ContentEngine => "For this task, call the 'content_engine' tool. Actions: create, update, set_status, get, list, search, delete, schedule, calendar_add, calendar_list, calendar_remove, stats, adapt, export_markdown.",
            TaskClassification::SkillTracker => "For this task, call the 'skill_tracker' tool. Actions: add_skill, log_practice, assess, list_skills, knowledge_gaps, learning_path, progress_report, daily_practice.",
            TaskClassification::CareerIntel => "For this task, call the 'career_intel' tool. Actions: set_goal, log_achievement, add_portfolio, gap_analysis, market_scan, network_note, progress_report, strategy_review.",
            TaskClassification::SystemMonitor => "For this task, call the 'system_monitor' tool. Actions: add_service, topology, health_check, log_incident, correlate, generate_runbook, impact_analysis, list_services.",
            TaskClassification::LifePlanner => "For this task, call the 'life_planner' tool. Actions: set_energy_profile, add_deadline, log_habit, daily_plan, weekly_review, context_switch_log, balance_report, optimize_schedule.",
            TaskClassification::PrivacyManager => "For this task, call the 'privacy_manager' tool. Actions: set_boundary, list_boundaries, audit_access, compliance_check, export_data, delete_data, encrypt_store, privacy_report.",
            TaskClassification::SelfImprovement => "For this task, call the 'self_improvement' tool. Actions: analyze_patterns, performance_report, suggest_improvements, set_preference, get_preferences, cognitive_load, feedback, reset_baseline.",
            // No dedicated tool for this classification: no routing hint.
            _ => return None,
        };
        Some(format!("TOOL ROUTING: {}", instruction))
    })
}
2085
2086 #[cfg(not(target_os = "macos"))]
2088 fn tool_routing_hint_from_classification(
2089 classification: &TaskClassification,
2090 ) -> Option<String> {
2091 if let Some(hint) = Self::workflow_routing_hint(classification) {
2093 return Some(hint);
2094 }
2095
2096 let tool_hint = match classification {
2097 TaskClassification::WebSearch => {
2098 "For this task, call the 'web_search' tool with {\"query\": \"your search terms\"}. Do NOT use shell_exec for web searches — use the dedicated web_search tool which queries DuckDuckGo."
2099 }
2100 TaskClassification::WebFetch => {
2101 "For this task, call the 'web_fetch' tool with {\"url\": \"https://...\"} to retrieve page content. Do NOT use shell_exec — use the dedicated web_fetch tool."
2102 }
2103 TaskClassification::Slack => {
2104 "For this task, call the 'slack' tool with the appropriate action (send_message, read_messages, list_channels, reply_thread, list_users, add_reaction). Do NOT use shell_exec to interact with Slack."
2105 }
2106 TaskClassification::ArxivResearch => {
2107 "For this task, call the 'arxiv_research' tool with {\"action\": \"search\", \"query\": \"your search terms\", \"max_results\": 10}. This tool uses the arXiv API directly — do NOT use shell_exec, or curl. Other actions: fetch (get by ID), analyze (LLM summary), trending (recent papers), paper_to_code, paper_to_notebook, save/library/remove, export_bibtex."
2108 }
2109 TaskClassification::KnowledgeGraph => {
2110 "For this task, call the 'knowledge_graph' tool. Actions: add_node, get_node, update_node, remove_node, add_edge, remove_edge, neighbors, search, list, path, stats, import_arxiv, export_dot."
2111 }
2112 TaskClassification::ExperimentTracking => {
2113 "For this task, call the 'experiment_tracker' tool. Actions: add_hypothesis, update_hypothesis, list_hypotheses, get_hypothesis, add_experiment, start_experiment, complete_experiment, fail_experiment, get_experiment, list_experiments, record_evidence, compare_experiments, summary, export_markdown."
2114 }
2115 TaskClassification::CodeIntelligence => {
2116 "For this task, call the 'code_intelligence' tool. Actions: analyze_architecture, detect_patterns, translate_snippet, compare_implementations, tech_debt_report, api_surface, dependency_map."
2117 }
2118 TaskClassification::ContentEngine => {
2119 "For this task, call the 'content_engine' tool. Actions: create, update, set_status, get, list, search, delete, schedule, calendar_add, calendar_list, calendar_remove, stats, adapt, export_markdown."
2120 }
2121 TaskClassification::SkillTracker => {
2122 "For this task, call the 'skill_tracker' tool. Actions: add_skill, log_practice, assess, list_skills, knowledge_gaps, learning_path, progress_report, daily_practice."
2123 }
2124 TaskClassification::CareerIntel => {
2125 "For this task, call the 'career_intel' tool. Actions: set_goal, log_achievement, add_portfolio, gap_analysis, market_scan, network_note, progress_report, strategy_review."
2126 }
2127 TaskClassification::SystemMonitor => {
2128 "For this task, call the 'system_monitor' tool. Actions: add_service, topology, health_check, log_incident, correlate, generate_runbook, impact_analysis, list_services."
2129 }
2130 TaskClassification::LifePlanner => {
2131 "For this task, call the 'life_planner' tool. Actions: set_energy_profile, add_deadline, log_habit, daily_plan, weekly_review, context_switch_log, balance_report, optimize_schedule."
2132 }
2133 TaskClassification::PrivacyManager => {
2134 "For this task, call the 'privacy_manager' tool. Actions: set_boundary, list_boundaries, audit_access, compliance_check, export_data, delete_data, encrypt_store, privacy_report."
2135 }
2136 TaskClassification::SelfImprovement => {
2137 "For this task, call the 'self_improvement' tool. Actions: analyze_patterns, performance_report, suggest_improvements, set_preference, get_preferences, cognitive_load, feedback, reset_baseline."
2138 }
2139 _ => return None,
2140 };
2141
2142 Some(format!("TOOL ROUTING: {}", tool_hint))
2143 }
2144
#[cfg(target_os = "macos")]
/// Suggest a replacement tool call when the LLM picked the wrong tool (macOS build).
///
/// Uses the task classification and current goal stored in `state`. When
/// `failed_tool` is a known "wrong tool" for that classification, returns the
/// dedicated tool's name plus default arguments. Returns `None` when there is
/// no classification or no correction applies. `_args` is currently unused.
fn auto_correct_tool_call(
    failed_tool: &str,
    _args: &serde_json::Value,
    state: &AgentState,
) -> Option<(String, serde_json::Value)> {
    let classification = state.task_classification.as_ref()?;
    let task = state.current_goal.as_deref().unwrap_or("");

    match classification {
        TaskClassification::Slack
            if matches!(
                failed_tool,
                "macos_gui_scripting" | "macos_app_control" | "shell_exec"
            ) =>
        {
            Some((
                "slack".to_string(),
                serde_json::json!({"action": "send_message"}),
            ))
        }
        TaskClassification::ArxivResearch
            if matches!(
                failed_tool,
                "macos_safari" | "shell_exec" | "web_fetch" | "web_search"
            ) =>
        {
            Some((
                "arxiv_research".to_string(),
                serde_json::json!({"action": "search", "query": task, "max_results": 10}),
            ))
        }
        TaskClassification::WebSearch
            if matches!(failed_tool, "macos_safari" | "shell_exec") =>
        {
            Some(("web_search".to_string(), serde_json::json!({"query": task})))
        }
        TaskClassification::Clipboard
            if matches!(failed_tool, "document_read" | "file_read" | "shell_exec") =>
        {
            Some((
                "macos_clipboard".to_string(),
                serde_json::json!({"action": "read"}),
            ))
        }
        TaskClassification::SystemInfo
            if matches!(failed_tool, "document_read" | "file_read" | "shell_exec") =>
        {
            let lower = task.to_lowercase();
            // "ram" must be matched as a whole word: a plain substring test
            // also fires on "program", "framework", "parameter", ...
            let mentions_ram = lower
                .split(|c: char| !c.is_alphanumeric())
                .any(|word| word == "ram");
            // Pick the most specific system_info action named in the goal.
            let action = if lower.contains("battery") {
                "battery"
            } else if lower.contains("disk") {
                "disk"
            } else if lower.contains("cpu") || lower.contains("processor") {
                "cpu"
            } else if lower.contains("memory") || mentions_ram {
                "memory"
            } else {
                "version"
            };
            Some((
                "macos_system_info".to_string(),
                serde_json::json!({"action": action}),
            ))
        }
        TaskClassification::AppControl
            if matches!(failed_tool, "document_read" | "file_read" | "shell_exec") =>
        {
            Some((
                "macos_app_control".to_string(),
                serde_json::json!({"action": "list_running"}),
            ))
        }
        _ => None,
    }
}
2231
2232 #[cfg(not(target_os = "macos"))]
2234 fn auto_correct_tool_call(
2235 failed_tool: &str,
2236 _args: &serde_json::Value,
2237 state: &AgentState,
2238 ) -> Option<(String, serde_json::Value)> {
2239 let classification = state.task_classification.as_ref()?;
2240
2241 if matches!(classification, TaskClassification::Slack)
2242 && matches!(failed_tool, "shell_exec" | "web_fetch")
2243 {
2244 return Some((
2245 "slack".to_string(),
2246 serde_json::json!({"action": "send_message"}),
2247 ));
2248 }
2249
2250 None
2251 }
2252
2253 fn build_decision_explanation(
2255 &self,
2256 tool_name: &str,
2257 arguments: &serde_json::Value,
2258 ) -> DecisionExplanation {
2259 let risk_level = self
2260 .tools
2261 .get(tool_name)
2262 .map(|t| t.risk_level)
2263 .unwrap_or(RiskLevel::Execute);
2264
2265 let mut builder = ExplanationBuilder::new(DecisionType::ToolSelection {
2266 selected_tool: tool_name.to_string(),
2267 });
2268
2269 builder.add_reasoning_step(
2271 format!("Selected tool '{}' (risk: {})", tool_name, risk_level),
2272 None,
2273 );
2274
2275 if let Some(obj) = arguments.as_object() {
2277 let param_keys: Vec<&str> = obj.keys().map(|k| k.as_str()).collect();
2278 if !param_keys.is_empty() {
2279 builder.add_reasoning_step(
2280 format!("Parameters: {}", param_keys.join(", ")),
2281 Some(&arguments.to_string()),
2282 );
2283 }
2284 }
2285
2286 if let Some(goal) = &self.memory.working.current_goal {
2288 builder.add_context_factor(
2289 &format!("Current goal: {}", goal),
2290 FactorInfluence::Positive,
2291 );
2292 }
2293
2294 builder.add_context_factor(
2295 &format!("Approval mode: {}", self.safety.approval_mode()),
2296 FactorInfluence::Neutral,
2297 );
2298
2299 builder.add_context_factor(
2300 &format!(
2301 "Iteration {}/{}",
2302 self.state.iteration, self.state.max_iterations
2303 ),
2304 if self.state.iteration as f64 / self.state.max_iterations as f64 > 0.8 {
2305 FactorInfluence::Negative
2306 } else {
2307 FactorInfluence::Neutral
2308 },
2309 );
2310
2311 for (name, tool) in &self.tools {
2313 if name != tool_name && tool.risk_level <= risk_level {
2314 builder.add_alternative(name, "Not selected by LLM for this step", tool.risk_level);
2315 }
2316 }
2317
2318 let confidence = self.calculate_tool_confidence(tool_name, risk_level);
2320 builder.set_confidence(confidence);
2321
2322 builder.build()
2323 }
2324
2325 fn calculate_tool_confidence(&self, tool_name: &str, risk_level: RiskLevel) -> f32 {
2329 let mut confidence: f32 = match risk_level {
2331 RiskLevel::ReadOnly => 0.90,
2332 RiskLevel::Write => 0.75,
2333 RiskLevel::Execute => 0.65,
2334 RiskLevel::Network => 0.70,
2335 RiskLevel::Destructive => 0.45,
2336 };
2337
2338 if self.tool_token_usage.contains_key(tool_name) {
2340 confidence += 0.05;
2341 }
2342
2343 if self.state.iteration > 10 {
2345 confidence -= 0.1;
2346 }
2347
2348 if self.state.max_iterations > 0
2350 && (self.state.iteration as f64 / self.state.max_iterations as f64) > 0.8
2351 {
2352 confidence -= 0.05;
2353 }
2354
2355 confidence.clamp(0.0, 1.0)
2356 }
2357
2358 pub fn state(&self) -> &AgentState {
2360 &self.state
2361 }
2362
2363 pub fn cancellation_token(&self) -> CancellationToken {
2365 self.cancellation.clone()
2366 }
2367
2368 pub fn cancel(&self) {
2370 self.cancellation.cancel();
2371 }
2372
2373 pub fn reset_cancellation(&mut self) {
2376 self.cancellation = CancellationToken::new();
2377 }
2378
2379 pub fn brain(&self) -> &Brain {
2381 &self.brain
2382 }
2383
2384 pub fn safety(&self) -> &SafetyGuardian {
2386 &self.safety
2387 }
2388
2389 pub fn safety_mut(&mut self) -> &mut SafetyGuardian {
2391 &mut self.safety
2392 }
2393
2394 pub fn memory(&self) -> &MemorySystem {
2396 &self.memory
2397 }
2398
2399 pub fn memory_mut(&mut self) -> &mut MemorySystem {
2401 &mut self.memory
2402 }
2403
2404 pub fn config(&self) -> &AgentConfig {
2406 &self.config
2407 }
2408
2409 pub fn config_mut(&mut self) -> &mut AgentConfig {
2411 &mut self.config
2412 }
2413
2414 pub fn cron_scheduler(&self) -> Option<&CronScheduler> {
2416 self.cron_scheduler.as_ref()
2417 }
2418
2419 pub fn cron_scheduler_mut(&mut self) -> Option<&mut CronScheduler> {
2421 self.cron_scheduler.as_mut()
2422 }
2423
2424 pub fn job_manager(&self) -> &JobManager {
2426 &self.job_manager
2427 }
2428
2429 pub fn job_manager_mut(&mut self) -> &mut JobManager {
2431 &mut self.job_manager
2432 }
2433
2434 pub fn check_scheduler(&mut self) -> Vec<String> {
2436 let mut due_tasks = Vec::new();
2437
2438 if let Some(ref scheduler) = self.cron_scheduler {
2440 let due_jobs: Vec<String> = scheduler
2441 .due_jobs()
2442 .iter()
2443 .map(|j| j.config.name.clone())
2444 .collect();
2445 for name in &due_jobs {
2446 if let Some(ref scheduler) = self.cron_scheduler
2447 && let Some(job) = scheduler.get_job(name)
2448 {
2449 due_tasks.push(job.config.task.clone());
2450 }
2451 }
2452 if let Some(ref mut scheduler) = self.cron_scheduler {
2454 for name in &due_jobs {
2455 let _ = scheduler.mark_executed(name);
2456 }
2457 }
2458 }
2459
2460 if let Some(ref mut heartbeat) = self.heartbeat_manager {
2462 let ready: Vec<(String, String)> = heartbeat
2463 .ready_tasks()
2464 .iter()
2465 .map(|t| (t.name.clone(), t.action.clone()))
2466 .collect();
2467 for (name, action) in &ready {
2468 if let Some(ref task_condition) = heartbeat
2469 .config()
2470 .tasks
2471 .iter()
2472 .find(|t| t.name == *name)
2473 .and_then(|t| t.condition.clone())
2474 {
2475 if HeartbeatManager::check_condition(task_condition) {
2476 due_tasks.push(action.clone());
2477 heartbeat.mark_executed(name);
2478 }
2479 } else {
2480 due_tasks.push(action.clone());
2481 heartbeat.mark_executed(name);
2482 }
2483 }
2484 }
2485
2486 due_tasks
2487 }
2488
2489 pub fn save_scheduler_state(
2491 &self,
2492 state_dir: &std::path::Path,
2493 ) -> Result<(), crate::error::SchedulerError> {
2494 if let Some(ref scheduler) = self.cron_scheduler {
2495 crate::scheduler::save_state(scheduler, &self.job_manager, state_dir)
2496 } else {
2497 Ok(())
2499 }
2500 }
2501
2502 pub fn load_scheduler_state(&mut self, state_dir: &std::path::Path) {
2504 if self.cron_scheduler.is_some() {
2505 let (loaded_scheduler, loaded_jm) = crate::scheduler::load_state(state_dir);
2506 if !loaded_scheduler.is_empty() {
2507 self.cron_scheduler = Some(loaded_scheduler);
2508 info!("Restored cron scheduler state from {:?}", state_dir);
2509 }
2510 if !loaded_jm.is_empty() {
2511 self.job_manager = loaded_jm;
2512 info!("Restored job manager state from {:?}", state_dir);
2513 }
2514 }
2515 }
2516
2517 pub fn recent_explanations(&self) -> &[DecisionExplanation] {
2519 &self.recent_explanations
2520 }
2521
2522 pub fn tool_token_breakdown(&self) -> &HashMap<String, usize> {
2524 &self.tool_token_usage
2525 }
2526
2527 pub fn top_tool_consumers(&self, n: usize) -> String {
2529 if self.tool_token_usage.is_empty() {
2530 return String::new();
2531 }
2532 let total: usize = self.tool_token_usage.values().sum();
2533 if total == 0 {
2534 return String::new();
2535 }
2536 let mut sorted: Vec<_> = self.tool_token_usage.iter().collect();
2537 sorted.sort_by(|a, b| b.1.cmp(a.1));
2538 let top: Vec<String> = sorted
2539 .iter()
2540 .take(n)
2541 .map(|(name, tokens)| {
2542 let pct = (**tokens as f64 / total as f64 * 100.0) as u8;
2543 format!("{} ({}%)", name, pct)
2544 })
2545 .collect();
2546 top.join(", ")
2547 }
2548
2549 pub async fn think_with_council(
2554 &self,
2555 task: &str,
2556 council: &crate::council::PlanningCouncil,
2557 ) -> Option<crate::council::CouncilResult> {
2558 if !crate::council::should_use_council(task) {
2559 debug!(task, "Skipping council — task is not a planning task");
2560 return None;
2561 }
2562
2563 info!(task, "Using council deliberation for planning task");
2564 match council.deliberate(task).await {
2565 Ok(result) => {
2566 info!(
2567 responses = result.member_responses.len(),
2568 reviews = result.peer_reviews.len(),
2569 cost = format!("${:.4}", result.total_cost),
2570 "Council deliberation succeeded"
2571 );
2572 Some(result)
2573 }
2574 Err(e) => {
2575 warn!(error = %e, "Council deliberation failed, falling back to single model");
2576 None
2577 }
2578 }
2579 }
2580
2581 pub fn set_plan_mode(&mut self, enabled: bool) {
2585 self.plan_mode = enabled;
2586 }
2587
2588 pub fn plan_mode(&self) -> bool {
2590 self.plan_mode
2591 }
2592
2593 pub fn current_plan(&self) -> Option<&crate::plan::ExecutionPlan> {
2595 self.current_plan.as_ref()
2596 }
2597
2598 async fn generate_plan(
2600 &mut self,
2601 task: &str,
2602 ) -> Result<crate::plan::ExecutionPlan, RustantError> {
2603 use crate::plan::{PLAN_GENERATION_PROMPT, PlanStatus};
2604
2605 let tool_list: Vec<String> = self
2608 .tool_definitions(None)
2609 .iter()
2610 .map(|t| format!("- {} — {}", t.name, t.description))
2611 .collect();
2612 let tools_str = tool_list.join("\n");
2613
2614 let plan_prompt = format!(
2615 "{}\n\nAvailable tools:\n{}\n\nTask: {}",
2616 PLAN_GENERATION_PROMPT, tools_str, task
2617 );
2618
2619 let messages = vec![Message::system(&plan_prompt), Message::user(task)];
2621
2622 let response = self
2623 .brain
2624 .think_with_retry(&messages, None, 3)
2625 .await
2626 .map_err(RustantError::Llm)?;
2627
2628 self.budget.record_usage(
2630 &response.usage,
2631 &CostEstimate {
2632 input_cost: 0.0,
2633 output_cost: 0.0,
2634 },
2635 );
2636
2637 let text = response.message.content.as_text().unwrap_or("").to_string();
2638 let mut plan = crate::plan::parse_plan_json(&text, task);
2639
2640 let max_steps = self.config.plan.as_ref().map(|p| p.max_steps).unwrap_or(20);
2642 if plan.steps.len() > max_steps {
2643 plan.steps.truncate(max_steps);
2644 }
2645
2646 plan.status = PlanStatus::PendingReview;
2647 Ok(plan)
2648 }
2649
2650 async fn execute_plan(
2652 &mut self,
2653 plan: &mut crate::plan::ExecutionPlan,
2654 ) -> Result<TaskResult, RustantError> {
2655 use crate::plan::{PlanStatus, StepStatus};
2656
2657 plan.status = PlanStatus::Executing;
2658 let task_id = Uuid::new_v4();
2659
2660 while let Some(step_idx) = plan.next_pending_step() {
2661 plan.current_step = Some(step_idx);
2662 let step = &plan.steps[step_idx];
2663 let step_desc = step.description.clone();
2664 let step_tool = step.tool.clone();
2665 let step_args = step.tool_args.clone();
2666
2667 self.callback
2669 .on_plan_step_start(step_idx, &plan.steps[step_idx])
2670 .await;
2671 plan.steps[step_idx].status = StepStatus::InProgress;
2672
2673 let result = if let Some(tool_name) = &step_tool {
2674 let args = step_args.unwrap_or(serde_json::json!({}));
2676
2677 self.callback.on_tool_start(tool_name, &args).await;
2678 let start = std::time::Instant::now();
2679 let exec_result = self.execute_tool("plan", tool_name, &args).await;
2680 let duration_ms = start.elapsed().as_millis() as u64;
2681
2682 match exec_result {
2683 Ok(output) => {
2684 self.callback
2685 .on_tool_result(tool_name, &output, duration_ms)
2686 .await;
2687 Ok(output.content)
2688 }
2689 Err(e) => Err(format!("{}", e)),
2690 }
2691 } else {
2692 let step_prompt = format!(
2695 "Execute plan step {}: {}\n\nPrevious step results are in context.",
2696 step_idx + 1,
2697 step_desc
2698 );
2699 self.memory.add_message(Message::user(&step_prompt));
2700
2701 let conversation = self.memory.context_messages();
2702 let tools = Some(self.tool_definitions(self.state.task_classification.as_ref()));
2703 let response = if self.config.llm.use_streaming {
2704 self.think_streaming(&conversation, tools).await
2705 } else {
2706 self.brain.think_with_retry(&conversation, tools, 3).await
2707 };
2708
2709 match response {
2710 Ok(resp) => {
2711 let text = resp
2712 .message
2713 .content
2714 .as_text()
2715 .unwrap_or("(no output)")
2716 .to_string();
2717 self.callback.on_assistant_message(&text).await;
2718 self.memory.add_message(resp.message);
2719 Ok(text)
2720 }
2721 Err(e) => Err(format!("{}", e)),
2722 }
2723 };
2724
2725 match result {
2726 Ok(output) => {
2727 plan.complete_step(step_idx, &output);
2728 }
2729 Err(error) => {
2730 plan.fail_step(step_idx, &error);
2731 self.callback
2733 .on_plan_step_complete(step_idx, &plan.steps[step_idx])
2734 .await;
2735 plan.status = PlanStatus::Failed;
2737 break;
2738 }
2739 }
2740
2741 self.callback
2743 .on_plan_step_complete(step_idx, &plan.steps[step_idx])
2744 .await;
2745 }
2746
2747 if plan.status != PlanStatus::Failed {
2749 let all_done = plan
2750 .steps
2751 .iter()
2752 .all(|s| s.status == StepStatus::Completed || s.status == StepStatus::Skipped);
2753 plan.status = if all_done {
2754 PlanStatus::Completed
2755 } else {
2756 PlanStatus::Failed
2757 };
2758 }
2759
2760 let success = plan.status == PlanStatus::Completed;
2761 let response = plan.progress_summary();
2762
2763 Ok(TaskResult {
2764 task_id,
2765 success,
2766 response,
2767 iterations: plan.steps.len(),
2768 total_usage: *self.brain.total_usage(),
2769 total_cost: *self.brain.total_cost(),
2770 })
2771 }
2772
2773 async fn process_task_with_plan(&mut self, task: &str) -> Result<TaskResult, RustantError> {
2775 use crate::plan::{PlanDecision, PlanStatus};
2776
2777 self.state.status = AgentStatus::Planning;
2779 self.callback.on_status_change(AgentStatus::Planning).await;
2780 self.callback.on_plan_generating(task).await;
2781
2782 let mut plan = self.generate_plan(task).await?;
2783
2784 for question in &plan.clarifications.clone() {
2786 let answer = self.callback.on_clarification_request(question).await;
2787 if !answer.is_empty() {
2788 self.memory
2790 .add_message(Message::user(format!("Q: {} A: {}", question, answer)));
2791 }
2792 }
2793
2794 loop {
2796 let decision = self.callback.on_plan_review(&plan).await;
2797 match decision {
2798 PlanDecision::Approve => break,
2799 PlanDecision::Reject => {
2800 plan.status = PlanStatus::Cancelled;
2801 self.current_plan = Some(plan);
2802 self.state.complete();
2803 self.callback.on_status_change(AgentStatus::Complete).await;
2804 let task_id = self.state.task_id.unwrap_or_else(Uuid::new_v4);
2805 return Ok(TaskResult {
2806 task_id,
2807 success: false,
2808 response: "Plan rejected by user.".to_string(),
2809 iterations: 0,
2810 total_usage: *self.brain.total_usage(),
2811 total_cost: *self.brain.total_cost(),
2812 });
2813 }
2814 PlanDecision::EditStep(idx, new_desc) => {
2815 if let Some(step) = plan.steps.get_mut(idx) {
2816 step.description = new_desc;
2817 plan.updated_at = chrono::Utc::now();
2818 }
2819 }
2820 PlanDecision::RemoveStep(idx) => {
2821 if idx < plan.steps.len() {
2822 plan.steps.remove(idx);
2823 for (i, step) in plan.steps.iter_mut().enumerate() {
2825 step.index = i;
2826 }
2827 plan.updated_at = chrono::Utc::now();
2828 }
2829 }
2830 PlanDecision::AddStep(idx, desc) => {
2831 let new_step = crate::plan::PlanStep {
2832 index: idx,
2833 description: desc,
2834 ..Default::default()
2835 };
2836 if idx <= plan.steps.len() {
2837 plan.steps.insert(idx, new_step);
2838 } else {
2839 plan.steps.push(new_step);
2840 }
2841 for (i, step) in plan.steps.iter_mut().enumerate() {
2843 step.index = i;
2844 }
2845 plan.updated_at = chrono::Utc::now();
2846 }
2847 PlanDecision::ReorderSteps(new_order) => {
2848 let old_steps = plan.steps.clone();
2849 plan.steps.clear();
2850 for (i, &old_idx) in new_order.iter().enumerate() {
2851 if let Some(mut step) = old_steps.get(old_idx).cloned() {
2852 step.index = i;
2853 plan.steps.push(step);
2854 }
2855 }
2856 plan.updated_at = chrono::Utc::now();
2857 }
2858 PlanDecision::AskQuestion(question) => {
2859 let messages = vec![
2861 Message::system("Answer this question about the plan you generated."),
2862 Message::user(&question),
2863 ];
2864 if let Ok(resp) = self.brain.think_with_retry(&messages, None, 1).await
2865 && let Some(answer) = resp.message.content.as_text()
2866 {
2867 self.callback.on_assistant_message(answer).await;
2868 }
2869 }
2870 }
2871 }
2872
2873 self.current_plan = Some(plan.clone());
2875 let result = self.execute_plan(&mut plan).await?;
2876 self.current_plan = Some(plan);
2877 self.state.complete();
2878 self.callback.on_status_change(AgentStatus::Complete).await;
2879
2880 Ok(result)
2881 }
2882
2883 async fn check_and_compress(&mut self) {
2888 if !self.memory.short_term.needs_compression() {
2889 return;
2890 }
2891
2892 debug!("Triggering LLM-based context compression");
2893 let msgs_to_summarize: Vec<crate::types::Message> = self
2894 .memory
2895 .short_term
2896 .messages_to_summarize()
2897 .into_iter()
2898 .cloned()
2899 .collect();
2900 let msgs_count = msgs_to_summarize.len();
2901 let pinned_count = self.memory.short_term.pinned_count();
2902
2903 let (summary_text, was_llm) = match self.summarizer.summarize(&msgs_to_summarize).await {
2904 Ok(result) => {
2905 info!(
2906 messages_summarized = result.messages_summarized,
2907 tokens_saved = result.tokens_saved,
2908 "Context compression via LLM summarization"
2909 );
2910 (result.text, true)
2911 }
2912 Err(e) => {
2913 warn!(
2914 error = %e,
2915 "LLM summarization failed, falling back to truncation"
2916 );
2917 let text = crate::summarizer::smart_fallback_summary(&msgs_to_summarize, 500);
2918 (text, false)
2919 }
2920 };
2921
2922 self.memory.short_term.compress(summary_text);
2923
2924 self.callback
2925 .on_context_health(&ContextHealthEvent::Compressed {
2926 messages_compressed: msgs_count,
2927 was_llm_summarized: was_llm,
2928 pinned_preserved: pinned_count,
2929 })
2930 .await;
2931 }
2932
2933 pub fn compact(&mut self) -> (usize, usize) {
2936 let before = self.memory.short_term.len();
2937 if before <= 2 {
2938 return (before, before);
2939 }
2940 let msgs: Vec<crate::types::Message> =
2941 self.memory.short_term.messages().iter().cloned().collect();
2942 let summary = crate::summarizer::smart_fallback_summary(&msgs, 500);
2943 self.memory.short_term.compress(summary);
2944 let after = self.memory.short_term.len();
2945 (before, after)
2946 }
2947}
2948
2949pub struct NoOpCallback;
2951
2952#[async_trait::async_trait]
2953impl AgentCallback for NoOpCallback {
2954 async fn on_assistant_message(&self, _message: &str) {}
2955 async fn on_token(&self, _token: &str) {}
2956 async fn request_approval(&self, _action: &ActionRequest) -> ApprovalDecision {
2957 ApprovalDecision::Approve }
2959 async fn on_tool_start(&self, _tool_name: &str, _args: &serde_json::Value) {}
2960 async fn on_tool_result(&self, _tool_name: &str, _output: &ToolOutput, _duration_ms: u64) {}
2961 async fn on_status_change(&self, _status: AgentStatus) {}
2962 async fn on_usage_update(&self, _usage: &TokenUsage, _cost: &CostEstimate) {}
2963 async fn on_decision_explanation(&self, _explanation: &DecisionExplanation) {}
2964}
2965
2966pub struct RecordingCallback {
2968 messages: tokio::sync::Mutex<Vec<String>>,
2969 tool_calls: tokio::sync::Mutex<Vec<String>>,
2970 status_changes: tokio::sync::Mutex<Vec<AgentStatus>>,
2971 explanations: tokio::sync::Mutex<Vec<DecisionExplanation>>,
2972 budget_warnings: tokio::sync::Mutex<Vec<(String, BudgetSeverity)>>,
2973 context_health_events: tokio::sync::Mutex<Vec<ContextHealthEvent>>,
2974}
2975
2976impl RecordingCallback {
2977 pub fn new() -> Self {
2978 Self {
2979 messages: tokio::sync::Mutex::new(Vec::new()),
2980 tool_calls: tokio::sync::Mutex::new(Vec::new()),
2981 status_changes: tokio::sync::Mutex::new(Vec::new()),
2982 explanations: tokio::sync::Mutex::new(Vec::new()),
2983 budget_warnings: tokio::sync::Mutex::new(Vec::new()),
2984 context_health_events: tokio::sync::Mutex::new(Vec::new()),
2985 }
2986 }
2987
2988 pub async fn messages(&self) -> Vec<String> {
2989 self.messages.lock().await.clone()
2990 }
2991
2992 pub async fn tool_calls(&self) -> Vec<String> {
2993 self.tool_calls.lock().await.clone()
2994 }
2995
2996 pub async fn status_changes(&self) -> Vec<AgentStatus> {
2997 self.status_changes.lock().await.clone()
2998 }
2999
3000 pub async fn explanations(&self) -> Vec<DecisionExplanation> {
3001 self.explanations.lock().await.clone()
3002 }
3003
3004 pub async fn budget_warnings(&self) -> Vec<(String, BudgetSeverity)> {
3005 self.budget_warnings.lock().await.clone()
3006 }
3007
3008 pub async fn context_health_events(&self) -> Vec<ContextHealthEvent> {
3009 self.context_health_events.lock().await.clone()
3010 }
3011}
3012
3013impl Default for RecordingCallback {
3014 fn default() -> Self {
3015 Self::new()
3016 }
3017}
3018
3019#[async_trait::async_trait]
3020impl AgentCallback for RecordingCallback {
3021 async fn on_assistant_message(&self, message: &str) {
3022 self.messages.lock().await.push(message.to_string());
3023 }
3024 async fn on_token(&self, _token: &str) {}
3025 async fn request_approval(&self, _action: &ActionRequest) -> ApprovalDecision {
3026 ApprovalDecision::Approve
3027 }
3028 async fn on_tool_start(&self, tool_name: &str, _args: &serde_json::Value) {
3029 self.tool_calls.lock().await.push(tool_name.to_string());
3030 }
3031 async fn on_tool_result(&self, _tool_name: &str, _output: &ToolOutput, _duration_ms: u64) {}
3032 async fn on_status_change(&self, status: AgentStatus) {
3033 self.status_changes.lock().await.push(status);
3034 }
3035 async fn on_usage_update(&self, _usage: &TokenUsage, _cost: &CostEstimate) {}
3036 async fn on_decision_explanation(&self, explanation: &DecisionExplanation) {
3037 self.explanations.lock().await.push(explanation.clone());
3038 }
3039 async fn on_budget_warning(&self, message: &str, severity: BudgetSeverity) {
3040 self.budget_warnings
3041 .lock()
3042 .await
3043 .push((message.to_string(), severity));
3044 }
3045 async fn on_context_health(&self, event: &ContextHealthEvent) {
3046 self.context_health_events.lock().await.push(event.clone());
3047 }
3048}
3049
3050#[cfg(test)]
3051mod tests {
3052 use super::*;
3053 use crate::brain::MockLlmProvider;
3054
3055 fn create_test_agent(provider: Arc<MockLlmProvider>) -> (Agent, Arc<RecordingCallback>) {
3056 let callback = Arc::new(RecordingCallback::new());
3057 let mut config = AgentConfig::default();
3058 config.llm.use_streaming = false;
3060 let agent = Agent::new(provider, config, callback.clone());
3061 (agent, callback)
3062 }
3063
3064 #[tokio::test]
3065 async fn test_agent_simple_text_response() {
3066 let provider = Arc::new(MockLlmProvider::new());
3067 provider.queue_response(MockLlmProvider::text_response("Hello! I can help you."));
3068
3069 let (mut agent, callback) = create_test_agent(provider);
3070 let result = agent.process_task("Say hello").await.unwrap();
3071
3072 assert!(result.success);
3073 assert_eq!(result.response, "Hello! I can help you.");
3074 assert_eq!(result.iterations, 1);
3075
3076 let messages = callback.messages().await;
3077 assert_eq!(messages.len(), 1);
3078 assert_eq!(messages[0], "Hello! I can help you.");
3079 }
3080
3081 #[tokio::test]
3082 async fn test_agent_tool_call_then_response() {
3083 let provider = Arc::new(MockLlmProvider::new());
3084
3085 provider.queue_response(MockLlmProvider::tool_call_response(
3087 "echo",
3088 serde_json::json!({"text": "test"}),
3089 ));
3090 provider.queue_response(MockLlmProvider::text_response(
3092 "I executed the echo tool successfully.",
3093 ));
3094
3095 let (mut agent, callback) = create_test_agent(provider);
3096
3097 agent.register_tool(RegisteredTool {
3099 definition: ToolDefinition {
3100 name: "echo".to_string(),
3101 description: "Echo input text".to_string(),
3102 parameters: serde_json::json!({
3103 "type": "object",
3104 "properties": { "text": { "type": "string" } },
3105 "required": ["text"]
3106 }),
3107 },
3108 risk_level: RiskLevel::ReadOnly,
3109 executor: Box::new(|args: serde_json::Value| {
3110 Box::pin(async move {
3111 let text = args["text"].as_str().unwrap_or("no text");
3112 Ok(ToolOutput::text(format!("Echo: {}", text)))
3113 })
3114 }),
3115 });
3116
3117 let result = agent.process_task("Test echo tool").await.unwrap();
3118
3119 assert!(result.success);
3120 assert_eq!(result.iterations, 2);
3121
3122 let tool_calls = callback.tool_calls().await;
3123 assert_eq!(tool_calls.len(), 1);
3124 assert_eq!(tool_calls[0], "echo");
3125 }
3126
3127 #[tokio::test]
3128 async fn test_agent_tool_not_found() {
3129 let provider = Arc::new(MockLlmProvider::new());
3130 provider.queue_response(MockLlmProvider::tool_call_response(
3131 "nonexistent_tool",
3132 serde_json::json!({}),
3133 ));
3134 provider.queue_response(MockLlmProvider::text_response(
3136 "Sorry, that tool doesn't exist.",
3137 ));
3138
3139 let (mut agent, _callback) = create_test_agent(provider);
3140 let result = agent.process_task("Use nonexistent tool").await.unwrap();
3141
3142 assert!(result.success);
3144 }
3145
3146 #[tokio::test]
3147 async fn test_agent_state_tracking() {
3148 let provider = Arc::new(MockLlmProvider::new());
3149 provider.queue_response(MockLlmProvider::text_response("Done"));
3150
3151 let (mut agent, callback) = create_test_agent(provider);
3152
3153 assert_eq!(agent.state().status, AgentStatus::Idle);
3154
3155 agent.process_task("Simple task").await.unwrap();
3156
3157 assert_eq!(agent.state().status, AgentStatus::Complete);
3158
3159 let statuses = callback.status_changes().await;
3160 assert!(statuses.contains(&AgentStatus::Thinking));
3161 assert!(statuses.contains(&AgentStatus::Complete));
3162 }
3163
3164 #[tokio::test]
3165 async fn test_agent_max_iterations() {
3166 let provider = Arc::new(MockLlmProvider::new());
3167 for _ in 0..55 {
3169 provider.queue_response(MockLlmProvider::tool_call_response(
3170 "echo",
3171 serde_json::json!({"text": "loop"}),
3172 ));
3173 }
3174
3175 let (mut agent, _callback) = create_test_agent(provider);
3176 agent.register_tool(RegisteredTool {
3177 definition: ToolDefinition {
3178 name: "echo".to_string(),
3179 description: "Echo".to_string(),
3180 parameters: serde_json::json!({}),
3181 },
3182 risk_level: RiskLevel::ReadOnly,
3183 executor: Box::new(|_| Box::pin(async { Ok(ToolOutput::text("echoed")) })),
3184 });
3185
3186 let result = agent.process_task("Infinite loop test").await;
3187 assert!(result.is_err());
3188 match result.unwrap_err() {
3189 RustantError::Agent(AgentError::MaxIterationsReached { max }) => {
3190 assert_eq!(max, 50);
3191 }
3192 e => panic!("Expected MaxIterationsReached, got: {:?}", e),
3193 }
3194 }
3195
3196 #[tokio::test]
3197 async fn test_agent_cancellation() {
3198 let provider = Arc::new(MockLlmProvider::new());
3199 provider.queue_response(MockLlmProvider::tool_call_response(
3201 "echo",
3202 serde_json::json!({"text": "test"}),
3203 ));
3204
3205 let (mut agent, _callback) = create_test_agent(provider);
3206 agent.register_tool(RegisteredTool {
3207 definition: ToolDefinition {
3208 name: "echo".to_string(),
3209 description: "Echo".to_string(),
3210 parameters: serde_json::json!({}),
3211 },
3212 risk_level: RiskLevel::ReadOnly,
3213 executor: Box::new(|_| Box::pin(async { Ok(ToolOutput::text("echoed")) })),
3214 });
3215
3216 agent.cancel();
3218 let result = agent.process_task("Cancelled task").await;
3219 assert!(result.is_err());
3220 match result.unwrap_err() {
3221 RustantError::Agent(AgentError::Cancelled) => {}
3222 e => panic!("Expected Cancelled, got: {:?}", e),
3223 }
3224 }
3225
3226 #[test]
3227 fn test_no_op_callback() {
3228 let _callback = NoOpCallback;
3230 }
3231
3232 #[tokio::test]
3233 async fn test_agent_streaming_mode() {
3234 let provider = Arc::new(MockLlmProvider::new());
3235 provider.queue_response(MockLlmProvider::text_response("streaming response"));
3236
3237 let callback = Arc::new(RecordingCallback::new());
3238 let mut config = AgentConfig::default();
3239 config.llm.use_streaming = true;
3240
3241 let mut agent = Agent::new(provider, config, callback.clone());
3242 let result = agent.process_task("Test streaming").await.unwrap();
3243
3244 assert!(result.success);
3245 assert!(result.response.contains("streaming"));
3246 }
3249
3250 #[tokio::test]
3251 async fn test_recording_callback() {
3252 let callback = RecordingCallback::new();
3253 callback.on_assistant_message("hello").await;
3254 callback
3255 .on_tool_start("file_read", &serde_json::json!({}))
3256 .await;
3257 callback.on_status_change(AgentStatus::Thinking).await;
3258
3259 assert_eq!(callback.messages().await, vec!["hello"]);
3260 assert_eq!(callback.tool_calls().await, vec!["file_read"]);
3261 assert_eq!(callback.status_changes().await, vec![AgentStatus::Thinking]);
3262 }
3263
3264 #[tokio::test]
3267 async fn test_recording_callback_records_explanations() {
3268 let callback = RecordingCallback::new();
3269 let explanation = ExplanationBuilder::new(DecisionType::ToolSelection {
3270 selected_tool: "echo".into(),
3271 })
3272 .build();
3273 callback.on_decision_explanation(&explanation).await;
3274
3275 let explanations = callback.explanations().await;
3276 assert_eq!(explanations.len(), 1);
3277 match &explanations[0].decision_type {
3278 DecisionType::ToolSelection { selected_tool } => {
3279 assert_eq!(selected_tool, "echo");
3280 }
3281 other => panic!("Expected ToolSelection, got {:?}", other),
3282 }
3283 }
3284
3285 #[tokio::test]
3286 async fn test_multipart_tool_call_emits_explanation() {
3287 let provider = Arc::new(MockLlmProvider::new());
3288
3289 provider.queue_response(MockLlmProvider::multipart_response(
3291 "I'll echo for you",
3292 "echo",
3293 serde_json::json!({"text": "test"}),
3294 ));
3295 provider.queue_response(MockLlmProvider::text_response("Done."));
3297
3298 let (mut agent, callback) = create_test_agent(provider);
3299 agent.register_tool(RegisteredTool {
3300 definition: ToolDefinition {
3301 name: "echo".to_string(),
3302 description: "Echo input text".to_string(),
3303 parameters: serde_json::json!({
3304 "type": "object",
3305 "properties": { "text": { "type": "string" } },
3306 "required": ["text"]
3307 }),
3308 },
3309 risk_level: RiskLevel::ReadOnly,
3310 executor: Box::new(|args: serde_json::Value| {
3311 Box::pin(async move {
3312 let text = args["text"].as_str().unwrap_or("no text");
3313 Ok(ToolOutput::text(format!("Echo: {}", text)))
3314 })
3315 }),
3316 });
3317
3318 agent.process_task("Echo test").await.unwrap();
3319
3320 let explanations = callback.explanations().await;
3321 assert!(
3322 !explanations.is_empty(),
3323 "MultiPart tool calls should emit explanations"
3324 );
3325 let has_echo = explanations.iter().any(|e| {
3327 matches!(&e.decision_type, DecisionType::ToolSelection { selected_tool } if selected_tool == "echo")
3328 });
3329 assert!(has_echo, "Should have explanation for echo tool selection");
3330 }
3331
3332 #[tokio::test]
3333 async fn test_single_tool_call_emits_explanation() {
3334 let provider = Arc::new(MockLlmProvider::new());
3335 provider.queue_response(MockLlmProvider::tool_call_response(
3336 "echo",
3337 serde_json::json!({"text": "hi"}),
3338 ));
3339 provider.queue_response(MockLlmProvider::text_response("Done."));
3340
3341 let (mut agent, callback) = create_test_agent(provider);
3342 agent.register_tool(RegisteredTool {
3343 definition: ToolDefinition {
3344 name: "echo".to_string(),
3345 description: "Echo".to_string(),
3346 parameters: serde_json::json!({}),
3347 },
3348 risk_level: RiskLevel::ReadOnly,
3349 executor: Box::new(|_| Box::pin(async { Ok(ToolOutput::text("echoed")) })),
3350 });
3351
3352 agent.process_task("Echo test").await.unwrap();
3353
3354 let explanations = callback.explanations().await;
3355 assert!(
3356 !explanations.is_empty(),
3357 "Single tool calls should emit explanations"
3358 );
3359 }
3360
3361 #[tokio::test]
3362 async fn test_contract_violation_emits_error_recovery_explanation() {
3363 use crate::safety::{Invariant, Predicate, SafetyContract};
3364
3365 let provider = Arc::new(MockLlmProvider::new());
3366 provider.queue_response(MockLlmProvider::tool_call_response(
3367 "echo",
3368 serde_json::json!({"text": "test"}),
3369 ));
3370 provider.queue_response(MockLlmProvider::text_response("OK, I'll skip that."));
3372
3373 let callback = Arc::new(RecordingCallback::new());
3374 let mut config = AgentConfig::default();
3375 config.llm.use_streaming = false;
3376 let mut agent = Agent::new(provider, config, callback.clone());
3377 agent.register_tool(RegisteredTool {
3378 definition: ToolDefinition {
3379 name: "echo".to_string(),
3380 description: "Echo".to_string(),
3381 parameters: serde_json::json!({}),
3382 },
3383 risk_level: RiskLevel::ReadOnly,
3384 executor: Box::new(|_| Box::pin(async { Ok(ToolOutput::text("echoed")) })),
3385 });
3386
3387 agent.safety_mut().set_contract(SafetyContract {
3389 name: "deny-all".into(),
3390 invariants: vec![Invariant {
3391 description: "no tools allowed".into(),
3392 predicate: Predicate::AlwaysFalse,
3393 }],
3394 ..Default::default()
3395 });
3396
3397 agent.process_task("Echo test").await.unwrap();
3398
3399 let explanations = callback.explanations().await;
3400 let has_error_recovery = explanations.iter().any(|e| {
3401 matches!(
3402 &e.decision_type,
3403 DecisionType::ErrorRecovery { error, .. } if error.contains("Contract violation")
3404 )
3405 });
3406 assert!(
3407 has_error_recovery,
3408 "Contract violations should emit ErrorRecovery explanations, got: {:?}",
3409 explanations
3410 .iter()
3411 .map(|e| &e.decision_type)
3412 .collect::<Vec<_>>()
3413 );
3414 }
3415
3416 #[tokio::test]
3419 async fn test_recording_callback_records_budget_warnings() {
3420 let callback = RecordingCallback::new();
3421 callback
3422 .on_budget_warning(
3423 "Session cost at 85% of $1.00 limit",
3424 BudgetSeverity::Warning,
3425 )
3426 .await;
3427 callback
3428 .on_budget_warning("Budget exceeded!", BudgetSeverity::Exceeded)
3429 .await;
3430
3431 let warnings = callback.budget_warnings().await;
3432 assert_eq!(warnings.len(), 2);
3433 assert!(warnings[0].0.contains("85%"));
3434 assert_eq!(warnings[0].1, BudgetSeverity::Warning);
3435 assert_eq!(warnings[1].1, BudgetSeverity::Exceeded);
3436 }
3437
3438 #[test]
3439 fn test_budget_severity_enum() {
3440 assert_ne!(BudgetSeverity::Warning, BudgetSeverity::Exceeded);
3441 assert_eq!(BudgetSeverity::Warning, BudgetSeverity::Warning);
3442 }
3443
3444 #[test]
3447 fn test_parse_action_details_file_read() {
3448 let args = serde_json::json!({"path": "src/lib.rs"});
3449 let details = Agent::parse_action_details("file_read", &args);
3450 match details {
3451 ActionDetails::FileRead { path } => {
3452 assert_eq!(path, std::path::PathBuf::from("src/lib.rs"));
3453 }
3454 other => panic!("Expected FileRead, got {:?}", other),
3455 }
3456 }
3457
3458 #[test]
3459 fn test_parse_action_details_file_list() {
3460 let args = serde_json::json!({"path": "src/"});
3461 let details = Agent::parse_action_details("file_list", &args);
3462 assert!(matches!(details, ActionDetails::FileRead { .. }));
3463 }
3464
3465 #[test]
3466 fn test_parse_action_details_file_write() {
3467 let args = serde_json::json!({"path": "x.rs", "content": "hello"});
3468 let details = Agent::parse_action_details("file_write", &args);
3469 match details {
3470 ActionDetails::FileWrite { path, size_bytes } => {
3471 assert_eq!(path, std::path::PathBuf::from("x.rs"));
3472 assert_eq!(size_bytes, 5); }
3474 other => panic!("Expected FileWrite, got {:?}", other),
3475 }
3476 }
3477
3478 #[test]
3479 fn test_parse_action_details_shell_exec() {
3480 let args = serde_json::json!({"command": "cargo test"});
3481 let details = Agent::parse_action_details("shell_exec", &args);
3482 match details {
3483 ActionDetails::ShellCommand { command } => {
3484 assert_eq!(command, "cargo test");
3485 }
3486 other => panic!("Expected ShellCommand, got {:?}", other),
3487 }
3488 }
3489
3490 #[test]
3491 fn test_parse_action_details_git_commit() {
3492 let args = serde_json::json!({"message": "fix bug"});
3493 let details = Agent::parse_action_details("git_commit", &args);
3494 match details {
3495 ActionDetails::GitOperation { operation } => {
3496 assert!(
3497 operation.contains("commit"),
3498 "Expected 'commit' in '{}'",
3499 operation
3500 );
3501 assert!(
3502 operation.contains("fix bug"),
3503 "Expected 'fix bug' in '{}'",
3504 operation
3505 );
3506 }
3507 other => panic!("Expected GitOperation, got {:?}", other),
3508 }
3509 }
3510
3511 #[test]
3512 fn test_parse_action_details_git_status() {
3513 let args = serde_json::json!({});
3514 let details = Agent::parse_action_details("git_status", &args);
3515 assert!(matches!(details, ActionDetails::GitOperation { .. }));
3516 }
3517
3518 #[test]
3519 fn test_parse_action_details_unknown_falls_back() {
3520 let args = serde_json::json!({"foo": "bar"});
3521 let details = Agent::parse_action_details("custom_tool", &args);
3522 assert!(matches!(details, ActionDetails::Other { .. }));
3523 }
3524
3525 #[test]
3526 fn test_build_approval_context_file_write_has_reasoning() {
3527 let details = ActionDetails::FileWrite {
3528 path: "test.rs".into(),
3529 size_bytes: 100,
3530 };
3531 let ctx = Agent::build_approval_context("file_write", &details, RiskLevel::Write);
3532 assert!(
3533 ctx.reasoning.is_some(),
3534 "FileWrite should produce reasoning"
3535 );
3536 let reasoning = ctx.reasoning.unwrap();
3537 assert!(
3538 reasoning.contains("100 bytes"),
3539 "Reasoning should mention size: {}",
3540 reasoning
3541 );
3542 assert!(
3543 !ctx.consequences.is_empty(),
3544 "FileWrite should have consequences"
3545 );
3546 }
3547
3548 #[test]
3549 fn test_build_approval_context_shell_command_has_reasoning() {
3550 let details = ActionDetails::ShellCommand {
3551 command: "rm -rf /tmp/test".to_string(),
3552 };
3553 let ctx = Agent::build_approval_context("shell_exec", &details, RiskLevel::Execute);
3554 assert!(ctx.reasoning.is_some());
3555 let reasoning = ctx.reasoning.unwrap();
3556 assert!(reasoning.contains("rm -rf"));
3557 }
3558
3559 struct SelectiveDenyCallback {
3563 deny_tools: Vec<String>,
3564 }
3565
3566 impl SelectiveDenyCallback {
3567 fn new(deny_tools: Vec<String>) -> Self {
3568 Self { deny_tools }
3569 }
3570 }
3571
3572 #[async_trait::async_trait]
3573 impl AgentCallback for SelectiveDenyCallback {
3574 async fn on_assistant_message(&self, _message: &str) {}
3575 async fn on_token(&self, _token: &str) {}
3576 async fn request_approval(&self, action: &ActionRequest) -> ApprovalDecision {
3577 if self.deny_tools.contains(&action.tool_name) {
3578 ApprovalDecision::Deny
3579 } else {
3580 ApprovalDecision::Approve
3581 }
3582 }
3583 async fn on_tool_start(&self, _tool_name: &str, _args: &serde_json::Value) {}
3584 async fn on_tool_result(&self, _tool_name: &str, _output: &ToolOutput, _duration_ms: u64) {}
3585 async fn on_status_change(&self, _status: AgentStatus) {}
3586 async fn on_usage_update(&self, _usage: &TokenUsage, _cost: &CostEstimate) {}
3587 async fn on_decision_explanation(&self, _explanation: &DecisionExplanation) {}
3588 }
3589
3590 #[tokio::test]
3591 async fn test_successful_tool_execution_records_fact() {
3592 let provider = Arc::new(MockLlmProvider::new());
3593 provider.queue_response(MockLlmProvider::tool_call_response(
3594 "echo",
3595 serde_json::json!({"text": "important finding about the code"}),
3596 ));
3597 provider.queue_response(MockLlmProvider::text_response("Done."));
3598
3599 let (mut agent, _callback) = create_test_agent(provider);
3600 agent.register_tool(RegisteredTool {
3601 definition: ToolDefinition {
3602 name: "echo".to_string(),
3603 description: "Echo text".to_string(),
3604 parameters: serde_json::json!({}),
3605 },
3606 risk_level: RiskLevel::ReadOnly,
3607 executor: Box::new(|args: serde_json::Value| {
3608 Box::pin(async move {
3609 let text = args["text"].as_str().unwrap_or("no text");
3610 Ok(ToolOutput::text(format!("Echo: {}", text)))
3611 })
3612 }),
3613 });
3614
3615 agent.process_task("Test echo").await.unwrap();
3616
3617 assert!(
3618 !agent.memory().long_term.facts.is_empty(),
3619 "Successful tool execution should record a fact"
3620 );
3621 let fact = &agent.memory().long_term.facts[0];
3622 assert!(
3623 fact.content.contains("echo"),
3624 "Fact should mention tool name: {}",
3625 fact.content
3626 );
3627 assert!(
3628 fact.tags.contains(&"tool_result".to_string()),
3629 "Fact should have 'tool_result' tag"
3630 );
3631 }
3632
3633 #[tokio::test]
3634 async fn test_short_tool_output_not_recorded() {
3635 let provider = Arc::new(MockLlmProvider::new());
3636 provider.queue_response(MockLlmProvider::tool_call_response(
3637 "echo",
3638 serde_json::json!({"text": "x"}),
3639 ));
3640 provider.queue_response(MockLlmProvider::text_response("Done."));
3641
3642 let (mut agent, _callback) = create_test_agent(provider);
3643 agent.register_tool(RegisteredTool {
3644 definition: ToolDefinition {
3645 name: "echo".to_string(),
3646 description: "Echo".to_string(),
3647 parameters: serde_json::json!({}),
3648 },
3649 risk_level: RiskLevel::ReadOnly,
3650 executor: Box::new(|_| Box::pin(async { Ok(ToolOutput::text("ok")) })),
3652 });
3653
3654 agent.process_task("Test").await.unwrap();
3655
3656 assert!(
3657 agent.memory().long_term.facts.is_empty(),
3658 "Short tool output (<10 chars) should NOT be recorded as fact"
3659 );
3660 }
3661
3662 #[tokio::test]
3663 async fn test_huge_tool_output_not_recorded() {
3664 let provider = Arc::new(MockLlmProvider::new());
3665 provider.queue_response(MockLlmProvider::tool_call_response(
3666 "echo",
3667 serde_json::json!({"text": "x"}),
3668 ));
3669 provider.queue_response(MockLlmProvider::text_response("Done."));
3670
3671 let (mut agent, _callback) = create_test_agent(provider);
3672 let huge = "x".repeat(10_000);
3673 agent.register_tool(RegisteredTool {
3674 definition: ToolDefinition {
3675 name: "echo".to_string(),
3676 description: "Echo".to_string(),
3677 parameters: serde_json::json!({}),
3678 },
3679 risk_level: RiskLevel::ReadOnly,
3680 executor: Box::new(move |_| {
3681 let h = huge.clone();
3682 Box::pin(async move { Ok(ToolOutput::text(h)) })
3683 }),
3684 });
3685
3686 agent.process_task("Test").await.unwrap();
3687
3688 assert!(
3689 agent.memory().long_term.facts.is_empty(),
3690 "Huge tool output (>5000 chars) should NOT be recorded as fact"
3691 );
3692 }
3693
3694 #[tokio::test]
3695 async fn test_user_denial_records_correction() {
3696 let provider = Arc::new(MockLlmProvider::new());
3697 provider.queue_response(MockLlmProvider::tool_call_response(
3699 "file_write",
3700 serde_json::json!({"path": "test.rs", "content": "bad code"}),
3701 ));
3702 provider.queue_response(MockLlmProvider::text_response("Understood, I won't write."));
3704
3705 let callback = Arc::new(SelectiveDenyCallback::new(vec!["file_write".to_string()]));
3706 let mut config = AgentConfig::default();
3707 config.llm.use_streaming = false;
3708 config.safety.approval_mode = crate::config::ApprovalMode::Paranoid;
3710
3711 let mut agent = Agent::new(provider, config, callback);
3712 agent.register_tool(RegisteredTool {
3713 definition: ToolDefinition {
3714 name: "file_write".to_string(),
3715 description: "Write file".to_string(),
3716 parameters: serde_json::json!({}),
3717 },
3718 risk_level: RiskLevel::Write,
3719 executor: Box::new(|_| Box::pin(async { Ok(ToolOutput::text("written")) })),
3720 });
3721
3722 agent.process_task("Write something").await.unwrap();
3723
3724 assert!(
3725 !agent.memory().long_term.corrections.is_empty(),
3726 "User denial should record a correction"
3727 );
3728 let correction = &agent.memory().long_term.corrections[0];
3729 assert!(
3730 correction.original.contains("file_write"),
3731 "Correction original should mention denied tool: {}",
3732 correction.original
3733 );
3734 assert!(
3735 correction.context.contains("denied"),
3736 "Correction context should mention denial: {}",
3737 correction.context
3738 );
3739 }
3740
3741 #[test]
3742 fn test_scheduler_fields_none_when_disabled() {
3743 let provider = Arc::new(MockLlmProvider::new());
3744 let (agent, _) = create_test_agent(provider);
3745 assert!(agent.cron_scheduler().is_none());
3747 }
3748
3749 #[test]
3750 fn test_save_scheduler_state_noop_when_disabled() {
3751 let provider = Arc::new(MockLlmProvider::new());
3752 let (agent, _) = create_test_agent(provider);
3753 let dir = tempfile::TempDir::new().unwrap();
3754 assert!(agent.save_scheduler_state(dir.path()).is_ok());
3756 }
3757
3758 #[test]
3759 fn test_load_scheduler_state_noop_when_disabled() {
3760 let provider = Arc::new(MockLlmProvider::new());
3761 let (mut agent, _) = create_test_agent(provider);
3762 let dir = tempfile::TempDir::new().unwrap();
3763 agent.load_scheduler_state(dir.path());
3765 assert!(agent.cron_scheduler().is_none());
3766 }
3767
3768 #[test]
3769 fn test_save_load_scheduler_roundtrip() {
3770 let provider = Arc::new(MockLlmProvider::new());
3771 let callback = Arc::new(RecordingCallback::new());
3772 let mut config = AgentConfig::default();
3773 config.llm.use_streaming = false;
3774 config.scheduler = Some(crate::config::SchedulerConfig {
3775 enabled: true,
3776 cron_jobs: vec![crate::scheduler::CronJobConfig::new(
3777 "test_job",
3778 "0 0 9 * * * *",
3779 "do something",
3780 )],
3781 ..Default::default()
3782 });
3783 let agent = Agent::new(provider.clone(), config, callback);
3784 assert_eq!(agent.cron_scheduler().unwrap().len(), 1);
3785
3786 let dir = tempfile::TempDir::new().unwrap();
3787 agent.save_scheduler_state(dir.path()).unwrap();
3788
3789 let callback2 = Arc::new(RecordingCallback::new());
3791 let mut config2 = AgentConfig::default();
3792 config2.llm.use_streaming = false;
3793 config2.scheduler = Some(crate::config::SchedulerConfig {
3794 enabled: true,
3795 cron_jobs: vec![],
3796 ..Default::default()
3797 });
3798 let mut agent2 = Agent::new(provider, config2, callback2);
3799 assert_eq!(agent2.cron_scheduler().unwrap().len(), 0);
3800
3801 agent2.load_scheduler_state(dir.path());
3802 assert_eq!(agent2.cron_scheduler().unwrap().len(), 1);
3803 }
3804
3805 #[test]
3806 fn test_tools_for_classification_calendar() {
3807 let set = Agent::tools_for_classification(&TaskClassification::Calendar)
3808 .expect("Calendar should return Some");
3809 assert!(set.contains("file_read"), "Missing core tool file_read");
3811 assert!(set.contains("ask_user"), "Missing core tool ask_user");
3812 assert!(set.contains("calculator"), "Missing core tool calculator");
3813 assert!(set.contains("macos_calendar"), "Missing macos_calendar");
3815 assert!(
3816 set.contains("macos_notification"),
3817 "Missing macos_notification"
3818 );
3819 assert!(
3821 !set.contains("macos_music"),
3822 "Should not include macos_music"
3823 );
3824 assert!(!set.contains("git_status"), "Should not include git_status");
3825 assert_eq!(set.len(), 12);
3827 }
3828
3829 #[test]
3830 fn test_tools_for_classification_general_returns_none() {
3831 assert!(
3832 Agent::tools_for_classification(&TaskClassification::General).is_none(),
3833 "General classification should return None (all tools)"
3834 );
3835 }
3836
3837 #[test]
3838 fn test_tools_for_classification_workflow_returns_none() {
3839 assert!(
3840 Agent::tools_for_classification(&TaskClassification::Workflow("security_scan".into()))
3841 .is_none(),
3842 "Workflow classification should return None (all tools)"
3843 );
3844 }
3845
3846 #[test]
3847 fn test_tool_definitions_filtered() {
3848 let provider = Arc::new(MockLlmProvider::new());
3849 let (mut agent, _) = create_test_agent(provider);
3850
3851 for name in &[
3853 "echo",
3854 "file_read",
3855 "macos_calendar",
3856 "git_status",
3857 "macos_music",
3858 ] {
3859 agent.register_tool(RegisteredTool {
3860 definition: ToolDefinition {
3861 name: name.to_string(),
3862 description: format!("{} tool", name),
3863 parameters: serde_json::json!({"type": "object"}),
3864 },
3865 risk_level: RiskLevel::ReadOnly,
3866 executor: Box::new(|_| Box::pin(async { Ok(ToolOutput::text("ok")) })),
3867 });
3868 }
3869
3870 let all_defs = agent.tool_definitions(None);
3872 assert_eq!(
3873 all_defs.len(),
3874 6,
3875 "Unfiltered should return all tools + ask_user"
3876 );
3877
3878 let calendar_defs = agent.tool_definitions(Some(&TaskClassification::Calendar));
3880 let names: Vec<&str> = calendar_defs.iter().map(|d| d.name.as_str()).collect();
3881 assert!(
3882 names.contains(&"echo"),
3883 "Calendar should include core tool echo"
3884 );
3885 assert!(
3886 names.contains(&"file_read"),
3887 "Calendar should include core tool file_read"
3888 );
3889 assert!(
3890 names.contains(&"macos_calendar"),
3891 "Calendar should include macos_calendar"
3892 );
3893 assert!(
3894 names.contains(&"ask_user"),
3895 "Should always include ask_user"
3896 );
3897 assert!(
3898 !names.contains(&"git_status"),
3899 "Calendar should NOT include git_status"
3900 );
3901 assert!(
3902 !names.contains(&"macos_music"),
3903 "Calendar should NOT include macos_music"
3904 );
3905
3906 let general_defs = agent.tool_definitions(Some(&TaskClassification::General));
3908 assert_eq!(general_defs.len(), 6, "General should return all tools");
3909 }
3910}