// stakpak_server/state.rs

1use crate::{
2    checkpoint_store::CheckpointStore, event_log::EventLog, idempotency::IdempotencyStore,
3    session_manager::SessionManager,
4};
5use stakpak_agent_core::{ProposedToolCall, ToolApprovalPolicy};
6use stakpak_api::SessionStorage;
7use stakpak_mcp_client::McpClient;
8use std::{collections::HashMap, sync::Arc, time::Instant};
9use tokio::sync::{RwLock, broadcast};
10use uuid::Uuid;
11
12#[derive(Debug, Clone)]
13pub struct PendingToolApprovals {
14    pub run_id: Uuid,
15    pub tool_calls: Vec<ProposedToolCall>,
16}
17
18#[derive(Clone)]
19pub struct AppState {
20    pub run_manager: SessionManager,
21    /// Durable session/checkpoint backend (SQLite/remote API).
22    pub session_store: Arc<dyn SessionStorage>,
23    pub events: Arc<EventLog>,
24    pub idempotency: Arc<IdempotencyStore>,
25    pub inference: Arc<stakai::Inference>,
26    /// Server-side latest-envelope cache (`stakai::Message` + runtime metadata).
27    pub checkpoint_store: Arc<CheckpointStore>,
28    pub models: Arc<Vec<stakai::Model>>,
29    pub default_model: Option<stakai::Model>,
30    pub tool_approval_policy: ToolApprovalPolicy,
31    pub started_at: Instant,
32    pub mcp_client: Option<Arc<McpClient>>,
33    pub mcp_tools: Arc<RwLock<Vec<stakai::Tool>>>,
34    pub mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
35    pub mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
36    pending_tools: Arc<RwLock<HashMap<Uuid, PendingToolApprovals>>>,
37}
38
39impl AppState {
40    #[allow(clippy::too_many_arguments)]
41    pub fn new(
42        session_store: Arc<dyn SessionStorage>,
43        events: Arc<EventLog>,
44        idempotency: Arc<IdempotencyStore>,
45        inference: Arc<stakai::Inference>,
46        models: Vec<stakai::Model>,
47        default_model: Option<stakai::Model>,
48        tool_approval_policy: ToolApprovalPolicy,
49    ) -> Self {
50        Self {
51            run_manager: SessionManager::new(),
52            session_store,
53            events,
54            idempotency,
55            inference,
56            checkpoint_store: Arc::new(CheckpointStore::default_local()),
57            models: Arc::new(models),
58            default_model,
59            tool_approval_policy,
60            started_at: Instant::now(),
61            mcp_client: None,
62            mcp_tools: Arc::new(RwLock::new(Vec::new())),
63            mcp_server_shutdown_tx: None,
64            mcp_proxy_shutdown_tx: None,
65            pending_tools: Arc::new(RwLock::new(HashMap::new())),
66        }
67    }
68
69    pub fn with_mcp(
70        mut self,
71        mcp_client: Arc<McpClient>,
72        mcp_tools: Vec<stakai::Tool>,
73        mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
74        mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
75    ) -> Self {
76        self.mcp_client = Some(mcp_client);
77        self.mcp_tools = Arc::new(RwLock::new(mcp_tools));
78        self.mcp_server_shutdown_tx = mcp_server_shutdown_tx;
79        self.mcp_proxy_shutdown_tx = mcp_proxy_shutdown_tx;
80        self
81    }
82
83    pub fn with_checkpoint_store(mut self, checkpoint_store: Arc<CheckpointStore>) -> Self {
84        self.checkpoint_store = checkpoint_store;
85        self
86    }
87
88    pub async fn current_mcp_tools(&self) -> Vec<stakai::Tool> {
89        self.mcp_tools.read().await.clone()
90    }
91
92    pub async fn refresh_mcp_tools(&self) -> Result<usize, String> {
93        let Some(mcp_client) = self.mcp_client.as_ref() else {
94            return Ok(self.mcp_tools.read().await.len());
95        };
96
97        let raw_tools = stakpak_mcp_client::get_tools(mcp_client)
98            .await
99            .map_err(|error| format!("Failed to refresh MCP tools: {error}"))?;
100
101        let converted = raw_tools
102            .into_iter()
103            .map(|tool| stakai::Tool {
104                tool_type: "function".to_string(),
105                function: stakai::ToolFunction {
106                    name: tool.name.as_ref().to_string(),
107                    description: tool
108                        .description
109                        .as_ref()
110                        .map(std::string::ToString::to_string)
111                        .unwrap_or_default(),
112                    parameters: serde_json::Value::Object((*tool.input_schema).clone()),
113                },
114                provider_options: None,
115            })
116            .collect::<Vec<_>>();
117
118        let mut guard = self.mcp_tools.write().await;
119        *guard = converted;
120        Ok(guard.len())
121    }
122
123    pub fn uptime_seconds(&self) -> u64 {
124        self.started_at.elapsed().as_secs()
125    }
126
127    pub fn resolve_model(&self, requested: Option<&str>) -> Option<stakai::Model> {
128        match requested {
129            Some(requested_model) => self.find_model(requested_model),
130            None => self
131                .default_model
132                .clone()
133                .or_else(|| self.models.first().cloned()),
134        }
135    }
136
137    pub async fn set_pending_tools(
138        &self,
139        session_id: Uuid,
140        run_id: Uuid,
141        tool_calls: Vec<ProposedToolCall>,
142    ) {
143        let mut guard = self.pending_tools.write().await;
144        guard.insert(session_id, PendingToolApprovals { run_id, tool_calls });
145    }
146
147    pub async fn clear_pending_tools(&self, session_id: Uuid, run_id: Uuid) {
148        let mut guard = self.pending_tools.write().await;
149        if guard
150            .get(&session_id)
151            .is_some_and(|pending| pending.run_id == run_id)
152        {
153            guard.remove(&session_id);
154        }
155    }
156
157    pub async fn pending_tools(&self, session_id: Uuid) -> Option<PendingToolApprovals> {
158        let guard = self.pending_tools.read().await;
159        guard.get(&session_id).cloned()
160    }
161
162    fn find_model(&self, requested: &str) -> Option<stakai::Model> {
163        if let Some((provider, id)) = requested.split_once('/') {
164            return self
165                .models
166                .iter()
167                .find(|model| model.provider == provider && model.id == id)
168                .cloned()
169                .or_else(|| Some(stakai::Model::custom(id, provider)));
170        }
171
172        self.models
173            .iter()
174            .find(|model| model.id == requested)
175            .cloned()
176            .or_else(|| {
177                self.default_model.as_ref().map(|default_model| {
178                    stakai::Model::custom(requested.to_string(), default_model.provider.clone())
179                })
180            })
181            .or_else(|| {
182                self.models.first().map(|model| {
183                    stakai::Model::custom(requested.to_string(), model.provider.clone())
184                })
185            })
186            .or_else(|| Some(stakai::Model::custom(requested.to_string(), "openai")))
187    }
188}