// stakpak_server/state.rs

1use crate::{
2    checkpoint_store::CheckpointStore, context::ContextBudget, context::ContextFile,
3    event_log::EventLog, idempotency::IdempotencyStore, sandbox::SandboxConfig,
4    session_manager::SessionManager,
5};
6use stakpak_agent_core::{ProposedToolCall, ToolApprovalPolicy};
7use stakpak_api::SessionStorage;
8use stakpak_mcp_client::McpClient;
9use std::{collections::HashMap, sync::Arc, time::Instant};
10use tokio::sync::{RwLock, broadcast};
11use uuid::Uuid;
12
13#[derive(Debug, Clone)]
14pub struct PendingToolApprovals {
15    pub run_id: Uuid,
16    pub tool_calls: Vec<ProposedToolCall>,
17}
18
/// Server-wide state.
///
/// Built with `AppState::new` plus the `with_*` builder methods. Most fields are
/// `Arc`-wrapped, so clones share those stores and caches. Plain-value fields such as
/// `default_model`, `base_system_prompt` and `project_dir` are copied per clone.
#[derive(Clone)]
pub struct AppState {
    /// Session/run manager; `AppState::new` creates a fresh one.
    pub run_manager: SessionManager,
    /// Durable session/checkpoint backend (SQLite/remote API).
    pub session_store: Arc<dyn SessionStorage>,
    /// Shared event log.
    pub events: Arc<EventLog>,
    /// Shared idempotency-key store.
    pub idempotency: Arc<IdempotencyStore>,
    /// Shared `stakai` inference handle.
    pub inference: Arc<stakai::Inference>,
    /// Server-side latest-envelope cache (`stakai::Message` + runtime metadata).
    pub checkpoint_store: Arc<CheckpointStore>,
    /// Configured models; searched by `AppState::resolve_model`.
    pub models: Arc<Vec<stakai::Model>>,
    /// Model used when a request names none. Its provider is also used for
    /// unqualified model names that match no configured model.
    pub default_model: Option<stakai::Model>,
    /// Tool approval policy (type from `stakpak_agent_core`).
    pub tool_approval_policy: ToolApprovalPolicy,
    /// Creation time of this state; basis for `AppState::uptime_seconds`.
    pub started_at: Instant,
    /// MCP client; `None` until `AppState::with_mcp` is called.
    pub mcp_client: Option<Arc<McpClient>>,
    /// Cached MCP tools in `stakai` function-tool form; rebuilt by
    /// `AppState::refresh_mcp_tools`.
    pub mcp_tools: Arc<RwLock<Vec<stakai::Tool>>>,
    /// Optional broadcast sender for the MCP server; presumably used to signal
    /// shutdown (per the name) — confirm at the call sites.
    pub mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
    /// Optional broadcast sender for the MCP proxy; presumably used to signal
    /// shutdown (per the name) — confirm at the call sites.
    pub mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
    /// Sandbox configuration, if one was set via `AppState::with_sandbox`.
    pub sandbox_config: Option<SandboxConfig>,
    /// Base system prompt. `AppState::with_base_system_prompt` stores blank
    /// (whitespace-only) prompts as `None`.
    pub base_system_prompt: Option<String>,
    /// Context budget; `ContextBudget::default()` unless overridden.
    pub context_budget: ContextBudget,
    /// Base directory for project context discovery (AGENTS.md, APPS.md).
    /// Falls back to process cwd if not set. Should be set to the directory
    /// where `stakpak up` was run so gateway sessions can discover project files.
    pub project_dir: Option<String>,
    /// Cached remote skill context files (currently fetched from the remote
    /// skills endpoint), injected into new sessions as baseline context.
    pub skills_context: Arc<RwLock<Vec<ContextFile>>>,
    /// Tool calls awaiting approval, keyed by session id. Kept private so access
    /// goes through the `*_pending_tools` methods.
    pending_tools: Arc<RwLock<HashMap<Uuid, PendingToolApprovals>>>,
}
49
50impl AppState {
51    #[allow(clippy::too_many_arguments)]
52    pub fn new(
53        session_store: Arc<dyn SessionStorage>,
54        events: Arc<EventLog>,
55        idempotency: Arc<IdempotencyStore>,
56        inference: Arc<stakai::Inference>,
57        models: Vec<stakai::Model>,
58        default_model: Option<stakai::Model>,
59        tool_approval_policy: ToolApprovalPolicy,
60    ) -> Self {
61        Self {
62            run_manager: SessionManager::new(),
63            session_store,
64            events,
65            idempotency,
66            inference,
67            checkpoint_store: Arc::new(CheckpointStore::default_local()),
68            models: Arc::new(models),
69            default_model,
70            tool_approval_policy,
71            started_at: Instant::now(),
72            mcp_client: None,
73            mcp_tools: Arc::new(RwLock::new(Vec::new())),
74            mcp_server_shutdown_tx: None,
75            mcp_proxy_shutdown_tx: None,
76            sandbox_config: None,
77            base_system_prompt: None,
78            context_budget: ContextBudget::default(),
79            project_dir: None,
80            skills_context: Arc::new(RwLock::new(Vec::new())),
81            pending_tools: Arc::new(RwLock::new(HashMap::new())),
82        }
83    }
84
85    pub fn with_mcp(
86        mut self,
87        mcp_client: Arc<McpClient>,
88        mcp_tools: Vec<stakai::Tool>,
89        mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
90        mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
91    ) -> Self {
92        self.mcp_client = Some(mcp_client);
93        self.mcp_tools = Arc::new(RwLock::new(mcp_tools));
94        self.mcp_server_shutdown_tx = mcp_server_shutdown_tx;
95        self.mcp_proxy_shutdown_tx = mcp_proxy_shutdown_tx;
96        self
97    }
98
99    pub fn with_sandbox(mut self, sandbox_config: SandboxConfig) -> Self {
100        self.sandbox_config = Some(sandbox_config);
101        self
102    }
103
104    pub fn with_base_system_prompt(mut self, prompt: Option<String>) -> Self {
105        self.base_system_prompt = prompt.filter(|value| !value.trim().is_empty());
106        self
107    }
108
109    pub fn with_context_budget(mut self, budget: ContextBudget) -> Self {
110        self.context_budget = budget;
111        self
112    }
113
114    pub fn with_project_dir(mut self, dir: Option<String>) -> Self {
115        self.project_dir = dir.filter(|value| !value.trim().is_empty());
116        self
117    }
118
119    pub fn with_skills(mut self, context_files: Vec<ContextFile>) -> Self {
120        self.skills_context = Arc::new(RwLock::new(context_files));
121        self
122    }
123
124    pub fn with_checkpoint_store(mut self, checkpoint_store: Arc<CheckpointStore>) -> Self {
125        self.checkpoint_store = checkpoint_store;
126        self
127    }
128
129    pub async fn current_skills(&self) -> Vec<ContextFile> {
130        self.skills_context.read().await.clone()
131    }
132
133    pub async fn replace_skills(&self, context_files: Vec<ContextFile>) {
134        let mut guard = self.skills_context.write().await;
135        *guard = context_files;
136    }
137
138    pub async fn current_mcp_tools(&self) -> Vec<stakai::Tool> {
139        self.mcp_tools.read().await.clone()
140    }
141
142    pub async fn refresh_mcp_tools(&self) -> Result<usize, String> {
143        let Some(mcp_client) = self.mcp_client.as_ref() else {
144            return Ok(self.mcp_tools.read().await.len());
145        };
146
147        let raw_tools = stakpak_mcp_client::get_tools(mcp_client)
148            .await
149            .map_err(|error| format!("Failed to refresh MCP tools: {error}"))?;
150
151        let converted = raw_tools
152            .into_iter()
153            .map(|tool| stakai::Tool {
154                tool_type: "function".to_string(),
155                function: stakai::ToolFunction {
156                    name: tool.name.as_ref().to_string(),
157                    description: tool
158                        .description
159                        .as_ref()
160                        .map(std::string::ToString::to_string)
161                        .unwrap_or_default(),
162                    parameters: serde_json::Value::Object((*tool.input_schema).clone()),
163                },
164                provider_options: None,
165            })
166            .collect::<Vec<_>>();
167
168        let mut guard = self.mcp_tools.write().await;
169        *guard = converted;
170        Ok(guard.len())
171    }
172
173    pub fn uptime_seconds(&self) -> u64 {
174        self.started_at.elapsed().as_secs()
175    }
176
177    pub fn resolve_model(&self, requested: Option<&str>) -> Option<stakai::Model> {
178        match requested {
179            Some(requested_model) => self.find_model(requested_model),
180            None => self
181                .default_model
182                .clone()
183                .or_else(|| self.models.first().cloned()),
184        }
185    }
186
187    pub async fn set_pending_tools(
188        &self,
189        session_id: Uuid,
190        run_id: Uuid,
191        tool_calls: Vec<ProposedToolCall>,
192    ) {
193        let mut guard = self.pending_tools.write().await;
194        guard.insert(session_id, PendingToolApprovals { run_id, tool_calls });
195    }
196
197    pub async fn clear_pending_tools(&self, session_id: Uuid, run_id: Uuid) {
198        let mut guard = self.pending_tools.write().await;
199        if guard
200            .get(&session_id)
201            .is_some_and(|pending| pending.run_id == run_id)
202        {
203            guard.remove(&session_id);
204        }
205    }
206
207    pub async fn pending_tools(&self, session_id: Uuid) -> Option<PendingToolApprovals> {
208        let guard = self.pending_tools.read().await;
209        guard.get(&session_id).cloned()
210    }
211
212    fn find_model(&self, requested: &str) -> Option<stakai::Model> {
213        if let Some((provider, id)) = requested.split_once('/') {
214            return self
215                .models
216                .iter()
217                .find(|model| model.provider == provider && model.id == id)
218                .cloned()
219                .or_else(|| Some(stakai::Model::custom(id, provider)));
220        }
221
222        self.models
223            .iter()
224            .find(|model| model.id == requested)
225            .cloned()
226            .or_else(|| {
227                self.default_model.as_ref().map(|default_model| {
228                    stakai::Model::custom(requested.to_string(), default_model.provider.clone())
229                })
230            })
231            .or_else(|| {
232                self.models.first().map(|model| {
233                    stakai::Model::custom(requested.to_string(), model.provider.clone())
234                })
235            })
236            .or_else(|| Some(stakai::Model::custom(requested.to_string(), "openai")))
237    }
238}