Skip to main content

smooth_operator/
widget_auth.rs

1//! Embeddable-widget auth: an origin allowlist + public-key `authContext`
2//! verification for browser-embedded chat widgets (`<smooth-agent-chat>`).
3//!
4//! Browser widgets connect from arbitrary customer sites, so an agent that's
5//! embeddable needs two protections the bearer-token path doesn't give:
6//!
7//! 1. **Origin allowlist** — only the sites a customer registered may embed and
8//!    drive their agent (mirrors a CORS/referrer allowlist, enforced server-side
9//!    on the WebSocket `Origin` header captured at connect).
//! 2. **Public-key `authContext`** — a host page can pre-authenticate a known
//!    user by HMAC-signing `{userId}:{timestamp}` with the agent's `public_key`
//!    (despite the name, a shared HMAC secret — see [`AgentWidgetAuth`]);
//!    the server verifies it (replay-protected) so the turn can skip OTP.
13//!
14//! This module is the **hook**: the public server defines the
15//! [`WidgetAuthProvider`] trait + the enforcement primitives ([`origin_allowed`],
16//! [`verify_auth_context`]); the host application plugs in a concrete provider
17//! (e.g. backed by its agent database) that maps an `agentId` to its
18//! [`AgentWidgetAuth`] policy. The bundled [`PermissiveWidgetAuth`] returns no
19//! policy for any agent, so a standalone OSS server enforces nothing until a
20//! real provider is installed (see `WIDGET_AUTH_STRICT` on the server for
21//! fail-closed behavior on unknown agents).
22
23use std::collections::HashMap;
24use std::sync::RwLock;
25use std::time::{Duration, Instant};
26
27use async_trait::async_trait;
28use hmac::{Hmac, Mac};
29use sha2::Sha256;
30
31/// The embed-auth policy for one agent.
32#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
33pub struct AgentWidgetAuth {
34    /// Origins permitted to embed this agent. Each entry is an exact origin
35    /// (`https://app.example.com`), a host wildcard (`https://*.smoo.ai`), or
36    /// `*` (any). An **empty** list means *no origin is allowed* (deny all) —
37    /// configure at least one entry to permit embedding.
38    #[serde(default)]
39    pub allowed_origins: Vec<String>,
40    /// Shared secret used to verify a pre-auth `authContext` HMAC. `None` means
41    /// the agent does not support `authContext` (any supplied one is rejected).
42    #[serde(default)]
43    pub public_key: Option<String>,
44    /// The organization this agent belongs to, when the host policy service
45    /// knows it. A multi-tenant host (whose `Conversation`/`Participant` rows
46    /// carry an org FK) populates this so a widget visitor's session is created
47    /// under the agent's real org — visitors authenticate via origin +
48    /// `authContext`, not a JWT, so the org cannot come from a bearer token.
49    /// `None` (the default, and the OSS/dev case) leaves org derivation to the
50    /// connection's JWT principal, then the server's seed org.
51    #[serde(default)]
52    pub organization_id: Option<String>,
53}
54
55/// Hook for resolving an agent's [`AgentWidgetAuth`] policy.
56///
57/// Implemented by the host application (commonly backed by its agent DB/API).
58/// Returning `None` means "no policy for this agent" — the server treats that as
59/// allow in permissive mode, or deny in strict mode (`WIDGET_AUTH_STRICT`).
60#[async_trait]
61pub trait WidgetAuthProvider: Send + Sync {
62    /// The embed-auth policy for `agent_id`, or `None` if the agent has none /
63    /// is unknown.
64    async fn agent_widget_auth(&self, agent_id: &str) -> Option<AgentWidgetAuth>;
65}
66
67/// Default provider: no policy for any agent → enforcement is off. Keeps the OSS
68/// server's `/ws` path open until a real [`WidgetAuthProvider`] is installed.
69#[derive(Debug, Default)]
70pub struct PermissiveWidgetAuth;
71
72#[async_trait]
73impl WidgetAuthProvider for PermissiveWidgetAuth {
74    async fn agent_widget_auth(&self, _agent_id: &str) -> Option<AgentWidgetAuth> {
75        None
76    }
77}
78
79/// Static map provider (`agentId` → policy). Lets a server enforce without a
80/// database, and gives hosts a simple wiring option (load from a JSON file/env).
81#[derive(Debug, Default)]
82pub struct StaticWidgetAuth {
83    rows: HashMap<String, AgentWidgetAuth>,
84}
85
86impl StaticWidgetAuth {
87    /// Build from an in-memory map.
88    #[must_use]
89    pub fn new(rows: HashMap<String, AgentWidgetAuth>) -> Self {
90        Self { rows }
91    }
92
93    /// Parse a JSON object of `{ "<agentId>": { "allowed_origins": [...],
94    /// "public_key": "..." }, ... }`.
95    ///
96    /// # Errors
97    /// Returns an error if `json` is not a valid map of the expected shape.
98    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
99        let rows: HashMap<String, AgentWidgetAuth> = serde_json::from_str(json)?;
100        Ok(Self { rows })
101    }
102}
103
104#[async_trait]
105impl WidgetAuthProvider for StaticWidgetAuth {
106    async fn agent_widget_auth(&self, agent_id: &str) -> Option<AgentWidgetAuth> {
107        self.rows.get(agent_id).cloned()
108    }
109}
110
111/// A cache entry: the resolved policy (or `None` for a known no-policy agent) and
112/// when it was fetched, for TTL expiry.
113struct CacheEntry {
114    value: Option<AgentWidgetAuth>,
115    fetched: Instant,
116}
117
118/// HTTP-backed provider: resolves `agentId` → [`AgentWidgetAuth`] by GETting
119/// `{base_url}/{agentId}` from a host's policy service, with TTL caching.
120///
121/// This is the **generic mechanism** a host installs instead of writing a custom
122/// [`WidgetAuthProvider`]: stand up an endpoint that returns the
123/// [`AgentWidgetAuth`] JSON (`{ "allowed_origins": [...], "public_key": "..." }`)
124/// for an agent, point `HttpWidgetAuth` at it, and embed-auth is enforced against
125/// live data. (SmooAI backs this with an api-prime route over its agent DB.)
126///
127/// Response handling — chosen so a flaky policy service never *silently* opens a
128/// hole:
129/// - **2xx** → parse + cache the policy.
130/// - **404** → cache `None` (the agent legitimately has no policy; in
131///   `WIDGET_AUTH_STRICT` the server then denies it).
132/// - **5xx / network / malformed body** → return `None` **without caching**, so
133///   the next connect retries. Combined with strict mode this fails closed; in
134///   permissive mode enforcement is off anyway.
135///
136/// Cached results (incl. 404s) are reused for `ttl` (default 60s) so a busy embed
137/// doesn't hammer the policy service on every WebSocket connect.
138pub struct HttpWidgetAuth {
139    client: reqwest::Client,
140    /// Policy endpoint base (no trailing slash); the agent id is appended as a
141    /// single percent-encoded path segment.
142    base_url: String,
143    /// Optional bearer token sent to the policy service (e.g. an M2M token).
144    bearer: Option<String>,
145    ttl: Duration,
146    cache: RwLock<HashMap<String, CacheEntry>>,
147}
148
149impl HttpWidgetAuth {
150    /// Build a provider that resolves policies from `base_url` (e.g.
151    /// `https://api.smoo.ai/internal/widget-auth`). Uses a client with a 5s
152    /// timeout so a hung policy service can't stall widget connects.
153    #[must_use]
154    pub fn new(base_url: impl Into<String>) -> Self {
155        let client = reqwest::Client::builder()
156            .timeout(Duration::from_secs(5))
157            .build()
158            .unwrap_or_default();
159        Self::with_client(base_url, client)
160    }
161
162    /// Build with a caller-supplied [`reqwest::Client`] (to share a pool / set
163    /// custom timeouts or TLS).
164    #[must_use]
165    pub fn with_client(base_url: impl Into<String>, client: reqwest::Client) -> Self {
166        Self {
167            client,
168            base_url: base_url.into().trim_end_matches('/').to_string(),
169            bearer: None,
170            ttl: Duration::from_secs(60),
171            cache: RwLock::new(HashMap::new()),
172        }
173    }
174
175    /// Send `Authorization: Bearer <token>` to the policy service (builder).
176    #[must_use]
177    pub fn with_bearer(mut self, token: impl Into<String>) -> Self {
178        self.bearer = Some(token.into());
179        self
180    }
181
182    /// Override the cache TTL (builder). Default 60s.
183    #[must_use]
184    pub fn with_ttl(mut self, ttl: Duration) -> Self {
185        self.ttl = ttl;
186        self
187    }
188
189    /// A live (non-expired) cached result for `agent_id`, if any. Outer `None` =
190    /// not cached / expired; inner `Option` = the cached policy-or-no-policy.
191    fn cached(&self, agent_id: &str) -> Option<Option<AgentWidgetAuth>> {
192        let cache = self.cache.read().ok()?;
193        let entry = cache.get(agent_id)?;
194        if entry.fetched.elapsed() < self.ttl {
195            Some(entry.value.clone())
196        } else {
197            None
198        }
199    }
200
201    /// Cache a definitive result (a 2xx policy or a 404 no-policy).
202    fn store(&self, agent_id: &str, value: Option<AgentWidgetAuth>) {
203        if let Ok(mut cache) = self.cache.write() {
204            cache.insert(
205                agent_id.to_string(),
206                CacheEntry {
207                    value,
208                    fetched: Instant::now(),
209                },
210            );
211        }
212    }
213}
214
215#[async_trait]
216impl WidgetAuthProvider for HttpWidgetAuth {
217    async fn agent_widget_auth(&self, agent_id: &str) -> Option<AgentWidgetAuth> {
218        if let Some(cached) = self.cached(agent_id) {
219            return cached;
220        }
221
222        // Build the URL by pushing the agent id as one percent-encoded segment,
223        // so an id can't manipulate the path.
224        let mut url = match reqwest::Url::parse(&self.base_url) {
225            Ok(u) => u,
226            Err(e) => {
227                tracing::warn!(error = %e, base_url = %self.base_url, "widget-auth: invalid base_url");
228                return None;
229            }
230        };
231        match url.path_segments_mut() {
232            Ok(mut segs) => {
233                segs.push(agent_id);
234            }
235            Err(()) => {
236                tracing::warn!(base_url = %self.base_url, "widget-auth: base_url cannot be a base");
237                return None;
238            }
239        }
240
241        let mut req = self.client.get(url);
242        if let Some(bearer) = &self.bearer {
243            req = req.bearer_auth(bearer);
244        }
245
246        let resp = match req.send().await {
247            Ok(r) => r,
248            Err(e) => {
249                // Transient — do NOT cache, so the next connect retries.
250                tracing::warn!(error = %e, agent_id, "widget-auth: policy fetch failed");
251                return None;
252            }
253        };
254
255        let status = resp.status();
256        if status.is_success() {
257            match resp.json::<AgentWidgetAuth>().await {
258                Ok(policy) => {
259                    let value = Some(policy);
260                    self.store(agent_id, value.clone());
261                    value
262                }
263                Err(e) => {
264                    // Malformed body (deploy skew?) — don't cache; retry next time.
265                    tracing::warn!(error = %e, agent_id, "widget-auth: malformed policy body");
266                    None
267                }
268            }
269        } else if status == reqwest::StatusCode::NOT_FOUND {
270            // Legitimate "this agent has no policy" — cache it.
271            self.store(agent_id, None);
272            None
273        } else {
274            // 5xx etc. — don't cache; fail open here, which strict mode turns
275            // into a deny.
276            tracing::warn!(%status, agent_id, "widget-auth: policy service error");
277            None
278        }
279    }
280}
281
282/// Whether `origin` is permitted by `allowed`.
283///
284/// An empty `allowed` permits nothing. Each pattern is matched as:
285/// - `*` → any origin,
286/// - an exact match (`https://app.example.com`),
287/// - a host wildcard `scheme://*.suffix` → the origin's scheme must match and
288///   its host must equal `suffix` or end with `.suffix`
289///   (`https://*.smoo.ai` matches `https://app.smoo.ai` and `https://smoo.ai`).
290#[must_use]
291pub fn origin_allowed(allowed: &[String], origin: &str) -> bool {
292    allowed
293        .iter()
294        .any(|pattern| origin_matches(pattern, origin))
295}
296
/// Whether the single allowlist entry `pattern` admits `origin`.
///
/// Checked in order:
/// 1. `*` admits everything.
/// 2. An entry identical to `origin` admits it.
/// 3. A host wildcard `scheme://*.suffix` admits an origin with the same
///    scheme whose remainder after `://` is exactly `suffix` or ends with
///    `.suffix`.
///
/// Wildcard matching is deliberately strict:
/// - An empty suffix (`scheme://*.`) never matches; otherwise it would admit
///   `scheme://` and every trailing-dot host.
/// - The part of `origin` after `://` must look like a bare `host[:port]`. A
///   serialized origin has no path, query, fragment or userinfo, so anything
///   containing `/`, `\`, `?`, `#` or `@` is rejected rather than
///   suffix-matched (e.g. `https://evil.com/.smoo.ai`).
/// - A port is part of that remainder, so `https://app.smoo.ai:8443` is not
///   admitted by `https://*.smoo.ai`.
fn origin_matches(pattern: &str, origin: &str) -> bool {
    if pattern == "*" || pattern == origin {
        return true;
    }
    // Host wildcard: scheme://*.suffix
    let (Some((p_scheme, p_host)), Some((o_scheme, o_host))) =
        (pattern.split_once("://"), origin.split_once("://"))
    else {
        return false;
    };
    if p_scheme != o_scheme {
        return false;
    }
    let Some(suffix) = p_host.strip_prefix("*.") else {
        return false;
    };
    if suffix.is_empty() {
        return false;
    }
    if o_host
        .chars()
        .any(|c| matches!(c, '/' | '\\' | '?' | '#' | '@'))
    {
        return false;
    }
    // "ends with `.suffix`", without allocating the dotted string.
    o_host == suffix
        || o_host
            .strip_suffix(suffix)
            .map_or(false, |rest| rest.ends_with('.'))
}
318
319/// Verify a pre-auth `authContext`: an HMAC-SHA256 over `"{user_id}:{timestamp}"`
320/// keyed by `public_key`, encoded as lowercase hex in `signature_hex`, signed no
321/// more than `max_age_secs` away from `now_unix` (replay protection).
322///
323/// Returns `false` (never panics) on any malformed input, a stale/future
324/// timestamp, or a signature mismatch. The comparison is constant-time
325/// (`Mac::verify_slice`).
326#[must_use]
327pub fn verify_auth_context(
328    public_key: &str,
329    user_id: &str,
330    signature_hex: &str,
331    timestamp: i64,
332    now_unix: i64,
333    max_age_secs: i64,
334) -> bool {
335    // Replay window: reject timestamps too far in the past or future.
336    if (now_unix - timestamp).abs() > max_age_secs {
337        return false;
338    }
339    let Ok(sig) = hex::decode(signature_hex) else {
340        return false;
341    };
342    let Ok(mut mac) = Hmac::<Sha256>::new_from_slice(public_key.as_bytes()) else {
343        return false;
344    };
345    mac.update(format!("{user_id}:{timestamp}").as_bytes());
346    mac.verify_slice(&sig).is_ok()
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned allowlist from string literals.
    fn allowlist(patterns: &[&str]) -> Vec<String> {
        patterns.iter().map(|p| (*p).to_string()).collect()
    }

    /// Produce the hex HMAC-SHA256 a host page would send for `user` at `ts`.
    fn sign(key: &str, user: &str, ts: i64) -> String {
        let payload = format!("{user}:{ts}");
        let mut hmac = Hmac::<Sha256>::new_from_slice(key.as_bytes()).unwrap();
        hmac.update(payload.as_bytes());
        let digest = hmac.finalize().into_bytes();
        hex::encode(digest)
    }

    #[test]
    fn origin_exact_and_wildcard() {
        let allow = allowlist(&["https://app.example.com", "https://*.smoo.ai"]);

        let permitted = [
            "https://app.example.com",
            "https://dash.smoo.ai",
            "https://smoo.ai",
        ];
        for origin in permitted {
            assert!(origin_allowed(&allow, origin));
        }

        let denied = [
            "https://evil.com",
            // Scheme must match.
            "http://dash.smoo.ai",
            // Not a sub-suffix match.
            "https://notsmoo.ai",
        ];
        for origin in denied {
            assert!(!origin_allowed(&allow, origin));
        }
    }

    #[test]
    fn origin_star_allows_all_but_empty_denies() {
        let any = allowlist(&["*"]);
        let nothing: Vec<String> = Vec::new();
        assert!(origin_allowed(&any, "https://anything.test"));
        assert!(!origin_allowed(&nothing, "https://anything.test"));
    }

    #[test]
    fn auth_context_valid_and_invalid() {
        const KEY: &str = "super-secret-public-key";
        const USER: &str = "user-123";
        const WINDOW: i64 = 60;
        let now: i64 = 1_000_000;
        let fresh_sig = sign(KEY, USER, now);

        assert!(verify_auth_context(KEY, USER, &fresh_sig, now, now, WINDOW));

        // Within the window but slightly old.
        let recent = now - 30;
        let recent_sig = sign(KEY, USER, recent);
        assert!(verify_auth_context(KEY, USER, &recent_sig, recent, now, WINDOW));

        // Wrong key.
        assert!(!verify_auth_context(
            "other-key",
            USER,
            &fresh_sig,
            now,
            now,
            WINDOW
        ));

        // Tampered user.
        assert!(!verify_auth_context(
            KEY, "user-999", &fresh_sig, now, now, WINDOW
        ));

        // Stale (outside replay window).
        let stale = now - 600;
        let stale_sig = sign(KEY, USER, stale);
        assert!(!verify_auth_context(KEY, USER, &stale_sig, stale, now, WINDOW));

        // Garbage signature.
        assert!(!verify_auth_context(KEY, USER, "not-hex", now, now, WINDOW));
    }

    #[tokio::test]
    async fn static_provider_resolves_known_agents() {
        let provider = StaticWidgetAuth::from_json(
            r#"{ "agent-1": { "allowed_origins": ["https://*.smoo.ai"], "public_key": "k" } }"#,
        )
        .unwrap();

        let policy = provider.agent_widget_auth("agent-1").await.unwrap();
        assert_eq!(policy.allowed_origins, allowlist(&["https://*.smoo.ai"]));
        assert_eq!(policy.public_key.as_deref(), Some("k"));

        let missing = provider.agent_widget_auth("unknown").await;
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn permissive_provider_returns_none() {
        let resolved = PermissiveWidgetAuth.agent_widget_auth("anything").await;
        assert!(resolved.is_none());
    }

    #[tokio::test]
    async fn http_provider_fetches_then_serves_from_cache() {
        use wiremock::matchers::{header, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let policy_body = serde_json::json!({
            "allowed_origins": ["https://app.smoo.ai"],
            "public_key": "secret"
        });

        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/agent-9"))
            .and(header("authorization", "Bearer m2m-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(policy_body))
            // The second lookup must be served from cache, not the server.
            .expect(1)
            .mount(&server)
            .await;

        let provider = HttpWidgetAuth::new(server.uri()).with_bearer("m2m-token");

        let fetched = provider.agent_widget_auth("agent-9").await.expect("policy");
        assert_eq!(fetched.allowed_origins, allowlist(&["https://app.smoo.ai"]));
        assert_eq!(fetched.public_key.as_deref(), Some("secret"));

        // Cache hit — no second upstream request (verified by `.expect(1)` on drop).
        let from_cache = provider.agent_widget_auth("agent-9").await.expect("cached");
        assert_eq!(from_cache.public_key.as_deref(), Some("secret"));
    }

    #[tokio::test]
    async fn http_provider_404_is_none_and_cached() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/ghost"))
            .respond_with(ResponseTemplate::new(404))
            // A known no-policy result is cached too.
            .expect(1)
            .mount(&server)
            .await;

        let provider = HttpWidgetAuth::new(server.uri());
        let first = provider.agent_widget_auth("ghost").await;
        let second = provider.agent_widget_auth("ghost").await; // cached
        assert!(first.is_none());
        assert!(second.is_none());
    }

    #[tokio::test]
    async fn http_provider_server_error_is_none_and_not_cached() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/flaky"))
            .respond_with(ResponseTemplate::new(500))
            // NOT cached on error → the next call retries upstream.
            .expect(2)
            .mount(&server)
            .await;

        let provider = HttpWidgetAuth::new(server.uri());
        let first = provider.agent_widget_auth("flaky").await;
        let second = provider.agent_widget_auth("flaky").await; // retried, not cached
        assert!(first.is_none());
        assert!(second.is_none());
    }
}
507}