Skip to main content

steer_core/app/domain/runtime/
agent_interpreter.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use thiserror::Error;
6use tokio::sync::mpsc;
7use tokio_util::sync::CancellationToken;
8
9use crate::api::Client as ApiClient;
10use crate::app::conversation::Message;
11use crate::app::domain::action::ApprovalDecision;
12use crate::app::domain::event::{CancellationInfo, OperationKind, SessionEvent};
13use crate::app::domain::session::EventStore;
14use crate::app::domain::types::{MessageId, OpId, RequestId, SessionId, ToolCallId};
15use crate::config::model::builtin::default_model;
16use crate::session::state::{
17    SessionConfig, SessionPolicyOverrides, SessionToolConfig, ToolApprovalPolicyOverrides,
18    ToolVisibility, WorkspaceConfig,
19};
20use crate::tools::{SessionMcpBackends, ToolExecutor};
21
22use super::interpreter::EffectInterpreter;
23use super::stepper::{AgentConfig, AgentInput, AgentOutput, AgentState, AgentStepper};
24
25#[derive(Clone, Default)]
26pub struct AgentInterpreterConfig {
27    pub auto_approve_tools: bool,
28    pub parent_session_id: Option<SessionId>,
29    pub session_config: Option<SessionConfig>,
30    pub session_backends: Option<Arc<SessionMcpBackends>>,
31}
32
33impl std::fmt::Debug for AgentInterpreterConfig {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("AgentInterpreterConfig")
36            .field("auto_approve_tools", &self.auto_approve_tools)
37            .field("parent_session_id", &self.parent_session_id)
38            .field("session_config", &self.session_config)
39            .field("session_backends", &self.session_backends.is_some())
40            .finish()
41    }
42}
43
44impl AgentInterpreterConfig {
45    pub fn for_sub_agent(parent_session_id: SessionId) -> Self {
46        Self {
47            auto_approve_tools: true,
48            parent_session_id: Some(parent_session_id),
49            session_config: None,
50            session_backends: None,
51        }
52    }
53}
54
55pub struct AgentInterpreter {
56    session_id: SessionId,
57    op_id: OpId,
58    config: AgentInterpreterConfig,
59    event_store: Arc<dyn EventStore>,
60    effect_interpreter: EffectInterpreter,
61}
62
63impl AgentInterpreter {
64    pub async fn new(
65        event_store: Arc<dyn EventStore>,
66        api_client: Arc<ApiClient>,
67        tool_executor: Arc<ToolExecutor>,
68        config: AgentInterpreterConfig,
69    ) -> Result<Self, AgentInterpreterError> {
70        let session_id = SessionId::new();
71        let op_id = OpId::new();
72
73        event_store
74            .create_session(session_id)
75            .await
76            .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
77
78        let mut session_config = config
79            .session_config
80            .clone()
81            .unwrap_or_else(|| default_session_config(default_model()));
82        if session_config.parent_session_id.is_none() {
83            session_config.parent_session_id = config.parent_session_id;
84        }
85
86        let session_created_event = SessionEvent::SessionCreated {
87            config: Box::new(session_config),
88            metadata: HashMap::new(),
89            parent_session_id: config.parent_session_id,
90        };
91        event_store
92            .append(session_id, &session_created_event)
93            .await
94            .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
95
96        let mut effect_interpreter =
97            EffectInterpreter::new(api_client, tool_executor).with_session(session_id);
98        if let Some(backends) = config.session_backends.clone() {
99            effect_interpreter = effect_interpreter.with_session_backends(backends);
100        }
101
102        Ok(Self {
103            session_id,
104            op_id,
105            config,
106            event_store,
107            effect_interpreter,
108        })
109    }
110
111    pub fn session_id(&self) -> SessionId {
112        self.session_id
113    }
114
115    pub async fn run(
116        &self,
117        agent_config: AgentConfig,
118        initial_messages: Vec<Message>,
119        message_tx: Option<mpsc::Sender<Message>>,
120        cancel_token: CancellationToken,
121    ) -> Result<Message, AgentInterpreterError> {
122        self.emit_event(SessionEvent::OperationStarted {
123            op_id: self.op_id,
124            kind: OperationKind::AgentLoop,
125        })
126        .await?;
127
128        let stepper = AgentStepper::new(agent_config.clone());
129        let mut state = AgentStepper::initial_state(initial_messages.clone());
130
131        let initial_outputs = vec![AgentOutput::CallModel {
132            model: agent_config.model.clone(),
133            messages: initial_messages,
134            system_context: Box::new(agent_config.system_context.clone()),
135            tools: agent_config.tools.clone(),
136        }];
137
138        let mut pending_outputs: VecDeque<AgentOutput> = VecDeque::from(initial_outputs);
139
140        loop {
141            if cancel_token.is_cancelled()
142                && !matches!(state, AgentState::Cancelled)
143                && !stepper.is_terminal(&state)
144            {
145                let (new_state, outputs) = stepper.step(state, AgentInput::Cancel);
146                state = new_state;
147                pending_outputs = VecDeque::from(outputs);
148            }
149
150            let output = if let Some(o) = pending_outputs.pop_front() {
151                o
152            } else {
153                if stepper.is_terminal(&state) {
154                    match state {
155                        AgentState::Complete { final_message } => {
156                            self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
157                                .await?;
158                            return Ok(final_message);
159                        }
160                        AgentState::Failed { error } => {
161                            self.emit_event(SessionEvent::Error {
162                                message: error.clone(),
163                            })
164                            .await?;
165                            self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
166                                .await?;
167                            return Err(AgentInterpreterError::Agent(error));
168                        }
169                        AgentState::Cancelled => {
170                            self.emit_event(SessionEvent::OperationCancelled {
171                                op_id: self.op_id,
172                                info: CancellationInfo {
173                                    pending_tool_calls: 0,
174                                    popped_queued_item: None,
175                                },
176                            })
177                            .await?;
178                            return Err(AgentInterpreterError::Cancelled);
179                        }
180                        _ => unreachable!(),
181                    }
182                }
183                return Err(AgentInterpreterError::Agent(
184                    "Stepper stuck with no outputs".to_string(),
185                ));
186            };
187
188            match output {
189                AgentOutput::CallModel {
190                    model,
191                    messages,
192                    system_context,
193                    tools,
194                } => {
195                    let result = self
196                        .effect_interpreter
197                        .call_model(
198                            model.clone(),
199                            messages,
200                            *system_context,
201                            tools,
202                            cancel_token.clone(),
203                        )
204                        .await;
205
206                    let message_id = MessageId::new();
207                    let timestamp = current_timestamp();
208
209                    let input = match result {
210                        Ok(content) => {
211                            let tool_calls: Vec<_> = content
212                                .iter()
213                                .filter_map(|c| {
214                                    if let crate::app::conversation::AssistantContent::ToolCall {
215                                        tool_call,
216                                        ..
217                                    } = c
218                                    {
219                                        Some(tool_call.clone())
220                                    } else {
221                                        None
222                                    }
223                                })
224                                .collect();
225
226                            AgentInput::ModelResponse {
227                                content,
228                                tool_calls,
229                                message_id,
230                                timestamp,
231                            }
232                        }
233                        Err(error) => AgentInput::ModelError { error },
234                    };
235
236                    let (new_state, outputs) = stepper.step(state, input);
237                    state = new_state;
238                    pending_outputs.extend(outputs);
239                }
240
241                AgentOutput::RequestApproval { tool_call } => {
242                    let tool_call_id = ToolCallId::from_string(&tool_call.id);
243                    let request_id = RequestId::new();
244
245                    self.emit_event(SessionEvent::ApprovalRequested {
246                        request_id,
247                        tool_call: tool_call.clone(),
248                    })
249                    .await?;
250
251                    if !self.config.auto_approve_tools {
252                        return Err(AgentInterpreterError::Agent(
253                            "Interactive tool approval not supported in AgentInterpreter".into(),
254                        ));
255                    }
256
257                    self.emit_event(SessionEvent::ApprovalDecided {
258                        request_id,
259                        decision: ApprovalDecision::Approved,
260                        remember: None,
261                    })
262                    .await?;
263
264                    let input = AgentInput::ToolApproved { tool_call_id };
265
266                    let (new_state, outputs) = stepper.step(state, input);
267                    state = new_state;
268                    pending_outputs.extend(outputs);
269                }
270
271                AgentOutput::ExecuteTool { tool_call } => {
272                    let tool_call_id = ToolCallId::from_string(&tool_call.id);
273
274                    self.emit_event(SessionEvent::ToolCallStarted {
275                        id: tool_call_id.clone(),
276                        name: tool_call.name.clone(),
277                        parameters: tool_call.parameters.clone(),
278                        model: agent_config.model.clone(),
279                    })
280                    .await?;
281
282                    let result = self
283                        .effect_interpreter
284                        .execute_tool(tool_call.clone(), cancel_token.clone())
285                        .await;
286
287                    let message_id = MessageId::new();
288                    let timestamp = current_timestamp();
289
290                    let input = match result {
291                        Ok(tool_result) => {
292                            self.emit_event(SessionEvent::ToolCallCompleted {
293                                id: tool_call_id.clone(),
294                                name: tool_call.name.clone(),
295                                result: tool_result.clone(),
296                                model: agent_config.model.clone(),
297                            })
298                            .await?;
299
300                            AgentInput::ToolCompleted {
301                                tool_call_id,
302                                result: tool_result,
303                                message_id,
304                                timestamp,
305                            }
306                        }
307                        Err(error) => {
308                            self.emit_event(SessionEvent::ToolCallFailed {
309                                id: tool_call_id.clone(),
310                                name: tool_call.name.clone(),
311                                error: error.to_string(),
312                                model: agent_config.model.clone(),
313                            })
314                            .await?;
315
316                            AgentInput::ToolFailed {
317                                tool_call_id,
318                                error,
319                                message_id,
320                                timestamp,
321                            }
322                        }
323                    };
324
325                    let (new_state, outputs) = stepper.step(state, input);
326                    state = new_state;
327                    pending_outputs.extend(outputs);
328                }
329
330                AgentOutput::EmitMessage { message } => {
331                    self.emit_event(SessionEvent::AssistantMessageAdded {
332                        message: message.clone(),
333                        model: agent_config.model.clone(),
334                    })
335                    .await?;
336
337                    if let Some(ref tx) = message_tx {
338                        let _ = tx.send(message).await;
339                    }
340                }
341
342                AgentOutput::Done { final_message } => {
343                    self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
344                        .await?;
345                    return Ok(final_message);
346                }
347
348                AgentOutput::Error { error } => {
349                    self.emit_event(SessionEvent::Error {
350                        message: error.clone(),
351                    })
352                    .await?;
353                    self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
354                        .await?;
355                    return Err(AgentInterpreterError::Agent(error));
356                }
357
358                AgentOutput::Cancelled => {
359                    self.emit_event(SessionEvent::OperationCancelled {
360                        op_id: self.op_id,
361                        info: CancellationInfo {
362                            pending_tool_calls: 0,
363                            popped_queued_item: None,
364                        },
365                    })
366                    .await?;
367                    return Err(AgentInterpreterError::Cancelled);
368                }
369            }
370        }
371    }
372
373    async fn emit_event(&self, event: SessionEvent) -> Result<(), AgentInterpreterError> {
374        self.event_store
375            .append(self.session_id, &event)
376            .await
377            .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
378        Ok(())
379    }
380}
381
382fn default_session_config(default_model: crate::config::model::ModelId) -> SessionConfig {
383    SessionConfig {
384        workspace: WorkspaceConfig::Local {
385            path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
386        },
387        workspace_ref: None,
388        workspace_id: None,
389        repo_ref: None,
390        parent_session_id: None,
391        workspace_name: None,
392        tool_config: SessionToolConfig {
393            backends: Vec::new(),
394            visibility: ToolVisibility::All,
395            approval_policy: crate::session::state::ToolApprovalPolicy::default(),
396            metadata: HashMap::new(),
397        },
398        system_prompt: None,
399        primary_agent_id: None,
400        policy_overrides: SessionPolicyOverrides {
401            default_model: None,
402            tool_visibility: Some(ToolVisibility::ReadOnly),
403            approval_policy: ToolApprovalPolicyOverrides::empty(),
404        },
405        metadata: HashMap::new(),
406        default_model,
407    }
408}
409
/// Whole seconds since the Unix epoch, according to the system clock.
///
/// If the clock reports a time before the epoch, this returns `0` instead of failing.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
416
417#[derive(Debug, Error)]
418pub enum AgentInterpreterError {
419    #[error("API error: {0}")]
420    Api(String),
421
422    #[error("Agent error: {0}")]
423    Agent(String),
424
425    #[error("Event store error: {0}")]
426    EventStore(String),
427
428    #[error("Cancelled")]
429    Cancelled,
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::error::ApiError;
    use crate::api::provider::{CompletionResponse, Provider};
    use crate::app::SystemContext;
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::app::validation::ValidatorRegistry;
    use crate::auth::ProviderRegistry;
    use crate::config::model::ModelId;
    use crate::config::provider::ProviderId;
    use crate::model_registry::ModelRegistry;
    use crate::tools::BackendRegistry;
    use async_trait::async_trait;
    use steer_tools::ToolSchema;

    /// Provider stub that always answers with a single `"ok"` text block.
    ///
    /// With `cancel_on_complete` set, it cancels the caller's token before returning,
    /// so the cancellation lands just after a model call has succeeded.
    #[derive(Clone)]
    struct StubProvider {
        cancel_on_complete: bool,
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            if self.cancel_on_complete {
                token.cancel();
            }

            let reply = AssistantContent::Text {
                text: "ok".to_string(),
            };
            Ok(CompletionResponse {
                content: vec![reply],
            })
        }
    }

    /// Builds an in-memory event store, an API client with empty registries, and a tool
    /// executor with no backends.
    async fn create_test_deps() -> (Arc<dyn EventStore>, Arc<ApiClient>, Arc<ToolExecutor>) {
        let event_store = Arc::new(InMemoryEventStore::new());

        let models = ModelRegistry::load(&[]).expect("model registry");
        let providers = ProviderRegistry::load(&[]).expect("provider registry");
        let llm_config = crate::test_utils::test_llm_config_provider().unwrap();
        let api_client = Arc::new(ApiClient::new_with_deps(
            llm_config,
            Arc::new(providers),
            Arc::new(models),
        ));

        let backends = Arc::new(BackendRegistry::new());
        let validators = Arc::new(ValidatorRegistry::new());
        let tool_executor = Arc::new(ToolExecutor::with_components(backends, validators));

        (event_store, api_client, tool_executor)
    }

    /// A cancellation that arrives after the model has already answered must not turn a
    /// completed run into a cancelled one.
    #[tokio::test]
    async fn test_cancel_after_completion_does_not_override_outputs() {
        let (event_store, api_client, tool_executor) = create_test_deps().await;

        let provider_id = ProviderId("stub".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        let stub = StubProvider {
            cancel_on_complete: true,
        };
        api_client.insert_test_provider(provider_id, Arc::new(stub));

        let interpreter = AgentInterpreter::new(
            event_store.clone(),
            api_client,
            tool_executor,
            AgentInterpreterConfig::default(),
        )
        .await
        .expect("interpreter");

        let cancel_token = CancellationToken::new();
        let agent_config = AgentConfig {
            model: model_id,
            system_context: None,
            tools: vec![],
        };
        let result = interpreter
            .run(agent_config, vec![], None, cancel_token.clone())
            .await;

        assert!(result.is_ok(), "expected run to complete, got {result:?}");
        assert!(cancel_token.is_cancelled(), "cancel token should be set");

        let events = event_store
            .load_events(interpreter.session_id())
            .await
            .expect("load events");

        // One pass over the log, flagging each event kind the test cares about.
        let mut saw_message = false;
        let mut saw_completed = false;
        let mut saw_cancelled = false;
        for (_, event) in events.iter() {
            match event {
                SessionEvent::AssistantMessageAdded { .. } => saw_message = true,
                SessionEvent::OperationCompleted { .. } => saw_completed = true,
                SessionEvent::OperationCancelled { .. } => saw_cancelled = true,
                _ => {}
            }
        }

        assert!(saw_message, "assistant message should be emitted");
        assert!(saw_completed, "operation should complete");
        assert!(!saw_cancelled, "operation should not be cancelled");
    }
}