traitclaw_core/
default_strategy.rs1use std::time::Instant;
8
9use async_trait::async_trait;
10
11use crate::agent::{AgentOutput, RunUsage};
12use crate::traits::execution_strategy::PendingToolCall;
13use crate::traits::hook::HookAction;
14use crate::traits::strategy::{AgentRuntime, AgentStrategy};
15use crate::types::agent_state::AgentState;
16use crate::types::completion::{CompletionRequest, ResponseContent};
17use crate::types::message::Message;
18use crate::types::tool_call::ToolCall;
19use crate::Result;
20
/// The stock [`AgentStrategy`] implementation.
///
/// A stateless unit type: everything it needs at run time is taken from the
/// `AgentRuntime` passed to `execute` / `stream`, so values of this type are
/// free to copy, clone and default-construct.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultStrategy;
28
29#[async_trait]
30#[allow(deprecated)]
31impl AgentStrategy for DefaultStrategy {
32 #[tracing::instrument(skip_all, fields(session_id = session_id, model = %runtime.provider.model_info().name))]
33 async fn execute(
34 &self,
35 runtime: &AgentRuntime,
36 input: &str,
37 session_id: &str,
38 ) -> Result<AgentOutput> {
39 let start = Instant::now();
40 let model_info = runtime.provider.model_info();
41
42 for hook in &runtime.hooks {
44 hook.on_agent_start(input).await;
45 }
46
47 let mut state = AgentState::new(model_info.tier, model_info.context_window);
48 if let Some(budget) = runtime.config.token_budget {
49 state.token_budget = budget;
50 }
51
52 let mut messages = match load_context(runtime, session_id, input).await {
53 Ok(msgs) => msgs,
54 Err(e) => {
55 for hook in &runtime.hooks {
56 hook.on_error(&e).await;
57 }
58 return Err(e);
59 }
60 };
61 let tool_schemas = runtime.tools.iter().map(|t| t.schema()).collect::<Vec<_>>();
62
63 for _iteration in 0..runtime.config.max_iterations {
65 state.iteration_count += 1;
66 runtime.tracker.on_iteration(&mut state);
67
68 inject_hints(runtime, &state, &mut messages);
69
70 runtime
71 .context_manager
72 .prepare(&mut messages, model_info.context_window, &mut state)
73 .await;
74
75 let request = CompletionRequest {
76 model: model_info.name.clone(),
77 messages: messages.clone(),
78 tools: tool_schemas.clone(),
79 max_tokens: runtime.config.max_tokens,
80 temperature: runtime.config.temperature,
81 response_format: None,
82 stream: false,
83 };
84
85 for hook in &runtime.hooks {
87 hook.on_provider_start(&request).await;
88 }
89
90 let provider_start = Instant::now();
91 let response = match runtime.provider.complete(request).await {
92 Ok(res) => res,
93 Err(e) => {
94 for hook in &runtime.hooks {
95 hook.on_error(&e).await;
96 }
97 return Err(e);
98 }
99 };
100 let provider_duration = provider_start.elapsed();
101
102 for hook in &runtime.hooks {
104 hook.on_provider_end(&response, provider_duration).await;
105 }
106
107 state.token_usage += response.usage.total_tokens;
108 state.total_context_tokens = response.usage.prompt_tokens;
109 runtime.tracker.on_llm_response(&response, &mut state);
110
111 match response.content {
112 ResponseContent::Text(text) => {
113 let assistant_msg = Message::assistant(&text);
114 if let Err(e) = runtime.memory.append(session_id, assistant_msg).await {
115 tracing::warn!("Failed to save assistant response to memory: {e}");
116 }
117
118 let usage = RunUsage {
119 tokens: state.token_usage,
120 iterations: state.iteration_count,
121 duration: start.elapsed(),
122 };
123
124 #[allow(clippy::cast_possible_truncation)]
125 let duration_ms = usage.duration.as_millis() as u64;
126
127 tracing::info!(
128 iterations = usage.iterations,
129 tokens = usage.tokens,
130 duration_ms,
131 "Agent completed"
132 );
133
134 let output = AgentOutput::text_with_usage(text, usage);
135
136 for hook in &runtime.hooks {
138 hook.on_agent_end(&output, start.elapsed()).await;
139 }
140
141 return Ok(output);
142 }
143 ResponseContent::ToolCalls(tool_calls) => {
144 process_tool_calls(runtime, &tool_calls, &state, &mut messages).await;
145 }
146 }
147 }
148
149 let err = crate::Error::Runtime(format!(
150 "Agent reached maximum iterations ({})",
151 runtime.config.max_iterations
152 ));
153
154 for hook in &runtime.hooks {
156 hook.on_error(&err).await;
157 }
158
159 Err(err)
160 }
161
162 fn stream(
163 &self,
164 runtime: &AgentRuntime,
165 input: &str,
166 session_id: &str,
167 ) -> std::pin::Pin<
168 Box<dyn tokio_stream::Stream<Item = Result<crate::types::stream::StreamEvent>> + Send>,
169 > {
170 crate::streaming::stream_runtime(runtime.clone(), input.to_string(), session_id.to_string())
172 }
173}
174
175async fn load_context(
177 runtime: &AgentRuntime,
178 session_id: &str,
179 input: &str,
180) -> Result<Vec<Message>> {
181 let mut messages = runtime
182 .memory
183 .messages(session_id)
184 .await
185 .unwrap_or_else(|e| {
186 tracing::warn!("Failed to load memory (continuing fresh): {e}");
187 Vec::new()
188 });
189
190 if let Some(ref system_prompt) = runtime.config.system_prompt {
191 if messages.is_empty() || messages[0].role != crate::types::message::MessageRole::System {
192 messages.insert(0, Message::system(system_prompt));
193 }
194 }
195
196 let user_msg = Message::user(input);
197 messages.push(user_msg.clone());
198
199 if let Err(e) = runtime.memory.append(session_id, user_msg).await {
200 tracing::warn!("Failed to save user message to memory: {e}");
201 }
202
203 Ok(messages)
204}
205
206fn inject_hints(runtime: &AgentRuntime, state: &AgentState, messages: &mut Vec<Message>) {
208 for hint in &runtime.hints {
209 if hint.should_trigger(state) {
210 let hint_msg = hint.generate(state);
211 messages.push(Message {
212 role: hint_msg.role,
213 content: hint_msg.content,
214 tool_call_id: None,
215 });
216 tracing::debug!(hint = hint.name(), "Hint injected");
217 }
218 }
219}
220
221#[allow(deprecated)]
223async fn process_tool_calls(
224 runtime: &AgentRuntime,
225 tool_calls: &[ToolCall],
226 state: &AgentState,
227 messages: &mut Vec<Message>,
228) {
229 if tool_calls.is_empty() {
230 tracing::debug!("process_tool_calls: empty tool-call slice, skipping");
231 return;
232 }
233
234 let summary: Vec<String> = tool_calls
235 .iter()
236 .map(|tc| format!("{}({})", tc.name, tc.arguments))
237 .collect();
238 messages.push(Message::assistant(format!(
239 "[Tool calls: {}]",
240 summary.join(", ")
241 )));
242
243 for tc in tool_calls {
245 let mut blocked = false;
246
247 for hook in &runtime.hooks {
248 if let HookAction::Block(reason) =
249 hook.before_tool_execute(&tc.name, &tc.arguments).await
250 {
251 messages.push(Message::tool_result(&tc.id, &reason));
252 tracing::debug!(
253 tool = tc.name.as_str(),
254 reason = reason.as_str(),
255 "Tool blocked by hook"
256 );
257 blocked = true;
258 break;
259 }
260 }
261
262 if blocked {
263 continue;
264 }
265
266 let tool_start = Instant::now();
267
268 let pending = vec![PendingToolCall::from(tc)];
270 let results = runtime
271 .execution_strategy
272 .execute_batch(pending, &runtime.tools, &runtime.guards, state)
273 .await;
274
275 for result in results {
276 let processed = runtime
277 .output_transformer
278 .transform(result.output, &tc.name, state)
279 .await;
280
281 for hook in &runtime.hooks {
283 hook.after_tool_execute(&tc.name, &processed, tool_start.elapsed())
284 .await;
285 }
286
287 messages.push(Message::tool_result(&result.id, &processed));
288 tracing::debug!(tool_call_id = result.id.as_str(), "Tool call processed");
289 }
290 }
291}