Skip to main content

smooth_operator/
gateway_key.rs

//! Per-org LLM gateway-key resolution: the seam that lets a multi-tenant
//! deployment bill/scope each org's turns to its **own** gateway key while a
//! standalone/local server keeps using the single environment key.
//!
//! A turn runs against an OpenAI-compatible LLM gateway authenticated with a
//! gateway key (e.g. a per-org LiteLLM virtual key). The reference server reads
//! one key from `SMOOAI_GATEWAY_KEY` and uses it for every turn. A hosted,
//! multi-tenant flavor instead wants to resolve a **different** key per org so
//! usage is attributed and budgeted per tenant.
//!
//! This module is the **hook**: the public server defines the
//! [`GatewayKeyResolver`] trait + the default [`EnvGatewayKeyResolver`] (which
//! returns the single env key — the unchanged local/default behavior); the host
//! application plugs in a concrete resolver (e.g. backed by its per-org key
//! store) via `AppState::with_gateway_key_resolver`. No SmooAI/DB specifics live
//! here — only the trait and the env default.
//!
//! ## Resolution contract
//!
//! [`GatewayKeyResolver::resolve`] returns `Some(key)` to **override** the key
//! for that org, or `None` to **fall back** to the server's configured env key.
//! The per-turn LLM-config build always falls back to the env key on `None`, so
//! a resolver that only knows about a subset of orgs is safe — unknown orgs use
//! the env key exactly as a server with no per-org resolver does.

26use std::sync::Arc;
27
28use async_trait::async_trait;
29
30/// Hook for resolving the LLM gateway key to use for a given org's turn.
31///
32/// Implemented by the host application (commonly backed by a per-org key store
33/// — e.g. a LiteLLM virtual key per tenant). Returning `None` means "no
34/// org-specific key" and the server falls back to its configured env key, so a
35/// resolver that covers only some orgs is safe.
36#[async_trait]
37pub trait GatewayKeyResolver: Send + Sync {
38    /// The gateway key to bill/scope this `org_id`'s turn to, or `None` to use
39    /// the server's default (env) key.
40    async fn resolve(&self, org_id: &str) -> Option<String>;
41}
42
/// Default resolver: returns the single configured environment gateway key for
/// every org (the unchanged local/default behavior — no per-org scoping).
///
/// Constructed from the server's resolved gateway key. When the env key is
/// absent (`None`), this resolver returns `None` for every org, so the server
/// behaves exactly as it would with no per-org resolution (a clean
/// `LLM_UNAVAILABLE` error on a turn).
///
/// `Debug` is implemented by hand rather than derived. The wrapped value is a
/// credential, and a derived `Debug` would print it verbatim into any log line
/// that formats this resolver (or a state struct containing it) with `{:?}`.
#[derive(Clone, Default)]
pub struct EnvGatewayKeyResolver {
    /// The configured gateway key, or `None` when no env key is set.
    env_key: Option<String>,
}

impl EnvGatewayKeyResolver {
    /// Build the env resolver over the server's configured gateway key (the
    /// value of `SMOOAI_GATEWAY_KEY`, or `None` when unset).
    #[must_use]
    pub fn new(env_key: Option<String>) -> Self {
        Self { env_key }
    }
}

impl std::fmt::Debug for EnvGatewayKeyResolver {
    /// Formats the resolver with the key redacted. Only the presence of a key
    /// is shown (`Some("<redacted>")` vs `None`), never its value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EnvGatewayKeyResolver")
            .field("env_key", &self.env_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
62
63#[async_trait]
64impl GatewayKeyResolver for EnvGatewayKeyResolver {
65    async fn resolve(&self, _org_id: &str) -> Option<String> {
66        self.env_key.clone()
67    }
68}
69
70/// Resolve the gateway key for `org_id`, falling back to `env_key` when the
71/// resolver returns `None`.
72///
73/// This is the single place the per-turn LLM-config build calls: inject any
74/// [`GatewayKeyResolver`] and the env key, and get back the key the turn should
75/// use, or `None` when neither the resolver nor the env supplies one (turn is
76/// then unavailable). Keeping the fallback here means every flavor — and every
77/// polyglot port — resolves identically.
78pub async fn resolve_gateway_key(
79    resolver: &Arc<dyn GatewayKeyResolver>,
80    org_id: &str,
81    env_key: Option<&str>,
82) -> Option<String> {
83    match resolver.resolve(org_id).await {
84        Some(key) => Some(key),
85        None => env_key.map(str::to_string),
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that covers exactly one org. It answers that org with a
    /// fixed key and yields `None` for every other org, so the env fallback
    /// gets exercised.
    struct OneOrgResolver {
        org: String,
        key: String,
    }

    impl OneOrgResolver {
        /// A resolver covering only `org-a` (keyed `org-a-key`), as a trait object.
        fn covering_org_a() -> Arc<dyn GatewayKeyResolver> {
            Arc::new(Self {
                org: "org-a".to_string(),
                key: "org-a-key".to_string(),
            })
        }
    }

    #[async_trait]
    impl GatewayKeyResolver for OneOrgResolver {
        async fn resolve(&self, org_id: &str) -> Option<String> {
            (org_id == self.org).then(|| self.key.clone())
        }
    }

    #[tokio::test]
    async fn env_resolver_returns_env_key_for_every_org() {
        let resolver = EnvGatewayKeyResolver::new(Some("env-key".to_string()));
        for org in ["org-a", "org-b"] {
            assert_eq!(resolver.resolve(org).await, Some("env-key".to_string()));
        }
    }

    #[tokio::test]
    async fn env_resolver_returns_none_when_env_absent() {
        let resolver = EnvGatewayKeyResolver::new(None);
        assert_eq!(resolver.resolve("org-a").await, None);
    }

    #[tokio::test]
    async fn injected_resolver_overrides_per_org() {
        let resolver = OneOrgResolver::covering_org_a();
        // The covered org gets its own key; the env key is not consulted.
        let resolved = resolve_gateway_key(&resolver, "org-a", Some("env-key")).await;
        assert_eq!(resolved, Some("org-a-key".to_string()));
    }

    #[tokio::test]
    async fn falls_back_to_env_when_resolver_returns_none() {
        let resolver = OneOrgResolver::covering_org_a();
        // An org the resolver does not cover falls back to the env key.
        let resolved = resolve_gateway_key(&resolver, "org-b", Some("env-key")).await;
        assert_eq!(resolved, Some("env-key".to_string()));
    }

    #[tokio::test]
    async fn resolves_to_none_when_neither_resolver_nor_env_supply_a_key() {
        let resolver: Arc<dyn GatewayKeyResolver> = Arc::new(EnvGatewayKeyResolver::new(None));
        let resolved = resolve_gateway_key(&resolver, "org-a", None).await;
        assert_eq!(resolved, None);
    }
}