1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use thiserror::Error;
6use tokio::sync::mpsc;
7use tokio_util::sync::CancellationToken;
8
9use crate::api::Client as ApiClient;
10use crate::app::conversation::Message;
11use crate::app::domain::action::ApprovalDecision;
12use crate::app::domain::event::{CancellationInfo, OperationKind, SessionEvent};
13use crate::app::domain::session::EventStore;
14use crate::app::domain::types::{MessageId, OpId, RequestId, SessionId, ToolCallId};
15use crate::config::model::builtin::default_model;
16use crate::session::state::{
17 SessionConfig, SessionPolicyOverrides, SessionToolConfig, ToolApprovalPolicyOverrides,
18 ToolVisibility, WorkspaceConfig,
19};
20use crate::tools::{SessionMcpBackends, ToolExecutor};
21
22use super::interpreter::EffectInterpreter;
23use super::stepper::{AgentConfig, AgentInput, AgentOutput, AgentState, AgentStepper};
24
/// Options controlling how an [`AgentInterpreter`] sets up its session and
/// responds to tool-approval requests.
///
/// `Debug` is implemented by hand for this type rather than derived.
#[derive(Clone, Default)]
pub struct AgentInterpreterConfig {
    /// When `true`, every approval request raised during `run` is recorded as
    /// approved; when `false`, `run` returns an error as soon as approval is
    /// requested.
    pub auto_approve_tools: bool,
    /// Parent session recorded in the `SessionCreated` event, and copied into
    /// the session config when that config does not already name a parent.
    pub parent_session_id: Option<SessionId>,
    /// Session configuration to use; when `None`, `AgentInterpreter::new`
    /// falls back to `default_session_config(default_model())`.
    pub session_config: Option<SessionConfig>,
    /// Optional backends handed to the effect interpreter via
    /// `with_session_backends`.
    pub session_backends: Option<Arc<SessionMcpBackends>>,
}
32
33impl std::fmt::Debug for AgentInterpreterConfig {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("AgentInterpreterConfig")
36 .field("auto_approve_tools", &self.auto_approve_tools)
37 .field("parent_session_id", &self.parent_session_id)
38 .field("session_config", &self.session_config)
39 .field("session_backends", &self.session_backends.is_some())
40 .finish()
41 }
42}
43
44impl AgentInterpreterConfig {
45 pub fn for_sub_agent(parent_session_id: SessionId) -> Self {
46 Self {
47 auto_approve_tools: true,
48 parent_session_id: Some(parent_session_id),
49 session_config: None,
50 session_backends: None,
51 }
52 }
53}
54
/// Drives an `AgentStepper` to completion for a single operation, performing
/// the requested effects and appending the resulting events to an event store.
pub struct AgentInterpreter {
    /// Session created in `new`; every event this interpreter emits is
    /// appended to it.
    session_id: SessionId,
    /// Operation id attached to the operation started/completed/cancelled events.
    op_id: OpId,
    /// Options supplied at construction time (e.g. whether tools are auto-approved).
    config: AgentInterpreterConfig,
    /// Store that receives the session's events.
    event_store: Arc<dyn EventStore>,
    /// Performs the model calls and tool executions requested by the stepper.
    effect_interpreter: EffectInterpreter,
}
62
63impl AgentInterpreter {
64 pub async fn new(
65 event_store: Arc<dyn EventStore>,
66 api_client: Arc<ApiClient>,
67 tool_executor: Arc<ToolExecutor>,
68 config: AgentInterpreterConfig,
69 ) -> Result<Self, AgentInterpreterError> {
70 let session_id = SessionId::new();
71 let op_id = OpId::new();
72
73 event_store
74 .create_session(session_id)
75 .await
76 .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
77
78 let mut session_config = config
79 .session_config
80 .clone()
81 .unwrap_or_else(|| default_session_config(default_model()));
82 if session_config.parent_session_id.is_none() {
83 session_config.parent_session_id = config.parent_session_id;
84 }
85
86 let session_created_event = SessionEvent::SessionCreated {
87 config: Box::new(session_config),
88 metadata: HashMap::new(),
89 parent_session_id: config.parent_session_id,
90 };
91 event_store
92 .append(session_id, &session_created_event)
93 .await
94 .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
95
96 let mut effect_interpreter =
97 EffectInterpreter::new(api_client, tool_executor).with_session(session_id);
98 if let Some(backends) = config.session_backends.clone() {
99 effect_interpreter = effect_interpreter.with_session_backends(backends);
100 }
101
102 Ok(Self {
103 session_id,
104 op_id,
105 config,
106 event_store,
107 effect_interpreter,
108 })
109 }
110
111 pub fn session_id(&self) -> SessionId {
112 self.session_id
113 }
114
115 pub async fn run(
116 &self,
117 agent_config: AgentConfig,
118 initial_messages: Vec<Message>,
119 message_tx: Option<mpsc::Sender<Message>>,
120 cancel_token: CancellationToken,
121 ) -> Result<Message, AgentInterpreterError> {
122 self.emit_event(SessionEvent::OperationStarted {
123 op_id: self.op_id,
124 kind: OperationKind::AgentLoop,
125 })
126 .await?;
127
128 let stepper = AgentStepper::new(agent_config.clone());
129 let mut state = AgentStepper::initial_state(initial_messages.clone());
130
131 let initial_outputs = vec![AgentOutput::CallModel {
132 model: agent_config.model.clone(),
133 messages: initial_messages,
134 system_context: Box::new(agent_config.system_context.clone()),
135 tools: agent_config.tools.clone(),
136 }];
137
138 let mut pending_outputs: VecDeque<AgentOutput> = VecDeque::from(initial_outputs);
139
140 loop {
141 if cancel_token.is_cancelled()
142 && !matches!(state, AgentState::Cancelled)
143 && !stepper.is_terminal(&state)
144 {
145 let (new_state, outputs) = stepper.step(state, AgentInput::Cancel);
146 state = new_state;
147 pending_outputs = VecDeque::from(outputs);
148 }
149
150 let output = if let Some(o) = pending_outputs.pop_front() {
151 o
152 } else {
153 if stepper.is_terminal(&state) {
154 match state {
155 AgentState::Complete { final_message } => {
156 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
157 .await?;
158 return Ok(final_message);
159 }
160 AgentState::Failed { error } => {
161 self.emit_event(SessionEvent::Error {
162 message: error.clone(),
163 })
164 .await?;
165 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
166 .await?;
167 return Err(AgentInterpreterError::Agent(error));
168 }
169 AgentState::Cancelled => {
170 self.emit_event(SessionEvent::OperationCancelled {
171 op_id: self.op_id,
172 info: CancellationInfo {
173 pending_tool_calls: 0,
174 popped_queued_item: None,
175 },
176 })
177 .await?;
178 return Err(AgentInterpreterError::Cancelled);
179 }
180 _ => unreachable!(),
181 }
182 }
183 return Err(AgentInterpreterError::Agent(
184 "Stepper stuck with no outputs".to_string(),
185 ));
186 };
187
188 match output {
189 AgentOutput::CallModel {
190 model,
191 messages,
192 system_context,
193 tools,
194 } => {
195 let result = self
196 .effect_interpreter
197 .call_model(
198 model.clone(),
199 messages,
200 *system_context,
201 tools,
202 cancel_token.clone(),
203 )
204 .await;
205
206 let message_id = MessageId::new();
207 let timestamp = current_timestamp();
208
209 let input = match result {
210 Ok(response) => {
211 let tool_calls: Vec<_> = response
212 .content
213 .iter()
214 .filter_map(|c| {
215 if let crate::app::conversation::AssistantContent::ToolCall {
216 tool_call,
217 ..
218 } = c
219 {
220 Some(tool_call.clone())
221 } else {
222 None
223 }
224 })
225 .collect();
226
227 AgentInput::ModelResponse {
228 content: response.content,
229 tool_calls,
230 message_id,
231 timestamp,
232 }
233 }
234 Err(error) => AgentInput::ModelError { error },
235 };
236
237 let (new_state, outputs) = stepper.step(state, input);
238 state = new_state;
239 pending_outputs.extend(outputs);
240 }
241
242 AgentOutput::RequestApproval { tool_call } => {
243 let tool_call_id = ToolCallId::from_string(&tool_call.id);
244 let request_id = RequestId::new();
245
246 self.emit_event(SessionEvent::ApprovalRequested {
247 request_id,
248 tool_call: tool_call.clone(),
249 })
250 .await?;
251
252 if !self.config.auto_approve_tools {
253 return Err(AgentInterpreterError::Agent(
254 "Interactive tool approval not supported in AgentInterpreter".into(),
255 ));
256 }
257
258 self.emit_event(SessionEvent::ApprovalDecided {
259 request_id,
260 decision: ApprovalDecision::Approved,
261 remember: None,
262 })
263 .await?;
264
265 let input = AgentInput::ToolApproved { tool_call_id };
266
267 let (new_state, outputs) = stepper.step(state, input);
268 state = new_state;
269 pending_outputs.extend(outputs);
270 }
271
272 AgentOutput::ExecuteTool { tool_call } => {
273 let tool_call_id = ToolCallId::from_string(&tool_call.id);
274
275 self.emit_event(SessionEvent::ToolCallStarted {
276 id: tool_call_id.clone(),
277 name: tool_call.name.clone(),
278 parameters: tool_call.parameters.clone(),
279 model: agent_config.model.clone(),
280 })
281 .await?;
282
283 let result = self
284 .effect_interpreter
285 .execute_tool(
286 tool_call.clone(),
287 Some(agent_config.model.clone()),
288 cancel_token.clone(),
289 )
290 .await;
291
292 let message_id = MessageId::new();
293 let timestamp = current_timestamp();
294
295 let input = match result {
296 Ok(tool_result) => {
297 self.emit_event(SessionEvent::ToolCallCompleted {
298 id: tool_call_id.clone(),
299 name: tool_call.name.clone(),
300 result: tool_result.clone(),
301 model: agent_config.model.clone(),
302 })
303 .await?;
304
305 AgentInput::ToolCompleted {
306 tool_call_id,
307 result: tool_result,
308 message_id,
309 timestamp,
310 }
311 }
312 Err(error) => {
313 self.emit_event(SessionEvent::ToolCallFailed {
314 id: tool_call_id.clone(),
315 name: tool_call.name.clone(),
316 error: error.to_string(),
317 model: agent_config.model.clone(),
318 })
319 .await?;
320
321 AgentInput::ToolFailed {
322 tool_call_id,
323 error,
324 message_id,
325 timestamp,
326 }
327 }
328 };
329
330 let (new_state, outputs) = stepper.step(state, input);
331 state = new_state;
332 pending_outputs.extend(outputs);
333 }
334
335 AgentOutput::EmitMessage { message } => {
336 self.emit_event(SessionEvent::AssistantMessageAdded {
337 message: message.clone(),
338 model: agent_config.model.clone(),
339 })
340 .await?;
341
342 if let Some(ref tx) = message_tx {
343 let _ = tx.send(message).await;
344 }
345 }
346
347 AgentOutput::Done { final_message } => {
348 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
349 .await?;
350 return Ok(final_message);
351 }
352
353 AgentOutput::Error { error } => {
354 self.emit_event(SessionEvent::Error {
355 message: error.clone(),
356 })
357 .await?;
358 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
359 .await?;
360 return Err(AgentInterpreterError::Agent(error));
361 }
362
363 AgentOutput::Cancelled => {
364 self.emit_event(SessionEvent::OperationCancelled {
365 op_id: self.op_id,
366 info: CancellationInfo {
367 pending_tool_calls: 0,
368 popped_queued_item: None,
369 },
370 })
371 .await?;
372 return Err(AgentInterpreterError::Cancelled);
373 }
374 }
375 }
376 }
377
378 async fn emit_event(&self, event: SessionEvent) -> Result<(), AgentInterpreterError> {
379 self.event_store
380 .append(self.session_id, &event)
381 .await
382 .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
383 Ok(())
384 }
385}
386
387fn default_session_config(default_model: crate::config::model::ModelId) -> SessionConfig {
388 SessionConfig {
389 workspace: WorkspaceConfig::Local {
390 path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
391 },
392 workspace_ref: None,
393 workspace_id: None,
394 repo_ref: None,
395 parent_session_id: None,
396 workspace_name: None,
397 tool_config: SessionToolConfig {
398 backends: Vec::new(),
399 visibility: ToolVisibility::All,
400 approval_policy: crate::session::state::ToolApprovalPolicy::default(),
401 metadata: HashMap::new(),
402 },
403 system_prompt: None,
404 primary_agent_id: None,
405 policy_overrides: SessionPolicyOverrides {
406 default_model: None,
407 tool_visibility: Some(ToolVisibility::ReadOnly),
408 approval_policy: ToolApprovalPolicyOverrides::empty(),
409 },
410 metadata: HashMap::new(),
411 default_model,
412 auto_compaction: crate::session::state::AutoCompactionConfig::default(),
413 }
414}
415
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
422
/// Errors returned by [`AgentInterpreter`] construction and execution.
#[derive(Debug, Error)]
pub enum AgentInterpreterError {
    /// An API-level failure, carried as a message.
    // NOTE(review): nothing in this file constructs this variant; presumably it
    // is kept for callers elsewhere — confirm before removing.
    #[error("API error: {0}")]
    Api(String),

