1use crate::{
2 checkpoint_store::CheckpointStore,
3 context::ContextBudget,
4 context::ContextFile,
5 event_log::EventLog,
6 idempotency::IdempotencyStore,
7 sandbox::{PersistentSandbox, SandboxConfig, SandboxMode},
8 session_manager::SessionManager,
9};
10use stakpak_agent_core::{ProposedToolCall, ToolApprovalPolicy};
11use stakpak_api::SessionStorage;
12use stakpak_mcp_client::McpClient;
13use std::{collections::HashMap, sync::Arc, time::Instant};
14use tokio::sync::{RwLock, broadcast};
15use uuid::Uuid;
16
17#[derive(Debug, Clone)]
18pub struct PendingToolApprovals {
19 pub run_id: Uuid,
20 pub tool_calls: Vec<ProposedToolCall>,
21}
22
23#[derive(Clone)]
24pub struct AppState {
25 pub run_manager: SessionManager,
26 pub session_store: Arc<dyn SessionStorage>,
28 pub events: Arc<EventLog>,
29 pub idempotency: Arc<IdempotencyStore>,
30 pub inference: Arc<stakai::Inference>,
31 pub checkpoint_store: Arc<CheckpointStore>,
33 pub models: Arc<Vec<stakai::Model>>,
34 pub default_model: Option<stakai::Model>,
35 pub tool_approval_policy: ToolApprovalPolicy,
36 pub started_at: Instant,
37 pub mcp_client: Option<Arc<McpClient>>,
38 pub mcp_tools: Arc<RwLock<Vec<stakai::Tool>>>,
39 pub mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
40 pub mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
41 pub sandbox_config: Option<SandboxConfig>,
42 pub persistent_sandbox: Option<Arc<PersistentSandbox>>,
45 pub base_system_prompt: Option<String>,
46 pub context_budget: ContextBudget,
47 pub project_dir: Option<String>,
51 pub skills_context: Arc<RwLock<Vec<ContextFile>>>,
54 pending_tools: Arc<RwLock<HashMap<Uuid, PendingToolApprovals>>>,
55}
56
57impl AppState {
58 #[allow(clippy::too_many_arguments)]
59 pub fn new(
60 session_store: Arc<dyn SessionStorage>,
61 events: Arc<EventLog>,
62 idempotency: Arc<IdempotencyStore>,
63 inference: Arc<stakai::Inference>,
64 models: Vec<stakai::Model>,
65 default_model: Option<stakai::Model>,
66 tool_approval_policy: ToolApprovalPolicy,
67 ) -> Self {
68 Self {
69 run_manager: SessionManager::new(),
70 session_store,
71 events,
72 idempotency,
73 inference,
74 checkpoint_store: Arc::new(CheckpointStore::default_local()),
75 models: Arc::new(models),
76 default_model,
77 tool_approval_policy,
78 started_at: Instant::now(),
79 mcp_client: None,
80 mcp_tools: Arc::new(RwLock::new(Vec::new())),
81 mcp_server_shutdown_tx: None,
82 mcp_proxy_shutdown_tx: None,
83 sandbox_config: None,
84 persistent_sandbox: None,
85 base_system_prompt: None,
86 context_budget: ContextBudget::default(),
87 project_dir: None,
88 skills_context: Arc::new(RwLock::new(Vec::new())),
89 pending_tools: Arc::new(RwLock::new(HashMap::new())),
90 }
91 }
92
93 pub fn with_mcp(
94 mut self,
95 mcp_client: Arc<McpClient>,
96 mcp_tools: Vec<stakai::Tool>,
97 mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
98 mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
99 ) -> Self {
100 self.mcp_client = Some(mcp_client);
101 self.mcp_tools = Arc::new(RwLock::new(mcp_tools));
102 self.mcp_server_shutdown_tx = mcp_server_shutdown_tx;
103 self.mcp_proxy_shutdown_tx = mcp_proxy_shutdown_tx;
104 self
105 }
106
107 pub fn with_sandbox(mut self, sandbox_config: SandboxConfig) -> Self {
108 self.sandbox_config = Some(sandbox_config);
109 self
110 }
111
112 pub fn with_persistent_sandbox(mut self, sandbox: PersistentSandbox) -> Self {
115 self.persistent_sandbox = Some(Arc::new(sandbox));
116 self
117 }
118
119 pub fn sandbox_mode(&self) -> SandboxMode {
121 self.sandbox_config
122 .as_ref()
123 .map(|c| c.mode.clone())
124 .unwrap_or_default()
125 }
126
127 pub fn with_base_system_prompt(mut self, prompt: Option<String>) -> Self {
128 self.base_system_prompt = prompt.filter(|value| !value.trim().is_empty());
129 self
130 }
131
132 pub fn with_context_budget(mut self, budget: ContextBudget) -> Self {
133 self.context_budget = budget;
134 self
135 }
136
137 pub fn with_project_dir(mut self, dir: Option<String>) -> Self {
138 self.project_dir = dir.filter(|value| !value.trim().is_empty());
139 self
140 }
141
142 pub fn with_skills(mut self, context_files: Vec<ContextFile>) -> Self {
143 self.skills_context = Arc::new(RwLock::new(context_files));
144 self
145 }
146
147 pub fn with_checkpoint_store(mut self, checkpoint_store: Arc<CheckpointStore>) -> Self {
148 self.checkpoint_store = checkpoint_store;
149 self
150 }
151
152 pub async fn current_skills(&self) -> Vec<ContextFile> {
153 self.skills_context.read().await.clone()
154 }
155
156 pub async fn replace_skills(&self, context_files: Vec<ContextFile>) {
157 let mut guard = self.skills_context.write().await;
158 *guard = context_files;
159 }
160
161 pub async fn current_mcp_tools(&self) -> Vec<stakai::Tool> {
162 self.mcp_tools.read().await.clone()
163 }
164
165 pub async fn refresh_mcp_tools(&self) -> Result<usize, String> {
166 let Some(mcp_client) = self.mcp_client.as_ref() else {
167 return Ok(self.mcp_tools.read().await.len());
168 };
169
170 let raw_tools = stakpak_mcp_client::get_tools(mcp_client)
171 .await
172 .map_err(|error| format!("Failed to refresh MCP tools: {error}"))?;
173
174 let converted = raw_tools
175 .into_iter()
176 .map(|tool| stakai::Tool {
177 tool_type: "function".to_string(),
178 function: stakai::ToolFunction {
179 name: tool.name.as_ref().to_string(),
180 description: tool
181 .description
182 .as_ref()
183 .map(std::string::ToString::to_string)
184 .unwrap_or_default(),
185 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
186 },
187 provider_options: None,
188 })
189 .collect::<Vec<_>>();
190
191 let mut guard = self.mcp_tools.write().await;
192 *guard = converted;
193 Ok(guard.len())
194 }
195
196 pub fn uptime_seconds(&self) -> u64 {
197 self.started_at.elapsed().as_secs()
198 }
199
200 pub fn resolve_model(&self, requested: Option<&str>) -> Option<stakai::Model> {
201 match requested {
202 Some(requested_model) => self.find_model(requested_model),
203 None => self
204 .default_model
205 .clone()
206 .or_else(|| self.models.first().cloned()),
207 }
208 }
209
210 pub async fn set_pending_tools(
211 &self,
212 session_id: Uuid,
213 run_id: Uuid,
214 tool_calls: Vec<ProposedToolCall>,
215 ) {
216 let mut guard = self.pending_tools.write().await;
217 guard.insert(session_id, PendingToolApprovals { run_id, tool_calls });
218 }
219
220 pub async fn clear_pending_tools(&self, session_id: Uuid, run_id: Uuid) {
221 let mut guard = self.pending_tools.write().await;
222 if guard
223 .get(&session_id)
224 .is_some_and(|pending| pending.run_id == run_id)
225 {
226 guard.remove(&session_id);
227 }
228 }
229
230 pub async fn pending_tools(&self, session_id: Uuid) -> Option<PendingToolApprovals> {
231 let guard = self.pending_tools.read().await;
232 guard.get(&session_id).cloned()
233 }
234
235 fn find_model(&self, requested: &str) -> Option<stakai::Model> {
236 if let Some((provider, id)) = requested.split_once('/') {
237 return self
238 .models
239 .iter()
240 .find(|model| model.provider == provider && model.id == id)
241 .cloned()
242 .or_else(|| Some(stakai::Model::custom(id, provider)));
243 }
244
245 self.models
246 .iter()
247 .find(|model| model.id == requested)
248 .cloned()
249 .or_else(|| {
250 self.default_model.as_ref().map(|default_model| {
251 stakai::Model::custom(requested.to_string(), default_model.provider.clone())
252 })
253 })
254 .or_else(|| {
255 self.models.first().map(|model| {
256 stakai::Model::custom(requested.to_string(), model.provider.clone())
257 })
258 })
259 .or_else(|| Some(stakai::Model::custom(requested.to_string(), "openai")))
260 }
261}