Skip to main content

smooai_config/
bootstrap.rs

1//! Lightweight cold-start config fetcher.
2//!
3//! This module exists for callers that need to read a single config
4//! value from a deploy script, container entry-point, or other
5//! cold-start context where the full SDK is too heavy or pulls in a
6//! problematic transitive dependency.
7//!
8//! It has **zero** imports from other modules in this crate and uses
9//! only `reqwest` + `serde_json` (which are already crate deps).
10//!
//! It performs a single OAuth `client_credentials` exchange, then a
//! single GET against `/organizations/{org_id}/config/values`, and
//! caches the resulting values map in-process for the most recently
//! fetched env, so repeated reads for that env avoid the round-trip.
//! Requesting a different env replaces the cached map.
15//!
16//! Inputs (read from `std::env`):
17//!
18//! - `SMOOAI_CONFIG_API_URL` — base URL (default `https://api.smoo.ai`)
19//! - `SMOOAI_CONFIG_AUTH_URL` — OAuth base URL (default `https://auth.smoo.ai`;
20//!   legacy `SMOOAI_AUTH_URL` also accepted)
21//! - `SMOOAI_CONFIG_CLIENT_ID` — OAuth M2M client id
22//! - `SMOOAI_CONFIG_CLIENT_SECRET` — OAuth M2M client secret
23//!   (legacy `SMOOAI_CONFIG_API_KEY` accepted)
24//! - `SMOOAI_CONFIG_ORG_ID` — target org id
25//! - `SMOOAI_CONFIG_ENV` — default env name (fallback when no SST stage)
26
27use std::collections::HashMap;
28use std::sync::Mutex;
29
30use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
31use serde_json::Value;
32use thiserror::Error;
33
/// Percent-encode set used for the org id path segment and the
/// `environment` query value.
///
/// Starts from `CONTROLS` and adds space plus every ASCII punctuation
/// character except `-`, `_`, `.`, and `~`, so ASCII letters, digits, and
/// those four characters (the RFC 3986 "unreserved" set) are the only
/// ASCII bytes left out of the set.
///
/// NOTE: this is stricter than JS `encodeURIComponent`, which additionally
/// leaves `!`, `'`, `(`, `)`, and `*` unescaped.
const URL_ENCODE_SET: &AsciiSet = &CONTROLS
    .add(b' ')
    .add(b'!')
    .add(b'"')
    .add(b'#')
    .add(b'$')
    .add(b'%')
    .add(b'&')
    .add(b'\'')
    .add(b'(')
    .add(b')')
    .add(b'*')
    .add(b'+')
    .add(b',')
    .add(b'/')
    .add(b':')
    .add(b';')
    .add(b'<')
    .add(b'=')
    .add(b'>')
    .add(b'?')
    .add(b'@')
    .add(b'[')
    .add(b'\\')
    .add(b']')
    .add(b'^')
    .add(b'`')
    .add(b'{')
    .add(b'|')
    .add(b'}');
