Skip to main content

stakpak_server/
state.rs

1use crate::{
2    checkpoint_store::CheckpointStore, event_log::EventLog, idempotency::IdempotencyStore,
3    sandbox::SandboxConfig, session_manager::SessionManager,
4};
5use stakpak_agent_core::{ProposedToolCall, ToolApprovalPolicy};
6use stakpak_api::SessionStorage;
7use stakpak_mcp_client::McpClient;
8use std::{collections::HashMap, sync::Arc, time::Instant};
9use tokio::sync::{RwLock, broadcast};
10use uuid::Uuid;
11
12#[derive(Debug, Clone)]
13pub struct PendingToolApprovals {
14    pub run_id: Uuid,
15    pub tool_calls: Vec<ProposedToolCall>,
16}
17
18#[derive(Clone)]
19pub struct AppState {
20    pub run_manager: SessionManager,
21    /// Durable session/checkpoint backend (SQLite/remote API).
22    pub session_store: Arc<dyn SessionStorage>,
23    pub events: Arc<EventLog>,
24    pub idempotency: Arc<IdempotencyStore>,
25    pub inference: Arc<stakai::Inference>,
26    /// Server-side latest-envelope cache (`stakai::Message` + runtime metadata).
27    pub checkpoint_store: Arc<CheckpointStore>,
28    pub models: Arc<Vec<stakai::Model>>,
29    pub default_model: Option<stakai::Model>,
30    pub tool_approval_policy: ToolApprovalPolicy,
31    pub started_at: Instant,
32    pub mcp_client: Option<Arc<McpClient>>,
33    pub mcp_tools: Arc<RwLock<Vec<stakai::Tool>>>,
34    pub mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
35    pub mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
36    pub sandbox_config: Option<SandboxConfig>,
37    pending_tools: Arc<RwLock<HashMap<Uuid, PendingToolApprovals>>>,
38}
39
40impl AppState {
41    #[allow(clippy::too_many_arguments)]
42    pub fn new(
43        session_store: Arc<dyn SessionStorage>,
44        events: Arc<EventLog>,
45        idempotency: Arc<IdempotencyStore>,
46        inference: Arc<stakai::Inference>,
47        models: Vec<stakai::Model>,
48        default_model: Option<stakai::Model>,
49        tool_approval_policy: ToolApprovalPolicy,
50    ) -> Self {
51        Self {
52            run_manager: SessionManager::new(),
53            session_store,
54            events,
55            idempotency,
56            inference,
57            checkpoint_store: Arc::new(CheckpointStore::default_local()),
58            models: Arc::new(models),
59            default_model,
60            tool_approval_policy,
61            started_at: Instant::now(),
62            mcp_client: None,
63            mcp_tools: Arc::new(RwLock::new(Vec::new())),
64            mcp_server_shutdown_tx: None,
65            mcp_proxy_shutdown_tx: None,
66            sandbox_config: None,
67            pending_tools: Arc::new(RwLock::new(HashMap::new())),
68        }
69    }
70
71    pub fn with_mcp(
72        mut self,
73        mcp_client: Arc<McpClient>,
74        mcp_tools: Vec<stakai::Tool>,
75        mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
76        mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
77    ) -> Self {
78        self.mcp_client = Some(mcp_client);
79        self.mcp_tools = Arc::new(RwLock::new(mcp_tools));
80        self.mcp_server_shutdown_tx = mcp_server_shutdown_tx;
81        self.mcp_proxy_shutdown_tx = mcp_proxy_shutdown_tx;
82        self
83    }
84
85    pub fn with_sandbox(mut self, sandbox_config: SandboxConfig) -> Self {
86        self.sandbox_config = Some(sandbox_config);
87        self
88    }
89
90    pub fn with_checkpoint_store(mut self, checkpoint_store: Arc<CheckpointStore>) -> Self {
91        self.checkpoint_store = checkpoint_store;
92        self
93    }
94
95    pub async fn current_mcp_tools(&self) -> Vec<stakai::Tool> {
96        self.mcp_tools.read().await.clone()
97    }
98
99    pub async fn refresh_mcp_tools(&self) -> Result<usize, String> {
100        let Some(mcp_client) = self.mcp_client.as_ref() else {
101            return Ok(self.mcp_tools.read().await.len());
102        };
103
104        let raw_tools = stakpak_mcp_client::get_tools(mcp_client)
105            .await
106            .map_err(|error| format!("Failed to refresh MCP tools: {error}"))?;
107
108        let converted = raw_tools
109            .into_iter()
110            .map(|tool| stakai::Tool {
111                tool_type: "function".to_string(),
112                function: stakai::ToolFunction {
113                    name: tool.name.as_ref().to_string(),
114                    description: tool
115                        .description
116                        .as_ref()
117                        .map(std::string::ToString::to_string)
118                        .unwrap_or_default(),
119                    parameters: serde_json::Value::Object((*tool.input_schema).clone()),
120                },
121                provider_options: None,
122            })
123            .collect::<Vec<_>>();
124
125        let mut guard = self.mcp_tools.write().await;
126        *guard = converted;
127        Ok(guard.len())
128    }
129
130    pub fn uptime_seconds(&self) -> u64 {
131        self.started_at.elapsed().as_secs()
132    }
133
134    pub fn resolve_model(&self, requested: Option<&str>) -> Option<stakai::Model> {
135        match requested {
136            Some(requested_model) => self.find_model(requested_model),
137            None => self
138                .default_model
139                .clone()
140                .or_else(|| self.models.first().cloned()),
141        }
142    }
143
144    pub async fn set_pending_tools(
145        &self,
146        session_id: Uuid,
147        run_id: Uuid,
148        tool_calls: Vec<ProposedToolCall>,
149    ) {
150        let mut guard = self.pending_tools.write().await;
151        guard.insert(session_id, PendingToolApprovals { run_id, tool_calls });
152    }
153
154    pub async fn clear_pending_tools(&self, session_id: Uuid, run_id: Uuid) {
155        let mut guard = self.pending_tools.write().await;
156        if guard
157            .get(&session_id)
158            .is_some_and(|pending| pending.run_id == run_id)
159        {
160            guard.remove(&session_id);
161        }
162    }
163
164    pub async fn pending_tools(&self, session_id: Uuid) -> Option<PendingToolApprovals> {
165        let guard = self.pending_tools.read().await;
166        guard.get(&session_id).cloned()
167    }
168
169    fn find_model(&self, requested: &str) -> Option<stakai::Model> {
170        if let Some((provider, id)) = requested.split_once('/') {
171            return self
172                .models
173                .iter()
174                .find(|model| model.provider == provider && model.id == id)
175                .cloned()
176                .or_else(|| Some(stakai::Model::custom(id, provider)));
177        }
178
179        self.models
180            .iter()
181            .find(|model| model.id == requested)
182            .cloned()
183            .or_else(|| {
184                self.default_model.as_ref().map(|default_model| {
185                    stakai::Model::custom(requested.to_string(), default_model.provider.clone())
186                })
187            })
188            .or_else(|| {
189                self.models.first().map(|model| {
190                    stakai::Model::custom(requested.to_string(), model.provider.clone())
191                })
192            })
193            .or_else(|| Some(stakai::Model::custom(requested.to_string(), "openai")))
194    }
195}