1use std::collections::HashMap;
2
3use crate::app::SystemContext;
4use crate::app::conversation::{AssistantContent, Message, MessageData};
5use crate::app::domain::types::{MessageId, ToolCallId};
6use crate::config::model::ModelId;
7use steer_tools::{ToolCall, ToolError, ToolResult, ToolSchema};
8
/// State of one agent turn as tracked by [`AgentStepper`].
///
/// A state is consumed by [`AgentStepper::step`] together with an [`AgentInput`] and
/// replaced by the state that step returns.
#[derive(Debug, Clone)]
pub enum AgentState {
    /// Waiting for `ModelResponse` or `ModelError`; `messages` is the conversation so far.
    AwaitingModel {
        messages: Vec<Message>,
    },
    /// Waiting for `ToolApproved` / `ToolDenied` decisions on the model's tool calls.
    AwaitingToolApprovals {
        /// Conversation so far, including the assistant message that requested the tools
        /// and any denial tool messages appended since.
        messages: Vec<Message>,
        /// Tool calls that have not been approved or denied yet.
        pending_approvals: Vec<ToolCall>,
        /// Approved tool calls; an `ExecuteTool` output was emitted for each on approval.
        approved: Vec<ToolCall>,
        /// Denied tool calls; a `DeniedByUser` tool message was appended for each.
        denied: Vec<ToolCall>,
    },
    /// Waiting for `ToolCompleted` / `ToolFailed` for every approved tool call.
    AwaitingToolResults {
        messages: Vec<Message>,
        /// Approved tool calls still without a result, keyed by their tool call id.
        pending_results: HashMap<ToolCallId, ToolCall>,
        /// Results received so far, in arrival order.
        completed_results: Vec<(ToolCallId, ToolResult)>,
    },
    /// The model answered without requesting tools; `final_message` is that answer.
    Complete {
        final_message: Message,
    },
    /// The turn failed (model error, or every requested tool was denied).
    Failed {
        error: String,
    },
    /// The turn was cancelled via `AgentInput::Cancel`.
    Cancelled,
}
33
/// Events fed into [`AgentStepper::step`].
#[derive(Debug, Clone)]
pub enum AgentInput {
    /// The model replied. `content` becomes an assistant message with the given
    /// `message_id` and `timestamp`; every entry in `tool_calls` must then be approved
    /// or denied.
    ModelResponse {
        content: Vec<AssistantContent>,
        tool_calls: Vec<ToolCall>,
        message_id: MessageId,
        timestamp: u64,
    },
    /// The model call failed; the turn moves to `AgentState::Failed` with this error.
    ModelError {
        error: String,
    },
    /// The tool call with this id may be executed.
    ToolApproved {
        tool_call_id: ToolCallId,
    },
    /// The tool call with this id must not be executed.
    ToolDenied {
        tool_call_id: ToolCallId,
    },
    /// An approved tool finished; `result` is recorded as a tool message with the given
    /// `message_id` and `timestamp`.
    ToolCompleted {
        tool_call_id: ToolCallId,
        result: ToolResult,
        message_id: MessageId,
        timestamp: u64,
    },
    /// An approved tool failed; `error` is recorded as a `ToolResult::Error` tool message.
    ToolFailed {
        tool_call_id: ToolCallId,
        error: ToolError,
        message_id: MessageId,
        timestamp: u64,
    },
    /// Abort the turn.
    Cancel,
}
65
/// Effects requested by [`AgentStepper::step`]; the caller is responsible for carrying
/// them out (the stepper itself does not call models or tools).
#[derive(Debug, Clone)]
pub enum AgentOutput {
    /// Call `model` with the conversation `messages`, the system context and `tools`.
    CallModel {
        model: ModelId,
        messages: Vec<Message>,
        // NOTE(review): boxed, presumably to keep `AgentOutput` small — confirm.
        system_context: Box<Option<SystemContext>>,
        tools: Vec<ToolSchema>,
    },
    /// Ask for an approve/deny decision on `tool_call`.
    RequestApproval {
        tool_call: ToolCall,
    },
    /// Run the (approved) `tool_call`.
    ExecuteTool {
        tool_call: ToolCall,
    },
    /// A message the stepper appended to its conversation.
    EmitMessage {
        message: Message,
    },
    /// The turn completed; `final_message` is the model's last answer.
    Done {
        final_message: Message,
    },
    /// The turn failed with `error`.
    Error {
        error: String,
    },
    /// The turn was cancelled.
    Cancelled,
}
91
/// Settings copied into every `AgentOutput::CallModel` the stepper emits.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Model to call.
    pub model: ModelId,
    /// System context passed along with each model call, if any.
    pub system_context: Option<SystemContext>,
    /// Tool schemas offered to the model on each call.
    pub tools: Vec<ToolSchema>,
}
98
/// Bundles the `AwaitingToolResults` state with the identifying fields of a
/// `ToolCompleted`/`ToolFailed` input so the handlers take one argument instead of six.
struct ToolCompletionContext {
    // Fields taken from `AgentState::AwaitingToolResults`.
    messages: Vec<Message>,
    pending_results: HashMap<ToolCallId, ToolCall>,
    completed_results: Vec<(ToolCallId, ToolResult)>,
    // Fields taken from the tool-result input; `message_id`/`timestamp` are used for the
    // tool message that records the result.
    tool_call_id: ToolCallId,
    message_id: MessageId,
    timestamp: u64,
}
107
/// State machine driving one agent turn: model call → tool approvals → tool results →
/// next model call. It returns requested effects as [`AgentOutput`]s instead of
/// performing them.
pub struct AgentStepper {
    // Model, system context and tools used for every `CallModel` output.
    config: AgentConfig,
}
111
112impl AgentStepper {
113 pub fn new(config: AgentConfig) -> Self {
114 Self { config }
115 }
116
117 pub fn initial_state(messages: Vec<Message>) -> AgentState {
118 AgentState::AwaitingModel { messages }
119 }
120
121 pub fn step(&self, state: AgentState, input: AgentInput) -> (AgentState, Vec<AgentOutput>) {
122 match (state, input) {
123 (
124 AgentState::AwaitingModel { messages },
125 AgentInput::ModelResponse {
126 content,
127 tool_calls,
128 message_id,
129 timestamp,
130 },
131 ) => Self::handle_model_response(messages, content, tool_calls, message_id, timestamp),
132
133 (AgentState::AwaitingModel { .. }, AgentInput::ModelError { error }) => (
134 AgentState::Failed {
135 error: error.clone(),
136 },
137 vec![AgentOutput::Error { error }],
138 ),
139
140 (
141 AgentState::AwaitingToolApprovals {
142 messages,
143 pending_approvals,
144 approved,
145 denied,
146 },
147 AgentInput::ToolApproved { tool_call_id },
148 ) => Self::handle_tool_approved(
149 messages,
150 pending_approvals,
151 approved,
152 denied,
153 tool_call_id,
154 ),
155
156 (
157 AgentState::AwaitingToolApprovals {
158 messages,
159 pending_approvals,
160 approved,
161 denied,
162 },
163 AgentInput::ToolDenied { tool_call_id },
164 ) => Self::handle_tool_denied(
165 messages,
166 pending_approvals,
167 approved,
168 denied,
169 tool_call_id,
170 ),
171
172 (
173 AgentState::AwaitingToolResults {
174 messages,
175 pending_results,
176 completed_results,
177 },
178 AgentInput::ToolCompleted {
179 tool_call_id,
180 result,
181 message_id,
182 timestamp,
183 },
184 ) => self.handle_tool_completed(
185 ToolCompletionContext {
186 messages,
187 pending_results,
188 completed_results,
189 tool_call_id,
190 message_id,
191 timestamp,
192 },
193 result,
194 ),
195
196 (
197 AgentState::AwaitingToolResults {
198 messages,
199 pending_results,
200 completed_results,
201 },
202 AgentInput::ToolFailed {
203 tool_call_id,
204 error,
205 message_id,
206 timestamp,
207 },
208 ) => self.handle_tool_failed(
209 ToolCompletionContext {
210 messages,
211 pending_results,
212 completed_results,
213 tool_call_id,
214 message_id,
215 timestamp,
216 },
217 error,
218 ),
219
220 (state, AgentInput::Cancel) => Self::handle_cancel(state),
221
222 (state, _) => (state, vec![]),
223 }
224 }
225
226 fn handle_model_response(
227 mut messages: Vec<Message>,
228 content: Vec<AssistantContent>,
229 tool_calls: Vec<ToolCall>,
230 message_id: MessageId,
231 timestamp: u64,
232 ) -> (AgentState, Vec<AgentOutput>) {
233 let parent_id = messages.last().map(|m| m.id().to_string());
234
235 let assistant_message = Message {
236 data: MessageData::Assistant { content },
237 timestamp,
238 id: message_id.0.clone(),
239 parent_message_id: parent_id,
240 };
241
242 messages.push(assistant_message.clone());
243
244 let mut outputs = vec![AgentOutput::EmitMessage {
245 message: assistant_message.clone(),
246 }];
247
248 if tool_calls.is_empty() {
249 (
250 AgentState::Complete {
251 final_message: assistant_message.clone(),
252 },
253 vec![
254 AgentOutput::EmitMessage {
255 message: assistant_message.clone(),
256 },
257 AgentOutput::Done {
258 final_message: assistant_message,
259 },
260 ],
261 )
262 } else {
263 for tool_call in &tool_calls {
264 outputs.push(AgentOutput::RequestApproval {
265 tool_call: tool_call.clone(),
266 });
267 }
268
269 (
270 AgentState::AwaitingToolApprovals {
271 messages,
272 pending_approvals: tool_calls,
273 approved: vec![],
274 denied: vec![],
275 },
276 outputs,
277 )
278 }
279 }
280
281 fn handle_tool_approved(
282 messages: Vec<Message>,
283 mut pending_approvals: Vec<ToolCall>,
284 mut approved: Vec<ToolCall>,
285 denied: Vec<ToolCall>,
286 tool_call_id: ToolCallId,
287 ) -> (AgentState, Vec<AgentOutput>) {
288 let mut outputs = vec![];
289
290 if let Some(pos) = pending_approvals
291 .iter()
292 .position(|tc| tc.id == tool_call_id.0)
293 {
294 let tool_call = pending_approvals.remove(pos);
295 outputs.push(AgentOutput::ExecuteTool {
296 tool_call: tool_call.clone(),
297 });
298 approved.push(tool_call);
299 }
300
301 if pending_approvals.is_empty() {
302 let mut pending_results = HashMap::new();
303 for tc in &approved {
304 pending_results.insert(ToolCallId::from_string(&tc.id), tc.clone());
305 }
306
307 (
308 AgentState::AwaitingToolResults {
309 messages,
310 pending_results,
311 completed_results: vec![],
312 },
313 outputs,
314 )
315 } else {
316 (
317 AgentState::AwaitingToolApprovals {
318 messages,
319 pending_approvals,
320 approved,
321 denied,
322 },
323 outputs,
324 )
325 }
326 }
327
328 fn handle_tool_denied(
329 mut messages: Vec<Message>,
330 mut pending_approvals: Vec<ToolCall>,
331 approved: Vec<ToolCall>,
332 mut denied: Vec<ToolCall>,
333 tool_call_id: ToolCallId,
334 ) -> (AgentState, Vec<AgentOutput>) {
335 let mut outputs = vec![];
336
337 if let Some(pos) = pending_approvals
338 .iter()
339 .position(|tc| tc.id == tool_call_id.0)
340 {
341 let tool_call = pending_approvals.remove(pos);
342 Self::emit_tool_error_message(
343 &mut messages,
344 &mut outputs,
345 &tool_call,
346 ToolError::DeniedByUser(tool_call.name.clone()),
347 );
348 denied.push(tool_call);
349 }
350
351 if pending_approvals.is_empty() {
352 if approved.is_empty() {
353 (
354 AgentState::Failed {
355 error: "All tools denied".to_string(),
356 },
357 {
358 outputs.push(AgentOutput::Error {
359 error: "All tools denied".to_string(),
360 });
361 outputs
362 },
363 )
364 } else {
365 let mut pending_results = HashMap::new();
366 for tc in &approved {
367 pending_results.insert(ToolCallId::from_string(&tc.id), tc.clone());
368 }
369
370 (
371 AgentState::AwaitingToolResults {
372 messages,
373 pending_results,
374 completed_results: vec![],
375 },
376 outputs,
377 )
378 }
379 } else {
380 (
381 AgentState::AwaitingToolApprovals {
382 messages,
383 pending_approvals,
384 approved,
385 denied,
386 },
387 outputs,
388 )
389 }
390 }
391
392 fn emit_tool_error_message(
393 messages: &mut Vec<Message>,
394 outputs: &mut Vec<AgentOutput>,
395 tool_call: &ToolCall,
396 error: ToolError,
397 ) {
398 let parent_id = messages.last().map(|m| m.id().to_string());
399 let message_id = MessageId::new();
400 let timestamp = Message::current_timestamp();
401
402 let tool_message = Message {
403 data: MessageData::Tool {
404 tool_use_id: tool_call.id.clone(),
405 result: ToolResult::Error(error),
406 },
407 timestamp,
408 id: message_id.0.clone(),
409 parent_message_id: parent_id,
410 };
411
412 messages.push(tool_message.clone());
413 outputs.push(AgentOutput::EmitMessage {
414 message: tool_message,
415 });
416 }
417
418 fn handle_tool_completed(
419 &self,
420 mut context: ToolCompletionContext,
421 result: ToolResult,
422 ) -> (AgentState, Vec<AgentOutput>) {
423 let mut outputs = vec![];
424
425 if let Some(tool_call) = context.pending_results.remove(&context.tool_call_id) {
426 let parent_id = context.messages.last().map(|m| m.id().to_string());
427
428 let tool_message = Message {
429 data: MessageData::Tool {
430 tool_use_id: tool_call.id.clone(),
431 result: result.clone(),
432 },
433 timestamp: context.timestamp,
434 id: context.message_id.0.clone(),
435 parent_message_id: parent_id,
436 };
437
438 context.messages.push(tool_message.clone());
439 outputs.push(AgentOutput::EmitMessage {
440 message: tool_message,
441 });
442 context
443 .completed_results
444 .push((context.tool_call_id, result));
445 }
446
447 if context.pending_results.is_empty() {
448 outputs.push(AgentOutput::CallModel {
449 model: self.config.model.clone(),
450 messages: context.messages.clone(),
451 system_context: Box::new(self.config.system_context.clone()),
452 tools: self.config.tools.clone(),
453 });
454
455 (
456 AgentState::AwaitingModel {
457 messages: context.messages,
458 },
459 outputs,
460 )
461 } else {
462 (
463 AgentState::AwaitingToolResults {
464 messages: context.messages,
465 pending_results: context.pending_results,
466 completed_results: context.completed_results,
467 },
468 outputs,
469 )
470 }
471 }
472
473 fn handle_tool_failed(
474 &self,
475 context: ToolCompletionContext,
476 error: ToolError,
477 ) -> (AgentState, Vec<AgentOutput>) {
478 let result = ToolResult::Error(error);
479 self.handle_tool_completed(context, result)
480 }
481
482 fn handle_cancel(state: AgentState) -> (AgentState, Vec<AgentOutput>) {
483 let mut outputs = Vec::new();
484
485 match state {
486 AgentState::AwaitingToolApprovals {
487 mut messages,
488 pending_approvals,
489 approved,
490 denied: _,
491 } => {
492 for tool_call in pending_approvals.into_iter().chain(approved.into_iter()) {
493 Self::emit_tool_error_message(
494 &mut messages,
495 &mut outputs,
496 &tool_call,
497 ToolError::Cancelled(tool_call.name.clone()),
498 );
499 }
500 }
501 AgentState::AwaitingToolResults {
502 mut messages,
503 pending_results,
504 completed_results: _,
505 } => {
506 for (_, tool_call) in pending_results {
507 Self::emit_tool_error_message(
508 &mut messages,
509 &mut outputs,
510 &tool_call,
511 ToolError::Cancelled(tool_call.name.clone()),
512 );
513 }
514 }
515 _ => {}
516 }
517
518 outputs.push(AgentOutput::Cancelled);
519 (AgentState::Cancelled, outputs)
520 }
521
522 pub fn needs_model_call(&self, state: &AgentState) -> bool {
523 matches!(state, AgentState::AwaitingModel { .. })
524 }
525
526 pub fn is_terminal(&self, state: &AgentState) -> bool {
527 matches!(
528 state,
529 AgentState::Complete { .. } | AgentState::Failed { .. } | AgentState::Cancelled
530 )
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::model::builtin;

    fn test_config() -> AgentConfig {
        AgentConfig {
            model: builtin::claude_sonnet_4_5(),
            system_context: None,
            tools: vec![],
        }
    }

    /// Builds a tool call named `test_tool` with the given id.
    fn make_tool_call(id: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            name: "test_tool".to_string(),
            parameters: serde_json::json!({}),
        }
    }

    /// Builds an `AwaitingToolApprovals` state with an empty conversation.
    fn awaiting_approvals(pending_approvals: Vec<ToolCall>) -> AgentState {
        AgentState::AwaitingToolApprovals {
            messages: vec![],
            pending_approvals,
            approved: vec![],
            denied: vec![],
        }
    }

    /// Returns the first message carried by an `EmitMessage` output.
    fn first_emitted_message<'a>(outputs: &'a [AgentOutput], context: &str) -> &'a Message {
        outputs
            .iter()
            .find_map(|output| match output {
                AgentOutput::EmitMessage { message } => Some(message),
                _ => None,
            })
            .expect(context)
    }

    #[test]
    fn test_initial_state() {
        let state = AgentStepper::initial_state(vec![]);
        assert!(matches!(state, AgentState::AwaitingModel { .. }));
    }

    #[test]
    fn test_model_response_no_tools_completes() {
        let stepper = AgentStepper::new(test_config());
        let state = AgentState::AwaitingModel { messages: vec![] };

        let (new_state, outputs) = stepper.step(
            state,
            AgentInput::ModelResponse {
                content: vec![],
                tool_calls: vec![],
                message_id: MessageId::new(),
                timestamp: 0,
            },
        );

        assert!(matches!(new_state, AgentState::Complete { .. }));
        assert!(matches!(
            outputs.as_slice(),
            [AgentOutput::EmitMessage { .. }, AgentOutput::Done { .. }]
        ));
    }

    #[test]
    fn test_model_response_with_tools_requests_approval() {
        let stepper = AgentStepper::new(test_config());
        let state = AgentState::AwaitingModel { messages: vec![] };

        let (new_state, outputs) = stepper.step(
            state,
            AgentInput::ModelResponse {
                content: vec![],
                tool_calls: vec![make_tool_call("tc_1")],
                message_id: MessageId::new(),
                timestamp: 0,
            },
        );

        assert!(matches!(
            new_state,
            AgentState::AwaitingToolApprovals { .. }
        ));
        assert!(
            outputs
                .iter()
                .any(|o| matches!(o, AgentOutput::RequestApproval { .. }))
        );
    }

    #[test]
    fn test_model_error_fails() {
        let stepper = AgentStepper::new(test_config());

        let (new_state, outputs) = stepper.step(
            AgentState::AwaitingModel { messages: vec![] },
            AgentInput::ModelError {
                error: "boom".to_string(),
            },
        );

        assert!(matches!(&new_state, AgentState::Failed { error } if error == "boom"));
        assert!(matches!(outputs.as_slice(), [AgentOutput::Error { error }] if error == "boom"));
    }

    #[test]
    fn test_partial_approval_keeps_waiting_and_executes() {
        let stepper = AgentStepper::new(test_config());
        let state = awaiting_approvals(vec![make_tool_call("tc_1"), make_tool_call("tc_2")]);

        let (new_state, outputs) = stepper.step(
            state,
            AgentInput::ToolApproved {
                tool_call_id: ToolCallId::from_string("tc_1"),
            },
        );

        match &new_state {
            AgentState::AwaitingToolApprovals {
                pending_approvals,
                approved,
                ..
            } => {
                assert_eq!(pending_approvals.len(), 1);
                assert_eq!(pending_approvals[0].id, "tc_2");
                assert_eq!(approved.len(), 1);
            }
            other => panic!("expected AwaitingToolApprovals, got {:?}", other),
        }
        assert!(matches!(
            outputs.as_slice(),
            [AgentOutput::ExecuteTool { tool_call }] if tool_call.id == "tc_1"
        ));
    }

    #[test]
    fn test_last_tool_result_triggers_model_call() {
        let stepper = AgentStepper::new(test_config());

        let (state, outputs) = stepper.step(
            awaiting_approvals(vec![make_tool_call("tc_1")]),
            AgentInput::ToolApproved {
                tool_call_id: ToolCallId::from_string("tc_1"),
            },
        );
        assert!(matches!(state, AgentState::AwaitingToolResults { .. }));
        assert!(matches!(
            outputs.as_slice(),
            [AgentOutput::ExecuteTool { tool_call }] if tool_call.id == "tc_1"
        ));

        let (state, outputs) = stepper.step(
            state,
            AgentInput::ToolFailed {
                tool_call_id: ToolCallId::from_string("tc_1"),
                error: ToolError::Cancelled("test_tool".to_string()),
                message_id: MessageId::new(),
                timestamp: 0,
            },
        );

        match &state {
            AgentState::AwaitingModel { messages } => assert_eq!(messages.len(), 1),
            other => panic!("expected AwaitingModel, got {:?}", other),
        }

        let call_messages = outputs
            .iter()
            .find_map(|o| match o {
                AgentOutput::CallModel { messages, .. } => Some(messages),
                _ => None,
            })
            .expect("the last tool result should trigger a model call");
        assert!(matches!(
            call_messages.last().map(|m| &m.data),
            Some(MessageData::Tool {
                tool_use_id,
                result: ToolResult::Error(ToolError::Cancelled(_)),
            }) if tool_use_id == "tc_1"
        ));
    }

    #[test]
    fn test_tool_denied_emits_tool_message() {
        let stepper = AgentStepper::new(test_config());
        let state = awaiting_approvals(vec![make_tool_call("tc_1")]);

        let (_new_state, outputs) = stepper.step(
            state,
            AgentInput::ToolDenied {
                tool_call_id: ToolCallId::from_string("tc_1"),
            },
        );

        let tool_message =
            first_emitted_message(&outputs, "tool denial should emit a tool result message");

        match &tool_message.data {
            MessageData::Tool { result, .. } => match result {
                ToolResult::Error(error) => {
                    assert!(matches!(error, ToolError::DeniedByUser(name) if name == "test_tool"));
                }
                _ => panic!("expected denied tool error"),
            },
            _ => panic!("expected tool message"),
        }
    }

    #[test]
    fn test_all_tools_denied_fails() {
        let stepper = AgentStepper::new(test_config());

        let (new_state, outputs) = stepper.step(
            awaiting_approvals(vec![make_tool_call("tc_1")]),
            AgentInput::ToolDenied {
                tool_call_id: ToolCallId::from_string("tc_1"),
            },
        );

        assert!(matches!(&new_state, AgentState::Failed { error } if error == "All tools denied"));
        assert!(matches!(
            outputs.last(),
            Some(AgentOutput::Error { error }) if error == "All tools denied"
        ));
    }

    #[test]
    fn test_cancel_emits_tool_results_for_pending_approvals() {
        let stepper = AgentStepper::new(test_config());
        let state = awaiting_approvals(vec![make_tool_call("tc_1")]);

        let (new_state, outputs) = stepper.step(state, AgentInput::Cancel);

        assert!(matches!(new_state, AgentState::Cancelled));
        assert!(outputs.iter().any(|o| matches!(o, AgentOutput::Cancelled)));

        let tool_message =
            first_emitted_message(&outputs, "cancel should emit tool result messages");

        match &tool_message.data {
            MessageData::Tool { result, .. } => match result {
                ToolResult::Error(error) => {
                    assert!(matches!(error, ToolError::Cancelled(name) if name == "test_tool"));
                }
                _ => panic!("expected cancelled tool error"),
            },
            _ => panic!("expected tool message"),
        }
    }

    #[test]
    fn test_cancel_from_any_state() {
        let stepper = AgentStepper::new(test_config());

        let states = vec![
            AgentState::AwaitingModel { messages: vec![] },
            awaiting_approvals(vec![]),
        ];

        for state in states {
            let (new_state, outputs) = stepper.step(state, AgentInput::Cancel);
            assert!(matches!(new_state, AgentState::Cancelled));
            assert!(outputs.iter().any(|o| matches!(o, AgentOutput::Cancelled)));
        }
    }
}