1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use thiserror::Error;
6use tokio::sync::mpsc;
7use tokio_util::sync::CancellationToken;
8
9use crate::api::Client as ApiClient;
10use crate::app::conversation::Message;
11use crate::app::domain::action::ApprovalDecision;
12use crate::app::domain::event::{CancellationInfo, OperationKind, SessionEvent};
13use crate::app::domain::session::EventStore;
14use crate::app::domain::types::{MessageId, OpId, RequestId, SessionId, ToolCallId};
15use crate::config::model::builtin::default_model;
16use crate::session::state::{
17 SessionConfig, SessionPolicyOverrides, SessionToolConfig, ToolApprovalPolicyOverrides,
18 ToolVisibility, WorkspaceConfig,
19};
20use crate::tools::{SessionMcpBackends, ToolExecutor};
21
22use super::interpreter::EffectInterpreter;
23use super::stepper::{AgentConfig, AgentInput, AgentOutput, AgentState, AgentStepper};
24
25#[derive(Clone, Default)]
26pub struct AgentInterpreterConfig {
27 pub auto_approve_tools: bool,
28 pub parent_session_id: Option<SessionId>,
29 pub session_config: Option<SessionConfig>,
30 pub session_backends: Option<Arc<SessionMcpBackends>>,
31}
32
33impl std::fmt::Debug for AgentInterpreterConfig {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("AgentInterpreterConfig")
36 .field("auto_approve_tools", &self.auto_approve_tools)
37 .field("parent_session_id", &self.parent_session_id)
38 .field("session_config", &self.session_config)
39 .field("session_backends", &self.session_backends.is_some())
40 .finish()
41 }
42}
43
44impl AgentInterpreterConfig {
45 pub fn for_sub_agent(parent_session_id: SessionId) -> Self {
46 Self {
47 auto_approve_tools: true,
48 parent_session_id: Some(parent_session_id),
49 session_config: None,
50 session_backends: None,
51 }
52 }
53}
54
55pub struct AgentInterpreter {
56 session_id: SessionId,
57 op_id: OpId,
58 config: AgentInterpreterConfig,
59 event_store: Arc<dyn EventStore>,
60 effect_interpreter: EffectInterpreter,
61}
62
63impl AgentInterpreter {
64 pub async fn new(
65 event_store: Arc<dyn EventStore>,
66 api_client: Arc<ApiClient>,
67 tool_executor: Arc<ToolExecutor>,
68 config: AgentInterpreterConfig,
69 ) -> Result<Self, AgentInterpreterError> {
70 let session_id = SessionId::new();
71 let op_id = OpId::new();
72
73 event_store
74 .create_session(session_id)
75 .await
76 .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
77
78 let mut session_config = config
79 .session_config
80 .clone()
81 .unwrap_or_else(|| default_session_config(default_model()));
82 if session_config.parent_session_id.is_none() {
83 session_config.parent_session_id = config.parent_session_id;
84 }
85
86 let session_created_event = SessionEvent::SessionCreated {
87 config: Box::new(session_config),
88 metadata: HashMap::new(),
89 parent_session_id: config.parent_session_id,
90 };
91 event_store
92 .append(session_id, &session_created_event)
93 .await
94 .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
95
96 let mut effect_interpreter =
97 EffectInterpreter::new(api_client, tool_executor).with_session(session_id);
98 if let Some(backends) = config.session_backends.clone() {
99 effect_interpreter = effect_interpreter.with_session_backends(backends);
100 }
101
102 Ok(Self {
103 session_id,
104 op_id,
105 config,
106 event_store,
107 effect_interpreter,
108 })
109 }
110
111 pub fn session_id(&self) -> SessionId {
112 self.session_id
113 }
114
115 pub async fn run(
116 &self,
117 agent_config: AgentConfig,
118 initial_messages: Vec<Message>,
119 message_tx: Option<mpsc::Sender<Message>>,
120 cancel_token: CancellationToken,
121 ) -> Result<Message, AgentInterpreterError> {
122 self.emit_event(SessionEvent::OperationStarted {
123 op_id: self.op_id,
124 kind: OperationKind::AgentLoop,
125 })
126 .await?;
127
128 let stepper = AgentStepper::new(agent_config.clone());
129 let mut state = AgentStepper::initial_state(initial_messages.clone());
130
131 let initial_outputs = vec![AgentOutput::CallModel {
132 model: agent_config.model.clone(),
133 messages: initial_messages,
134 system_context: Box::new(agent_config.system_context.clone()),
135 tools: agent_config.tools.clone(),
136 }];
137
138 let mut pending_outputs: VecDeque<AgentOutput> = VecDeque::from(initial_outputs);
139
140 loop {
141 if cancel_token.is_cancelled()
142 && !matches!(state, AgentState::Cancelled)
143 && !stepper.is_terminal(&state)
144 {
145 let (new_state, outputs) = stepper.step(state, AgentInput::Cancel);
146 state = new_state;
147 pending_outputs = VecDeque::from(outputs);
148 }
149
150 let output = if let Some(o) = pending_outputs.pop_front() {
151 o
152 } else {
153 if stepper.is_terminal(&state) {
154 match state {
155 AgentState::Complete { final_message } => {
156 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
157 .await?;
158 return Ok(final_message);
159 }
160 AgentState::Failed { error } => {
161 self.emit_event(SessionEvent::Error {
162 message: error.clone(),
163 })
164 .await?;
165 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
166 .await?;
167 return Err(AgentInterpreterError::Agent(error));
168 }
169 AgentState::Cancelled => {
170 self.emit_event(SessionEvent::OperationCancelled {
171 op_id: self.op_id,
172 info: CancellationInfo {
173 pending_tool_calls: 0,
174 popped_queued_item: None,
175 },
176 })
177 .await?;
178 return Err(AgentInterpreterError::Cancelled);
179 }
180 _ => unreachable!(),
181 }
182 }
183 return Err(AgentInterpreterError::Agent(
184 "Stepper stuck with no outputs".to_string(),
185 ));
186 };
187
188 match output {
189 AgentOutput::CallModel {
190 model,
191 messages,
192 system_context,
193 tools,
194 } => {
195 let result = self
196 .effect_interpreter
197 .call_model(
198 model.clone(),
199 messages,
200 *system_context,
201 tools,
202 cancel_token.clone(),
203 )
204 .await;
205
206 let message_id = MessageId::new();
207 let timestamp = current_timestamp();
208
209 let input = match result {
210 Ok(content) => {
211 let tool_calls: Vec<_> = content
212 .iter()
213 .filter_map(|c| {
214 if let crate::app::conversation::AssistantContent::ToolCall {
215 tool_call,
216 ..
217 } = c
218 {
219 Some(tool_call.clone())
220 } else {
221 None
222 }
223 })
224 .collect();
225
226 AgentInput::ModelResponse {
227 content,
228 tool_calls,
229 message_id,
230 timestamp,
231 }
232 }
233 Err(error) => AgentInput::ModelError { error },
234 };
235
236 let (new_state, outputs) = stepper.step(state, input);
237 state = new_state;
238 pending_outputs.extend(outputs);
239 }
240
241 AgentOutput::RequestApproval { tool_call } => {
242 let tool_call_id = ToolCallId::from_string(&tool_call.id);
243 let request_id = RequestId::new();
244
245 self.emit_event(SessionEvent::ApprovalRequested {
246 request_id,
247 tool_call: tool_call.clone(),
248 })
249 .await?;
250
251 if !self.config.auto_approve_tools {
252 return Err(AgentInterpreterError::Agent(
253 "Interactive tool approval not supported in AgentInterpreter".into(),
254 ));
255 }
256
257 self.emit_event(SessionEvent::ApprovalDecided {
258 request_id,
259 decision: ApprovalDecision::Approved,
260 remember: None,
261 })
262 .await?;
263
264 let input = AgentInput::ToolApproved { tool_call_id };
265
266 let (new_state, outputs) = stepper.step(state, input);
267 state = new_state;
268 pending_outputs.extend(outputs);
269 }
270
271 AgentOutput::ExecuteTool { tool_call } => {
272 let tool_call_id = ToolCallId::from_string(&tool_call.id);
273
274 self.emit_event(SessionEvent::ToolCallStarted {
275 id: tool_call_id.clone(),
276 name: tool_call.name.clone(),
277 parameters: tool_call.parameters.clone(),
278 model: agent_config.model.clone(),
279 })
280 .await?;
281
282 let result = self
283 .effect_interpreter
284 .execute_tool(tool_call.clone(), cancel_token.clone())
285 .await;
286
287 let message_id = MessageId::new();
288 let timestamp = current_timestamp();
289
290 let input = match result {
291 Ok(tool_result) => {
292 self.emit_event(SessionEvent::ToolCallCompleted {
293 id: tool_call_id.clone(),
294 name: tool_call.name.clone(),
295 result: tool_result.clone(),
296 model: agent_config.model.clone(),
297 })
298 .await?;
299
300 AgentInput::ToolCompleted {
301 tool_call_id,
302 result: tool_result,
303 message_id,
304 timestamp,
305 }
306 }
307 Err(error) => {
308 self.emit_event(SessionEvent::ToolCallFailed {
309 id: tool_call_id.clone(),
310 name: tool_call.name.clone(),
311 error: error.to_string(),
312 model: agent_config.model.clone(),
313 })
314 .await?;
315
316 AgentInput::ToolFailed {
317 tool_call_id,
318 error,
319 message_id,
320 timestamp,
321 }
322 }
323 };
324
325 let (new_state, outputs) = stepper.step(state, input);
326 state = new_state;
327 pending_outputs.extend(outputs);
328 }
329
330 AgentOutput::EmitMessage { message } => {
331 self.emit_event(SessionEvent::AssistantMessageAdded {
332 message: message.clone(),
333 model: agent_config.model.clone(),
334 })
335 .await?;
336
337 if let Some(ref tx) = message_tx {
338 let _ = tx.send(message).await;
339 }
340 }
341
342 AgentOutput::Done { final_message } => {
343 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
344 .await?;
345 return Ok(final_message);
346 }
347
348 AgentOutput::Error { error } => {
349 self.emit_event(SessionEvent::Error {
350 message: error.clone(),
351 })
352 .await?;
353 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
354 .await?;
355 return Err(AgentInterpreterError::Agent(error));
356 }
357
358 AgentOutput::Cancelled => {
359 self.emit_event(SessionEvent::OperationCancelled {
360 op_id: self.op_id,
361 info: CancellationInfo {
362 pending_tool_calls: 0,
363 popped_queued_item: None,
364 },
365 })
366 .await?;
367 return Err(AgentInterpreterError::Cancelled);
368 }
369 }
370 }
371 }
372
373 async fn emit_event(&self, event: SessionEvent) -> Result<(), AgentInterpreterError> {
374 self.event_store
375 .append(self.session_id, &event)
376 .await
377 .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
378 Ok(())
379 }
380}
381
382fn default_session_config(default_model: crate::config::model::ModelId) -> SessionConfig {
383 SessionConfig {
384 workspace: WorkspaceConfig::Local {
385 path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
386 },
387 workspace_ref: None,
388 workspace_id: None,
389 repo_ref: None,
390 parent_session_id: None,
391 workspace_name: None,
392 tool_config: SessionToolConfig {
393 backends: Vec::new(),
394 visibility: ToolVisibility::All,
395 approval_policy: crate::session::state::ToolApprovalPolicy::default(),
396 metadata: HashMap::new(),
397 },
398 system_prompt: None,
399 primary_agent_id: None,
400 policy_overrides: SessionPolicyOverrides {
401 default_model: None,
402 tool_visibility: Some(ToolVisibility::ReadOnly),
403 approval_policy: ToolApprovalPolicyOverrides::empty(),
404 },
405 metadata: HashMap::new(),
406 default_model,
407 }
408}
409
/// Returns the current wall-clock time in whole seconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, `0` is returned instead
/// of an error.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
416
417#[derive(Debug, Error)]
418pub enum AgentInterpreterError {
419 #[error("API error: {0}")]
420 Api(String),
421
422 #[error("Agent error: {0}")]
423 Agent(String),
424
425 #[error("Event store error: {0}")]
426 EventStore(String),
427
428 #[error("Cancelled")]
429 Cancelled,
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::error::ApiError;
    use crate::api::provider::{CompletionResponse, Provider};
    use crate::app::SystemContext;
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::app::validation::ValidatorRegistry;
    use crate::auth::ProviderRegistry;
    use crate::config::model::ModelId;
    use crate::config::provider::ProviderId;
    use crate::model_registry::ModelRegistry;
    use crate::tools::BackendRegistry;
    use async_trait::async_trait;
    use steer_tools::ToolSchema;

    /// Provider stub that always answers with a single `"ok"` text block.
    ///
    /// When `cancel_on_complete` is set, it cancels the caller's token while
    /// answering, so the cancellation races a successful completion.
    #[derive(Clone)]
    struct StubProvider {
        cancel_on_complete: bool,
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            if self.cancel_on_complete {
                token.cancel();
            }

            let reply = AssistantContent::Text {
                text: "ok".to_string(),
            };
            Ok(CompletionResponse {
                content: vec![reply],
            })
        }
    }

    /// Builds an in-memory event store, plus an API client and a tool executor
    /// backed by empty registries.
    async fn create_test_deps() -> (Arc<dyn EventStore>, Arc<ApiClient>, Arc<ToolExecutor>) {
        let event_store: Arc<dyn EventStore> = Arc::new(InMemoryEventStore::new());

        let models = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let providers = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let llm_config = crate::test_utils::test_llm_config_provider().unwrap();
        let api_client = Arc::new(ApiClient::new_with_deps(llm_config, providers, models));

        let backends = Arc::new(BackendRegistry::new());
        let validators = Arc::new(ValidatorRegistry::new());
        let tool_executor = Arc::new(ToolExecutor::with_components(backends, validators));

        (event_store, api_client, tool_executor)
    }

    /// The token is cancelled while the model call is finishing, but the call
    /// still returns a response. That cancellation must not discard the outputs:
    /// the run succeeds, the assistant message is recorded, and the operation is
    /// completed rather than cancelled.
    #[tokio::test]
    async fn test_cancel_after_completion_does_not_override_outputs() {
        let (event_store, api_client, tool_executor) = create_test_deps().await;

        let provider_id = ProviderId("stub".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        let stub = StubProvider {
            cancel_on_complete: true,
        };
        api_client.insert_test_provider(provider_id, Arc::new(stub));

        let interpreter = AgentInterpreter::new(
            Arc::clone(&event_store),
            api_client,
            tool_executor,
            AgentInterpreterConfig::default(),
        )
        .await
        .expect("interpreter");

        let cancel_token = CancellationToken::new();
        let agent_config = AgentConfig {
            model: model_id,
            system_context: None,
            tools: vec![],
        };
        let result = interpreter
            .run(agent_config, vec![], None, cancel_token.clone())
            .await;

        assert!(result.is_ok(), "expected run to complete, got {result:?}");
        assert!(cancel_token.is_cancelled(), "cancel token should be set");

        let events = event_store
            .load_events(interpreter.session_id())
            .await
            .expect("load events");
        let recorded: Vec<&SessionEvent> = events.iter().map(|(_, event)| event).collect();

        assert!(
            recorded
                .iter()
                .any(|e| matches!(e, SessionEvent::AssistantMessageAdded { .. })),
            "assistant message should be emitted"
        );
        assert!(
            recorded
                .iter()
                .any(|e| matches!(e, SessionEvent::OperationCompleted { .. })),
            "operation should complete"
        );
        assert!(
            !recorded
                .iter()
                .any(|e| matches!(e, SessionEvent::OperationCancelled { .. })),
            "operation should not be cancelled"
        );
    }
}