1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use thiserror::Error;
6use tokio::sync::mpsc;
7use tokio_util::sync::CancellationToken;
8
9use crate::api::Client as ApiClient;
10use crate::app::conversation::Message;
11use crate::app::domain::action::ApprovalDecision;
12use crate::app::domain::event::{CancellationInfo, OperationKind, SessionEvent};
13use crate::app::domain::session::EventStore;
14use crate::app::domain::types::{MessageId, OpId, RequestId, SessionId, ToolCallId};
15use crate::config::model::builtin::default_model;
16use crate::session::state::{
17 SessionConfig, SessionPolicyOverrides, SessionToolConfig, ToolApprovalPolicyOverrides,
18 ToolVisibility, WorkspaceConfig,
19};
20use crate::tools::{SessionMcpBackends, ToolExecutor};
21
22use super::interpreter::EffectInterpreter;
23use super::stepper::{AgentConfig, AgentInput, AgentOutput, AgentState, AgentStepper};
24
25#[derive(Clone, Default)]
26pub struct AgentInterpreterConfig {
27 pub auto_approve_tools: bool,
28 pub parent_session_id: Option<SessionId>,
29 pub session_config: Option<SessionConfig>,
30 pub session_backends: Option<Arc<SessionMcpBackends>>,
31}
32
33impl std::fmt::Debug for AgentInterpreterConfig {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("AgentInterpreterConfig")
36 .field("auto_approve_tools", &self.auto_approve_tools)
37 .field("parent_session_id", &self.parent_session_id)
38 .field("session_config", &self.session_config)
39 .field("session_backends", &self.session_backends.is_some())
40 .finish()
41 }
42}
43
44impl AgentInterpreterConfig {
45 pub fn for_sub_agent(parent_session_id: SessionId) -> Self {
46 Self {
47 auto_approve_tools: true,
48 parent_session_id: Some(parent_session_id),
49 session_config: None,
50 session_backends: None,
51 }
52 }
53}
54
55pub struct AgentInterpreter {
56 session_id: SessionId,
57 op_id: OpId,
58 config: AgentInterpreterConfig,
59 event_store: Arc<dyn EventStore>,
60 effect_interpreter: EffectInterpreter,
61}
62
63impl AgentInterpreter {
64 pub async fn new(
65 event_store: Arc<dyn EventStore>,
66 api_client: Arc<ApiClient>,
67 tool_executor: Arc<ToolExecutor>,
68 config: AgentInterpreterConfig,
69 ) -> Result<Self, AgentInterpreterError> {
70 let session_id = SessionId::new();
71 let op_id = OpId::new();
72
73 event_store
74 .create_session(session_id)
75 .await
76 .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
77
78 let mut session_config = config
79 .session_config
80 .clone()
81 .unwrap_or_else(|| default_session_config(default_model()));
82 if session_config.parent_session_id.is_none() {
83 session_config.parent_session_id = config.parent_session_id;
84 }
85
86 let session_created_event = SessionEvent::SessionCreated {
87 config: Box::new(session_config),
88 metadata: HashMap::new(),
89 parent_session_id: config.parent_session_id,
90 };
91 event_store
92 .append(session_id, &session_created_event)
93 .await
94 .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
95
96 let mut effect_interpreter =
97 EffectInterpreter::new(api_client, tool_executor).with_session(session_id);
98 if let Some(backends) = config.session_backends.clone() {
99 effect_interpreter = effect_interpreter.with_session_backends(backends);
100 }
101
102 Ok(Self {
103 session_id,
104 op_id,
105 config,
106 event_store,
107 effect_interpreter,
108 })
109 }
110
111 pub fn session_id(&self) -> SessionId {
112 self.session_id
113 }
114
115 pub async fn run(
116 &self,
117 agent_config: AgentConfig,
118 initial_messages: Vec<Message>,
119 message_tx: Option<mpsc::Sender<Message>>,
120 cancel_token: CancellationToken,
121 ) -> Result<Message, AgentInterpreterError> {
122 self.emit_event(SessionEvent::OperationStarted {
123 op_id: self.op_id,
124 kind: OperationKind::AgentLoop,
125 })
126 .await?;
127
128 let stepper = AgentStepper::new(agent_config.clone());
129 let mut state = AgentStepper::initial_state(initial_messages.clone());
130
131 let initial_outputs = vec![AgentOutput::CallModel {
132 model: agent_config.model.clone(),
133 messages: initial_messages,
134 system_context: Box::new(agent_config.system_context.clone()),
135 tools: agent_config.tools.clone(),
136 }];
137
138 let mut pending_outputs: VecDeque<AgentOutput> = VecDeque::from(initial_outputs);
139
140 loop {
141 if cancel_token.is_cancelled()
142 && !matches!(state, AgentState::Cancelled)
143 && !stepper.is_terminal(&state)
144 {
145 let (new_state, outputs) = stepper.step(state, AgentInput::Cancel);
146 state = new_state;
147 pending_outputs = VecDeque::from(outputs);
148 }
149
150 let output = if let Some(o) = pending_outputs.pop_front() {
151 o
152 } else {
153 if stepper.is_terminal(&state) {
154 match state {
155 AgentState::Complete { final_message } => {
156 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
157 .await?;
158 return Ok(final_message);
159 }
160 AgentState::Failed { error } => {
161 self.emit_event(SessionEvent::Error {
162 message: error.clone(),
163 })
164 .await?;
165 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
166 .await?;
167 return Err(AgentInterpreterError::Agent(error));
168 }
169 AgentState::Cancelled => {
170 self.emit_event(SessionEvent::OperationCancelled {
171 op_id: self.op_id,
172 info: CancellationInfo {
173 pending_tool_calls: 0,
174 popped_queued_item: None,
175 },
176 })
177 .await?;
178 return Err(AgentInterpreterError::Cancelled);
179 }
180 _ => unreachable!(),
181 }
182 }
183 return Err(AgentInterpreterError::Agent(
184 "Stepper stuck with no outputs".to_string(),
185 ));
186 };
187
188 match output {
189 AgentOutput::CallModel {
190 model,
191 messages,
192 system_context,
193 tools,
194 } => {
195 let result = self
196 .effect_interpreter
197 .call_model(
198 model.clone(),
199 messages,
200 *system_context,
201 tools,
202 cancel_token.clone(),
203 )
204 .await;
205
206 let message_id = MessageId::new();
207 let timestamp = current_timestamp();
208
209 let input = match result {
210 Ok(response) => {
211 let tool_calls: Vec<_> = response
212 .content
213 .iter()
214 .filter_map(|c| {
215 if let crate::app::conversation::AssistantContent::ToolCall {
216 tool_call,
217 ..
218 } = c
219 {
220 Some(tool_call.clone())
221 } else {
222 None
223 }
224 })
225 .collect();
226
227 AgentInput::ModelResponse {
228 content: response.content,
229 tool_calls,
230 message_id,
231 timestamp,
232 }
233 }
234 Err(error) => AgentInput::ModelError { error },
235 };
236
237 let (new_state, outputs) = stepper.step(state, input);
238 state = new_state;
239 pending_outputs.extend(outputs);
240 }
241
242 AgentOutput::RequestApproval { tool_call } => {
243 let tool_call_id = ToolCallId::from_string(&tool_call.id);
244 let request_id = RequestId::new();
245
246 self.emit_event(SessionEvent::ApprovalRequested {
247 request_id,
248 tool_call: tool_call.clone(),
249 })
250 .await?;
251
252 if !self.config.auto_approve_tools {
253 return Err(AgentInterpreterError::Agent(
254 "Interactive tool approval not supported in AgentInterpreter".into(),
255 ));
256 }
257
258 self.emit_event(SessionEvent::ApprovalDecided {
259 request_id,
260 decision: ApprovalDecision::Approved,
261 remember: None,
262 })
263 .await?;
264
265 let input = AgentInput::ToolApproved { tool_call_id };
266
267 let (new_state, outputs) = stepper.step(state, input);
268 state = new_state;
269 pending_outputs.extend(outputs);
270 }
271
272 AgentOutput::ExecuteTool { tool_call } => {
273 let tool_call_id = ToolCallId::from_string(&tool_call.id);
274
275 self.emit_event(SessionEvent::ToolCallStarted {
276 id: tool_call_id.clone(),
277 name: tool_call.name.clone(),
278 parameters: tool_call.parameters.clone(),
279 model: agent_config.model.clone(),
280 })
281 .await?;
282
283 let result = self
284 .effect_interpreter
285 .execute_tool(tool_call.clone(), cancel_token.clone())
286 .await;
287
288 let message_id = MessageId::new();
289 let timestamp = current_timestamp();
290
291 let input = match result {
292 Ok(tool_result) => {
293 self.emit_event(SessionEvent::ToolCallCompleted {
294 id: tool_call_id.clone(),
295 name: tool_call.name.clone(),
296 result: tool_result.clone(),
297 model: agent_config.model.clone(),
298 })
299 .await?;
300
301 AgentInput::ToolCompleted {
302 tool_call_id,
303 result: tool_result,
304 message_id,
305 timestamp,
306 }
307 }
308 Err(error) => {
309 self.emit_event(SessionEvent::ToolCallFailed {
310 id: tool_call_id.clone(),
311 name: tool_call.name.clone(),
312 error: error.to_string(),
313 model: agent_config.model.clone(),
314 })
315 .await?;
316
317 AgentInput::ToolFailed {
318 tool_call_id,
319 error,
320 message_id,
321 timestamp,
322 }
323 }
324 };
325
326 let (new_state, outputs) = stepper.step(state, input);
327 state = new_state;
328 pending_outputs.extend(outputs);
329 }
330
331 AgentOutput::EmitMessage { message } => {
332 self.emit_event(SessionEvent::AssistantMessageAdded {
333 message: message.clone(),
334 model: agent_config.model.clone(),
335 })
336 .await?;
337
338 if let Some(ref tx) = message_tx {
339 let _ = tx.send(message).await;
340 }
341 }
342
343 AgentOutput::Done { final_message } => {
344 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
345 .await?;
346 return Ok(final_message);
347 }
348
349 AgentOutput::Error { error } => {
350 self.emit_event(SessionEvent::Error {
351 message: error.clone(),
352 })
353 .await?;
354 self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
355 .await?;
356 return Err(AgentInterpreterError::Agent(error));
357 }
358
359 AgentOutput::Cancelled => {
360 self.emit_event(SessionEvent::OperationCancelled {
361 op_id: self.op_id,
362 info: CancellationInfo {
363 pending_tool_calls: 0,
364 popped_queued_item: None,
365 },
366 })
367 .await?;
368 return Err(AgentInterpreterError::Cancelled);
369 }
370 }
371 }
372 }
373
374 async fn emit_event(&self, event: SessionEvent) -> Result<(), AgentInterpreterError> {
375 self.event_store
376 .append(self.session_id, &event)
377 .await
378 .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
379 Ok(())
380 }
381}
382
383fn default_session_config(default_model: crate::config::model::ModelId) -> SessionConfig {
384 SessionConfig {
385 workspace: WorkspaceConfig::Local {
386 path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
387 },
388 workspace_ref: None,
389 workspace_id: None,
390 repo_ref: None,
391 parent_session_id: None,
392 workspace_name: None,
393 tool_config: SessionToolConfig {
394 backends: Vec::new(),
395 visibility: ToolVisibility::All,
396 approval_policy: crate::session::state::ToolApprovalPolicy::default(),
397 metadata: HashMap::new(),
398 },
399 system_prompt: None,
400 primary_agent_id: None,
401 policy_overrides: SessionPolicyOverrides {
402 default_model: None,
403 tool_visibility: Some(ToolVisibility::ReadOnly),
404 approval_policy: ToolApprovalPolicyOverrides::empty(),
405 },
406 metadata: HashMap::new(),
407 default_model,
408 auto_compaction: crate::session::state::AutoCompactionConfig::default(),
409 }
410}
411
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    since_epoch.map(|elapsed| elapsed.as_secs()).unwrap_or(0)
}
418
419#[derive(Debug, Error)]
420pub enum AgentInterpreterError {
421 #[error("API error: {0}")]
422 Api(String),
423
424 #[error("Agent error: {0}")]
425 Agent(String),
426
427 #[error("Event store error: {0}")]
428 EventStore(String),
429
430 #[error("Cancelled")]
431 Cancelled,
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::error::ApiError;
    use crate::api::provider::{CompletionResponse, Provider};
    use crate::app::SystemContext;
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::app::validation::ValidatorRegistry;
    use crate::auth::ProviderRegistry;
    use crate::config::model::ModelId;
    use crate::config::provider::ProviderId;
    use crate::model_registry::ModelRegistry;
    use crate::tools::BackendRegistry;
    use async_trait::async_trait;
    use steer_tools::ToolSchema;

    /// Provider that always replies with a single `"ok"` text block and can optionally
    /// cancel the caller's token from inside `complete`.
    #[derive(Clone)]
    struct StubProvider {
        cancel_on_complete: bool,
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            if self.cancel_on_complete {
                token.cancel();
            }

            let reply = AssistantContent::Text {
                text: "ok".to_string(),
            };
            Ok(CompletionResponse {
                content: vec![reply],
                usage: None,
            })
        }
    }

    /// Builds an in-memory event store, an API client over empty registries, and a tool
    /// executor with no backends.
    async fn create_test_deps() -> (Arc<dyn EventStore>, Arc<ApiClient>, Arc<ToolExecutor>) {
        let event_store: Arc<dyn EventStore> = Arc::new(InMemoryEventStore::new());

        let model_registry = ModelRegistry::load(&[]).expect("model registry");
        let provider_registry = ProviderRegistry::load(&[]).expect("provider registry");
        let llm_config = crate::test_utils::test_llm_config_provider().unwrap();
        let api_client = Arc::new(ApiClient::new_with_deps(
            llm_config,
            Arc::new(provider_registry),
            Arc::new(model_registry),
        ));

        let backends = Arc::new(BackendRegistry::new());
        let validators = Arc::new(ValidatorRegistry::new());
        let tool_executor = Arc::new(ToolExecutor::with_components(backends, validators));

        (event_store, api_client, tool_executor)
    }

    #[tokio::test]
    async fn test_cancel_after_completion_does_not_override_outputs() {
        let (event_store, api_client, tool_executor) = create_test_deps().await;

        let provider_id = ProviderId("stub".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        let stub = StubProvider {
            cancel_on_complete: true,
        };
        api_client.insert_test_provider(provider_id, Arc::new(stub));

        let interpreter = AgentInterpreter::new(
            Arc::clone(&event_store),
            api_client,
            tool_executor,
            AgentInterpreterConfig::default(),
        )
        .await
        .expect("interpreter");

        let cancel_token = CancellationToken::new();
        let agent_config = AgentConfig {
            model: model_id,
            system_context: None,
            tools: vec![],
        };
        let result = interpreter
            .run(agent_config, vec![], None, cancel_token.clone())
            .await;

        assert!(result.is_ok(), "expected run to complete, got {result:?}");
        assert!(cancel_token.is_cancelled(), "cancel token should be set");

        let events = event_store
            .load_events(interpreter.session_id())
            .await
            .expect("load events");

        let assistant_message_emitted = events
            .iter()
            .any(|(_, event)| matches!(event, SessionEvent::AssistantMessageAdded { .. }));
        let operation_completed = events
            .iter()
            .any(|(_, event)| matches!(event, SessionEvent::OperationCompleted { .. }));
        let operation_cancelled = events
            .iter()
            .any(|(_, event)| matches!(event, SessionEvent::OperationCancelled { .. }));

        assert!(
            assistant_message_emitted,
            "assistant message should be emitted"
        );
        assert!(operation_completed, "operation should complete");
        assert!(!operation_cancelled, "operation should not be cancelled");
    }
}