1use crate::api::provider::TokenUsage;
2use crate::app::SystemContext;
3use crate::app::conversation::MessageGraph;
4use crate::app::conversation::UserContent;
5use crate::app::domain::action::McpServerState;
6use crate::app::domain::event::ContextWindowUsage;
7use crate::app::domain::types::{MessageId, OpId, RequestId, SessionId, ToolCallId};
8use crate::config::model::ModelId;
9use crate::prompts::system_prompt_for_model;
10use crate::session::state::SessionConfig;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet, VecDeque};
13use steer_tools::{ToolCall, ToolSchema};
14
/// In-memory state for a single session; mutated through the methods on
/// `impl AppState`.
#[derive(Debug, Clone)]
pub struct AppState {
    /// Identifier of the session this state belongs to.
    pub session_id: SessionId,
    /// Currently applied session configuration, if any.
    pub session_config: Option<SessionConfig>,
    /// Configuration recorded as the session's base; written by
    /// `apply_session_config` when its `update_base` flag is true.
    pub base_session_config: Option<SessionConfig>,
    /// Identifier of the primary agent, if known.
    pub primary_agent_id: Option<String>,

    /// Conversation messages for this session.
    pub message_graph: MessageGraph,

    /// System prompt plus environment; rebuilt by `apply_session_config`,
    /// which keeps any previously cached environment.
    pub cached_system_context: Option<SystemContext>,

    /// Tool schemas available in this session.
    pub tools: Vec<ToolSchema>,

    /// Tool names treated as pre-approved (see `is_tool_pre_approved`).
    pub approved_tools: HashSet<String>,
    /// Bash patterns (exact commands or globs) approved during the session.
    pub approved_bash_patterns: HashSet<String>,
    /// Bash patterns (exact commands or globs) taken from configuration.
    pub static_bash_patterns: Vec<String>,
    /// Tool call currently awaiting an approval decision.
    pub pending_approval: Option<PendingApproval>,
    /// Further tool calls waiting for approval.
    pub approval_queue: VecDeque<QueuedApproval>,
    /// Queued user messages and direct bash commands; consecutive user
    /// messages are merged by `queue_user_message`.
    pub queued_work: VecDeque<QueuedWorkItem>,

    /// The operation currently in progress, if any.
    pub current_operation: Option<OperationState>,

    /// Streamed responses in progress, keyed by operation id.
    pub active_streams: HashMap<OpId, StreamingMessage>,

    /// Workspace file list. NOTE(review): presumably paths used for
    /// completion/context — confirm against callers.
    pub workspace_files: Vec<String>,

    /// MCP server states. NOTE(review): presumably keyed by server name —
    /// confirm.
    pub mcp_servers: HashMap<String, McpServerState>,

    /// Recently cancelled operations; capped at `MAX_CANCELLED_OPS` entries
    /// by `record_cancelled_op`.
    pub cancelled_ops: HashSet<OpId>,

    /// Model used per operation; entries are removed by `complete_operation`.
    pub operation_models: HashMap<OpId, ModelId>,
    /// Message associated with each operation; entries are removed by
    /// `complete_operation`.
    pub operation_messages: HashMap<OpId, MessageId>,

    /// Latest usage snapshot per operation (each report replaces the previous
    /// one for that operation).
    pub llm_usage_by_op: HashMap<OpId, LlmUsageSnapshot>,
    /// Sum of the usage in `llm_usage_by_op`; recomputed by `record_llm_usage`.
    pub llm_usage_totals: TokenUsage,

    /// Event counter advanced by `increment_sequence`.
    pub event_sequence: u64,