    /// The agent loop failed: the stepper reported an error, a tool needed
    /// interactive approval, or the stepper produced no further outputs.
    #[error("Agent error: {0}")]
    Agent(String),

    /// Creating the session or appending an event failed; holds the store
    /// error rendered as a string.
    #[error("Event store error: {0}")]
    EventStore(String),

    /// The run ended because it was cancelled.
    #[error("Cancelled")]
    Cancelled,
}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::error::ApiError;
    use crate::api::provider::{CompletionResponse, Provider};
    use crate::app::SystemContext;
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::app::validation::ValidatorRegistry;
    use crate::auth::ProviderRegistry;
    use crate::config::model::ModelId;
    use crate::config::provider::ProviderId;
    use crate::model_registry::ModelRegistry;
    use crate::tools::BackendRegistry;
    use async_trait::async_trait;
    use steer_tools::ToolSchema;

    /// Provider stub that always replies with a single `"ok"` text block and
    /// can optionally cancel the token it is handed before replying.
    #[derive(Clone)]
    struct StubProvider {
        cancel_on_complete: bool,
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            // Cancel before replying so the caller sees a cancelled token
            // together with a successful response.
            if self.cancel_on_complete {
                token.cancel();
            }

            let reply = AssistantContent::Text {
                text: String::from("ok"),
            };
            Ok(CompletionResponse {
                content: vec![reply],
                usage: None,
            })
        }
    }

    /// Builds an in-memory event store, an API client over empty registries and
    /// a tool executor with no backends or validators registered.
    async fn create_test_deps() -> (Arc<dyn EventStore>, Arc<ApiClient>, Arc<ToolExecutor>) {
        let event_store: Arc<dyn EventStore> = Arc::new(InMemoryEventStore::new());

        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let llm_config_provider = crate::test_utils::test_llm_config_provider().unwrap();
        let api_client = Arc::new(ApiClient::new_with_deps(
            llm_config_provider,
            provider_registry,
            model_registry,
        ));

        let backends = Arc::new(BackendRegistry::new());
        let validators = Arc::new(ValidatorRegistry::new());
        let tool_executor = Arc::new(ToolExecutor::with_components(backends, validators));

        (event_store, api_client, tool_executor)
    }

    /// A token cancelled while the model call is completing must not replace
    /// the outputs of an already-finished run with a cancellation.
    #[tokio::test]
    async fn test_cancel_after_completion_does_not_override_outputs() {
        let (event_store, api_client, tool_executor) = create_test_deps().await;

        let provider_id = ProviderId("stub".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        let stub = Arc::new(StubProvider {
            cancel_on_complete: true,
        });
        api_client.insert_test_provider(provider_id, stub);

        let interpreter = AgentInterpreter::new(
            event_store.clone(),
            api_client,
            tool_executor,
            AgentInterpreterConfig::default(),
        )
        .await
        .expect("interpreter");

        let agent_config = AgentConfig {
            model: model_id,
            system_context: None,
            tools: vec![],
        };
        let cancel_token = CancellationToken::new();
        let result = interpreter
            .run(agent_config, vec![], None, cancel_token.clone())
            .await;

        assert!(result.is_ok(), "expected run to complete, got {result:?}");
        assert!(cancel_token.is_cancelled(), "cancel token should be set");

        let events = event_store
            .load_events(interpreter.session_id())
            .await
            .expect("load events");

        // One pass over the log, noting which event kinds occurred.
        let mut saw_assistant_message = false;
        let mut saw_completed = false;
        let mut saw_cancelled = false;
        for (_, event) in events.iter() {
            match event {
                SessionEvent::AssistantMessageAdded { .. } => saw_assistant_message = true,
                SessionEvent::OperationCompleted { .. } => saw_completed = true,
                SessionEvent::OperationCancelled { .. } => saw_cancelled = true,
                _ => {}
            }
        }

        assert!(saw_assistant_message, "assistant message should be emitted");
        assert!(saw_completed, "operation should complete");
        assert!(!saw_cancelled, "operation should not be cancelled");
    }
}