steer_core/app/domain/state.rs

use crate::api::provider::TokenUsage;
use crate::app::SystemContext;
use crate::app::conversation::MessageGraph;
use crate::app::conversation::UserContent;
use crate::app::domain::action::McpServerState;
use crate::app::domain::event::ContextWindowUsage;
use crate::app::domain::types::{MessageId, OpId, RequestId, SessionId, ToolCallId};
use crate::config::model::ModelId;
use crate::prompts::system_prompt_for_model;
use crate::session::state::SessionConfig;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use steer_tools::{ToolCall, ToolSchema};

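/// In-memory state for a single session: the conversation graph, tool-approval
/// bookkeeping, queued work, the current operation, streaming buffers, MCP
/// server state, and per-operation LLM usage accounting.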
#[derive(Debug, Clone)]
pub struct AppState {
    pub session_id: SessionId,
    pub session_config: Option<SessionConfig>,
    pub base_session_config: Option<SessionConfig>,
    pub primary_agent_id: Option<String>,

    pub message_graph: MessageGraph,

    pub cached_system_context: Option<SystemContext>,

    pub tools: Vec<ToolSchema>,

    pub approved_tools: HashSet<String>,
    pub approved_bash_patterns: HashSet<String>,
    pub static_bash_patterns: Vec<String>,
    pub pending_approval: Option<PendingApproval>,
    pub approval_queue: VecDeque<QueuedApproval>,
    pub queued_work: VecDeque<QueuedWorkItem>,

    pub current_operation: Option<OperationState>,

    pub active_streams: HashMap<OpId, StreamingMessage>,

    pub workspace_files: Vec<String>,

    pub mcp_servers: HashMap<String, McpServerState>,

    pub cancelled_ops: HashSet<OpId>,

    pub operation_models: HashMap<OpId, ModelId>,
    pub operation_messages: HashMap<OpId, MessageId>,

    pub llm_usage_by_op: HashMap<OpId, LlmUsageSnapshot>,
    pub llm_usage_totals: TokenUsage,

    pub event_sequence: u64,

    /// Message IDs that are compaction summaries.
    pub compaction_summary_ids: HashSet<String>,
}

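/// A tool call currently awaiting a user approval decision, identified by the
/// request that asked for it.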
#[derive(Debug, Clone)]
pub struct PendingApproval {
    pub request_id: RequestId,
    pub tool_call: ToolCall,
}

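/// Token usage recorded for a single operation, with the model used and
/// optional context-window information.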
#[derive(Debug, Clone)]
pub struct LlmUsageSnapshot {
    pub model: ModelId,
    pub usage: TokenUsage,
    pub context_window: Option<ContextWindowUsage>,
}

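/// A tool call queued behind the currently pending approval.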
#[derive(Debug, Clone)]
pub struct QueuedApproval {
    pub tool_call: ToolCall,
}

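/// A user message held in the work queue; consecutive queued messages are
/// merged by `queue_user_message`.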
#[derive(Debug, Clone)]
pub struct QueuedUserMessage {
    pub content: Vec<UserContent>,
    pub op_id: OpId,
    pub message_id: MessageId,
    pub model: ModelId,
    pub queued_at: u64,
}

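/// A direct bash command held in the work queue.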
#[derive(Debug, Clone)]
pub struct QueuedBashCommand {
    pub command: String,
    pub op_id: OpId,
    pub message_id: MessageId,
    pub queued_at: u64,
}

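/// An item in the deferred work queue: either a queued user message or a
/// direct bash command.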
#[derive(Debug, Clone)]
pub enum QueuedWorkItem {
    UserMessage(QueuedUserMessage),
    DirectBash(QueuedBashCommand),
}

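/// The operation currently in progress, along with the tool calls it is still
/// waiting on.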
#[derive(Debug, Clone)]
pub struct OperationState {
    pub op_id: OpId,
    pub kind: OperationKind,
    pub pending_tool_calls: HashSet<ToolCallId>,
}

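/// The kind of work an operation performs: an agent loop, a compaction, or a
/// direct bash command.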
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OperationKind {
    AgentLoop,
    Compact {
        trigger: crate::app::domain::event::CompactTrigger,
    },
    DirectBash {
        command: String,
    },
}

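/// An assistant message that is still streaming for an operation: accumulated
/// content, tool calls, and a running byte count.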
#[derive(Debug, Clone)]
pub struct StreamingMessage {
    pub message_id: MessageId,
    pub op_id: OpId,
    pub content: String,
    pub tool_calls: Vec<ToolCall>,
    pub byte_count: usize,
}

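/// Limits applied to streaming: maximum buffer size in bytes and the maximum
/// number of concurrent streams.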
pub struct StreamingConfig {
    pub max_buffer_bytes: usize,
    pub max_concurrent_streams: usize,
}

impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            max_buffer_bytes: 64 * 1024,
            max_concurrent_streams: 3,
        }
    }
}

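/// Maximum number of cancelled operation ids retained in `cancelled_ops`.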
const MAX_CANCELLED_OPS: usize = 100;

impl AppState {
    pub fn new(session_id: SessionId) -> Self {
        Self {
            session_id,
            session_config: None,
            base_session_config: None,
            primary_agent_id: None,
            message_graph: MessageGraph::new(),
            cached_system_context: None,
            tools: Vec::new(),
            approved_tools: HashSet::new(),
            approved_bash_patterns: HashSet::new(),
            static_bash_patterns: Vec::new(),
            pending_approval: None,
            approval_queue: VecDeque::new(),
            queued_work: VecDeque::new(),
            current_operation: None,
            active_streams: HashMap::new(),
            workspace_files: Vec::new(),
            mcp_servers: HashMap::new(),
            cancelled_ops: HashSet::new(),
            operation_models: HashMap::new(),
            operation_messages: HashMap::new(),
            llm_usage_by_op: HashMap::new(),
            llm_usage_totals: TokenUsage::new(0, 0, 0),
            event_sequence: 0,
            compaction_summary_ids: HashSet::new(),
        }
    }

    pub fn with_approved_patterns(mut self, patterns: Vec<String>) -> Self {
        self.static_bash_patterns = patterns;
        self
    }

    pub fn with_approved_tools(mut self, tools: HashSet<String>) -> Self {
        self.approved_tools = tools;
        self
    }

    pub fn is_tool_pre_approved(&self, tool_name: &str) -> bool {
        self.approved_tools.contains(tool_name)
    }

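    /// Returns true if `command` matches any static or session-approved bash
    /// pattern, either exactly or as a glob.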
    pub fn is_bash_pattern_approved(&self, command: &str) -> bool {
        for pattern in self
            .static_bash_patterns
            .iter()
            .chain(self.approved_bash_patterns.iter())
        {
            if pattern == command {
                return true;
            }
            if let Ok(glob) = glob::Pattern::new(pattern)
                && glob.matches(command)
            {
                return true;
            }
        }
        false
    }

    pub fn approve_tool(&mut self, tool_name: String) {
        self.approved_tools.insert(tool_name);
    }

    pub fn approve_bash_pattern(&mut self, pattern: String) {
        self.approved_bash_patterns.insert(pattern);
    }

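    /// Records a cancelled operation id. Once more than `MAX_CANCELLED_OPS`
    /// ids are tracked, one entry is evicted; `HashSet` iteration has no
    /// defined order, so the evicted id is arbitrary rather than the oldest.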
    pub fn record_cancelled_op(&mut self, op_id: OpId) {
        self.cancelled_ops.insert(op_id);
        if self.cancelled_ops.len() > MAX_CANCELLED_OPS
            && let Some(&oldest) = self.cancelled_ops.iter().next()
        {
            self.cancelled_ops.remove(&oldest);
        }
    }

    pub fn is_op_cancelled(&self, op_id: &OpId) -> bool {
        self.cancelled_ops.contains(op_id)
    }

    pub fn has_pending_approval(&self) -> bool {
        self.pending_approval.is_some()
    }

    pub fn has_active_operation(&self) -> bool {
        self.current_operation.is_some()
    }

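    /// Queues a user message. If the most recently queued item is also a user
    /// message, the new content is appended to it (separated by a blank line
    /// when the existing content is non-empty) instead of adding a new entry.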
    pub fn queue_user_message(&mut self, item: QueuedUserMessage) {
        if let Some(QueuedWorkItem::UserMessage(tail)) = self.queued_work.back_mut() {
            if tail.content.iter().any(|item| {
                !matches!(item, UserContent::Text { text } if text.as_str().trim().is_empty())
            }) {
                tail.content.push(UserContent::Text {
                    text: "\n\n".to_string(),
                });
            }
            tail.content.extend(item.content);
            tail.op_id = item.op_id;
            tail.message_id = item.message_id;
            tail.model = item.model;
            tail.queued_at = item.queued_at;
            return;
        }
        self.queued_work
            .push_back(QueuedWorkItem::UserMessage(item));
    }

    pub fn queue_bash_command(&mut self, item: QueuedBashCommand) {
        self.queued_work.push_back(QueuedWorkItem::DirectBash(item));
    }

    pub fn pop_next_queued_work(&mut self) -> Option<QueuedWorkItem> {
        self.queued_work.pop_front()
    }

    pub fn queued_summary(&self) -> (Option<QueuedWorkItem>, usize) {
        (self.queued_work.front().cloned(), self.queued_work.len())
    }

    pub fn start_operation(&mut self, op_id: OpId, kind: OperationKind) {
        self.current_operation = Some(OperationState {
            op_id,
            kind,
            pending_tool_calls: HashSet::new(),
        });
    }

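    /// Drops the per-operation model and message bookkeeping for `op_id`, and
    /// clears `current_operation` if it refers to the same operation.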
    pub fn complete_operation(&mut self, op_id: OpId) {
        self.operation_models.remove(&op_id);
        self.operation_messages.remove(&op_id);
        if self
            .current_operation
            .as_ref()
            .is_some_and(|op| op.op_id == op_id)
        {
            self.current_operation = None;
        }
    }

    pub fn add_pending_tool_call(&mut self, tool_call_id: ToolCallId) {
        if let Some(ref mut op) = self.current_operation {
            op.pending_tool_calls.insert(tool_call_id);
        }
    }

    pub fn remove_pending_tool_call(&mut self, tool_call_id: &ToolCallId) {
        if let Some(ref mut op) = self.current_operation {
            op.pending_tool_calls.remove(tool_call_id);
        }
    }

    pub fn increment_sequence(&mut self) -> u64 {
        self.event_sequence += 1;
        self.event_sequence
    }

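    /// Stores (replacing any previous snapshot) the usage for `op_id` and
    /// recomputes the session-wide totals from all recorded snapshots.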
    pub fn record_llm_usage(
        &mut self,
        op_id: OpId,
        model: ModelId,
        usage: TokenUsage,
        context_window: Option<ContextWindowUsage>,
    ) {
        self.llm_usage_by_op.insert(
            op_id,
            LlmUsageSnapshot {
                model,
                usage,
                context_window,
            },
        );
        self.recompute_llm_usage_totals();
    }

    fn recompute_llm_usage_totals(&mut self) {
        let mut input_tokens = 0u32;
        let mut output_tokens = 0u32;
        let mut total_tokens = 0u32;

        for snapshot in self.llm_usage_by_op.values() {
            input_tokens = input_tokens.saturating_add(snapshot.usage.input_tokens);
            output_tokens = output_tokens.saturating_add(snapshot.usage.output_tokens);
            total_tokens = total_tokens.saturating_add(snapshot.usage.total_tokens);
        }

        self.llm_usage_totals = TokenUsage::new(input_tokens, output_tokens, total_tokens);
    }

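    /// Applies a session config: rebuilds the cached system prompt (falling
    /// back to the model default when the configured prompt is empty), resets
    /// tool and bash-pattern approvals from the approval policy, clears any
    /// pending approvals, updates the primary agent id when one is available,
    /// and optionally records the config as the new base.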
    pub fn apply_session_config(
        &mut self,
        config: &SessionConfig,
        primary_agent_id: Option<String>,
        update_base: bool,
    ) {
        self.session_config = Some(config.clone());
        let prompt = config
            .system_prompt
            .as_ref()
            .and_then(|prompt| {
                if prompt.trim().is_empty() {
                    None
                } else {
                    Some(prompt.clone())
                }
            })
            .unwrap_or_else(|| system_prompt_for_model(&config.default_model));
        let environment = self
            .cached_system_context
            .as_ref()
            .and_then(|context| context.environment.clone());
        self.cached_system_context = Some(SystemContext::with_environment(prompt, environment));

        self.approved_tools
            .clone_from(config.tool_config.approval_policy.pre_approved_tools());
        self.approved_bash_patterns.clear();
        self.static_bash_patterns = config
            .tool_config
            .approval_policy
            .preapproved
            .bash_patterns()
            .map(|patterns| patterns.to_vec())
            .unwrap_or_default();
        self.pending_approval = None;
        self.approval_queue.clear();

        if let Some(primary_agent_id) = primary_agent_id.or_else(|| config.primary_agent_id.clone())
        {
            self.primary_agent_id = Some(primary_agent_id);
        }

        if update_base {
            self.base_session_config = Some(config.clone());
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::model::builtin;

    #[test]
    fn record_llm_usage_replaces_snapshot_for_op_and_recomputes_totals() {
        let mut state = AppState::new(SessionId::new());
        let model = builtin::claude_sonnet_4_5();
        let op_a = OpId::new();
        let op_b = OpId::new();

        state.record_llm_usage(op_a, model.clone(), TokenUsage::new(3, 5, 8), None);
        assert_eq!(state.llm_usage_totals, TokenUsage::new(3, 5, 8));

        state.record_llm_usage(
            op_a,
            model.clone(),
            TokenUsage::new(7, 11, 18),
            Some(ContextWindowUsage {
                max_context_tokens: Some(200_000),
                remaining_tokens: Some(199_982),
                utilization_ratio: Some(0.00009),
                estimated: false,
            }),
        );

        let snapshot_a = state
            .llm_usage_by_op
            .get(&op_a)
            .expect("usage for op_a should be present");
        assert_eq!(snapshot_a.usage, TokenUsage::new(7, 11, 18));
        assert!(snapshot_a.context_window.is_some());
        assert_eq!(state.llm_usage_totals, TokenUsage::new(7, 11, 18));

        state.record_llm_usage(op_b, model, TokenUsage::new(2, 4, 6), None);
        assert_eq!(state.llm_usage_totals, TokenUsage::new(9, 15, 24));
    }
}