// steer_core/app/domain/runtime/stepper.rs

1use std::collections::HashMap;
2
3use crate::app::SystemContext;
4use crate::app::conversation::{AssistantContent, Message, MessageData};
5use crate::app::domain::types::{MessageId, ToolCallId};
6use crate::config::model::ModelId;
7use steer_tools::{ToolCall, ToolError, ToolResult, ToolSchema};
8
9#[derive(Debug, Clone)]
10pub enum AgentState {
11    AwaitingModel {
12        messages: Vec<Message>,
13    },
14    AwaitingToolApprovals {
15        messages: Vec<Message>,
16        pending_approvals: Vec<ToolCall>,
17        approved: Vec<ToolCall>,
18        denied: Vec<ToolCall>,
19    },
20    AwaitingToolResults {
21        messages: Vec<Message>,
22        pending_results: HashMap<ToolCallId, ToolCall>,
23        completed_results: Vec<(ToolCallId, ToolResult)>,
24    },
25    Complete {
26        final_message: Message,
27    },
28    Failed {
29        error: String,
30    },
31    Cancelled,
32}
33
34#[derive(Debug, Clone)]
35pub enum AgentInput {
36    ModelResponse {
37        content: Vec<AssistantContent>,
38        tool_calls: Vec<ToolCall>,
39        message_id: MessageId,
40        timestamp: u64,
41    },
42    ModelError {
43        error: String,
44    },
45    ToolApproved {
46        tool_call_id: ToolCallId,
47    },
48    ToolDenied {
49        tool_call_id: ToolCallId,
50    },
51    ToolCompleted {
52        tool_call_id: ToolCallId,
53        result: ToolResult,
54        message_id: MessageId,
55        timestamp: u64,
56    },
57    ToolFailed {
58        tool_call_id: ToolCallId,
59        error: ToolError,
60        message_id: MessageId,
61        timestamp: u64,
62    },
63    Cancel,
64}
65
66#[derive(Debug, Clone)]
67pub enum AgentOutput {
68    CallModel {
69        model: ModelId,
70        messages: Vec<Message>,
71        system_context: Box<Option<SystemContext>>,
72        tools: Vec<ToolSchema>,
73    },
74    RequestApproval {
75        tool_call: ToolCall,
76    },
77    ExecuteTool {
78        tool_call: ToolCall,
79    },
80    EmitMessage {
81        message: Message,
82    },
83    Done {
84        final_message: Message,
85    },
86    Error {
87        error: String,
88    },
89    Cancelled,
90}
91
92#[derive(Debug, Clone)]
93pub struct AgentConfig {
94    pub model: ModelId,
95    pub system_context: Option<SystemContext>,
96    pub tools: Vec<ToolSchema>,
97}
98
99struct ToolCompletionContext {
100    messages: Vec<Message>,
101    pending_results: HashMap<ToolCallId, ToolCall>,
102    completed_results: Vec<(ToolCallId, ToolResult)>,
103    tool_call_id: ToolCallId,
104    message_id: MessageId,
105    timestamp: u64,
106}
107
108pub struct AgentStepper {
109    config: AgentConfig,
110}
111
112impl AgentStepper {
113    pub fn new(config: AgentConfig) -> Self {
114        Self { config }
115    }
116
117    pub fn initial_state(messages: Vec<Message>) -> AgentState {
118        AgentState::AwaitingModel { messages }
119    }
120
121    pub fn step(&self, state: AgentState, input: AgentInput) -> (AgentState, Vec<AgentOutput>) {
122        match (state, input) {
123            (
124                AgentState::AwaitingModel { messages },
125                AgentInput::ModelResponse {
126                    content,
127                    tool_calls,
128                    message_id,
129                    timestamp,
130                },
131            ) => Self::handle_model_response(messages, content, tool_calls, message_id, timestamp),
132
133            (AgentState::AwaitingModel { .. }, AgentInput::ModelError { error }) => (
134                AgentState::Failed {
135                    error: error.clone(),
136                },
137                vec![AgentOutput::Error { error }],
138            ),
139
140            (
141                AgentState::AwaitingToolApprovals {
142                    messages,
143                    pending_approvals,
144                    approved,
145                    denied,
146                },
147                AgentInput::ToolApproved { tool_call_id },
148            ) => Self::handle_tool_approved(
149                messages,
150                pending_approvals,
151                approved,
152                denied,
153                tool_call_id,
154            ),
155
156            (
157                AgentState::AwaitingToolApprovals {
158                    messages,
159                    pending_approvals,
160                    approved,
161                    denied,
162                },
163                AgentInput::ToolDenied { tool_call_id },
164            ) => Self::handle_tool_denied(
165                messages,
166                pending_approvals,
167                approved,
168                denied,
169                tool_call_id,
170            ),
171
172            (
173                AgentState::AwaitingToolResults {
174                    messages,
175                    pending_results,
176                    completed_results,
177                },
178                AgentInput::ToolCompleted {
179                    tool_call_id,
180                    result,
181                    message_id,
182                    timestamp,
183                },
184            ) => self.handle_tool_completed(
185                ToolCompletionContext {
186                    messages,
187                    pending_results,
188                    completed_results,
189                    tool_call_id,
190                    message_id,
191                    timestamp,
192                },
193                result,
194            ),
195
196            (
197                AgentState::AwaitingToolResults {
198                    messages,
199                    pending_results,
200                    completed_results,
201                },
202                AgentInput::ToolFailed {
203                    tool_call_id,
204                    error,
205                    message_id,
206                    timestamp,
207                },
208            ) => self.handle_tool_failed(
209                ToolCompletionContext {
210                    messages,
211                    pending_results,
212                    completed_results,
213                    tool_call_id,
214                    message_id,
215                    timestamp,
216                },
217                error,
218            ),
219
220            (state, AgentInput::Cancel) => Self::handle_cancel(state),
221
222            (state, _) => (state, vec![]),
223        }
224    }
225
226    fn handle_model_response(
227        mut messages: Vec<Message>,
228        content: Vec<AssistantContent>,
229        tool_calls: Vec<ToolCall>,
230        message_id: MessageId,
231        timestamp: u64,
232    ) -> (AgentState, Vec<AgentOutput>) {
233        let parent_id = messages.last().map(|m| m.id().to_string());
234
235        let assistant_message = Message {
236            data: MessageData::Assistant { content },
237            timestamp,
238            id: message_id.0.clone(),
239            parent_message_id: parent_id,
240        };
241
242        messages.push(assistant_message.clone());
243
244        let mut outputs = vec![AgentOutput::EmitMessage {
245            message: assistant_message.clone(),
246        }];
247
248        if tool_calls.is_empty() {
249            (
250                AgentState::Complete {
251                    final_message: assistant_message.clone(),
252                },
253                vec![
254                    AgentOutput::EmitMessage {
255                        message: assistant_message.clone(),
256                    },
257                    AgentOutput::Done {
258                        final_message: assistant_message,
259                    },
260                ],
261            )
262        } else {
263            for tool_call in &tool_calls {
264                outputs.push(AgentOutput::RequestApproval {
265                    tool_call: tool_call.clone(),
266                });
267            }
268
269            (
270                AgentState::AwaitingToolApprovals {
271                    messages,
272                    pending_approvals: tool_calls,
273                    approved: vec![],
274                    denied: vec![],
275                },
276                outputs,
277            )
278        }
279    }
280
281    fn handle_tool_approved(
282        messages: Vec<Message>,
283        mut pending_approvals: Vec<ToolCall>,
284        mut approved: Vec<ToolCall>,
285        denied: Vec<ToolCall>,
286        tool_call_id: ToolCallId,
287    ) -> (AgentState, Vec<AgentOutput>) {
288        let mut outputs = vec![];
289
290        if let Some(pos) = pending_approvals
291            .iter()
292            .position(|tc| tc.id == tool_call_id.0)
293        {
294            let tool_call = pending_approvals.remove(pos);
295            outputs.push(AgentOutput::ExecuteTool {
296                tool_call: tool_call.clone(),
297            });
298            approved.push(tool_call);
299        }
300
301        if pending_approvals.is_empty() {
302            let mut pending_results = HashMap::new();
303            for tc in &approved {
304                pending_results.insert(ToolCallId::from_string(&tc.id), tc.clone());
305            }
306
307            (
308                AgentState::AwaitingToolResults {
309                    messages,
310                    pending_results,
311                    completed_results: vec![],
312                },
313                outputs,
314            )
315        } else {
316            (
317                AgentState::AwaitingToolApprovals {
318                    messages,
319                    pending_approvals,
320                    approved,
321                    denied,
322                },
323                outputs,
324            )
325        }
326    }
327
328    fn handle_tool_denied(
329        mut messages: Vec<Message>,
330        mut pending_approvals: Vec<ToolCall>,
331        approved: Vec<ToolCall>,
332        mut denied: Vec<ToolCall>,
333        tool_call_id: ToolCallId,
334    ) -> (AgentState, Vec<AgentOutput>) {
335        let mut outputs = vec![];
336
337        if let Some(pos) = pending_approvals
338            .iter()
339            .position(|tc| tc.id == tool_call_id.0)
340        {
341            let tool_call = pending_approvals.remove(pos);
342            Self::emit_tool_error_message(
343                &mut messages,
344                &mut outputs,
345                &tool_call,
346                ToolError::DeniedByUser(tool_call.name.clone()),
347            );
348            denied.push(tool_call);
349        }
350
351        if pending_approvals.is_empty() {
352            if approved.is_empty() {
353                (
354                    AgentState::Failed {
355                        error: "All tools denied".to_string(),
356                    },
357                    {
358                        outputs.push(AgentOutput::Error {
359                            error: "All tools denied".to_string(),
360                        });
361                        outputs
362                    },
363                )
364            } else {
365                let mut pending_results = HashMap::new();
366                for tc in &approved {
367                    pending_results.insert(ToolCallId::from_string(&tc.id), tc.clone());
368                }
369
370                (
371                    AgentState::AwaitingToolResults {
372                        messages,
373                        pending_results,
374                        completed_results: vec![],
375                    },
376                    outputs,
377                )
378            }
379        } else {
380            (
381                AgentState::AwaitingToolApprovals {
382                    messages,
383                    pending_approvals,
384                    approved,
385                    denied,
386                },
387                outputs,
388            )
389        }
390    }
391
392    fn emit_tool_error_message(
393        messages: &mut Vec<Message>,
394        outputs: &mut Vec<AgentOutput>,
395        tool_call: &ToolCall,
396        error: ToolError,
397    ) {
398        let parent_id = messages.last().map(|m| m.id().to_string());
399        let message_id = MessageId::new();
400        let timestamp = Message::current_timestamp();
401
402        let tool_message = Message {
403            data: MessageData::Tool {
404                tool_use_id: tool_call.id.clone(),
405                result: ToolResult::Error(error),
406            },
407            timestamp,
408            id: message_id.0.clone(),
409            parent_message_id: parent_id,
410        };
411
412        messages.push(tool_message.clone());
413        outputs.push(AgentOutput::EmitMessage {
414            message: tool_message,
415        });
416    }
417
418    fn handle_tool_completed(
419        &self,
420        mut context: ToolCompletionContext,
421        result: ToolResult,
422    ) -> (AgentState, Vec<AgentOutput>) {
423        let mut outputs = vec![];
424
425        if let Some(tool_call) = context.pending_results.remove(&context.tool_call_id) {
426            let parent_id = context.messages.last().map(|m| m.id().to_string());
427
428            let tool_message = Message {
429                data: MessageData::Tool {
430                    tool_use_id: tool_call.id.clone(),
431                    result: result.clone(),
432                },
433                timestamp: context.timestamp,
434                id: context.message_id.0.clone(),
435                parent_message_id: parent_id,
436            };
437
438            context.messages.push(tool_message.clone());
439            outputs.push(AgentOutput::EmitMessage {
440                message: tool_message,
441            });
442            context
443                .completed_results
444                .push((context.tool_call_id, result));
445        }
446
447        if context.pending_results.is_empty() {
448            outputs.push(AgentOutput::CallModel {
449                model: self.config.model.clone(),
450                messages: context.messages.clone(),
451                system_context: Box::new(self.config.system_context.clone()),
452                tools: self.config.tools.clone(),
453            });
454
455            (
456                AgentState::AwaitingModel {
457                    messages: context.messages,
458                },
459                outputs,
460            )
461        } else {
462            (
463                AgentState::AwaitingToolResults {
464                    messages: context.messages,
465                    pending_results: context.pending_results,
466                    completed_results: context.completed_results,
467                },
468                outputs,
469            )
470        }
471    }
472
473    fn handle_tool_failed(
474        &self,
475        context: ToolCompletionContext,
476        error: ToolError,
477    ) -> (AgentState, Vec<AgentOutput>) {
478        let result = ToolResult::Error(error);
479        self.handle_tool_completed(context, result)
480    }
481
482    fn handle_cancel(state: AgentState) -> (AgentState, Vec<AgentOutput>) {
483        let mut outputs = Vec::new();
484
485        match state {
486            AgentState::AwaitingToolApprovals {
487                mut messages,
488                pending_approvals,
489                approved,
490                denied: _,
491            } => {
492                for tool_call in pending_approvals.into_iter().chain(approved.into_iter()) {
493                    Self::emit_tool_error_message(
494                        &mut messages,
495                        &mut outputs,
496                        &tool_call,
497                        ToolError::Cancelled(tool_call.name.clone()),
498                    );
499                }
500            }
501            AgentState::AwaitingToolResults {
502                mut messages,
503                pending_results,
504                completed_results: _,
505            } => {
506                for (_, tool_call) in pending_results {
507                    Self::emit_tool_error_message(
508                        &mut messages,
509                        &mut outputs,
510                        &tool_call,
511                        ToolError::Cancelled(tool_call.name.clone()),
512                    );
513                }
514            }
515            _ => {}
516        }
517
518        outputs.push(AgentOutput::Cancelled);
519        (AgentState::Cancelled, outputs)
520    }
521
522    pub fn needs_model_call(&self, state: &AgentState) -> bool {
523        matches!(state, AgentState::AwaitingModel { .. })
524    }
525
526    pub fn is_terminal(&self, state: &AgentState) -> bool {
527        matches!(
528            state,
529            AgentState::Complete { .. } | AgentState::Failed { .. } | AgentState::Cancelled
530        )
531    }
532}
533
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::model::builtin;

    fn test_config() -> AgentConfig {
        AgentConfig {
            model: builtin::claude_sonnet_4_5(),
            system_context: None,
            tools: vec![],
        }
    }

    /// A tool call named `test_tool` with the given id and no parameters.
    fn sample_tool_call(id: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            name: "test_tool".to_string(),
            parameters: serde_json::json!({}),
        }
    }

    /// An approval-phase state with `pending` awaiting decisions and an empty history.
    fn awaiting_approvals(pending: Vec<ToolCall>) -> AgentState {
        AgentState::AwaitingToolApprovals {
            messages: vec![],
            pending_approvals: pending,
            approved: vec![],
            denied: vec![],
        }
    }

    /// The first message published through `AgentOutput::EmitMessage`, if any.
    fn first_emitted_message(outputs: &[AgentOutput]) -> Option<&Message> {
        outputs.iter().find_map(|output| match output {
            AgentOutput::EmitMessage { message } => Some(message),
            _ => None,
        })
    }

    /// The error carried by a tool-result message; panics on any other message shape.
    fn tool_error(message: &Message) -> &ToolError {
        match &message.data {
            MessageData::Tool {
                result: ToolResult::Error(error),
                ..
            } => error,
            MessageData::Tool { .. } => panic!("expected a tool error result"),
            _ => panic!("expected tool message"),
        }
    }

    #[test]
    fn test_initial_state() {
        let state = AgentStepper::initial_state(vec![]);
        assert!(matches!(state, AgentState::AwaitingModel { .. }));
    }

    #[test]
    fn test_model_response_no_tools_completes() {
        let stepper = AgentStepper::new(test_config());

        let (new_state, outputs) = stepper.step(
            AgentState::AwaitingModel { messages: vec![] },
            AgentInput::ModelResponse {
                content: vec![],
                tool_calls: vec![],
                message_id: MessageId::new(),
                timestamp: 0,
            },
        );

        assert!(matches!(new_state, AgentState::Complete { .. }));
        assert!(outputs.iter().any(|o| matches!(o, AgentOutput::Done { .. })));
    }

    #[test]
    fn test_model_response_with_tools_requests_approval() {
        let stepper = AgentStepper::new(test_config());

        let (new_state, outputs) = stepper.step(
            AgentState::AwaitingModel { messages: vec![] },
            AgentInput::ModelResponse {
                content: vec![],
                tool_calls: vec![sample_tool_call("tc_1")],
                message_id: MessageId::new(),
                timestamp: 0,
            },
        );

        assert!(matches!(new_state, AgentState::AwaitingToolApprovals { .. }));
        assert!(
            outputs
                .iter()
                .any(|o| matches!(o, AgentOutput::RequestApproval { .. }))
        );
    }

    #[test]
    fn test_model_error_fails() {
        let stepper = AgentStepper::new(test_config());

        let (new_state, outputs) = stepper.step(
            AgentState::AwaitingModel { messages: vec![] },
            AgentInput::ModelError {
                error: "boom".to_string(),
            },
        );

        assert!(matches!(&new_state, AgentState::Failed { error } if error == "boom"));
        assert!(
            outputs
                .iter()
                .any(|o| matches!(o, AgentOutput::Error { error } if error == "boom"))
        );
    }

    #[test]
    fn test_tool_denied_emits_tool_message() {
        let stepper = AgentStepper::new(test_config());

        let (_new_state, outputs) = stepper.step(
            awaiting_approvals(vec![sample_tool_call("tc_1")]),
            AgentInput::ToolDenied {
                tool_call_id: ToolCallId::from_string("tc_1"),
            },
        );

        let message = first_emitted_message(&outputs)
            .expect("tool denial should emit a tool result message");
        assert!(matches!(
            tool_error(message),
            ToolError::DeniedByUser(name) if name == "test_tool"
        ));
    }

    #[test]
    fn test_all_tools_denied_fails() {
        let stepper = AgentStepper::new(test_config());

        let (new_state, outputs) = stepper.step(
            awaiting_approvals(vec![sample_tool_call("tc_1")]),
            AgentInput::ToolDenied {
                tool_call_id: ToolCallId::from_string("tc_1"),
            },
        );

        assert!(matches!(
            &new_state,
            AgentState::Failed { error } if error == "All tools denied"
        ));
        assert!(outputs.iter().any(|o| matches!(o, AgentOutput::Error { .. })));
    }

    #[test]
    fn test_partial_denial_executes_remaining_approved_tools() {
        let stepper = AgentStepper::new(test_config());
        let state = awaiting_approvals(vec![sample_tool_call("tc_1"), sample_tool_call("tc_2")]);

        let (state, _) = stepper.step(
            state,
            AgentInput::ToolDenied {
                tool_call_id: ToolCallId::from_string("tc_1"),
            },
        );
        assert!(matches!(&state, AgentState::AwaitingToolApprovals { .. }));

        let (state, outputs) = stepper.step(
            state,
            AgentInput::ToolApproved {
                tool_call_id: ToolCallId::from_string("tc_2"),
            },
        );
        assert!(matches!(
            &state,
            AgentState::AwaitingToolResults { pending_results, .. }
                if pending_results.contains_key(&ToolCallId::from_string("tc_2"))
        ));
        assert!(
            outputs
                .iter()
                .any(|o| matches!(o, AgentOutput::ExecuteTool { tool_call } if tool_call.id == "tc_2"))
        );
    }

    #[test]
    fn test_tool_failure_resumes_model_call() {
        let stepper = AgentStepper::new(test_config());
        let mut pending_results = HashMap::new();
        pending_results.insert(ToolCallId::from_string("tc_1"), sample_tool_call("tc_1"));
        let state = AgentState::AwaitingToolResults {
            messages: vec![],
            pending_results,
            completed_results: vec![],
        };

        let (new_state, outputs) = stepper.step(
            state,
            AgentInput::ToolFailed {
                tool_call_id: ToolCallId::from_string("tc_1"),
                error: ToolError::Cancelled("test_tool".to_string()),
                message_id: MessageId::new(),
                timestamp: 0,
            },
        );

        assert!(matches!(
            &new_state,
            AgentState::AwaitingModel { messages } if messages.len() == 1
        ));
        let message = first_emitted_message(&outputs)
            .expect("tool failure should emit a tool result message");
        assert!(matches!(
            tool_error(message),
            ToolError::Cancelled(name) if name == "test_tool"
        ));
        assert!(
            outputs
                .iter()
                .any(|o| matches!(o, AgentOutput::CallModel { messages, .. } if messages.len() == 1))
        );
    }

    #[test]
    fn test_cancel_emits_tool_results_for_pending_approvals() {
        let stepper = AgentStepper::new(test_config());

        let (new_state, outputs) = stepper.step(
            awaiting_approvals(vec![sample_tool_call("tc_1")]),
            AgentInput::Cancel,
        );

        assert!(matches!(new_state, AgentState::Cancelled));
        assert!(outputs.iter().any(|o| matches!(o, AgentOutput::Cancelled)));

        let message =
            first_emitted_message(&outputs).expect("cancel should emit tool result messages");
        assert!(matches!(
            tool_error(message),
            ToolError::Cancelled(name) if name == "test_tool"
        ));
    }

    #[test]
    fn test_cancel_from_any_state() {
        let stepper = AgentStepper::new(test_config());

        let states = vec![
            AgentState::AwaitingModel { messages: vec![] },
            awaiting_approvals(vec![]),
            AgentState::AwaitingToolResults {
                messages: vec![],
                pending_results: HashMap::new(),
                completed_results: vec![],
            },
        ];

        for state in states {
            let (new_state, outputs) = stepper.step(state, AgentInput::Cancel);
            assert!(matches!(new_state, AgentState::Cancelled));
            assert!(outputs.iter().any(|o| matches!(o, AgentOutput::Cancelled)));
        }
    }

    #[test]
    fn test_state_predicates() {
        let stepper = AgentStepper::new(test_config());
        let awaiting = AgentState::AwaitingModel { messages: vec![] };

        assert!(stepper.needs_model_call(&awaiting));
        assert!(!stepper.is_terminal(&awaiting));
        assert!(!stepper.needs_model_call(&AgentState::Cancelled));
        assert!(stepper.is_terminal(&AgentState::Cancelled));
        assert!(stepper.is_terminal(&AgentState::Failed {
            error: String::new(),
        }));
    }
}