66
/// Errors returned by [`bootstrap_fetch`].
///
/// Every message is prefixed with `[smooai-config/bootstrap]` so the
/// source is identifiable in deploy-script output.
#[derive(Debug, Error)]
pub enum BootstrapError {
    /// The client id, client secret, or org id was absent or empty in the
    /// environment map.
    #[error("[smooai-config/bootstrap] missing SMOOAI_CONFIG_{{CLIENT_ID,CLIENT_SECRET,ORG_ID}} in env. Set these (e.g. via `pnpm sst shell --stage <stage>`) before calling bootstrap_fetch.")]
    MissingCredentials,
    /// The token endpoint answered with a non-success HTTP status; carries
    /// the status code and the response body text.
    #[error("[smooai-config/bootstrap] OAuth token exchange failed: HTTP {status} {body}")]
    OAuthFailed { status: u16, body: String },
    /// The token endpoint answered successfully but the JSON had no
    /// non-empty string `access_token` field.
    #[error("[smooai-config/bootstrap] OAuth token endpoint returned no access_token")]
    MissingAccessToken,
    /// The config values endpoint answered with a non-success HTTP status;
    /// carries the status code and the response body text.
    #[error("[smooai-config/bootstrap] GET /config/values failed: HTTP {status} {body}")]
    ValuesFailed { status: u16, body: String },
    /// A `reqwest` error while sending a request (converted via `?`).
    #[error("[smooai-config/bootstrap] HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// A response body could not be parsed as JSON (converted via `?`).
    #[error("[smooai-config/bootstrap] response not JSON: {0}")]
    InvalidJson(#[from] serde_json::Error),
}
83
/// Connection settings and M2M credentials resolved from the environment
/// map by `read_creds`.
#[derive(Debug, Clone)]
struct BootstrapCreds {
    /// Base URL of the config API (trailing `/` is trimmed at use site).
    api_url: String,
    /// Base URL of the OAuth server (trailing `/` is trimmed at use site).
    auth_url: String,
    /// OAuth client id sent in the token request form.
    client_id: String,
    /// OAuth client secret sent in the token request form.
    client_secret: String,
    /// Organization id interpolated (percent-encoded) into the values URL.
    org_id: String,
}
92
/// Returns a copy of the first candidate that is `Some` and not the empty
/// string, scanning `values` in order; `None` when no candidate qualifies.
fn first_non_empty(values: &[Option<String>]) -> Option<String> {
    for candidate in values {
        match candidate.as_deref() {
            Some(text) if !text.is_empty() => return Some(text.to_owned()),
            _ => continue,
        }
    }
    None
}
98
99fn read_creds(env: &HashMap<String, String>) -> Result<BootstrapCreds, BootstrapError> {
100    let api_url = env
101        .get("SMOOAI_CONFIG_API_URL")
102        .cloned()
103        .filter(|s| !s.is_empty())
104        .unwrap_or_else(|| "https://api.smoo.ai".to_string());
105    let auth_url = first_non_empty(&[
106        env.get("SMOOAI_CONFIG_AUTH_URL").cloned(),
107        env.get("SMOOAI_AUTH_URL").cloned(),
108    ])
109    .unwrap_or_else(|| "https://auth.smoo.ai".to_string());
110    let client_id = env.get("SMOOAI_CONFIG_CLIENT_ID").cloned().unwrap_or_default();
111    let client_secret = first_non_empty(&[
112        env.get("SMOOAI_CONFIG_CLIENT_SECRET").cloned(),
113        env.get("SMOOAI_CONFIG_API_KEY").cloned(),
114    ])
115    .unwrap_or_default();
116    let org_id = env.get("SMOOAI_CONFIG_ORG_ID").cloned().unwrap_or_default();
117
118    if client_id.is_empty() || client_secret.is_empty() || org_id.is_empty() {
119        return Err(BootstrapError::MissingCredentials);
120    }
121    Ok(BootstrapCreds {
122        api_url,
123        auth_url,
124        client_id,
125        client_secret,
126        org_id,
127    })
128}
129
130fn resolve_env(env: &HashMap<String, String>, explicit: Option<&str>) -> String {
131    if let Some(e) = explicit {
132        if !e.is_empty() {
133            return e.to_string();
134        }
135    }
136    let mut stage = env.get("SST_STAGE").cloned().filter(|s| !s.is_empty());
137    if stage.is_none() {
138        stage = env.get("NEXT_PUBLIC_SST_STAGE").cloned().filter(|s| !s.is_empty());
139    }
140    if stage.is_none() {
141        if let Some(raw) = env.get("SST_RESOURCE_App").filter(|s| !s.is_empty()) {
142            if let Ok(parsed) = serde_json::from_str::<Value>(raw) {
143                if let Some(s) = parsed.get("stage").and_then(|v| v.as_str()) {
144                    if !s.is_empty() {
145                        stage = Some(s.to_string());
146                    }
147                }
148            }
149        }
150    }
151    match stage {
152        Some(s) if s == "production" => "production".to_string(),
153        Some(s) => s,
154        None => env
155            .get("SMOOAI_CONFIG_ENV")
156            .cloned()
157            .filter(|s| !s.is_empty())
158            .unwrap_or_else(|| "development".to_string()),
159    }
160}
161
/// In-process cache of fetched values.
///
/// Holds at most one entry: an environment name paired with the values
/// map fetched for it. It is a single slot, not a map keyed by env, so
/// storing values for a different environment replaces the previous entry.
static CACHE: Mutex<Option<(String, HashMap<String, Value>)>> = Mutex::new(None);
164
165/// Test-only: clear the in-process cache.
166#[doc(hidden)]
167pub fn __reset_bootstrap_cache() {
168    let mut guard = CACHE.lock().unwrap();
169    *guard = None;
170}
171
/// Snapshot the process environment as a `String` → `String` map.
///
/// Variables whose name or value is not valid Unicode are skipped.
/// `std::env::vars()` panics when it meets such an entry, which would take
/// down the whole cold-start path over a variable this module never
/// reads; iterating `vars_os()` and dropping unconvertible pairs avoids
/// that.
fn env_map() -> HashMap<String, String> {
    std::env::vars_os()
        .filter_map(|(name, value)| {
            let name = name.into_string().ok()?;
            let value = value.into_string().ok()?;
            Some((name, value))
        })
        .collect()
}
175
176/// Fetch a single config value by camelCase key.
177///
178/// Returns `Ok(None)` if the key is not present in the values map. Only
179/// env, auth, and network failures produce errors.
180///
181/// The full values map is cached per-process per-env after the first
182/// call.
183pub async fn bootstrap_fetch(key: &str, environment: Option<&str>) -> Result<Option<String>, BootstrapError> {
184    bootstrap_fetch_with_env(key, environment, &env_map(), &reqwest::Client::new()).await
185}
186
187/// Same as [`bootstrap_fetch`] but with an explicit env map and client.
188/// Useful for tests; not part of the stable public API.
189#[doc(hidden)]
190pub async fn bootstrap_fetch_with_env(
191    key: &str,
192    environment: Option<&str>,
193    env: &HashMap<String, String>,
194    client: &reqwest::Client,
195) -> Result<Option<String>, BootstrapError> {
196    let env_name = resolve_env(env, environment);
197
198    let need_fetch = {
199        let guard = CACHE.lock().unwrap();
200        match guard.as_ref() {
201            Some((cached_env, _)) => cached_env != &env_name,
202            None => true,
203        }
204    };
205
206    if need_fetch {
207        let creds = read_creds(env)?;
208        let token = mint_access_token(client, &creds).await?;
209        let values = fetch_values(client, &creds, &token, &env_name).await?;
210        let mut guard = CACHE.lock().unwrap();
211        *guard = Some((env_name.clone(), values));
212    }
213
214    let guard = CACHE.lock().unwrap();
215    let values = &guard.as_ref().expect("cache populated above").1;
216    Ok(values.get(key).and_then(value_to_string))
217}
218
219fn value_to_string(v: &Value) -> Option<String> {
220    match v {
221        Value::Null => None,
222        Value::String(s) => Some(s.clone()),
223        Value::Bool(b) => Some(if *b { "true".to_string() } else { "false".to_string() }),
224        Value::Number(n) => Some(n.to_string()),
225        other => Some(other.to_string()),
226    }
227}
228
229async fn mint_access_token(client: &reqwest::Client, creds: &BootstrapCreds) -> Result<String, BootstrapError> {
230    let auth_base = creds.auth_url.trim_end_matches('/');
231    let url = format!("{}/token", auth_base);
232    let form = [
233        ("grant_type", "client_credentials"),
234        ("provider", "client_credentials"),
235        ("client_id", creds.client_id.as_str()),
236        ("client_secret", creds.client_secret.as_str()),
237    ];
238
239    let resp = client.post(&url).form(&form).send().await?;
240    let status = resp.status();
241    let body = resp.text().await.unwrap_or_default();
242    if !status.is_success() {
243        return Err(BootstrapError::OAuthFailed {
244            status: status.as_u16(),
245            body,
246        });
247    }
248    let parsed: Value = serde_json::from_str(&body)?;
249    let token = parsed
250        .get("access_token")
251        .and_then(|v| v.as_str())
252        .map(|s| s.to_string());
253    token
254        .filter(|t| !t.is_empty())
255        .ok_or(BootstrapError::MissingAccessToken)
256}
257
258async fn fetch_values(
259    client: &reqwest::Client,
260    creds: &BootstrapCreds,
261    token: &str,
262    env: &str,
263) -> Result<HashMap<String, Value>, BootstrapError> {
264    let api_base = creds.api_url.trim_end_matches('/');
265    let org = utf8_percent_encode(&creds.org_id, URL_ENCODE_SET).to_string();
266    let env_enc = utf8_percent_encode(env, URL_ENCODE_SET).to_string();
267    let url = format!(
268        "{}/organizations/{}/config/values?environment={}",
269        api_base, org, env_enc
270    );
271    let resp = client
272        .get(&url)
273        .bearer_auth(token)
274        .header("Accept", "application/json")
275        .send()
276        .await?;
277    let status = resp.status();
278    let body = resp.text().await.unwrap_or_default();
279    if !status.is_success() {
280        return Err(BootstrapError::ValuesFailed {
281            status: status.as_u16(),
282            body,
283        });
284    }
285    let parsed: Value = serde_json::from_str(&body)?;
286    let values = parsed
287        .get("values")
288        .and_then(|v| v.as_object())
289        .map(|m| {
290            m.iter()
291                .map(|(k, v)| (k.clone(), v.clone()))
292                .collect::<HashMap<String, Value>>()
293        })
294        .unwrap_or_default();
295    Ok(values)
296}
297
#[cfg(test)]
// The TEST_LOCK guard is intentionally held across `.await` so that tests
// touching the process-wide CACHE run one at a time.
#[allow(clippy::await_holding_lock)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::Mutex as StdMutex;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // All bootstrap_fetch_with_env tests share the process-wide CACHE,
    // so we serialize them with a dedicated mutex.
    static TEST_LOCK: StdMutex<()> = StdMutex::new(());

