Skip to main content

steer_core/app/domain/
state.rs

1use crate::app::SystemContext;
2use crate::app::conversation::MessageGraph;
3use crate::app::conversation::UserContent;
4use crate::app::domain::action::McpServerState;
5use crate::app::domain::types::{MessageId, OpId, RequestId, SessionId, ToolCallId};
6use crate::config::model::ModelId;
7use crate::prompts::system_prompt_for_model;
8use crate::session::state::SessionConfig;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11use steer_tools::{ToolCall, ToolSchema};
12
/// Mutable domain state for a single session.
///
/// Fields are public; the methods in `impl AppState` maintain the invariants
/// noted on individual fields (e.g. the cancelled-op cap, FIFO work queue).
#[derive(Debug, Clone)]
pub struct AppState {
    /// Session this state was created for (set by `AppState::new`).
    pub session_id: SessionId,
    /// Currently active session configuration; replaced by `apply_session_config`.
    pub session_config: Option<SessionConfig>,
    /// Configuration stored by `apply_session_config` when called with `update_base = true`.
    pub base_session_config: Option<SessionConfig>,
    /// Primary agent id; updated by `apply_session_config` from its argument or the config.
    pub primary_agent_id: Option<String>,

    /// Conversation message graph (see `MessageGraph`).
    pub message_graph: MessageGraph,

    /// System prompt plus environment; rebuilt by `apply_session_config`, which
    /// carries over any previously cached environment.
    pub cached_system_context: Option<SystemContext>,

    /// Tool schemas held by this state (not populated by the methods in this file).
    pub tools: Vec<ToolSchema>,

    /// Approved tool names, checked by `is_tool_pre_approved`; extended by
    /// `approve_tool` and replaced by `apply_session_config`.
    pub approved_tools: HashSet<String>,
    /// Bash patterns approved at runtime via `approve_bash_pattern`; cleared by
    /// `apply_session_config`.
    pub approved_bash_patterns: HashSet<String>,
    /// Bash patterns from `with_approved_patterns` or the session config's approval policy.
    pub static_bash_patterns: Vec<String>,
    /// Approval currently awaiting an answer; cleared by `apply_session_config`.
    pub pending_approval: Option<PendingApproval>,
    /// Further tool calls awaiting approval; cleared by `apply_session_config`.
    pub approval_queue: VecDeque<QueuedApproval>,
    /// FIFO of deferred work: `queue_*` methods push to the back,
    /// `pop_next_queued_work` pops from the front.
    pub queued_work: VecDeque<QueuedWorkItem>,

    /// Operation set by `start_operation`; cleared by `complete_operation` for the same op id.
    pub current_operation: Option<OperationState>,

    /// In-progress streamed messages keyed by operation (not managed in this file).
    pub active_streams: HashMap<OpId, StreamingMessage>,

    /// Workspace file list (not managed in this file).
    pub workspace_files: Vec<String>,

    /// MCP server states; presumably keyed by server name — confirm against callers.
    pub mcp_servers: HashMap<String, McpServerState>,

    /// Recently cancelled operations, capped at `MAX_CANCELLED_OPS` by `record_cancelled_op`.
    pub cancelled_ops: HashSet<OpId>,

    /// Model used per operation; entry removed by `complete_operation`.
    pub operation_models: HashMap<OpId, ModelId>,
    /// Message id per operation; entry removed by `complete_operation`.
    pub operation_messages: HashMap<OpId, MessageId>,

