1use crate::{
2 checkpoint_store::CheckpointStore, context::ContextBudget, context::ContextFile,
3 event_log::EventLog, idempotency::IdempotencyStore, sandbox::SandboxConfig,
4 session_manager::SessionManager,
5};
6use stakpak_agent_core::{ProposedToolCall, ToolApprovalPolicy};
7use stakpak_api::SessionStorage;
8use stakpak_mcp_client::McpClient;
9use std::{collections::HashMap, sync::Arc, time::Instant};
10use tokio::sync::{RwLock, broadcast};
11use uuid::Uuid;
12
13#[derive(Debug, Clone)]
14pub struct PendingToolApprovals {
15 pub run_id: Uuid,
16 pub tool_calls: Vec<ProposedToolCall>,
17}
18
/// Cloneable application state.
///
/// Most members are held behind `Arc`, so clones share the same underlying
/// stores and locks. Plain-value members (`default_model`, `sandbox_config`,
/// `base_system_prompt`, `project_dir`, ...) are copied per clone.
#[derive(Clone)]
pub struct AppState {
    /// Session/run manager; `AppState::new` creates it with `SessionManager::new()`.
    pub run_manager: SessionManager,
    /// Session storage backend (trait object).
    pub session_store: Arc<dyn SessionStorage>,
    /// Shared event log.
    pub events: Arc<EventLog>,
    /// Shared idempotency store.
    pub idempotency: Arc<IdempotencyStore>,
    /// Shared inference client.
    pub inference: Arc<stakai::Inference>,
    /// Checkpoint store; defaults to `CheckpointStore::default_local()`.
    pub checkpoint_store: Arc<CheckpointStore>,
    /// Model catalog consulted by `resolve_model`; its first entry is the
    /// fallback when no `default_model` is configured.
    pub models: Arc<Vec<stakai::Model>>,
    /// Model used when a request does not name one.
    pub default_model: Option<stakai::Model>,
    /// Policy governing tool-call approval.
    pub tool_approval_policy: ToolApprovalPolicy,
    /// Instant the state was constructed; basis for `uptime_seconds`.
    pub started_at: Instant,
    /// MCP client, `None` until `with_mcp` is called.
    pub mcp_client: Option<Arc<McpClient>>,
    /// Cached MCP tool definitions, refreshed by `refresh_mcp_tools`.
    pub mcp_tools: Arc<RwLock<Vec<stakai::Tool>>>,
    // NOTE(review): presumably used to signal MCP server / proxy shutdown — confirm with the owners of these senders.
    pub mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
    pub mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
    /// Sandbox configuration, `None` unless `with_sandbox` is called.
    pub sandbox_config: Option<SandboxConfig>,
    /// Base system prompt; blank prompts are normalized to `None` by the builder.
    pub base_system_prompt: Option<String>,
    /// Context budget; defaults to `ContextBudget::default()`.
    pub context_budget: ContextBudget,
    /// Project directory; blank values are normalized to `None` by the builder.
    pub project_dir: Option<String>,
    /// Skill context files, shared across clones and replaceable at runtime.
    pub skills_context: Arc<RwLock<Vec<ContextFile>>>,
    // At most one pending batch per session id; accessed only through the
    // `*_pending_tools` methods.
    pending_tools: Arc<RwLock<HashMap<Uuid, PendingToolApprovals>>>,
}
49
50impl AppState {
51 #[allow(clippy::too_many_arguments)]
52 pub fn new(
53 session_store: Arc<dyn SessionStorage>,
54 events: Arc<EventLog>,
55 idempotency: Arc<IdempotencyStore>,
56 inference: Arc<stakai::Inference>,
57 models: Vec<stakai::Model>,
58 default_model: Option<stakai::Model>,
59 tool_approval_policy: ToolApprovalPolicy,
60 ) -> Self {
61 Self {
62 run_manager: SessionManager::new(),
63 session_store,
64 events,
65 idempotency,
66 inference,
67 checkpoint_store: Arc::new(CheckpointStore::default_local()),
68 models: Arc::new(models),
69 default_model,
70 tool_approval_policy,
71 started_at: Instant::now(),
72 mcp_client: None,
73 mcp_tools: Arc::new(RwLock::new(Vec::new())),
74 mcp_server_shutdown_tx: None,
75 mcp_proxy_shutdown_tx: None,
76 sandbox_config: None,
77 base_system_prompt: None,
78 context_budget: ContextBudget::default(),
79 project_dir: None,
80 skills_context: Arc::new(RwLock::new(Vec::new())),
81 pending_tools: Arc::new(RwLock::new(HashMap::new())),
82 }
83 }
84
85 pub fn with_mcp(
86 mut self,
87 mcp_client: Arc<McpClient>,
88 mcp_tools: Vec<stakai::Tool>,
89 mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
90 mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
91 ) -> Self {
92 self.mcp_client = Some(mcp_client);
93 self.mcp_tools = Arc::new(RwLock::new(mcp_tools));
94 self.mcp_server_shutdown_tx = mcp_server_shutdown_tx;
95 self.mcp_proxy_shutdown_tx = mcp_proxy_shutdown_tx;
96 self
97 }
98
99 pub fn with_sandbox(mut self, sandbox_config: SandboxConfig) -> Self {
100 self.sandbox_config = Some(sandbox_config);
101 self
102 }
103
104 pub fn with_base_system_prompt(mut self, prompt: Option<String>) -> Self {
105 self.base_system_prompt = prompt.filter(|value| !value.trim().is_empty());
106 self
107 }
108
109 pub fn with_context_budget(mut self, budget: ContextBudget) -> Self {
110 self.context_budget = budget;
111 self
112 }
113
114 pub fn with_project_dir(mut self, dir: Option<String>) -> Self {
115 self.project_dir = dir.filter(|value| !value.trim().is_empty());
116 self
117 }
118
119 pub fn with_skills(mut self, context_files: Vec<ContextFile>) -> Self {
120 self.skills_context = Arc::new(RwLock::new(context_files));
121 self
122 }
123
124 pub fn with_checkpoint_store(mut self, checkpoint_store: Arc<CheckpointStore>) -> Self {
125 self.checkpoint_store = checkpoint_store;
126 self
127 }
128
129 pub async fn current_skills(&self) -> Vec<ContextFile> {
130 self.skills_context.read().await.clone()
131 }
132
133 pub async fn replace_skills(&self, context_files: Vec<ContextFile>) {
134 let mut guard = self.skills_context.write().await;
135 *guard = context_files;
136 }
137
138 pub async fn current_mcp_tools(&self) -> Vec<stakai::Tool> {
139 self.mcp_tools.read().await.clone()
140 }
141
142 pub async fn refresh_mcp_tools(&self) -> Result<usize, String> {
143 let Some(mcp_client) = self.mcp_client.as_ref() else {
144 return Ok(self.mcp_tools.read().await.len());
145 };
146
147 let raw_tools = stakpak_mcp_client::get_tools(mcp_client)
148 .await
149 .map_err(|error| format!("Failed to refresh MCP tools: {error}"))?;
150
151 let converted = raw_tools
152 .into_iter()
153 .map(|tool| stakai::Tool {
154 tool_type: "function".to_string(),
155 function: stakai::ToolFunction {
156 name: tool.name.as_ref().to_string(),
157 description: tool
158 .description
159 .as_ref()
160 .map(std::string::ToString::to_string)
161 .unwrap_or_default(),
162 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
163 },
164 provider_options: None,
165 })
166 .collect::<Vec<_>>();
167
168 let mut guard = self.mcp_tools.write().await;
169 *guard = converted;
170 Ok(guard.len())
171 }
172
173 pub fn uptime_seconds(&self) -> u64 {
174 self.started_at.elapsed().as_secs()
175 }
176
177 pub fn resolve_model(&self, requested: Option<&str>) -> Option<stakai::Model> {
178 match requested {
179 Some(requested_model) => self.find_model(requested_model),
180 None => self
181 .default_model
182 .clone()
183 .or_else(|| self.models.first().cloned()),
184 }
185 }
186
187 pub async fn set_pending_tools(
188 &self,
189 session_id: Uuid,
190 run_id: Uuid,
191 tool_calls: Vec<ProposedToolCall>,
192 ) {
193 let mut guard = self.pending_tools.write().await;
194 guard.insert(session_id, PendingToolApprovals { run_id, tool_calls });
195 }
196
197 pub async fn clear_pending_tools(&self, session_id: Uuid, run_id: Uuid) {
198 let mut guard = self.pending_tools.write().await;
199 if guard
200 .get(&session_id)
201 .is_some_and(|pending| pending.run_id == run_id)
202 {
203 guard.remove(&session_id);
204 }
205 }
206
207 pub async fn pending_tools(&self, session_id: Uuid) -> Option<PendingToolApprovals> {
208 let guard = self.pending_tools.read().await;
209 guard.get(&session_id).cloned()
210 }
211
212 fn find_model(&self, requested: &str) -> Option<stakai::Model> {
213 if let Some((provider, id)) = requested.split_once('/') {
214 return self
215 .models
216 .iter()
217 .find(|model| model.provider == provider && model.id == id)
218 .cloned()
219 .or_else(|| Some(stakai::Model::custom(id, provider)));
220 }
221
222 self.models
223 .iter()
224 .find(|model| model.id == requested)
225 .cloned()
226 .or_else(|| {
227 self.default_model.as_ref().map(|default_model| {
228 stakai::Model::custom(requested.to_string(), default_model.provider.clone())
229 })
230 })
231 .or_else(|| {
232 self.models.first().map(|model| {
233 stakai::Model::custom(requested.to_string(), model.provider.clone())
234 })
235 })
236 .or_else(|| Some(stakai::Model::custom(requested.to_string(), "openai")))
237 }
238}