Skip to main content

steer_core/app/domain/runtime/
agent_interpreter.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use thiserror::Error;
6use tokio::sync::mpsc;
7use tokio_util::sync::CancellationToken;
8
9use crate::api::Client as ApiClient;
10use crate::app::conversation::Message;
11use crate::app::domain::action::ApprovalDecision;
12use crate::app::domain::event::{CancellationInfo, OperationKind, SessionEvent};
13use crate::app::domain::session::EventStore;
14use crate::app::domain::types::{MessageId, OpId, RequestId, SessionId, ToolCallId};
15use crate::config::model::builtin::default_model;
16use crate::session::state::{
17    SessionConfig, SessionPolicyOverrides, SessionToolConfig, ToolApprovalPolicyOverrides,
18    ToolVisibility, WorkspaceConfig,
19};
20use crate::tools::{SessionMcpBackends, ToolExecutor};
21
22use super::interpreter::EffectInterpreter;
23use super::stepper::{AgentConfig, AgentInput, AgentOutput, AgentState, AgentStepper};
24
25#[derive(Clone, Default)]
26pub struct AgentInterpreterConfig {
27    pub auto_approve_tools: bool,
28    pub parent_session_id: Option<SessionId>,
29    pub session_config: Option<SessionConfig>,
30    pub session_backends: Option<Arc<SessionMcpBackends>>,
31}
32
33impl std::fmt::Debug for AgentInterpreterConfig {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("AgentInterpreterConfig")
36            .field("auto_approve_tools", &self.auto_approve_tools)
37            .field("parent_session_id", &self.parent_session_id)
38            .field("session_config", &self.session_config)
39            .field("session_backends", &self.session_backends.is_some())
40            .finish()
41    }
42}
43
44impl AgentInterpreterConfig {
45    pub fn for_sub_agent(parent_session_id: SessionId) -> Self {
46        Self {
47            auto_approve_tools: true,
48            parent_session_id: Some(parent_session_id),
49            session_config: None,
50            session_backends: None,
51        }
52    }
53}
54
55pub struct AgentInterpreter {
56    session_id: SessionId,
57    op_id: OpId,
58    config: AgentInterpreterConfig,
59    event_store: Arc<dyn EventStore>,
60    effect_interpreter: EffectInterpreter,
61}
62
63impl AgentInterpreter {
64    pub async fn new(
65        event_store: Arc<dyn EventStore>,
66        api_client: Arc<ApiClient>,
67        tool_executor: Arc<ToolExecutor>,
68        config: AgentInterpreterConfig,
69    ) -> Result<Self, AgentInterpreterError> {
70        let session_id = SessionId::new();
71        let op_id = OpId::new();
72
73        event_store
74            .create_session(session_id)
75            .await
76            .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
77
78        let mut session_config = config
79            .session_config
80            .clone()
81            .unwrap_or_else(|| default_session_config(default_model()));
82        if session_config.parent_session_id.is_none() {
83            session_config.parent_session_id = config.parent_session_id;
84        }
85
86        let session_created_event = SessionEvent::SessionCreated {
87            config: Box::new(session_config),
88            metadata: HashMap::new(),
89            parent_session_id: config.parent_session_id,
90        };
91        event_store
92            .append(session_id, &session_created_event)
93            .await
94            .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
95
96        let mut effect_interpreter =
97            EffectInterpreter::new(api_client, tool_executor).with_session(session_id);
98        if let Some(backends) = config.session_backends.clone() {
99            effect_interpreter = effect_interpreter.with_session_backends(backends);
100        }
101
102        Ok(Self {
103            session_id,
104            op_id,
105            config,
106            event_store,
107            effect_interpreter,
108        })
109    }
110
111    pub fn session_id(&self) -> SessionId {
112        self.session_id
113    }
114
115    pub async fn run(
116        &self,
117        agent_config: AgentConfig,
118        initial_messages: Vec<Message>,
119        message_tx: Option<mpsc::Sender<Message>>,
120        cancel_token: CancellationToken,
121    ) -> Result<Message, AgentInterpreterError> {
122        self.emit_event(SessionEvent::OperationStarted {
123            op_id: self.op_id,
124            kind: OperationKind::AgentLoop,
125        })
126        .await?;
127
128        let stepper = AgentStepper::new(agent_config.clone());
129        let mut state = AgentStepper::initial_state(initial_messages.clone());
130
131        let initial_outputs = vec![AgentOutput::CallModel {
132            model: agent_config.model.clone(),
133            messages: initial_messages,
134            system_context: Box::new(agent_config.system_context.clone()),
135            tools: agent_config.tools.clone(),
136        }];
137
138        let mut pending_outputs: VecDeque<AgentOutput> = VecDeque::from(initial_outputs);
139
140        loop {
141            if cancel_token.is_cancelled()
142                && !matches!(state, AgentState::Cancelled)
143                && !stepper.is_terminal(&state)
144            {
145                let (new_state, outputs) = stepper.step(state, AgentInput::Cancel);
146                state = new_state;
147                pending_outputs = VecDeque::from(outputs);
148            }
149
150            let output = if let Some(o) = pending_outputs.pop_front() {
151                o
152            } else {
153                if stepper.is_terminal(&state) {
154                    match state {
155                        AgentState::Complete { final_message } => {
156                            self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
157                                .await?;
158                            return Ok(final_message);
159                        }
160                        AgentState::Failed { error } => {
161                            self.emit_event(SessionEvent::Error {
162                                message: error.clone(),
163                            })
164                            .await?;
165                            self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
166                                .await?;
167                            return Err(AgentInterpreterError::Agent(error));
168                        }
169                        AgentState::Cancelled => {
170                            self.emit_event(SessionEvent::OperationCancelled {
171                                op_id: self.op_id,
172                                info: CancellationInfo {
173                                    pending_tool_calls: 0,
174                                    popped_queued_item: None,
175                                },
176                            })
177                            .await?;
178                            return Err(AgentInterpreterError::Cancelled);
179                        }
180                        _ => unreachable!(),
181                    }
182                }
183                return Err(AgentInterpreterError::Agent(
184                    "Stepper stuck with no outputs".to_string(),
185                ));
186            };
187
188            match output {
189                AgentOutput::CallModel {
190                    model,
191                    messages,
192                    system_context,
193                    tools,
194                } => {
195                    let result = self
196                        .effect_interpreter
197                        .call_model(
198                            model.clone(),
199                            messages,
200                            *system_context,
201                            tools,
202                            cancel_token.clone(),
203                        )
204                        .await;
205
206                    let message_id = MessageId::new();
207                    let timestamp = current_timestamp();
208
209                    let input = match result {
210                        Ok(response) => {
211                            let tool_calls: Vec<_> = response
212                                .content
213                                .iter()
214                                .filter_map(|c| {
215                                    if let crate::app::conversation::AssistantContent::ToolCall {
216                                        tool_call,
217                                        ..
218                                    } = c
219                                    {
220                                        Some(tool_call.clone())
221                                    } else {
222                                        None
223                                    }
224                                })
225                                .collect();
226
227                            AgentInput::ModelResponse {
228                                content: response.content,
229                                tool_calls,
230                                message_id,
231                                timestamp,
232                            }
233                        }
234                        Err(error) => AgentInput::ModelError { error },
235                    };
236
237                    let (new_state, outputs) = stepper.step(state, input);
238                    state = new_state;
239                    pending_outputs.extend(outputs);
240                }
241
242                AgentOutput::RequestApproval { tool_call } => {
243                    let tool_call_id = ToolCallId::from_string(&tool_call.id);
244                    let request_id = RequestId::new();
245
246                    self.emit_event(SessionEvent::ApprovalRequested {
247                        request_id,
248                        tool_call: tool_call.clone(),
249                    })
250                    .await?;
251
252                    if !self.config.auto_approve_tools {
253                        return Err(AgentInterpreterError::Agent(
254                            "Interactive tool approval not supported in AgentInterpreter".into(),
255                        ));
256                    }
257
258                    self.emit_event(SessionEvent::ApprovalDecided {
259                        request_id,
260                        decision: ApprovalDecision::Approved,
261                        remember: None,
262                    })
263                    .await?;
264
265                    let input = AgentInput::ToolApproved { tool_call_id };
266
267                    let (new_state, outputs) = stepper.step(state, input);
268                    state = new_state;
269                    pending_outputs.extend(outputs);
270                }
271
272                AgentOutput::ExecuteTool { tool_call } => {
273                    let tool_call_id = ToolCallId::from_string(&tool_call.id);
274
275                    self.emit_event(SessionEvent::ToolCallStarted {
276                        id: tool_call_id.clone(),
277                        name: tool_call.name.clone(),
278                        parameters: tool_call.parameters.clone(),
279                        model: agent_config.model.clone(),
280                    })
281                    .await?;
282
283                    let result = self
284                        .effect_interpreter
285                        .execute_tool(tool_call.clone(), cancel_token.clone())
286                        .await;
287
288                    let message_id = MessageId::new();
289                    let timestamp = current_timestamp();
290
291                    let input = match result {
292                        Ok(tool_result) => {
293                            self.emit_event(SessionEvent::ToolCallCompleted {
294                                id: tool_call_id.clone(),
295                                name: tool_call.name.clone(),
296                                result: tool_result.clone(),
297                                model: agent_config.model.clone(),
298                            })
299                            .await?;
300
301                            AgentInput::ToolCompleted {
302                                tool_call_id,
303                                result: tool_result,
304                                message_id,
305                                timestamp,
306                            }
307                        }
308                        Err(error) => {
309                            self.emit_event(SessionEvent::ToolCallFailed {
310                                id: tool_call_id.clone(),
311                                name: tool_call.name.clone(),
312                                error: error.to_string(),
313                                model: agent_config.model.clone(),
314                            })
315                            .await?;
316
317                            AgentInput::ToolFailed {
318                                tool_call_id,
319                                error,
320                                message_id,
321                                timestamp,
322                            }
323                        }
324                    };
325
326                    let (new_state, outputs) = stepper.step(state, input);
327                    state = new_state;
328                    pending_outputs.extend(outputs);
329                }
330
331                AgentOutput::EmitMessage { message } => {
332                    self.emit_event(SessionEvent::AssistantMessageAdded {
333                        message: message.clone(),
334                        model: agent_config.model.clone(),
335                    })
336                    .await?;
337
338                    if let Some(ref tx) = message_tx {
339                        let _ = tx.send(message).await;
340                    }
341                }
342
343                AgentOutput::Done { final_message } => {
344                    self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
345                        .await?;
346                    return Ok(final_message);
347                }
348
349                AgentOutput::Error { error } => {
350                    self.emit_event(SessionEvent::Error {
351                        message: error.clone(),
352                    })
353                    .await?;
354                    self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
355                        .await?;
356                    return Err(AgentInterpreterError::Agent(error));
357                }
358
359                AgentOutput::Cancelled => {
360                    self.emit_event(SessionEvent::OperationCancelled {
361                        op_id: self.op_id,
362                        info: CancellationInfo {
363                            pending_tool_calls: 0,
364                            popped_queued_item: None,
365                        },
366                    })
367                    .await?;
368                    return Err(AgentInterpreterError::Cancelled);
369                }
370            }
371        }
372    }
373
374    async fn emit_event(&self, event: SessionEvent) -> Result<(), AgentInterpreterError> {
375        self.event_store
376            .append(self.session_id, &event)
377            .await
378            .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
379        Ok(())
380    }
381}
382
383fn default_session_config(default_model: crate::config::model::ModelId) -> SessionConfig {
384    SessionConfig {
385        workspace: WorkspaceConfig::Local {
386            path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
387        },
388        workspace_ref: None,
389        workspace_id: None,
390        repo_ref: None,
391        parent_session_id: None,
392        workspace_name: None,
393        tool_config: SessionToolConfig {
394            backends: Vec::new(),
395            visibility: ToolVisibility::All,
396            approval_policy: crate::session::state::ToolApprovalPolicy::default(),
397            metadata: HashMap::new(),
398        },
399        system_prompt: None,
400        primary_agent_id: None,
401        policy_overrides: SessionPolicyOverrides {
402            default_model: None,
403            tool_visibility: Some(ToolVisibility::ReadOnly),
404            approval_policy: ToolApprovalPolicyOverrides::empty(),
405        },
406        metadata: HashMap::new(),
407        default_model,
408        auto_compaction: crate::session::state::AutoCompactionConfig::default(),
409    }
410}
411
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
418
419#[derive(Debug, Error)]
420pub enum AgentInterpreterError {
421    #[error("API error: {0}")]
422    Api(String),
423
424    #[error("Agent error: {0}")]
425    Agent(String),
426
427    #[error("Event store error: {0}")]
428    EventStore(String),
429
430    #[error("Cancelled")]
431    Cancelled,
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::error::ApiError;
    use crate::api::provider::{CompletionResponse, Provider};
    use crate::app::SystemContext;
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::app::validation::ValidatorRegistry;
    use crate::auth::ProviderRegistry;
    use crate::config::model::ModelId;
    use crate::config::provider::ProviderId;
    use crate::model_registry::ModelRegistry;
    use crate::tools::BackendRegistry;
    use async_trait::async_trait;
    use steer_tools::ToolSchema;

    /// Provider double that always answers with a single `"ok"` text block and can
    /// optionally cancel the caller's token while "completing".
    #[derive(Clone)]
    struct StubProvider {
        cancel_on_complete: bool,
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            // Simulate a cancellation that lands just as the response arrives.
            if self.cancel_on_complete {
                token.cancel();
            }

            let content = vec![AssistantContent::Text {
                text: "ok".to_string(),
            }];
            Ok(CompletionResponse {
                content,
                usage: None,
            })
        }
    }

    /// Builds an in-memory event store, an API client with empty registries, and a tool
    /// executor with no backends.
    async fn create_test_deps() -> (Arc<dyn EventStore>, Arc<ApiClient>, Arc<ToolExecutor>) {
        let store = Arc::new(InMemoryEventStore::new());
        let models = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let providers = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let llm_config = crate::test_utils::test_llm_config_provider().unwrap();
        let client = Arc::new(ApiClient::new_with_deps(llm_config, providers, models));

        let executor = Arc::new(ToolExecutor::with_components(
            Arc::new(BackendRegistry::new()),
            Arc::new(ValidatorRegistry::new()),
        ));

        (store, client, executor)
    }

    #[tokio::test]
    async fn test_cancel_after_completion_does_not_override_outputs() {
        let (event_store, api_client, tool_executor) = create_test_deps().await;
        let provider_id = ProviderId("stub".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        api_client.insert_test_provider(
            provider_id,
            Arc::new(StubProvider {
                cancel_on_complete: true,
            }),
        );

        let interpreter = AgentInterpreter::new(
            event_store.clone(),
            api_client,
            tool_executor,
            AgentInterpreterConfig::default(),
        )
        .await
        .expect("interpreter");

        let cancel_token = CancellationToken::new();
        let agent_config = AgentConfig {
            model: model_id,
            system_context: None,
            tools: vec![],
        };
        let outcome = interpreter
            .run(agent_config, vec![], None, cancel_token.clone())
            .await;

        assert!(outcome.is_ok(), "expected run to complete, got {outcome:?}");
        assert!(cancel_token.is_cancelled(), "cancel token should be set");

        let recorded = event_store
            .load_events(interpreter.session_id())
            .await
            .expect("load events");

        let saw_assistant_message = recorded
            .iter()
            .any(|(_, event)| matches!(event, SessionEvent::AssistantMessageAdded { .. }));
        let saw_completion = recorded
            .iter()
            .any(|(_, event)| matches!(event, SessionEvent::OperationCompleted { .. }));
        let saw_cancellation = recorded
            .iter()
            .any(|(_, event)| matches!(event, SessionEvent::OperationCancelled { .. }));

        assert!(saw_assistant_message, "assistant message should be emitted");
        assert!(saw_completion, "operation should complete");
        assert!(!saw_cancellation, "operation should not be cancelled");
    }
}