    /// Last value handed out by `increment_sequence` (0 before the first call).
    pub event_sequence: u64,
}
48
/// A tool call currently waiting for an approval decision, tagged with the
/// request id it was raised under.
#[derive(Debug, Clone)]
pub struct PendingApproval {
    pub request_id: RequestId,
    pub tool_call: ToolCall,
}
54
/// A tool call waiting in `AppState::approval_queue` behind the pending approval.
#[derive(Debug, Clone)]
pub struct QueuedApproval {
    pub tool_call: ToolCall,
}
59
/// A user message deferred in `AppState::queued_work`.
///
/// Consecutive queued user messages are merged by `AppState::queue_user_message`.
#[derive(Debug, Clone)]
pub struct QueuedUserMessage {
    pub content: Vec<UserContent>,
    pub op_id: OpId,
    pub message_id: MessageId,
    pub model: ModelId,
    /// Enqueue timestamp; units are not visible here — confirm with the producer.
    pub queued_at: u64,
}
68
/// A direct bash command deferred in `AppState::queued_work` (never merged).
#[derive(Debug, Clone)]
pub struct QueuedBashCommand {
    pub command: String,
    pub op_id: OpId,
    pub message_id: MessageId,
    /// Enqueue timestamp; units are not visible here — confirm with the producer.
    pub queued_at: u64,
}
76
/// One entry of the FIFO `AppState::queued_work` queue.
#[derive(Debug, Clone)]
pub enum QueuedWorkItem {
    UserMessage(QueuedUserMessage),
    DirectBash(QueuedBashCommand),
}
82
/// The operation currently tracked in `AppState::current_operation`.
#[derive(Debug, Clone)]
pub struct OperationState {
    pub op_id: OpId,
    pub kind: OperationKind,
    /// Tool calls still outstanding; maintained by
    /// `AppState::add_pending_tool_call` / `remove_pending_tool_call`.
    pub pending_tool_calls: HashSet<ToolCallId>,
}
89
/// What kind of work an operation performs.
///
/// Serializable; do not reorder variants without checking the serialization
/// format in use (index-based formats depend on variant order).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OperationKind {
    AgentLoop,
    Compact,
    DirectBash { command: String },
}
96
/// A message being streamed for operation `op_id` (stored in `AppState::active_streams`).
#[derive(Debug, Clone)]
pub struct StreamingMessage {
    pub message_id: MessageId,
    pub op_id: OpId,
    /// Text accumulated so far.
    pub content: String,
    pub tool_calls: Vec<ToolCall>,
    /// Presumably the number of bytes buffered so far, likely compared against
    /// `StreamingConfig::max_buffer_bytes` — confirm at the call site.
    pub byte_count: usize,
}
105
/// Limits applied to message streaming.
///
/// Derives the same `Debug`/`Clone` traits as the other domain types in this
/// module (plus equality) so configurations can be logged, copied and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamingConfig {
    /// Maximum number of bytes buffered for a single stream.
    pub max_buffer_bytes: usize,
    /// Maximum number of streams allowed at the same time.
    pub max_concurrent_streams: usize,
}
110
111impl Default for StreamingConfig {
112    fn default() -> Self {
113        Self {
114            max_buffer_bytes: 64 * 1024,
115            max_concurrent_streams: 3,
116        }
117    }
118}
119
/// Upper bound on the number of ids kept in `AppState::cancelled_ops`;
/// `record_cancelled_op` evicts an entry once this is exceeded.
const MAX_CANCELLED_OPS: usize = 100;
121
122impl AppState {
123    pub fn new(session_id: SessionId) -> Self {
124        Self {
125            session_id,
126            session_config: None,
127            base_session_config: None,
128            primary_agent_id: None,
129            message_graph: MessageGraph::new(),
130            cached_system_context: None,
131            tools: Vec::new(),
132            approved_tools: HashSet::new(),
133            approved_bash_patterns: HashSet::new(),
134            static_bash_patterns: Vec::new(),
135            pending_approval: None,
136            approval_queue: VecDeque::new(),
137            queued_work: VecDeque::new(),
138            current_operation: None,
139            active_streams: HashMap::new(),
140            workspace_files: Vec::new(),
141            mcp_servers: HashMap::new(),
142            cancelled_ops: HashSet::new(),
143            operation_models: HashMap::new(),
144            operation_messages: HashMap::new(),
145            event_sequence: 0,
146        }
147    }
148
149    pub fn with_approved_patterns(mut self, patterns: Vec<String>) -> Self {
150        self.static_bash_patterns = patterns;
151        self
152    }
153
154    pub fn with_approved_tools(mut self, tools: HashSet<String>) -> Self {
155        self.approved_tools = tools;
156        self
157    }
158
159    pub fn is_tool_pre_approved(&self, tool_name: &str) -> bool {
160        self.approved_tools.contains(tool_name)
161    }
162
163    pub fn is_bash_pattern_approved(&self, command: &str) -> bool {
164        for pattern in self
165            .static_bash_patterns
166            .iter()
167            .chain(self.approved_bash_patterns.iter())
168        {
169            if pattern == command {
170                return true;
171            }
172            if let Ok(glob) = glob::Pattern::new(pattern)
173                && glob.matches(command)
174            {
175                return true;
176            }
177        }
178        false
179    }
180
181    pub fn approve_tool(&mut self, tool_name: String) {
182        self.approved_tools.insert(tool_name);
183    }
184
185    pub fn approve_bash_pattern(&mut self, pattern: String) {
186        self.approved_bash_patterns.insert(pattern);
187    }
188
189    pub fn record_cancelled_op(&mut self, op_id: OpId) {
190        self.cancelled_ops.insert(op_id);
191        if self.cancelled_ops.len() > MAX_CANCELLED_OPS
192            && let Some(&oldest) = self.cancelled_ops.iter().next()
193        {
194            self.cancelled_ops.remove(&oldest);
195        }
196    }
197
198    pub fn is_op_cancelled(&self, op_id: &OpId) -> bool {
199        self.cancelled_ops.contains(op_id)
200    }
201
202    pub fn has_pending_approval(&self) -> bool {
203        self.pending_approval.is_some()
204    }
205
206    pub fn has_active_operation(&self) -> bool {
207        self.current_operation.is_some()
208    }
209
210    pub fn queue_user_message(&mut self, item: QueuedUserMessage) {
211        if let Some(QueuedWorkItem::UserMessage(tail)) = self.queued_work.back_mut() {
212            if tail.content.iter().any(|item| {
213                !matches!(item, UserContent::Text { text } if text.as_str().trim().is_empty())
214            }) {
215                tail.content.push(UserContent::Text {
216                    text: "\n\n".to_string(),
217                });
218            }
219            tail.content.extend(item.content);
220            tail.op_id = item.op_id;
221            tail.message_id = item.message_id;
222            tail.model = item.model;
223            tail.queued_at = item.queued_at;
224            return;
225        }
226        self.queued_work
227            .push_back(QueuedWorkItem::UserMessage(item));
228    }
229
230    pub fn queue_bash_command(&mut self, item: QueuedBashCommand) {
231        self.queued_work.push_back(QueuedWorkItem::DirectBash(item));
232    }
233
234    pub fn pop_next_queued_work(&mut self) -> Option<QueuedWorkItem> {
235        self.queued_work.pop_front()
236    }
237
238    pub fn queued_summary(&self) -> (Option<QueuedWorkItem>, usize) {
239        (self.queued_work.front().cloned(), self.queued_work.len())
240    }
241
242    pub fn start_operation(&mut self, op_id: OpId, kind: OperationKind) {
243        self.current_operation = Some(OperationState {
244            op_id,
245            kind,
246            pending_tool_calls: HashSet::new(),
247        });
248    }
249
250    pub fn complete_operation(&mut self, op_id: OpId) {
251        self.operation_models.remove(&op_id);
252        self.operation_messages.remove(&op_id);
253        if self
254            .current_operation
255            .as_ref()
256            .is_some_and(|op| op.op_id == op_id)
257        {
258            self.current_operation = None;
259        }
260    }
261
262    pub fn add_pending_tool_call(&mut self, tool_call_id: ToolCallId) {
263        if let Some(ref mut op) = self.current_operation {
264            op.pending_tool_calls.insert(tool_call_id);
265        }
266    }
267
268    pub fn remove_pending_tool_call(&mut self, tool_call_id: &ToolCallId) {
269        if let Some(ref mut op) = self.current_operation {
270            op.pending_tool_calls.remove(tool_call_id);
271        }
272    }
273
274    pub fn increment_sequence(&mut self) -> u64 {
275        self.event_sequence += 1;
276        self.event_sequence
277    }
278
279    pub fn apply_session_config(
280        &mut self,
281        config: &SessionConfig,
282        primary_agent_id: Option<String>,
283        update_base: bool,
284    ) {
285        self.session_config = Some(config.clone());
286        let prompt = config
287            .system_prompt
288            .as_ref()
289            .and_then(|prompt| {
290                if prompt.trim().is_empty() {
291                    None
292                } else {
293                    Some(prompt.clone())
294                }
295            })
296            .unwrap_or_else(|| system_prompt_for_model(&config.default_model));
297        let environment = self
298            .cached_system_context
299            .as_ref()
300            .and_then(|context| context.environment.clone());
301        self.cached_system_context = Some(SystemContext::with_environment(prompt, environment));
302
303        self.approved_tools
304            .clone_from(config.tool_config.approval_policy.pre_approved_tools());
305        self.approved_bash_patterns.clear();
306        self.static_bash_patterns = config
307            .tool_config
308            .approval_policy
309            .preapproved
310            .bash_patterns()
311            .map(|patterns| patterns.to_vec())
312            .unwrap_or_default();
313        self.pending_approval = None;
314        self.approval_queue.clear();
315
316        if let Some(primary_agent_id) = primary_agent_id.or_else(|| config.primary_agent_id.clone())
317        {
318            self.primary_agent_id = Some(primary_agent_id);
319        }
320
321        if update_base {
322            self.base_session_config = Some(config.clone());
323        }
324    }
325}