1use std::collections::HashMap;
24use std::sync::RwLock;
25use std::time::{Duration, Instant};
26
27use async_trait::async_trait;
28use hmac::{Hmac, Mac};
29use sha2::Sha256;
30
31#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
33pub struct AgentWidgetAuth {
34 #[serde(default)]
39 pub allowed_origins: Vec<String>,
40 #[serde(default)]
43 pub public_key: Option<String>,
44 #[serde(default)]
52 pub organization_id: Option<String>,
53}
54
55#[async_trait]
61pub trait WidgetAuthProvider: Send + Sync {
62 async fn agent_widget_auth(&self, agent_id: &str) -> Option<AgentWidgetAuth>;
65}
66
/// Provider that never has a policy: every lookup yields `None`.
#[derive(Default, Debug)]
pub struct PermissiveWidgetAuth;
71
72#[async_trait]
73impl WidgetAuthProvider for PermissiveWidgetAuth {
74 async fn agent_widget_auth(&self, _agent_id: &str) -> Option<AgentWidgetAuth> {
75 None
76 }
77}
78
79#[derive(Debug, Default)]
82pub struct StaticWidgetAuth {
83 rows: HashMap<String, AgentWidgetAuth>,
84}
85
86impl StaticWidgetAuth {
87 #[must_use]
89 pub fn new(rows: HashMap<String, AgentWidgetAuth>) -> Self {
90 Self { rows }
91 }
92
93 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
99 let rows: HashMap<String, AgentWidgetAuth> = serde_json::from_str(json)?;
100 Ok(Self { rows })
101 }
102}
103
104#[async_trait]
105impl WidgetAuthProvider for StaticWidgetAuth {
106 async fn agent_widget_auth(&self, agent_id: &str) -> Option<AgentWidgetAuth> {
107 self.rows.get(agent_id).cloned()
108 }
109}
110
111struct CacheEntry {
114 value: Option<AgentWidgetAuth>,
115 fetched: Instant,
116}
117
118pub struct HttpWidgetAuth {
139 client: reqwest::Client,
140 base_url: String,
143 bearer: Option<String>,
145 ttl: Duration,
146 cache: RwLock<HashMap<String, CacheEntry>>,
147}
148
149impl HttpWidgetAuth {
150 #[must_use]
154 pub fn new(base_url: impl Into<String>) -> Self {
155 let client = reqwest::Client::builder()
156 .timeout(Duration::from_secs(5))
157 .build()
158 .unwrap_or_default();
159 Self::with_client(base_url, client)
160 }
161
162 #[must_use]
165 pub fn with_client(base_url: impl Into<String>, client: reqwest::Client) -> Self {
166 Self {
167 client,
168 base_url: base_url.into().trim_end_matches('/').to_string(),
169 bearer: None,
170 ttl: Duration::from_secs(60),
171 cache: RwLock::new(HashMap::new()),
172 }
173 }
174
175 #[must_use]
177 pub fn with_bearer(mut self, token: impl Into<String>) -> Self {
178 self.bearer = Some(token.into());
179 self
180 }
181
182 #[must_use]
184 pub fn with_ttl(mut self, ttl: Duration) -> Self {
185 self.ttl = ttl;
186 self
187 }
188
189 fn cached(&self, agent_id: &str) -> Option<Option<AgentWidgetAuth>> {
192 let cache = self.cache.read().ok()?;
193 let entry = cache.get(agent_id)?;
194 if entry.fetched.elapsed() < self.ttl {
195 Some(entry.value.clone())
196 } else {
197 None
198 }
199 }
200
201 fn store(&self, agent_id: &str, value: Option<AgentWidgetAuth>) {
203 if let Ok(mut cache) = self.cache.write() {
204 cache.insert(
205 agent_id.to_string(),
206 CacheEntry {
207 value,
208 fetched: Instant::now(),
209 },
210 );
211 }
212 }
213}
214
215#[async_trait]
216impl WidgetAuthProvider for HttpWidgetAuth {
217 async fn agent_widget_auth(&self, agent_id: &str) -> Option<AgentWidgetAuth> {
218 if let Some(cached) = self.cached(agent_id) {
219 return cached;
220 }
221
222 let mut url = match reqwest::Url::parse(&self.base_url) {
225 Ok(u) => u,
226 Err(e) => {
227 tracing::warn!(error = %e, base_url = %self.base_url, "widget-auth: invalid base_url");
228 return None;
229 }
230 };
231 match url.path_segments_mut() {
232 Ok(mut segs) => {
233 segs.push(agent_id);
234 }
235 Err(()) => {
236 tracing::warn!(base_url = %self.base_url, "widget-auth: base_url cannot be a base");
237 return None;
238 }
239 }
240
241 let mut req = self.client.get(url);
242 if let Some(bearer) = &self.bearer {
243 req = req.bearer_auth(bearer);
244 }
245
246 let resp = match req.send().await {
247 Ok(r) => r,
248 Err(e) => {
249 tracing::warn!(error = %e, agent_id, "widget-auth: policy fetch failed");
251 return None;
252 }
253 };
254
255 let status = resp.status();
256 if status.is_success() {
257 match resp.json::<AgentWidgetAuth>().await {
258 Ok(policy) => {
259 let value = Some(policy);
260 self.store(agent_id, value.clone());
261 value
262 }
263 Err(e) => {
264 tracing::warn!(error = %e, agent_id, "widget-auth: malformed policy body");
266 None
267 }
268 }
269 } else if status == reqwest::StatusCode::NOT_FOUND {
270 self.store(agent_id, None);
272 None
273 } else {
274 tracing::warn!(%status, agent_id, "widget-auth: policy service error");
277 None
278 }
279 }
280}
281
282#[must_use]
291pub fn origin_allowed(allowed: &[String], origin: &str) -> bool {
292 allowed
293 .iter()
294 .any(|pattern| origin_matches(pattern, origin))
295}
296
/// Matches a single allow-list `pattern` against `origin`.
///
/// Scheme and host comparisons are ASCII case-insensitive. RFC 6454 serializes
/// origins in lowercase, so a pattern written as `https://App.Example.com` must
/// still match the `https://app.example.com` a browser sends.
///
/// A port is part of the host text. So `https://*.example.com` does not match
/// `https://a.example.com:8443`, while `https://*.example.com:8443` does.
fn origin_matches(pattern: &str, origin: &str) -> bool {
    if pattern == "*" || pattern.eq_ignore_ascii_case(origin) {
        return true;
    }
    let (Some((p_scheme, p_host)), Some((o_scheme, o_host))) =
        (pattern.split_once("://"), origin.split_once("://"))
    else {
        return false;
    };
    if !p_scheme.eq_ignore_ascii_case(o_scheme) {
        return false;
    }
    let Some(suffix) = p_host.strip_prefix("*.") else {
        return false;
    };
    // `*.domain` covers the apex domain itself.
    if o_host.eq_ignore_ascii_case(suffix) {
        return true;
    }
    // Otherwise the host must end in ".domain". The dot check is what keeps
    // `notdomain` from matching. Comparing bytes avoids slicing the `str` at a
    // non-char boundary.
    let (host, dom) = (o_host.as_bytes(), suffix.as_bytes());
    host.len() > dom.len()
        && host[host.len() - dom.len() - 1] == b'.'
        && host[host.len() - dom.len()..].eq_ignore_ascii_case(dom)
}
318
319#[must_use]
327pub fn verify_auth_context(
328 public_key: &str,
329 user_id: &str,
330 signature_hex: &str,
331 timestamp: i64,
332 now_unix: i64,
333 max_age_secs: i64,
334) -> bool {
335 if (now_unix - timestamp).abs() > max_age_secs {
337 return false;
338 }
339 let Ok(sig) = hex::decode(signature_hex) else {
340 return false;
341 };
342 let Ok(mut mac) = Hmac::<Sha256>::new_from_slice(public_key.as_bytes()) else {
343 return false;
344 };
345 mac.update(format!("{user_id}:{timestamp}").as_bytes());
346 mac.verify_slice(&sig).is_ok()
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Hex HMAC-SHA256 of `"{user}:{ts}"` under `key`: the signature format
    /// `verify_auth_context` checks.
    fn sign(key: &str, user: &str, ts: i64) -> String {
        let mut mac = Hmac::<Sha256>::new_from_slice(key.as_bytes()).unwrap();
        mac.update(format!("{user}:{ts}").as_bytes());
        hex::encode(mac.finalize().into_bytes())
    }

    #[test]
    fn origin_exact_and_wildcard() {
        let allow: Vec<String> = ["https://app.example.com", "https://*.smoo.ai"]
            .iter()
            .map(|p| (*p).to_owned())
            .collect();

        // Exact entry, a subdomain under the wildcard, and the wildcard's apex.
        for origin in ["https://app.example.com", "https://dash.smoo.ai", "https://smoo.ai"] {
            assert!(origin_allowed(&allow, origin), "expected {origin} to be allowed");
        }
        // Unlisted host, scheme mismatch, and a host that merely ends in the same text.
        for origin in ["https://evil.com", "http://dash.smoo.ai", "https://notsmoo.ai"] {
            assert!(!origin_allowed(&allow, origin), "expected {origin} to be rejected");
        }
    }

    #[test]
    fn origin_star_allows_all_but_empty_denies() {
        let star = ["*".to_string()];
        assert!(origin_allowed(&star, "https://anything.test"));
        assert!(!origin_allowed(&[], "https://anything.test"));
    }

    #[test]
    fn auth_context_valid_and_invalid() {
        const KEY: &str = "super-secret-public-key";
        const MAX_AGE: i64 = 60;
        let now: i64 = 1_000_000;
        let check = |key: &str, user: &str, sig: &str, ts: i64| {
            verify_auth_context(key, user, sig, ts, now, MAX_AGE)
        };
        let good = sign(KEY, "user-123", now);

        // A fresh signature, and one 30s old (still inside the window).
        assert!(check(KEY, "user-123", &good, now));
        assert!(check(KEY, "user-123", &sign(KEY, "user-123", now - 30), now - 30));

        // Wrong key, wrong user, stale (10 minutes old), and non-hex signature.
        assert!(!check("other-key", "user-123", &good, now));
        assert!(!check(KEY, "user-999", &good, now));
        assert!(!check(KEY, "user-123", &sign(KEY, "user-123", now - 600), now - 600));
        assert!(!check(KEY, "user-123", "not-hex", now));
    }

    #[tokio::test]
    async fn static_provider_resolves_known_agents() {
        let provider = StaticWidgetAuth::from_json(
            r#"{ "agent-1": { "allowed_origins": ["https://*.smoo.ai"], "public_key": "k" } }"#,
        )
        .unwrap();

        let policy = provider.agent_widget_auth("agent-1").await.unwrap();
        assert_eq!(policy.allowed_origins, vec!["https://*.smoo.ai".to_string()]);
        assert_eq!(policy.public_key.as_deref(), Some("k"));
        assert!(provider.agent_widget_auth("unknown").await.is_none());
    }

    #[tokio::test]
    async fn permissive_provider_returns_none() {
        let policy = PermissiveWidgetAuth.agent_widget_auth("anything").await;
        assert!(policy.is_none());
    }

    #[tokio::test]
    async fn http_provider_fetches_then_serves_from_cache() {
        use wiremock::matchers::{header, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        let body = serde_json::json!({
            "allowed_origins": ["https://app.smoo.ai"],
            "public_key": "secret"
        });
        // `expect(1)`: the second lookup must be answered from the cache.
        Mock::given(method("GET"))
            .and(path("/agent-9"))
            .and(header("authorization", "Bearer m2m-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .expect(1)
            .mount(&server)
            .await;

        let provider = HttpWidgetAuth::new(server.uri()).with_bearer("m2m-token");

        let first = provider.agent_widget_auth("agent-9").await.expect("policy");
        assert_eq!(first.allowed_origins, vec!["https://app.smoo.ai".to_string()]);
        assert_eq!(first.public_key.as_deref(), Some("secret"));

        let second = provider.agent_widget_auth("agent-9").await.expect("cached");
        assert_eq!(second.public_key.as_deref(), Some("secret"));
    }

    #[tokio::test]
    async fn http_provider_404_is_none_and_cached() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        // One request only: the 404 must be cached as a negative result.
        Mock::given(method("GET"))
            .and(path("/ghost"))
            .respond_with(ResponseTemplate::new(404))
            .expect(1)
            .mount(&server)
            .await;

        let provider = HttpWidgetAuth::new(server.uri());
        for _ in 0..2 {
            assert!(provider.agent_widget_auth("ghost").await.is_none());
        }
    }

    #[tokio::test]
    async fn http_provider_server_error_is_none_and_not_cached() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        // Two requests: a 5xx must not be cached, so each lookup hits the server.
        Mock::given(method("GET"))
            .and(path("/flaky"))
            .respond_with(ResponseTemplate::new(500))
            .expect(2)
            .mount(&server)
            .await;

        let provider = HttpWidgetAuth::new(server.uri());
        for _ in 0..2 {
            assert!(provider.agent_widget_auth("flaky").await.is_none());
        }
    }
}