    /// Take the test lock and start from an empty cache.
    fn lock_and_reset() -> std::sync::MutexGuard<'static, ()> {
        // Recover from any prior poisoned panic so a single failing
        // test doesn't break the rest of the suite.
        let g = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        super::__reset_bootstrap_cache();
        g
    }

    /// Env map with valid credentials, pointing both the API and the auth
    /// base URL at the mock server.
    fn base_env(server_url: &str) -> HashMap<String, String> {
        let mut m = HashMap::new();
        m.insert("SMOOAI_CONFIG_API_URL".into(), server_url.into());
        m.insert("SMOOAI_CONFIG_AUTH_URL".into(), server_url.into());
        m.insert("SMOOAI_CONFIG_CLIENT_ID".into(), "client-id-123".into());
        m.insert("SMOOAI_CONFIG_CLIENT_SECRET".into(), "client-secret-456".into());
        m.insert("SMOOAI_CONFIG_ORG_ID".into(), "org-789".into());
        m
    }

    /// Mount a token endpoint that always succeeds with `token`.
    async fn mount_oauth_ok(server: &MockServer, token: &str) {
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"access_token": token})))
            .mount(server)
            .await;
    }

    /// Mount a values endpoint that answers `values` for environment `env`.
    async fn mount_values(server: &MockServer, env: &str, values: serde_json::Value) {
        Mock::given(method("GET"))
            .and(path("/organizations/org-789/config/values"))
            .and(query_param("environment", env))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"values": values})))
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn returns_value_for_known_key() {
        let _g = lock_and_reset();
        let server = MockServer::start().await;
        mount_oauth_ok(&server, "TOKEN").await;
        mount_values(&server, "development", json!({"databaseUrl": "postgres://x"})).await;
        let env = base_env(&server.uri());
        let v = bootstrap_fetch_with_env("databaseUrl", None, &env, &reqwest::Client::new())
            .await
            .unwrap();
        assert_eq!(v, Some("postgres://x".to_string()));
    }

    #[tokio::test]
    async fn returns_none_for_missing_key() {
        let _g = lock_and_reset();
        let server = MockServer::start().await;
        mount_oauth_ok(&server, "T").await;
        mount_values(&server, "development", json!({"other": "x"})).await;
        let env = base_env(&server.uri());
        let v = bootstrap_fetch_with_env("databaseUrl", None, &env, &reqwest::Client::new())
            .await
            .unwrap();
        assert_eq!(v, None);
    }

    #[tokio::test]
    async fn caches_values_per_env() {
        let _g = lock_and_reset();
        let server = MockServer::start().await;
        // Each mount with `expect(1)` would fail if called more than once.
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"access_token": "T"})))
            .expect(1)
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/organizations/org-789/config/values"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"values": {"a": "1", "b": "2"}})))
            .expect(1)
            .mount(&server)
            .await;
        let env = base_env(&server.uri());
        let c = reqwest::Client::new();
        assert_eq!(
            bootstrap_fetch_with_env("a", None, &env, &c).await.unwrap(),
            Some("1".into())
        );
        assert_eq!(
            bootstrap_fetch_with_env("b", None, &env, &c).await.unwrap(),
            Some("2".into())
        );
    }

    #[tokio::test]
    async fn refetches_on_env_change() {
        let _g = lock_and_reset();
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"access_token": "T"})))
            .expect(2)
            .mount(&server)
            .await;
        mount_values(&server, "development", json!({"a": "dev"})).await;
        mount_values(&server, "production", json!({"a": "prod"})).await;
        let env = base_env(&server.uri());
        let c = reqwest::Client::new();
        assert_eq!(
            bootstrap_fetch_with_env("a", Some("development"), &env, &c)
                .await
                .unwrap(),
            Some("dev".into())
        );
        assert_eq!(
            bootstrap_fetch_with_env("a", Some("production"), &env, &c)
                .await
                .unwrap(),
            Some("prod".into())
        );
    }

    #[tokio::test]
    async fn missing_creds_errors() {
        let _g = lock_and_reset();
        let mut env = base_env("http://example.test");
        env.remove("SMOOAI_CONFIG_CLIENT_ID");
        let err = bootstrap_fetch_with_env("k", None, &env, &reqwest::Client::new())
            .await
            .unwrap_err();
        // `matches!` only yields a bool; without `assert!` the check is a no-op.
        assert!(
            matches!(err, BootstrapError::MissingCredentials),
            "expected MissingCredentials, got {:?}",
            err
        );
    }

    #[tokio::test]
    async fn accepts_legacy_api_key() {
        let _g = lock_and_reset();
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/token"))
            .and(wiremock::matchers::body_string_contains("client_secret=legacy-secret"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"access_token": "T"})))
            .expect(1)
            .mount(&server)
            .await;
        mount_values(&server, "development", json!({"k": "v"})).await;
        let mut env = base_env(&server.uri());
        env.remove("SMOOAI_CONFIG_CLIENT_SECRET");
        env.insert("SMOOAI_CONFIG_API_KEY".into(), "legacy-secret".into());
        let v = bootstrap_fetch_with_env("k", None, &env, &reqwest::Client::new())
            .await
            .unwrap();
        assert_eq!(v, Some("v".into()));
    }

    #[tokio::test]
    async fn oauth_failure_returns_error() {
        let _g = lock_and_reset();
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(401).set_body_string("invalid_client"))
            .mount(&server)
            .await;
        let env = base_env(&server.uri());
        let err = bootstrap_fetch_with_env("k", None, &env, &reqwest::Client::new())
            .await
            .unwrap_err();
        match err {
            BootstrapError::OAuthFailed { status, .. } => assert_eq!(status, 401),
            _ => panic!("expected OAuthFailed, got {:?}", err),
        }
    }

    #[tokio::test]
    async fn values_failure_returns_error() {
        let _g = lock_and_reset();
        let server = MockServer::start().await;
        mount_oauth_ok(&server, "T").await;
        Mock::given(method("GET"))
            .and(path("/organizations/org-789/config/values"))
            .respond_with(ResponseTemplate::new(500).set_body_string("boom"))
            .mount(&server)
            .await;
        let env = base_env(&server.uri());
        let err = bootstrap_fetch_with_env("k", None, &env, &reqwest::Client::new())
            .await
            .unwrap_err();
        match err {
            BootstrapError::ValuesFailed { status, .. } => assert_eq!(status, 500),
            _ => panic!("expected ValuesFailed, got {:?}", err),
        }
    }

    #[tokio::test]
    async fn oauth_missing_access_token_errors() {
        let _g = lock_and_reset();
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({})))
            .mount(&server)
            .await;
        let env = base_env(&server.uri());
        let err = bootstrap_fetch_with_env("k", None, &env, &reqwest::Client::new())
            .await
            .unwrap_err();
        // `matches!` only yields a bool; without `assert!` the check is a no-op.
        assert!(
            matches!(err, BootstrapError::MissingAccessToken),
            "expected MissingAccessToken, got {:?}",
            err
        );
    }

    #[test]
    fn resolve_env_explicit_wins() {
        let mut env = HashMap::new();
        env.insert("SST_STAGE".into(), "ignored".into());
        assert_eq!(resolve_env(&env, Some("explicit")), "explicit");
    }

    #[test]
    fn resolve_env_sst_stage() {
        let mut env = HashMap::new();
        env.insert("SST_STAGE".into(), "brentrager".into());
        assert_eq!(resolve_env(&env, None), "brentrager");
    }

    #[test]
    fn resolve_env_next_public_stage() {
        let mut env = HashMap::new();
        env.insert("NEXT_PUBLIC_SST_STAGE".into(), "dev-stage".into());
        assert_eq!(resolve_env(&env, None), "dev-stage");
    }

    #[test]
    fn resolve_env_sst_resource_app() {
        let mut env = HashMap::new();
        env.insert("SST_RESOURCE_App".into(), r#"{"stage":"sst-resource-stage"}"#.into());
        assert_eq!(resolve_env(&env, None), "sst-resource-stage");
    }

    #[test]
    fn resolve_env_production() {
        let mut env = HashMap::new();
        env.insert("SST_STAGE".into(), "production".into());
        assert_eq!(resolve_env(&env, None), "production");
    }

    #[test]
    fn resolve_env_smooai_env_fallback() {
        let mut env = HashMap::new();
        env.insert("SMOOAI_CONFIG_ENV".into(), "qa".into());
        assert_eq!(resolve_env(&env, None), "qa");
    }

    #[test]
    fn resolve_env_development_default() {
        let env = HashMap::new();
        assert_eq!(resolve_env(&env, None), "development");
    }

    #[test]
    fn resolve_env_malformed_sst_resource_app_falls_through() {
        let mut env = HashMap::new();
        env.insert("SST_RESOURCE_App".into(), "{not json".into());
        env.insert("SMOOAI_CONFIG_ENV".into(), "qa".into());
        assert_eq!(resolve_env(&env, None), "qa");
    }

    #[tokio::test]
    async fn stringifies_non_string_values() {
        let _g = lock_and_reset();
        let server = MockServer::start().await;
        mount_oauth_ok(&server, "T").await;
        mount_values(&server, "development", json!({"count": 42, "flag": true, "pi": 3.5})).await;
        let env = base_env(&server.uri());
        let c = reqwest::Client::new();
        assert_eq!(
            bootstrap_fetch_with_env("count", None, &env, &c).await.unwrap(),
            Some("42".into())
        );
        assert_eq!(
            bootstrap_fetch_with_env("flag", None, &env, &c).await.unwrap(),
            Some("true".into())
        );
        assert_eq!(
            bootstrap_fetch_with_env("pi", None, &env, &c).await.unwrap(),
            Some("3.5".into())
        );
    }
}