Skip to main content

steer_core/app/domain/runtime/
agent_interpreter.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use thiserror::Error;
6use tokio::sync::mpsc;
7use tokio_util::sync::CancellationToken;
8
9use crate::api::Client as ApiClient;
10use crate::app::conversation::Message;
11use crate::app::domain::action::ApprovalDecision;
12use crate::app::domain::event::{CancellationInfo, OperationKind, SessionEvent};
13use crate::app::domain::session::EventStore;
14use crate::app::domain::types::{MessageId, OpId, RequestId, SessionId, ToolCallId};
15use crate::config::model::builtin::default_model;
16use crate::session::state::{
17    SessionConfig, SessionPolicyOverrides, SessionToolConfig, ToolApprovalPolicyOverrides,
18    ToolVisibility, WorkspaceConfig,
19};
20use crate::tools::{SessionMcpBackends, ToolExecutor};
21
22use super::interpreter::EffectInterpreter;
23use super::stepper::{AgentConfig, AgentInput, AgentOutput, AgentState, AgentStepper};
24
25#[derive(Clone, Default)]
26pub struct AgentInterpreterConfig {
27    pub auto_approve_tools: bool,
28    pub parent_session_id: Option<SessionId>,
29    pub session_config: Option<SessionConfig>,
30    pub session_backends: Option<Arc<SessionMcpBackends>>,
31}
32
33impl std::fmt::Debug for AgentInterpreterConfig {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("AgentInterpreterConfig")
36            .field("auto_approve_tools", &self.auto_approve_tools)
37            .field("parent_session_id", &self.parent_session_id)
38            .field("session_config", &self.session_config)
39            .field("session_backends", &self.session_backends.is_some())
40            .finish()
41    }
42}
43
44impl AgentInterpreterConfig {
45    pub fn for_sub_agent(parent_session_id: SessionId) -> Self {
46        Self {
47            auto_approve_tools: true,
48            parent_session_id: Some(parent_session_id),
49            session_config: None,
50            session_backends: None,
51        }
52    }
53}
54
55pub struct AgentInterpreter {
56    session_id: SessionId,
57    op_id: OpId,
58    config: AgentInterpreterConfig,
59    event_store: Arc<dyn EventStore>,
60    effect_interpreter: EffectInterpreter,
61}
62
63impl AgentInterpreter {
64    pub async fn new(
65        event_store: Arc<dyn EventStore>,
66        api_client: Arc<ApiClient>,
67        tool_executor: Arc<ToolExecutor>,
68        config: AgentInterpreterConfig,
69    ) -> Result<Self, AgentInterpreterError> {
70        let session_id = SessionId::new();
71        let op_id = OpId::new();
72
73        event_store
74            .create_session(session_id)
75            .await
76            .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
77
78        let mut session_config = config
79            .session_config
80            .clone()
81            .unwrap_or_else(|| default_session_config(default_model()));
82        if session_config.parent_session_id.is_none() {
83            session_config.parent_session_id = config.parent_session_id;
84        }
85
86        let session_created_event = SessionEvent::SessionCreated {
87            config: Box::new(session_config),
88            metadata: HashMap::new(),
89            parent_session_id: config.parent_session_id,
90        };
91        event_store
92            .append(session_id, &session_created_event)
93            .await
94            .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
95
96        let mut effect_interpreter =
97            EffectInterpreter::new(api_client, tool_executor).with_session(session_id);
98        if let Some(backends) = config.session_backends.clone() {
99            effect_interpreter = effect_interpreter.with_session_backends(backends);
100        }
101
102        Ok(Self {
103            session_id,
104            op_id,
105            config,
106            event_store,
107            effect_interpreter,
108        })
109    }
110
111    pub fn session_id(&self) -> SessionId {
112        self.session_id
113    }
114
115    pub async fn run(
116        &self,
117        agent_config: AgentConfig,
118        initial_messages: Vec<Message>,
119        message_tx: Option<mpsc::Sender<Message>>,
120        cancel_token: CancellationToken,
121    ) -> Result<Message, AgentInterpreterError> {
122        self.emit_event(SessionEvent::OperationStarted {
123            op_id: self.op_id,
124            kind: OperationKind::AgentLoop,
125        })
126        .await?;
127
128        let stepper = AgentStepper::new(agent_config.clone());
129        let mut state = AgentStepper::initial_state(initial_messages.clone());
130
131        let initial_outputs = vec![AgentOutput::CallModel {
132            model: agent_config.model.clone(),
133            messages: initial_messages,
134            system_context: Box::new(agent_config.system_context.clone()),
135            tools: agent_config.tools.clone(),
136        }];
137
138        let mut pending_outputs: VecDeque<AgentOutput> = VecDeque::from(initial_outputs);
139
140        loop {
141            if cancel_token.is_cancelled()
142                && !matches!(state, AgentState::Cancelled)
143                && !stepper.is_terminal(&state)
144            {
145                let (new_state, outputs) = stepper.step(state, AgentInput::Cancel);
146                state = new_state;
147                pending_outputs = VecDeque::from(outputs);
148            }
149
150            let output = if let Some(o) = pending_outputs.pop_front() {
151                o
152            } else {
153                if stepper.is_terminal(&state) {
154                    match state {
155                        AgentState::Complete { final_message } => {
156                            self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
157                                .await?;
158                            return Ok(final_message);
159                        }
160                        AgentState::Failed { error } => {
161                            self.emit_event(SessionEvent::Error {
162                                message: error.clone(),
163                            })
164                            .await?;
165                            self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
166                                .await?;
167                            return Err(AgentInterpreterError::Agent(error));
168                        }
169                        AgentState::Cancelled => {
170                            self.emit_event(SessionEvent::OperationCancelled {
171                                op_id: self.op_id,
172                                info: CancellationInfo {
173                                    pending_tool_calls: 0,
174                                    popped_queued_item: None,
175                                },
176                            })
177                            .await?;
178                            return Err(AgentInterpreterError::Cancelled);
179                        }
180                        _ => unreachable!(),
181                    }
182                }
183                return Err(AgentInterpreterError::Agent(
184                    "Stepper stuck with no outputs".to_string(),
185                ));
186            };
187
188            match output {
189                AgentOutput::CallModel {
190                    model,
191                    messages,
192                    system_context,
193                    tools,
194                } => {
195                    let result = self
196                        .effect_interpreter
197                        .call_model(
198                            model.clone(),
199                            messages,
200                            *system_context,
201                            tools,
202                            cancel_token.clone(),
203                        )
204                        .await;
205
206                    let message_id = MessageId::new();
207                    let timestamp = current_timestamp();
208
209                    let input = match result {
210                        Ok(response) => {
211                            let tool_calls: Vec<_> = response
212                                .content
213                                .iter()
214                                .filter_map(|c| {
215                                    if let crate::app::conversation::AssistantContent::ToolCall {
216                                        tool_call,
217                                        ..
218                                    } = c
219                                    {
220                                        Some(tool_call.clone())
221                                    } else {
222                                        None
223                                    }
224                                })
225                                .collect();
226
227                            AgentInput::ModelResponse {
228                                content: response.content,
229                                tool_calls,
230                                message_id,
231                                timestamp,
232                            }
233                        }
234                        Err(error) => AgentInput::ModelError { error },
235                    };
236
237                    let (new_state, outputs) = stepper.step(state, input);
238                    state = new_state;
239                    pending_outputs.extend(outputs);
240                }
241
242                AgentOutput::RequestApproval { tool_call } => {
243                    let tool_call_id = ToolCallId::from_string(&tool_call.id);
244                    let request_id = RequestId::new();
245
246                    self.emit_event(SessionEvent::ApprovalRequested {
247                        request_id,
248                        tool_call: tool_call.clone(),
249                    })
250                    .await?;
251
252                    if !self.config.auto_approve_tools {
253                        return Err(AgentInterpreterError::Agent(
254                            "Interactive tool approval not supported in AgentInterpreter".into(),
255                        ));
256                    }
257
258                    self.emit_event(SessionEvent::ApprovalDecided {
259                        request_id,
260                        decision: ApprovalDecision::Approved,
261                        remember: None,
262                    })
263                    .await?;
264
265                    let input = AgentInput::ToolApproved { tool_call_id };
266
267                    let (new_state, outputs) = stepper.step(state, input);
268                    state = new_state;
269                    pending_outputs.extend(outputs);
270                }
271
272                AgentOutput::ExecuteTool { tool_call } => {
273                    let tool_call_id = ToolCallId::from_string(&tool_call.id);
274
275                    self.emit_event(SessionEvent::ToolCallStarted {
276                        id: tool_call_id.clone(),
277                        name: tool_call.name.clone(),
278                        parameters: tool_call.parameters.clone(),
279                        model: agent_config.model.clone(),
280                    })
281                    .await?;
282
283                    let result = self
284                        .effect_interpreter
285                        .execute_tool(
286                            tool_call.clone(),
287                            Some(agent_config.model.clone()),
288                            cancel_token.clone(),
289                        )
290                        .await;
291
292                    let message_id = MessageId::new();
293                    let timestamp = current_timestamp();
294
295                    let input = match result {
296                        Ok(tool_result) => {
297                            self.emit_event(SessionEvent::ToolCallCompleted {
298                                id: tool_call_id.clone(),
299                                name: tool_call.name.clone(),
300                                result: tool_result.clone(),
301                                model: agent_config.model.clone(),
302                            })
303                            .await?;
304
305                            AgentInput::ToolCompleted {
306                                tool_call_id,
307                                result: tool_result,
308                                message_id,
309                                timestamp,
310                            }
311                        }
312                        Err(error) => {
313                            self.emit_event(SessionEvent::ToolCallFailed {
314                                id: tool_call_id.clone(),
315                                name: tool_call.name.clone(),
316                                error: error.to_string(),
317                                model: agent_config.model.clone(),
318                            })
319                            .await?;
320
321                            AgentInput::ToolFailed {
322                                tool_call_id,
323                                error,
324                                message_id,
325                                timestamp,
326                            }
327                        }
328                    };
329
330                    let (new_state, outputs) = stepper.step(state, input);
331                    state = new_state;
332                    pending_outputs.extend(outputs);
333                }
334
335                AgentOutput::EmitMessage { message } => {
336                    self.emit_event(SessionEvent::AssistantMessageAdded {
337                        message: message.clone(),
338                        model: agent_config.model.clone(),
339                    })
340                    .await?;
341
342                    if let Some(ref tx) = message_tx {
343                        let _ = tx.send(message).await;
344                    }
345                }
346
347                AgentOutput::Done { final_message } => {
348                    self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
349                        .await?;
350                    return Ok(final_message);
351                }
352
353                AgentOutput::Error { error } => {
354                    self.emit_event(SessionEvent::Error {
355                        message: error.clone(),
356                    })
357                    .await?;
358                    self.emit_event(SessionEvent::OperationCompleted { op_id: self.op_id })
359                        .await?;
360                    return Err(AgentInterpreterError::Agent(error));
361                }
362
363                AgentOutput::Cancelled => {
364                    self.emit_event(SessionEvent::OperationCancelled {
365                        op_id: self.op_id,
366                        info: CancellationInfo {
367                            pending_tool_calls: 0,
368                            popped_queued_item: None,
369                        },
370                    })
371                    .await?;
372                    return Err(AgentInterpreterError::Cancelled);
373                }
374            }
375        }
376    }
377
378    async fn emit_event(&self, event: SessionEvent) -> Result<(), AgentInterpreterError> {
379        self.event_store
380            .append(self.session_id, &event)
381            .await
382            .map_err(|e| AgentInterpreterError::EventStore(e.to_string()))?;
383        Ok(())
384    }
385}
386
387fn default_session_config(default_model: crate::config::model::ModelId) -> SessionConfig {
388    SessionConfig {
389        workspace: WorkspaceConfig::Local {
390            path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
391        },
392        workspace_ref: None,
393        workspace_id: None,
394        repo_ref: None,
395        parent_session_id: None,
396        workspace_name: None,
397        tool_config: SessionToolConfig {
398            backends: Vec::new(),
399            visibility: ToolVisibility::All,
400            approval_policy: crate::session::state::ToolApprovalPolicy::default(),
401            metadata: HashMap::new(),
402        },
403        system_prompt: None,
404        primary_agent_id: None,
405        policy_overrides: SessionPolicyOverrides {
406            default_model: None,
407            tool_visibility: Some(ToolVisibility::ReadOnly),
408            approval_policy: ToolApprovalPolicyOverrides::empty(),
409        },
410        metadata: HashMap::new(),
411        default_model,
412        auto_compaction: crate::session::state::AutoCompactionConfig::default(),
413    }
414}
415
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
422
423#[derive(Debug, Error)]
424pub enum AgentInterpreterError {
425    #[error("API error: {0}")]
426    Api(String),
427
428    #[error("Agent error: {0}")]
429    Agent(String),
430
431    #[error("Event store error: {0}")]
432    EventStore(String),
433
434    #[error("Cancelled")]
435    Cancelled,
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::error::ApiError;
    use crate::api::provider::{CompletionResponse, Provider};
    use crate::app::SystemContext;
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::app::validation::ValidatorRegistry;
    use crate::auth::ProviderRegistry;
    use crate::config::model::ModelId;
    use crate::config::provider::ProviderId;
    use crate::model_registry::ModelRegistry;
    use crate::tools::BackendRegistry;
    use async_trait::async_trait;
    use steer_tools::ToolSchema;

    /// Provider stub that always replies with a single `"ok"` text block and
    /// can optionally cancel the supplied token while doing so.
    #[derive(Clone)]
    struct StubProvider {
        cancel_on_complete: bool,
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            // Simulates a cancellation that arrives while the model call is
            // finishing.
            if self.cancel_on_complete {
                token.cancel();
            }

            let reply = AssistantContent::Text {
                text: "ok".to_string(),
            };
            Ok(CompletionResponse {
                content: vec![reply],
                usage: None,
            })
        }
    }

    /// Builds an in-memory event store, an API client backed by empty
    /// registries, and a tool executor with empty backend/validator registries.
    async fn create_test_deps() -> (Arc<dyn EventStore>, Arc<ApiClient>, Arc<ToolExecutor>) {
        let store: Arc<dyn EventStore> = Arc::new(InMemoryEventStore::new());

        let models = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let providers = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let client = Arc::new(ApiClient::new_with_deps(
            crate::test_utils::test_llm_config_provider().unwrap(),
            providers,
            models,
        ));

        let backends = Arc::new(BackendRegistry::new());
        let validators = Arc::new(ValidatorRegistry::new());
        let executor = Arc::new(ToolExecutor::with_components(backends, validators));

        (store, client, executor)
    }

    /// A cancellation raised while the model call completes must not replace
    /// the outputs of a run that has already finished.
    #[tokio::test]
    async fn test_cancel_after_completion_does_not_override_outputs() {
        let (event_store, api_client, tool_executor) = create_test_deps().await;

        let provider_id = ProviderId("stub".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        let stub = StubProvider {
            cancel_on_complete: true,
        };
        api_client.insert_test_provider(provider_id, Arc::new(stub));

        let interpreter = AgentInterpreter::new(
            event_store.clone(),
            api_client,
            tool_executor,
            AgentInterpreterConfig::default(),
        )
        .await
        .expect("interpreter");

        let cancel_token = CancellationToken::new();
        let agent_config = AgentConfig {
            model: model_id,
            system_context: None,
            tools: vec![],
        };
        let result = interpreter
            .run(agent_config, vec![], None, cancel_token.clone())
            .await;

        assert!(result.is_ok(), "expected run to complete, got {result:?}");
        assert!(cancel_token.is_cancelled(), "cancel token should be set");

        let events = event_store
            .load_events(interpreter.session_id())
            .await
            .expect("load events");

        // True when at least one recorded event satisfies `is_match`.
        let saw_event = |is_match: fn(&SessionEvent) -> bool| {
            events.iter().any(|(_, event)| is_match(event))
        };

        assert!(
            saw_event(|event| matches!(event, SessionEvent::AssistantMessageAdded { .. })),
            "assistant message should be emitted"
        );
        assert!(
            saw_event(|event| matches!(event, SessionEvent::OperationCompleted { .. })),
            "operation should complete"
        );
        assert!(
            !saw_event(|event| matches!(event, SessionEvent::OperationCancelled { .. })),
            "operation should not be cancelled"
        );
    }
}