    /// IDs related to conversation compaction. NOTE(review): presumably the
    /// ids of compaction summary messages — confirm.
    pub compaction_summary_ids: HashSet<String>,
}
56
/// A tool call that is waiting on an approval decision.
#[derive(Debug, Clone)]
pub struct PendingApproval {
    /// Identifier of the approval request.
    pub request_id: RequestId,
    /// The tool call that needs approval.
    pub tool_call: ToolCall,
}
62
/// Most recently reported LLM usage for one operation.
#[derive(Debug, Clone)]
pub struct LlmUsageSnapshot {
    /// Model that produced the usage.
    pub model: ModelId,
    /// Token counts reported for the operation.
    pub usage: TokenUsage,
    /// Context-window utilisation, when it was reported.
    pub context_window: Option<ContextWindowUsage>,
}
69
/// A tool call queued for approval behind the currently pending one.
#[derive(Debug, Clone)]
pub struct QueuedApproval {
    /// The tool call that needs approval.
    pub tool_call: ToolCall,
}
74
/// A user message queued to run later.
#[derive(Debug, Clone)]
pub struct QueuedUserMessage {
    /// Message content parts.
    pub content: Vec<UserContent>,
    /// Operation id the message will run under.
    pub op_id: OpId,
    /// Identifier of the message.
    pub message_id: MessageId,
    /// Model to use for the message.
    pub model: ModelId,
    /// Time the message was queued. NOTE(review): units/epoch not visible
    /// here — confirm (likely a timestamp).
    pub queued_at: u64,
}
83
/// A direct bash command queued to run later.
#[derive(Debug, Clone)]
pub struct QueuedBashCommand {
    /// Command text to run.
    pub command: String,
    /// Operation id the command will run under.
    pub op_id: OpId,
    /// Identifier of the message associated with the command.
    pub message_id: MessageId,
    /// Time the command was queued. NOTE(review): units/epoch not visible
    /// here — confirm.
    pub queued_at: u64,
}
91
/// An entry in `AppState::queued_work`.
#[derive(Debug, Clone)]
pub enum QueuedWorkItem {
    /// A queued user message (consecutive ones may be merged).
    UserMessage(QueuedUserMessage),
    /// A queued direct bash command.
    DirectBash(QueuedBashCommand),
}
97
/// Bookkeeping for the operation currently in progress.
#[derive(Debug, Clone)]
pub struct OperationState {
    /// Identifier of the operation.
    pub op_id: OpId,
    /// What kind of operation this is.
    pub kind: OperationKind,
    /// Tool calls started by this operation that have not yet been removed
    /// via `remove_pending_tool_call`.
    pub pending_tool_calls: HashSet<ToolCallId>,
}
104
/// Kind of a running operation. Serialized with serde, so variant and field
/// names/order form part of the wire/persisted format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OperationKind {
    /// Agent loop operation. NOTE(review): presumably the model/tool-call
    /// loop for a user message — confirm.
    AgentLoop,
    /// Conversation compaction.
    Compact {
        /// What triggered the compaction.
        trigger: crate::app::domain::event::CompactTrigger,
    },
    /// A directly executed bash command.
    DirectBash {
        /// Command text being run.
        command: String,
    },
}
115
/// A streamed response being accumulated for an operation.
#[derive(Debug, Clone)]
pub struct StreamingMessage {
    /// Identifier of the message being streamed.
    pub message_id: MessageId,
    /// Operation producing the stream.
    pub op_id: OpId,
    /// Text content received so far.
    pub content: String,
    /// Tool calls received so far.
    pub tool_calls: Vec<ToolCall>,
    /// Byte counter for the stream. NOTE(review): presumably compared against
    /// `StreamingConfig::max_buffer_bytes` — confirm.
    pub byte_count: usize,
}
124
/// Limits applied to streamed responses.
///
/// Both fields are plain `usize` limits, so the type is cheap to copy and
/// compare. It derives the same diagnostic traits as the other public types in
/// this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamingConfig {
    /// Maximum number of buffered bytes. NOTE(review): presumably per stream
    /// — confirm against the consumer.
    pub max_buffer_bytes: usize,
    /// Maximum number of streams allowed at the same time.
    pub max_concurrent_streams: usize,
}
129
130impl Default for StreamingConfig {
131 fn default() -> Self {
132 Self {
133 max_buffer_bytes: 64 * 1024,
134 max_concurrent_streams: 3,
135 }
136 }
137}
138
/// Upper bound on the number of entries kept in `AppState::cancelled_ops`;
/// enforced by `AppState::record_cancelled_op`.
const MAX_CANCELLED_OPS: usize = 100;
140
141impl AppState {
142 pub fn new(session_id: SessionId) -> Self {
143 Self {
144 session_id,
145 session_config: None,
146 base_session_config: None,
147 primary_agent_id: None,
148 message_graph: MessageGraph::new(),
149 cached_system_context: None,
150 tools: Vec::new(),
151 approved_tools: HashSet::new(),
152 approved_bash_patterns: HashSet::new(),
153 static_bash_patterns: Vec::new(),
154 pending_approval: None,
155 approval_queue: VecDeque::new(),
156 queued_work: VecDeque::new(),
157 current_operation: None,
158 active_streams: HashMap::new(),
159 workspace_files: Vec::new(),
160 mcp_servers: HashMap::new(),
161 cancelled_ops: HashSet::new(),
162 operation_models: HashMap::new(),
163 operation_messages: HashMap::new(),
164 llm_usage_by_op: HashMap::new(),
165 llm_usage_totals: TokenUsage::new(0, 0, 0),
166 event_sequence: 0,
167 compaction_summary_ids: HashSet::new(),
168 }
169 }
170
171 pub fn with_approved_patterns(mut self, patterns: Vec<String>) -> Self {
172 self.static_bash_patterns = patterns;
173 self
174 }
175
176 pub fn with_approved_tools(mut self, tools: HashSet<String>) -> Self {
177 self.approved_tools = tools;
178 self
179 }
180
181 pub fn is_tool_pre_approved(&self, tool_name: &str) -> bool {
182 self.approved_tools.contains(tool_name)
183 }
184
185 pub fn is_bash_pattern_approved(&self, command: &str) -> bool {
186 for pattern in self
187 .static_bash_patterns
188 .iter()
189 .chain(self.approved_bash_patterns.iter())
190 {
191 if pattern == command {
192 return true;
193 }
194 if let Ok(glob) = glob::Pattern::new(pattern)
195 && glob.matches(command)
196 {
197 return true;
198 }
199 }
200 false
201 }
202
203 pub fn approve_tool(&mut self, tool_name: String) {
204 self.approved_tools.insert(tool_name);
205 }
206
207 pub fn approve_bash_pattern(&mut self, pattern: String) {
208 self.approved_bash_patterns.insert(pattern);
209 }
210
211 pub fn record_cancelled_op(&mut self, op_id: OpId) {
212 self.cancelled_ops.insert(op_id);
213 if self.cancelled_ops.len() > MAX_CANCELLED_OPS
214 && let Some(&oldest) = self.cancelled_ops.iter().next()
215 {
216 self.cancelled_ops.remove(&oldest);
217 }
218 }
219
220 pub fn is_op_cancelled(&self, op_id: &OpId) -> bool {
221 self.cancelled_ops.contains(op_id)
222 }
223
224 pub fn has_pending_approval(&self) -> bool {
225 self.pending_approval.is_some()
226 }
227
228 pub fn has_active_operation(&self) -> bool {
229 self.current_operation.is_some()
230 }
231
232 pub fn queue_user_message(&mut self, item: QueuedUserMessage) {
233 if let Some(QueuedWorkItem::UserMessage(tail)) = self.queued_work.back_mut() {
234 if tail.content.iter().any(|item| {
235 !matches!(item, UserContent::Text { text } if text.as_str().trim().is_empty())
236 }) {
237 tail.content.push(UserContent::Text {
238 text: "\n\n".to_string(),
239 });
240 }
241 tail.content.extend(item.content);
242 tail.op_id = item.op_id;
243 tail.message_id = item.message_id;
244 tail.model = item.model;
245 tail.queued_at = item.queued_at;
246 return;
247 }
248 self.queued_work
249 .push_back(QueuedWorkItem::UserMessage(item));
250 }
251
252 pub fn queue_bash_command(&mut self, item: QueuedBashCommand) {
253 self.queued_work.push_back(QueuedWorkItem::DirectBash(item));
254 }
255
256 pub fn pop_next_queued_work(&mut self) -> Option<QueuedWorkItem> {
257 self.queued_work.pop_front()
258 }
259
260 pub fn queued_summary(&self) -> (Option<QueuedWorkItem>, usize) {
261 (self.queued_work.front().cloned(), self.queued_work.len())
262 }
263
264 pub fn start_operation(&mut self, op_id: OpId, kind: OperationKind) {
265 self.current_operation = Some(OperationState {
266 op_id,
267 kind,
268 pending_tool_calls: HashSet::new(),
269 });
270 }
271
272 pub fn complete_operation(&mut self, op_id: OpId) {
273 self.operation_models.remove(&op_id);
274 self.operation_messages.remove(&op_id);
275 if self
276 .current_operation
277 .as_ref()
278 .is_some_and(|op| op.op_id == op_id)
279 {
280 self.current_operation = None;
281 }
282 }
283
284 pub fn add_pending_tool_call(&mut self, tool_call_id: ToolCallId) {
285 if let Some(ref mut op) = self.current_operation {
286 op.pending_tool_calls.insert(tool_call_id);
287 }
288 }
289
290 pub fn remove_pending_tool_call(&mut self, tool_call_id: &ToolCallId) {
291 if let Some(ref mut op) = self.current_operation {
292 op.pending_tool_calls.remove(tool_call_id);
293 }
294 }
295
296 pub fn increment_sequence(&mut self) -> u64 {
297 self.event_sequence += 1;
298 self.event_sequence
299 }
300
301 pub fn record_llm_usage(
302 &mut self,
303 op_id: OpId,
304 model: ModelId,
305 usage: TokenUsage,
306 context_window: Option<ContextWindowUsage>,
307 ) {
308 self.llm_usage_by_op.insert(
309 op_id,
310 LlmUsageSnapshot {
311 model,
312 usage,
313 context_window,
314 },
315 );
316 self.recompute_llm_usage_totals();
317 }
318
319 fn recompute_llm_usage_totals(&mut self) {
320 let mut input_tokens = 0u32;
321 let mut output_tokens = 0u32;
322 let mut total_tokens = 0u32;
323
324 for snapshot in self.llm_usage_by_op.values() {
325 input_tokens = input_tokens.saturating_add(snapshot.usage.input_tokens);
326 output_tokens = output_tokens.saturating_add(snapshot.usage.output_tokens);
327 total_tokens = total_tokens.saturating_add(snapshot.usage.total_tokens);
328 }
329
330 self.llm_usage_totals = TokenUsage::new(input_tokens, output_tokens, total_tokens);
331 }
332
333 pub fn apply_session_config(
334 &mut self,
335 config: &SessionConfig,
336 primary_agent_id: Option<String>,
337 update_base: bool,
338 ) {
339 self.session_config = Some(config.clone());
340 let prompt = config
341 .system_prompt
342 .as_ref()
343 .and_then(|prompt| {
344 if prompt.trim().is_empty() {
345 None
346 } else {
347 Some(prompt.clone())
348 }
349 })
350 .unwrap_or_else(|| system_prompt_for_model(&config.default_model));
351 let environment = self
352 .cached_system_context
353 .as_ref()
354 .and_then(|context| context.environment.clone());
355 self.cached_system_context = Some(SystemContext::with_environment(prompt, environment));
356
357 self.approved_tools
358 .clone_from(config.tool_config.approval_policy.pre_approved_tools());
359 self.approved_bash_patterns.clear();
360 self.static_bash_patterns = config
361 .tool_config
362 .approval_policy
363 .preapproved
364 .bash_patterns()
365 .map(|patterns| patterns.to_vec())
366 .unwrap_or_default();
367 self.pending_approval = None;
368 self.approval_queue.clear();
369
370 if let Some(primary_agent_id) = primary_agent_id.or_else(|| config.primary_agent_id.clone())
371 {
372 self.primary_agent_id = Some(primary_agent_id);
373 }
374
375 if update_base {
376 self.base_session_config = Some(config.clone());
377 }
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::model::builtin;

    /// A second report for the same operation overwrites the first rather than
    /// accumulating. The session totals always equal the sum over the
    /// per-operation snapshots.
    #[test]
    fn record_llm_usage_replaces_snapshot_for_op_and_recomputes_totals() {
        let sonnet = builtin::claude_sonnet_4_5();
        let first_op = OpId::new();
        let second_op = OpId::new();
        let mut app = AppState::new(SessionId::new());

        // Initial report for the first operation.
        app.record_llm_usage(first_op, sonnet.clone(), TokenUsage::new(3, 5, 8), None);
        assert_eq!(app.llm_usage_totals, TokenUsage::new(3, 5, 8));

        // Re-reporting the same operation replaces its snapshot.
        let window = ContextWindowUsage {
            max_context_tokens: Some(200_000),
            remaining_tokens: Some(199_982),
            utilization_ratio: Some(0.00009),
            estimated: false,
        };
        app.record_llm_usage(
            first_op,
            sonnet.clone(),
            TokenUsage::new(7, 11, 18),
            Some(window),
        );

        let first_snapshot = app
            .llm_usage_by_op
            .get(&first_op)
            .expect("usage for op_a should be present");
        assert_eq!(first_snapshot.usage, TokenUsage::new(7, 11, 18));
        assert!(first_snapshot.context_window.is_some());
        assert_eq!(app.llm_usage_totals, TokenUsage::new(7, 11, 18));

        // A different operation adds to the totals.
        app.record_llm_usage(second_op, sonnet, TokenUsage::new(2, 4, 6), None);
        assert_eq!(app.llm_usage_totals, TokenUsage::new(9, 15, 24));
    }
}