Skip to main content

steer_core/app/domain/
state.rs

1use crate::app::SystemContext;
2use crate::app::conversation::MessageGraph;
3use crate::app::domain::action::McpServerState;
4use crate::app::domain::types::{
5    MessageId, NonEmptyString, OpId, RequestId, SessionId, ToolCallId,
6};
7use crate::config::model::ModelId;
8use crate::prompts::system_prompt_for_model;
9use crate::session::state::SessionConfig;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12use steer_tools::{ToolCall, ToolSchema};
13
14#[derive(Debug, Clone)]
15pub struct AppState {
16    pub session_id: SessionId,
17    pub session_config: Option<SessionConfig>,
18    pub base_session_config: Option<SessionConfig>,
19    pub primary_agent_id: Option<String>,
20
21    pub message_graph: MessageGraph,
22
23    pub cached_system_context: Option<SystemContext>,
24
25    pub tools: Vec<ToolSchema>,
26
27    pub approved_tools: HashSet<String>,
28    pub approved_bash_patterns: HashSet<String>,
29    pub static_bash_patterns: Vec<String>,
30    pub pending_approval: Option<PendingApproval>,
31    pub approval_queue: VecDeque<QueuedApproval>,
32    pub queued_work: VecDeque<QueuedWorkItem>,
33
34    pub current_operation: Option<OperationState>,
35
36    pub active_streams: HashMap<OpId, StreamingMessage>,
37
38    pub workspace_files: Vec<String>,
39
40    pub mcp_servers: HashMap<String, McpServerState>,
41
42    pub cancelled_ops: HashSet<OpId>,
43
44    pub operation_models: HashMap<OpId, ModelId>,
45    pub operation_messages: HashMap<OpId, MessageId>,
46
47    pub event_sequence: u64,
48}
49
50#[derive(Debug, Clone)]
51pub struct PendingApproval {
52    pub request_id: RequestId,
53    pub tool_call: ToolCall,
54}
55
56#[derive(Debug, Clone)]
57pub struct QueuedApproval {
58    pub tool_call: ToolCall,
59}
60
61#[derive(Debug, Clone)]
62pub struct QueuedUserMessage {
63    pub text: NonEmptyString,
64    pub op_id: OpId,
65    pub message_id: MessageId,
66    pub model: ModelId,
67    pub queued_at: u64,
68}
69
70#[derive(Debug, Clone)]
71pub struct QueuedBashCommand {
72    pub command: String,
73    pub op_id: OpId,
74    pub message_id: MessageId,
75    pub queued_at: u64,
76}
77
78#[derive(Debug, Clone)]
79pub enum QueuedWorkItem {
80    UserMessage(QueuedUserMessage),
81    DirectBash(QueuedBashCommand),
82}
83
84#[derive(Debug, Clone)]
85pub struct OperationState {
86    pub op_id: OpId,
87    pub kind: OperationKind,
88    pub pending_tool_calls: HashSet<ToolCallId>,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub enum OperationKind {
93    AgentLoop,
94    Compact,
95    DirectBash { command: String },
96}
97
98#[derive(Debug, Clone)]
99pub struct StreamingMessage {
100    pub message_id: MessageId,
101    pub op_id: OpId,
102    pub content: String,
103    pub tool_calls: Vec<ToolCall>,
104    pub byte_count: usize,
105}
106
/// Limits for message streaming.
///
/// Plain data of two `usize`s, so it derives `Copy` and equality along with `Debug`.
/// The original declared no derives at all, which made it impossible to log or compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamingConfig {
    /// Buffer size limit in bytes; how it is enforced is not shown in this module.
    pub max_buffer_bytes: usize,
    /// Maximum number of concurrent streams.
    pub max_concurrent_streams: usize,
}
111
112impl Default for StreamingConfig {
113    fn default() -> Self {
114        Self {
115            max_buffer_bytes: 64 * 1024,
116            max_concurrent_streams: 3,
117        }
118    }
119}
120
/// Upper bound on the number of entries `AppState::record_cancelled_op` keeps in
/// `AppState::cancelled_ops`.
const MAX_CANCELLED_OPS: usize = 100;
122
123impl AppState {
124    pub fn new(session_id: SessionId) -> Self {
125        Self {
126            session_id,
127            session_config: None,
128            base_session_config: None,
129            primary_agent_id: None,
130            message_graph: MessageGraph::new(),
131            cached_system_context: None,
132            tools: Vec::new(),
133            approved_tools: HashSet::new(),
134            approved_bash_patterns: HashSet::new(),
135            static_bash_patterns: Vec::new(),
136            pending_approval: None,
137            approval_queue: VecDeque::new(),
138            queued_work: VecDeque::new(),
139            current_operation: None,
140            active_streams: HashMap::new(),
141            workspace_files: Vec::new(),
142            mcp_servers: HashMap::new(),
143            cancelled_ops: HashSet::new(),
144            operation_models: HashMap::new(),
145            operation_messages: HashMap::new(),
146            event_sequence: 0,
147        }
148    }
149
150    pub fn with_approved_patterns(mut self, patterns: Vec<String>) -> Self {
151        self.static_bash_patterns = patterns;
152        self
153    }
154
155    pub fn with_approved_tools(mut self, tools: HashSet<String>) -> Self {
156        self.approved_tools = tools;
157        self
158    }
159
160    pub fn is_tool_pre_approved(&self, tool_name: &str) -> bool {
161        self.approved_tools.contains(tool_name)
162    }
163
164    pub fn is_bash_pattern_approved(&self, command: &str) -> bool {
165        for pattern in self
166            .static_bash_patterns
167            .iter()
168            .chain(self.approved_bash_patterns.iter())
169        {
170            if pattern == command {
171                return true;
172            }
173            if let Ok(glob) = glob::Pattern::new(pattern)
174                && glob.matches(command)
175            {
176                return true;
177            }
178        }
179        false
180    }
181
182    pub fn approve_tool(&mut self, tool_name: String) {
183        self.approved_tools.insert(tool_name);
184    }
185
186    pub fn approve_bash_pattern(&mut self, pattern: String) {
187        self.approved_bash_patterns.insert(pattern);
188    }
189
190    pub fn record_cancelled_op(&mut self, op_id: OpId) {
191        self.cancelled_ops.insert(op_id);
192        if self.cancelled_ops.len() > MAX_CANCELLED_OPS
193            && let Some(&oldest) = self.cancelled_ops.iter().next()
194        {
195            self.cancelled_ops.remove(&oldest);
196        }
197    }
198
199    pub fn is_op_cancelled(&self, op_id: &OpId) -> bool {
200        self.cancelled_ops.contains(op_id)
201    }
202
203    pub fn has_pending_approval(&self) -> bool {
204        self.pending_approval.is_some()
205    }
206
207    pub fn has_active_operation(&self) -> bool {
208        self.current_operation.is_some()
209    }
210
211    pub fn queue_user_message(&mut self, item: QueuedUserMessage) {
212        if let Some(QueuedWorkItem::UserMessage(tail)) = self.queued_work.back_mut() {
213            let combined = format!("{}\n\n{}", tail.text.as_str(), item.text.as_str());
214            if let Some(text) = NonEmptyString::new(combined) {
215                tail.text = text;
216                tail.op_id = item.op_id;
217                tail.message_id = item.message_id;
218                tail.model = item.model;
219                tail.queued_at = item.queued_at;
220                return;
221            }
222        }
223        self.queued_work
224            .push_back(QueuedWorkItem::UserMessage(item));
225    }
226
227    pub fn queue_bash_command(&mut self, item: QueuedBashCommand) {
228        self.queued_work.push_back(QueuedWorkItem::DirectBash(item));
229    }
230
231    pub fn pop_next_queued_work(&mut self) -> Option<QueuedWorkItem> {
232        self.queued_work.pop_front()
233    }
234
235    pub fn queued_summary(&self) -> (Option<QueuedWorkItem>, usize) {
236        (self.queued_work.front().cloned(), self.queued_work.len())
237    }
238
239    pub fn start_operation(&mut self, op_id: OpId, kind: OperationKind) {
240        self.current_operation = Some(OperationState {
241            op_id,
242            kind,
243            pending_tool_calls: HashSet::new(),
244        });
245    }
246
247    pub fn complete_operation(&mut self, op_id: OpId) {
248        self.operation_models.remove(&op_id);
249        self.operation_messages.remove(&op_id);
250        if self
251            .current_operation
252            .as_ref()
253            .is_some_and(|op| op.op_id == op_id)
254        {
255            self.current_operation = None;
256        }
257    }
258
259    pub fn add_pending_tool_call(&mut self, tool_call_id: ToolCallId) {
260        if let Some(ref mut op) = self.current_operation {
261            op.pending_tool_calls.insert(tool_call_id);
262        }
263    }
264
265    pub fn remove_pending_tool_call(&mut self, tool_call_id: &ToolCallId) {
266        if let Some(ref mut op) = self.current_operation {
267            op.pending_tool_calls.remove(tool_call_id);
268        }
269    }
270
271    pub fn increment_sequence(&mut self) -> u64 {
272        self.event_sequence += 1;
273        self.event_sequence
274    }
275
276    pub fn apply_session_config(
277        &mut self,
278        config: &SessionConfig,
279        primary_agent_id: Option<String>,
280        update_base: bool,
281    ) {
282        self.session_config = Some(config.clone());
283        let prompt = config
284            .system_prompt
285            .as_ref()
286            .and_then(|prompt| {
287                if prompt.trim().is_empty() {
288                    None
289                } else {
290                    Some(prompt.clone())
291                }
292            })
293            .unwrap_or_else(|| system_prompt_for_model(&config.default_model));
294        let environment = self
295            .cached_system_context
296            .as_ref()
297            .and_then(|context| context.environment.clone());
298        self.cached_system_context = Some(SystemContext::with_environment(prompt, environment));
299
300        self.approved_tools
301            .clone_from(config.tool_config.approval_policy.pre_approved_tools());
302        self.approved_bash_patterns.clear();
303        self.static_bash_patterns = config
304            .tool_config
305            .approval_policy
306            .preapproved
307            .bash_patterns()
308            .map(|patterns| patterns.to_vec())
309            .unwrap_or_default();
310        self.pending_approval = None;
311        self.approval_queue.clear();
312
313        if let Some(primary_agent_id) = primary_agent_id.or_else(|| config.primary_agent_id.clone())
314        {
315            self.primary_agent_id = Some(primary_agent_id);
316        }
317
318        if update_base {
319            self.base_session_config = Some(config.clone());
320        }
321    }
322}