Skip to main content

traitclaw_core/
default_strategy.rs

1//! Default strategy — preserves v0.1.0 agent loop behavior.
2//!
3//! This module provides `DefaultStrategy`, which encapsulates the original
4//! agent runtime loop. When no custom strategy is configured, this strategy
5//! is used automatically, ensuring backward compatibility.
6
7use std::time::Instant;
8
9use async_trait::async_trait;
10
11use crate::agent::{AgentOutput, RunUsage};
12use crate::traits::execution_strategy::PendingToolCall;
13use crate::traits::hook::HookAction;
14use crate::traits::strategy::{AgentRuntime, AgentStrategy};
15use crate::types::agent_state::AgentState;
16use crate::types::completion::{CompletionRequest, ResponseContent};
17use crate::types::message::Message;
18use crate::types::tool_call::ToolCall;
19use crate::Result;
20
/// The default agent strategy, preserving v0.1.0 behavior.
///
/// This strategy implements the standard agent loop:
/// 1. Load context (memory history + system prompt + user message).
/// 2. Loop: LLM call → parse response → execute tools → repeat.
/// 3. Return as soon as the LLM answers with text instead of tool calls, or
///    fail with a runtime error once the configured `max_iterations` is reached.
///
/// The strategy is a stateless unit struct: every run builds its own state
/// locally, so a single value can be copied or shared freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultStrategy;
28
29#[async_trait]
30#[allow(deprecated)]
31impl AgentStrategy for DefaultStrategy {
32    #[tracing::instrument(skip_all, fields(session_id = session_id, model = %runtime.provider.model_info().name))]
33    async fn execute(
34        &self,
35        runtime: &AgentRuntime,
36        input: &str,
37        session_id: &str,
38    ) -> Result<AgentOutput> {
39        let start = Instant::now();
40        let model_info = runtime.provider.model_info();
41
42        // Fire on_agent_start hooks
43        for hook in &runtime.hooks {
44            hook.on_agent_start(input).await;
45        }
46
47        let mut state = AgentState::new(model_info.tier, model_info.context_window);
48        if let Some(budget) = runtime.config.token_budget {
49            state.token_budget = budget;
50        }
51
52        let mut messages = match load_context(runtime, session_id, input).await {
53            Ok(msgs) => msgs,
54            Err(e) => {
55                for hook in &runtime.hooks {
56                    hook.on_error(&e).await;
57                }
58                return Err(e);
59            }
60        };
61        let tool_schemas = runtime.tools.iter().map(|t| t.schema()).collect::<Vec<_>>();
62
63        // === Agent Loop ===
64        for _iteration in 0..runtime.config.max_iterations {
65            state.iteration_count += 1;
66            runtime.tracker.on_iteration(&mut state);
67
68            inject_hints(runtime, &state, &mut messages);
69
70            runtime
71                .context_manager
72                .prepare(&mut messages, model_info.context_window, &mut state)
73                .await;
74
75            let request = CompletionRequest {
76                model: model_info.name.clone(),
77                messages: messages.clone(),
78                tools: tool_schemas.clone(),
79                max_tokens: runtime.config.max_tokens,
80                temperature: runtime.config.temperature,
81                response_format: None,
82                stream: false,
83            };
84
85            // Fire on_provider_start hooks
86            for hook in &runtime.hooks {
87                hook.on_provider_start(&request).await;
88            }
89
90            let provider_start = Instant::now();
91            let response = match runtime.provider.complete(request).await {
92                Ok(res) => res,
93                Err(e) => {
94                    for hook in &runtime.hooks {
95                        hook.on_error(&e).await;
96                    }
97                    return Err(e);
98                }
99            };
100            let provider_duration = provider_start.elapsed();
101
102            // Fire on_provider_end hooks
103            for hook in &runtime.hooks {
104                hook.on_provider_end(&response, provider_duration).await;
105            }
106
107            state.token_usage += response.usage.total_tokens;
108            state.total_context_tokens = response.usage.prompt_tokens;
109            runtime.tracker.on_llm_response(&response, &mut state);
110
111            match response.content {
112                ResponseContent::Text(text) => {
113                    let assistant_msg = Message::assistant(&text);
114                    if let Err(e) = runtime.memory.append(session_id, assistant_msg).await {
115                        tracing::warn!("Failed to save assistant response to memory: {e}");
116                    }
117
118                    let usage = RunUsage {
119                        tokens: state.token_usage,
120                        iterations: state.iteration_count,
121                        duration: start.elapsed(),
122                    };
123
124                    #[allow(clippy::cast_possible_truncation)]
125                    let duration_ms = usage.duration.as_millis() as u64;
126
127                    tracing::info!(
128                        iterations = usage.iterations,
129                        tokens = usage.tokens,
130                        duration_ms,
131                        "Agent completed"
132                    );
133
134                    let output = AgentOutput::text_with_usage(text, usage);
135
136                    // Fire on_agent_end hooks
137                    for hook in &runtime.hooks {
138                        hook.on_agent_end(&output, start.elapsed()).await;
139                    }
140
141                    return Ok(output);
142                }
143                ResponseContent::ToolCalls(tool_calls) => {
144                    process_tool_calls(runtime, &tool_calls, &state, &mut messages).await;
145                }
146            }
147        }
148
149        let err = crate::Error::Runtime(format!(
150            "Agent reached maximum iterations ({})",
151            runtime.config.max_iterations
152        ));
153
154        // Fire on_error hooks
155        for hook in &runtime.hooks {
156            hook.on_error(&err).await;
157        }
158
159        Err(err)
160    }
161
162    fn stream(
163        &self,
164        runtime: &AgentRuntime,
165        input: &str,
166        session_id: &str,
167    ) -> std::pin::Pin<
168        Box<dyn tokio_stream::Stream<Item = Result<crate::types::stream::StreamEvent>> + Send>,
169    > {
170        // Fire synchronous starting hooks if any (currently all hooks in v0.2.0 are async so we emit them inside the spawned stream task)
171        crate::streaming::stream_runtime(runtime.clone(), input.to_string(), session_id.to_string())
172    }
173}
174
175/// Load conversation context: history + system prompt + user message.
176async fn load_context(
177    runtime: &AgentRuntime,
178    session_id: &str,
179    input: &str,
180) -> Result<Vec<Message>> {
181    let mut messages = runtime
182        .memory
183        .messages(session_id)
184        .await
185        .unwrap_or_else(|e| {
186            tracing::warn!("Failed to load memory (continuing fresh): {e}");
187            Vec::new()
188        });
189
190    if let Some(ref system_prompt) = runtime.config.system_prompt {
191        if messages.is_empty() || messages[0].role != crate::types::message::MessageRole::System {
192            messages.insert(0, Message::system(system_prompt));
193        }
194    }
195
196    let user_msg = Message::user(input);
197    messages.push(user_msg.clone());
198
199    if let Err(e) = runtime.memory.append(session_id, user_msg).await {
200        tracing::warn!("Failed to save user message to memory: {e}");
201    }
202
203    Ok(messages)
204}
205
206/// Check hints and inject guidance messages.
207fn inject_hints(runtime: &AgentRuntime, state: &AgentState, messages: &mut Vec<Message>) {
208    for hint in &runtime.hints {
209        if hint.should_trigger(state) {
210            let hint_msg = hint.generate(state);
211            messages.push(Message {
212                role: hint_msg.role,
213                content: hint_msg.content,
214                tool_call_id: None,
215            });
216            tracing::debug!(hint = hint.name(), "Hint injected");
217        }
218    }
219}
220
221/// Process tool calls with hook interception support.
222#[allow(deprecated)]
223async fn process_tool_calls(
224    runtime: &AgentRuntime,
225    tool_calls: &[ToolCall],
226    state: &AgentState,
227    messages: &mut Vec<Message>,
228) {
229    if tool_calls.is_empty() {
230        tracing::debug!("process_tool_calls: empty tool-call slice, skipping");
231        return;
232    }
233
234    let summary: Vec<String> = tool_calls
235        .iter()
236        .map(|tc| format!("{}({})", tc.name, tc.arguments))
237        .collect();
238    messages.push(Message::assistant(format!(
239        "[Tool calls: {}]",
240        summary.join(", ")
241    )));
242
243    // Check hooks for interception before executing
244    for tc in tool_calls {
245        let mut blocked = false;
246
247        for hook in &runtime.hooks {
248            if let HookAction::Block(reason) =
249                hook.before_tool_execute(&tc.name, &tc.arguments).await
250            {
251                messages.push(Message::tool_result(&tc.id, &reason));
252                tracing::debug!(
253                    tool = tc.name.as_str(),
254                    reason = reason.as_str(),
255                    "Tool blocked by hook"
256                );
257                blocked = true;
258                break;
259            }
260        }
261
262        if blocked {
263            continue;
264        }
265
266        let tool_start = Instant::now();
267
268        // Execute single tool via execution strategy
269        let pending = vec![PendingToolCall::from(tc)];
270        let results = runtime
271            .execution_strategy
272            .execute_batch(pending, &runtime.tools, &runtime.guards, state)
273            .await;
274
275        for result in results {
276            let processed = runtime
277                .output_transformer
278                .transform(result.output, &tc.name, state)
279                .await;
280
281            // Fire after_tool_execute hooks
282            for hook in &runtime.hooks {
283                hook.after_tool_execute(&tc.name, &processed, tool_start.elapsed())
284                    .await;
285            }
286
287            messages.push(Message::tool_result(&result.id, &processed));
288            tracing::debug!(tool_call_id = result.id.as_str(), "Tool call processed");
289        }
290    }
291}