Skip to main content

smooth_operator_server/
state.rs

1//! Server + per-connection state.
2//!
3//! [`AppState`] is shared across every connection + every admin HTTP request
4//! (cloneable `Arc` handles): the storage adapter, the resolved
5//! [`ServerConfig`], the session registry, and — for the admin API (Phase 12) —
6//! the [`AuthVerifier`], an [`IndexingStore`], and the document-set registry.
7//!
8//! Sessions live in an in-memory map keyed by `sessionId` so `get_session` and
9//! reconnects work across connections (mirrors the protocol's "connection →
10//! session" / "session → connections" state model, simplified for the reference
11//! single-process server). On AWS this map would be DynamoDB; on k8s, Redis or
12//! Postgres.
13
14use std::collections::HashMap;
15use std::sync::{Arc, RwLock};
16
17use smooth_operator_core::HumanResponse;
18use tokio::sync::mpsc::UnboundedSender;
19
20use smooth_operator::adapter::StorageAdapter;
21use smooth_operator::auth::{AuthVerifier, NoAuthVerifier};
22use smooth_operator::backplane::{Backplane, InMemoryBackplane};
23use smooth_operator::connector_config::{ConnectorConfigStore, InMemoryConnectorConfigStore};
24use smooth_operator::domain::Session;
25use smooth_operator::gateway_key::{EnvGatewayKeyResolver, GatewayKeyResolver};
26use smooth_operator::settings::{InMemorySettingsStore, SettingsStore};
27use smooth_operator::tool_provider::ToolProvider;
28use smooth_operator::widget_auth::{PermissiveWidgetAuth, WidgetAuthProvider};
29use tokio_util::sync::CancellationToken;
30
31use smooth_operator_core::llm_provider::LlmProvider;
32use smooth_operator_ingestion::indexing::{InMemoryIndexingStore, IndexingStore};
33
34use crate::config::ServerConfig;
35
/// Shared, cloneable application state handed to every WebSocket connection +
/// every admin HTTP request.
///
/// Every field is an `Arc`, an `Option<Arc<_>>`, or a [`CancellationToken`],
/// so the derived `Clone` only copies handles: every clone observes the same
/// storage, collaborators, registries, and shutdown signal. The four
/// in-process registries (`sessions`, `doc_sets`, `connectors`,
/// `pending_confirmations`) are private fields behind `std::sync::RwLock`s,
/// so only code in this module can touch them directly; everyone else goes
/// through the accessor methods on this type.
#[derive(Clone)]
pub struct AppState {
    /// The single storage seam (conversations / participants / messages /
    /// sessions / checkpoints / knowledge).
    pub storage: Arc<dyn StorageAdapter>,
    /// Resolved server configuration (gateway, model, limits).
    pub config: Arc<ServerConfig>,
    /// The configured auth verifier (jwt / smoo / none). Used by the admin API's
    /// `require_role` extractor to turn a bearer token into a `Principal`.
    pub auth: Arc<dyn AuthVerifier>,
    /// Indexing-run status store, surfaced by `GET /admin/indexing/runs`.
    pub indexing: Arc<dyn IndexingStore>,
    /// Connector-configuration store, CRUD'd by the admin write API
    /// (`/admin/connectors`). Org-scoped; holds an `auth_ref` (secret name), not
    /// the secret itself.
    pub connector_configs: Arc<dyn ConnectorConfigStore>,
    /// Per-org agent settings store, read/written by `/admin/settings`.
    pub settings: Arc<dyn SettingsStore>,
    /// **Host tool-injection seam.** When `Some`, the runner asks this provider
    /// for EXTRA tools and merges them into every turn's `ToolRegistry`
    /// alongside the built-ins. Defaults to `None` (built-ins only); a host
    /// installs one via [`with_tools`](Self::with_tools) to contribute its own
    /// per-org tool catalog without forking the runner.
    pub tool_provider: Option<Arc<dyn ToolProvider>>,
    /// Embeddable-widget auth hook: resolves an agent's origin-allowlist +
    /// public-key policy for `<smooth-agent-chat>` connections. Defaults to
    /// [`PermissiveWidgetAuth`] (no enforcement) until a host installs a real
    /// provider via [`with_widget_auth`](Self::with_widget_auth).
    pub widget_auth: Arc<dyn WidgetAuthProvider>,
    /// Connection backplane: per-pod sink registry + cross-pod event delivery.
    /// Defaults to [`InMemoryBackplane`] (single-process); a host installs a
    /// Redis/NATS impl via [`with_backplane`](Self::with_backplane) to scale out
    /// and to let non-AI publishers push realtime events to connected clients.
    pub backplane: Arc<dyn Backplane>,
    /// Test-only injected LLM surface. When `Some`, every `send_message` turn
    /// runs the engine against this provider (a
    /// [`MockLlmClient`](smooth_operator_core::llm_provider::MockLlmClient))
    /// instead of building a live gateway client from `config` — exactly the
    /// `ServerState(chat_client=mock)` seam the Python reference uses to drive the
    /// scenario-parity corpus deterministically offline. **`None` in production**
    /// (a live client is built from the gateway config), so the `/ws` path is
    /// byte-for-byte unchanged for real deployments. Installed via
    /// [`with_chat_provider`](Self::with_chat_provider).
    pub chat_provider: Option<Arc<dyn LlmProvider>>,
    /// Per-org LLM gateway-key resolver: maps a turn's `org_id` to the gateway
    /// key it should bill/scope to. Defaults to [`EnvGatewayKeyResolver`] (the
    /// single `SMOOAI_GATEWAY_KEY` for every org — unchanged local behavior); a
    /// multi-tenant host installs a per-org resolver via
    /// [`with_gateway_key_resolver`](Self::with_gateway_key_resolver) so each
    /// tenant's usage is attributed to its own key. The per-turn LLM-config build
    /// falls back to the env key whenever the resolver returns `None`.
    pub gateway_key_resolver: Arc<dyn GatewayKeyResolver>,
    /// Graceful-shutdown signal, shared across every per-connection clone of this
    /// state. On SIGTERM/ctrl_c the serve loop cancels this token; each
    /// connection's reader loop selects on [`CancellationToken::cancelled`] so it
    /// finishes its in-flight turn, exits, and detaches from the [`Backplane`] —
    /// no in-flight turn dropped, no stale registry entry left behind. A fresh
    /// token from [`new`](Self::new) is never cancelled, so the `/ws` path and
    /// tests are unaffected until a `run`/serve path wires the signal.
    pub shutdown: CancellationToken,
    /// Session registry: `sessionId` → session blob. Shared across connections.
    /// In-memory only, so it covers this process alone (the reference
    /// single-process server); a scaled-out deployment needs an external store.
    sessions: Arc<RwLock<HashMap<String, Session>>>,
    /// Document-set registry, **org-scoped**: `org_id` → (set name → document
    /// count). The in-memory knowledge backend drops document metadata on
    /// ingest, so the admin API reads document-set membership from this side
    /// registry. Keyed by org so org A's document sets are never reported to an
    /// org-B caller (cross-org leak fix — SMOODEV access-control hardening).
    doc_sets: Arc<RwLock<HashMap<String, HashMap<String, usize>>>>,
    /// Connector registry, **org-scoped**: `org_id` → set of connector names
    /// whose indexing runs should be listed. Keyed by org so a same-named
    /// connector in two orgs does not collide, and `GET /admin/indexing/runs`
    /// only ever lists the caller's org's connectors. Stored as a `Vec` rather
    /// than a set: `record_connector` skips names already present, which keeps
    /// each org's list duplicate-free.
    connectors: Arc<RwLock<HashMap<String, Vec<String>>>>,
    /// **Human-in-the-loop pending confirmations**: `sessionId` →
    /// [`HumanResponse`] sender for a turn currently parked on a write-tool
    /// confirmation. When an agent turn calls a tool that requires human
    /// approval, the runner installs a `ConfirmationHook` (smooth-operator-core)
    /// that parks the loop and registers its response sender here. A subsequent
    /// `confirm_tool_action` frame looks the session up, takes the sender, and
    /// feeds it [`HumanResponse::Approved`] / [`HumanResponse::Denied`] to resume
    /// the parked turn (execute or reject the tool). Keyed by session so each
    /// session has at most one outstanding confirmation; an empty map means no
    /// turn is parked (the default, byte-for-byte unchanged from before HITL).
    pending_confirmations: Arc<RwLock<HashMap<String, UnboundedSender<HumanResponse>>>>,
}
123
/// Build the [`IndexingStore`] key for a connector, namespaced by its org.
///
/// The key has the shape `IXCONN#<org_id>\u{1}<connector_name>`, so two orgs
/// that each own a connector called `"docs"` record and list **separate**
/// indexing runs. The `\u{1}` (SOH) control character marks the boundary
/// between the org and the connector name.
///
/// NOTE(review): org isolation relies on neither part containing `\u{1}`.
/// This function does not check that, so it presumably relies on upstream
/// validation of connector names (and org ids) — confirm that validation
/// exists wherever user-supplied names enter.
#[must_use]
pub fn scoped_connector_key(org_id: &str, connector_name: &str) -> String {
    const PREFIX: &str = "IXCONN#";
    const SEPARATOR: char = '\u{1}';

    let capacity = PREFIX.len() + org_id.len() + SEPARATOR.len_utf8() + connector_name.len();
    let mut key = String::with_capacity(capacity);
    key.push_str(PREFIX);
    key.push_str(org_id);
    key.push(SEPARATOR);
    key.push_str(connector_name);
    key
}
132
133impl AppState {
134    /// Construct shared state over a storage adapter and config.
135    ///
136    /// Defaults the admin-API collaborators: a [`NoAuthVerifier`] (overridden via
137    /// [`with_auth`](Self::with_auth)) and an empty [`InMemoryIndexingStore`]
138    /// (overridden via [`with_indexing`](Self::with_indexing)). The `/ws` path
139    /// uses none of these, so existing callers are unaffected.
140    #[must_use]
141    pub fn new(storage: Arc<dyn StorageAdapter>, config: ServerConfig) -> Self {
142        // Default resolver returns the single env gateway key for every org, so
143        // the local/default flavor is unchanged until a host installs a per-org
144        // resolver via `with_gateway_key_resolver`.
145        let gateway_key_resolver: Arc<dyn GatewayKeyResolver> =
146            Arc::new(EnvGatewayKeyResolver::new(config.gateway_key.clone()));
147        Self {
148            storage,
149            config: Arc::new(config),
150            auth: Arc::new(NoAuthVerifier::default()),
151            indexing: Arc::new(InMemoryIndexingStore::new()),
152            connector_configs: Arc::new(InMemoryConnectorConfigStore::new()),
153            settings: Arc::new(InMemorySettingsStore::new()),
154            tool_provider: None,
155            widget_auth: Arc::new(PermissiveWidgetAuth),
156            backplane: Arc::new(InMemoryBackplane::new()),
157            chat_provider: None,
158            gateway_key_resolver,
159            // A fresh, never-cancelled token: every clone of this state shares
160            // its cancellation state, so the serve loop cancelling once fans out
161            // to every connection. Defaulting here (rather than at each call
162            // site) keeps construction ripple-free.
163            shutdown: CancellationToken::new(),
164            sessions: Arc::new(RwLock::new(HashMap::new())),
165            doc_sets: Arc::new(RwLock::new(HashMap::new())),
166            connectors: Arc::new(RwLock::new(HashMap::new())),
167            pending_confirmations: Arc::new(RwLock::new(HashMap::new())),
168        }
169    }
170
171    /// Install the configured auth verifier (builder).
172    #[must_use]
173    pub fn with_auth(mut self, auth: Arc<dyn AuthVerifier>) -> Self {
174        self.auth = auth;
175        self
176    }
177
178    /// Install the indexing store (builder).
179    #[must_use]
180    pub fn with_indexing(mut self, indexing: Arc<dyn IndexingStore>) -> Self {
181        self.indexing = indexing;
182        self
183    }
184
185    /// Install the connector-configuration store (builder).
186    #[must_use]
187    pub fn with_connector_configs(mut self, store: Arc<dyn ConnectorConfigStore>) -> Self {
188        self.connector_configs = store;
189        self
190    }
191
192    /// Install the agent-settings store (builder).
193    #[must_use]
194    pub fn with_settings(mut self, store: Arc<dyn SettingsStore>) -> Self {
195        self.settings = store;
196        self
197    }
198
199    /// Install a host [`ToolProvider`] (builder). The runner merges the
200    /// provider's per-turn tools into every turn's registry alongside the
201    /// built-ins. Without this, the registry is exactly the built-ins, so the
202    /// default/local flavor is unaffected.
203    #[must_use]
204    pub fn with_tools(mut self, provider: Arc<dyn ToolProvider>) -> Self {
205        self.tool_provider = Some(provider);
206        self
207    }
208
209    /// Install the embeddable-widget auth provider (builder). A host backs this
210    /// with its agent store so embed origins + public keys are enforced.
211    #[must_use]
212    pub fn with_widget_auth(mut self, provider: Arc<dyn WidgetAuthProvider>) -> Self {
213        self.widget_auth = provider;
214        self
215    }
216
217    /// Install the connection backplane (builder). A host installs a Redis/NATS
218    /// impl to scale the WS service horizontally and to let other services push
219    /// realtime events to connected clients via [`Backplane::publish`].
220    #[must_use]
221    pub fn with_backplane(mut self, backplane: Arc<dyn Backplane>) -> Self {
222        self.backplane = backplane;
223        self
224    }
225
226    /// Install a test-injected LLM provider (builder). Every `send_message` turn
227    /// then runs the engine against this provider instead of a live gateway
228    /// client — the [`MockLlmClient`](smooth_operator_core::llm_provider::MockLlmClient)
229    /// seam the scenario-parity corpus drives. Production never calls this, so the
230    /// live path is unchanged. See [`chat_provider`](Self::chat_provider).
231    #[must_use]
232    pub fn with_chat_provider(mut self, provider: Arc<dyn LlmProvider>) -> Self {
233        self.chat_provider = Some(provider);
234        self
235    }
236
237    /// Install a per-org gateway-key resolver (builder). A multi-tenant host
238    /// installs a resolver backed by its per-org key store (e.g. one LiteLLM
239    /// virtual key per tenant) so each org's turns are billed/scoped to its own
240    /// key. The per-turn LLM-config build falls back to the env key whenever the
241    /// resolver returns `None`, so a resolver covering only some orgs is safe.
242    /// Leaving this unset keeps the default [`EnvGatewayKeyResolver`] (single env
243    /// key for every org — unchanged local behavior).
244    #[must_use]
245    pub fn with_gateway_key_resolver(mut self, resolver: Arc<dyn GatewayKeyResolver>) -> Self {
246        self.gateway_key_resolver = resolver;
247        self
248    }
249
250    /// Install the graceful-shutdown signal (builder). The serve loop owns a
251    /// clone of this token and cancels it on SIGTERM/ctrl_c; every per-connection
252    /// clone observes the cancellation and drains. Defaulted to a fresh token in
253    /// [`new`](Self::new), so this is only needed when a caller wants to drive
254    /// shutdown from its own token.
255    #[must_use]
256    pub fn with_shutdown(mut self, shutdown: CancellationToken) -> Self {
257        self.shutdown = shutdown;
258        self
259    }
260
261    /// Register a freshly created session.
262    pub fn insert_session(&self, session: Session) {
263        if let Ok(mut map) = self.sessions.write() {
264            map.insert(session.session_id.clone(), session);
265        }
266    }
267
268    /// Look up a session by id.
269    #[must_use]
270    pub fn get_session(&self, session_id: &str) -> Option<Session> {
271        self.sessions.read().ok()?.get(session_id).cloned()
272    }
273
274    /// Record that a document was added to a named document set **within an org**
275    /// (increments its count). Used by seeding + the ingest path so
276    /// `GET /admin/document-sets` can report set names + counts despite the
277    /// in-memory backend dropping document metadata. Org-scoped so org A's sets
278    /// are never reported to an org-B caller.
279    pub fn record_document_set(&self, org_id: impl Into<String>, set: impl Into<String>) {
280        if let Ok(mut map) = self.doc_sets.write() {
281            *map.entry(org_id.into())
282                .or_default()
283                .entry(set.into())
284                .or_insert(0) += 1;
285        }
286    }
287
288    /// Snapshot **one org's** document-set registry as `(name, count)` pairs,
289    /// sorted by name for a stable response. Never returns another org's sets.
290    #[must_use]
291    pub fn document_sets(&self, org_id: &str) -> Vec<(String, usize)> {
292        let Ok(map) = self.doc_sets.read() else {
293            return Vec::new();
294        };
295        let Some(org_sets) = map.get(org_id) else {
296            return Vec::new();
297        };
298        let mut out: Vec<(String, usize)> = org_sets.iter().map(|(k, v)| (k.clone(), *v)).collect();
299        out.sort_by(|a, b| a.0.cmp(&b.0));
300        out
301    }
302
303    /// Record a connector (within an org) whose indexing runs should be listed
304    /// (idempotent). Org-scoped so a same-named connector in two orgs records
305    /// separately and `GET /admin/indexing/runs` only lists the caller's org's.
306    pub fn record_connector(&self, org_id: impl Into<String>, name: impl Into<String>) {
307        let name = name.into();
308        if let Ok(mut map) = self.connectors.write() {
309            let v = map.entry(org_id.into()).or_default();
310            if !v.iter().any(|c| c == &name) {
311                v.push(name);
312            }
313        }
314    }
315
316    /// Snapshot **one org's** recorded connector names (sorted, stable). Never
317    /// returns another org's connectors.
318    #[must_use]
319    pub fn connectors(&self, org_id: &str) -> Vec<String> {
320        let Ok(map) = self.connectors.read() else {
321            return Vec::new();
322        };
323        let mut out = map.get(org_id).cloned().unwrap_or_default();
324        out.sort();
325        out
326    }
327
328    /// Register a parked turn's [`HumanResponse`] sender for `session_id`, so a
329    /// later `confirm_tool_action` can resume it. Any prior pending sender for
330    /// the same session is replaced (one outstanding confirmation per session).
331    /// Called by the runner's confirmation bridge when a write tool emits a
332    /// `HumanRequest::Confirm`.
333    pub fn register_confirmation(
334        &self,
335        session_id: impl Into<String>,
336        responder: UnboundedSender<HumanResponse>,
337    ) {
338        if let Ok(mut map) = self.pending_confirmations.write() {
339            map.insert(session_id.into(), responder);
340        }
341    }
342
343    /// Take (remove + return) the pending [`HumanResponse`] sender for
344    /// `session_id`, if a turn is parked on a confirmation. Returns `None` when
345    /// no turn awaits confirmation for that session (the common case). Taking it
346    /// out — rather than cloning — guarantees a single confirmation resolves a
347    /// single parked tool call, and a duplicate `confirm_tool_action` is a no-op.
348    #[must_use]
349    pub fn take_confirmation(&self, session_id: &str) -> Option<UnboundedSender<HumanResponse>> {
350        self.pending_confirmations.write().ok()?.remove(session_id)
351    }
352
353    /// Drop any pending confirmation registered for `session_id` without
354    /// resolving it. Called when a parked turn ends (the bridge task finishes)
355    /// so a stale sender can't linger and mis-route a later confirmation.
356    pub fn clear_confirmation(&self, session_id: &str) {
357        if let Ok(mut map) = self.pending_confirmations.write() {
358            map.remove(session_id);
359        }
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use smooth_operator::gateway_key::resolve_gateway_key;
    use smooth_operator_adapter_memory::InMemoryStorageAdapter;

    use crate::config::{ServerConfig, StorageBackend, DEFAULT_GATEWAY_URL, DEFAULT_MODEL};

    /// Build a config with an explicit env gateway key for the resolver tests.
    fn config_with_env_key(env_key: Option<&str>) -> ServerConfig {
        ServerConfig {
            bind: "127.0.0.1".to_string(),
            port: 0,
            gateway_url: DEFAULT_GATEWAY_URL.to_string(),
            gateway_key: env_key.map(str::to_string),
            model: DEFAULT_MODEL.to_string(),
            seed_kb: false,
            max_iterations: 6,
            max_tokens: 512,
            storage: StorageBackend::Memory,
            widget_auth_strict: false,
            confirm_tools: Vec::new(),
        }
    }

    fn state_with(config: ServerConfig) -> AppState {
        AppState::new(Arc::new(InMemoryStorageAdapter::new()), config)
    }

    /// Per-org resolver covering exactly one org; `None` (→ env fallback) for any
    /// other org. Mirrors what a multi-tenant host installs.
    struct OneOrgResolver {
        org: String,
        key: String,
    }

    #[async_trait]
    impl GatewayKeyResolver for OneOrgResolver {
        async fn resolve(&self, org_id: &str) -> Option<String> {
            (org_id == self.org).then(|| self.key.clone())
        }
    }

    #[tokio::test]
    async fn default_state_resolves_env_key_for_every_org() {
        // No resolver injected: the default `EnvGatewayKeyResolver` returns the
        // single env key for every org — unchanged local behavior.
        let state = state_with(config_with_env_key(Some("env-key")));
        let env = state.config.gateway_key.as_deref();
        assert_eq!(
            resolve_gateway_key(&state.gateway_key_resolver, "org-a", env).await,
            Some("env-key".to_string())
        );
        assert_eq!(
            resolve_gateway_key(&state.gateway_key_resolver, "org-z", env).await,
            Some("env-key".to_string())
        );
    }

    #[tokio::test]
    async fn injected_resolver_overrides_per_org_and_falls_back_to_env() {
        let config = config_with_env_key(Some("env-key"));
        let state = state_with(config).with_gateway_key_resolver(Arc::new(OneOrgResolver {
            org: "org-a".to_string(),
            key: "org-a-key".to_string(),
        }));
        let env = state.config.gateway_key.as_deref();

        // Covered org → its own key.
        assert_eq!(
            resolve_gateway_key(&state.gateway_key_resolver, "org-a", env).await,
            Some("org-a-key".to_string())
        );
        // Uncovered org → env fallback.
        assert_eq!(
            resolve_gateway_key(&state.gateway_key_resolver, "org-b", env).await,
            Some("env-key".to_string())
        );
    }

    #[tokio::test]
    async fn no_env_key_and_no_resolver_match_resolves_to_none() {
        // Env key absent + default resolver → no key (turn is unavailable). Same
        // behavior as today's `llm_config()` returning `None`.
        let state = state_with(config_with_env_key(None));
        let env = state.config.gateway_key.as_deref();
        assert_eq!(
            resolve_gateway_key(&state.gateway_key_resolver, "org-a", env).await,
            None
        );
    }

    #[tokio::test]
    async fn document_sets_are_counted_sorted_and_org_scoped() {
        // Pins the cross-org leak fix: one org's sets never show up for another.
        let state = state_with(config_with_env_key(None));
        state.record_document_set("org-a", "manuals");
        state.record_document_set("org-a", "faq");
        state.record_document_set("org-a", "manuals");
        state.record_document_set("org-b", "secret");

        assert_eq!(
            state.document_sets("org-a"),
            vec![("faq".to_string(), 1), ("manuals".to_string(), 2)]
        );
        assert_eq!(state.document_sets("org-b"), vec![("secret".to_string(), 1)]);
        assert!(state.document_sets("org-unknown").is_empty());
    }

    #[tokio::test]
    async fn connectors_are_deduplicated_sorted_and_org_scoped() {
        let state = state_with(config_with_env_key(None));
        state.record_connector("org-a", "zendesk");
        state.record_connector("org-a", "docs");
        state.record_connector("org-a", "docs");
        state.record_connector("org-b", "docs");

        assert_eq!(
            state.connectors("org-a"),
            vec!["docs".to_string(), "zendesk".to_string()]
        );
        assert_eq!(state.connectors("org-b"), vec!["docs".to_string()]);
        assert!(state.connectors("org-unknown").is_empty());
    }

    #[test]
    fn scoped_connector_key_separates_same_named_connectors_across_orgs() {
        assert_eq!(scoped_connector_key("org-a", "docs"), "IXCONN#org-a\u{1}docs");
        assert_ne!(
            scoped_connector_key("org-a", "docs"),
            scoped_connector_key("org-b", "docs")
        );
    }

    #[tokio::test]
    async fn confirmation_is_taken_exactly_once_and_can_be_cleared() {
        let state = state_with(config_with_env_key(None));
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel::<HumanResponse>();

        // Register → take returns the same channel, and only once.
        state.register_confirmation("sess-1", tx.clone());
        let taken = state
            .take_confirmation("sess-1")
            .expect("a confirmation was registered for sess-1");
        assert!(taken.same_channel(&tx));
        assert!(state.take_confirmation("sess-1").is_none());

        // Register → clear → nothing left to take.
        state.register_confirmation("sess-1", tx);
        state.clear_confirmation("sess-1");
        assert!(state.take_confirmation("sess-1").is_none());

        // Unknown sessions never yield a sender.
        assert!(state.take_confirmation("sess-other").is_none());
    }
}