Skip to main content

smooai_config/
build.rs

1//! Deploy-time baker for `smooai-config` (Rust parity with TypeScript/Python).
2//!
//! Fetches every config value for an environment via [`ConfigClient`], partitions
//! them into `public` and `secret` sections (keys the classifier marks `Skip`,
//! typically feature flags, are dropped),
5//! JSON-encodes the partition, and encrypts with AES-256-GCM. The caller writes
6//! the resulting blob to disk, ships it with the function bundle, and sets two
7//! environment variables on the runtime:
8//!
9//! ```text
10//! SMOO_CONFIG_KEY_FILE = <absolute path to the blob at runtime>
11//! SMOO_CONFIG_KEY      = <returned key_b64>
12//! ```
13//!
14//! At cold start, [`crate::runtime::build_config_runtime`] reads both and
15//! decrypts once into an in-memory cache.
16//!
17//! Blob layout (wire-compatible with the TypeScript + Python bakers):
18//! `nonce (12 random bytes) || ciphertext || authTag (16 bytes)`.
19
20use std::collections::HashMap;
21
22use aes_gcm::aead::{Aead, KeyInit, OsRng, Payload};
23use aes_gcm::{AeadCore, Aes256Gcm, Key, Nonce};
24use base64::engine::general_purpose::STANDARD as B64;
25use base64::Engine as _;
26use serde_json::Value;
27use thiserror::Error;
28
29use crate::client::{ConfigClient, ConfigClientError};
30
/// Destination a [`Classifier`] picks for a single config key.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Classification {
    /// Stored in the blob's `public` section.
    Public,
    /// Stored in the blob's `secret` section.
    Secret,
    /// Left out of the blob entirely; feature flags are the usual example.
    Skip,
}
41
42/// Classifier type: given a key + value, decides which partition the key lands in.
43pub type Classifier = Box<dyn Fn(&str, &Value) -> Classification + Send + Sync>;
44
45/// Inputs for [`build_bundle`]. Mirrors the TypeScript `BuildBundleOptions` shape —
46/// bundles the [`ConfigClient`] connection params plus an optional classifier.
47pub struct BuildBundleOptions {
48    /// Base URL of the config API, e.g. `https://config.smoo.ai`.
49    pub base_url: String,
50    /// OAuth issuer base URL, e.g. `https://auth.smoo.ai`. `None` falls back to
51    /// `SMOOAI_CONFIG_AUTH_URL` env var (or the default `https://auth.smoo.ai`).
52    /// SMOODEV-975.
53    pub auth_url: Option<String>,
54    /// OAuth2 client ID. SMOODEV-975 — when `None`, the runtime falls back
55    /// to `api_key` so legacy deploy scripts that only ever set a single
56    /// secret still authenticate.
57    pub client_id: Option<String>,
58    /// OAuth2 client secret used to mint a JWT. (Field name retained for
59    /// backwards-compat with existing deploy glue; treat it as the client secret.)
60    pub api_key: String,
61    /// Organization ID that owns the config values.
62    pub org_id: String,
63    /// Environment to fetch (e.g. `production`, `staging`). Defaults to the
64    /// client's own default environment when `None`.
65    pub environment: Option<String>,
66    /// Per-key classifier. If `None`, every key lands in `public`. Use a
67    /// schema-driven classifier for the typical case — the default is rarely
68    /// what production code wants.
69    pub classify: Option<Classifier>,
70}
71
/// Output of [`build_bundle`].
///
/// `Debug` is implemented by hand so that logging a result never prints the
/// decryption key and does not dump the full encrypted blob.
pub struct BuildBundleResult {
    /// Base64-encoded 32-byte AES-256 key. Set as `SMOO_CONFIG_KEY`.
    pub key_b64: String,
    /// Encrypted blob: `nonce || ciphertext || authTag`. Write to disk and
    /// bundle with the function. Point `SMOO_CONFIG_KEY_FILE` at the path.
    pub blob: Vec<u8>,
    /// Size of the blob in bytes.
    pub size: u64,
    /// Number of keys baked into the blob (public + secret).
    pub key_count: usize,
    /// Number of keys skipped (e.g. feature flags).
    pub skipped_count: usize,
}

impl std::fmt::Debug for BuildBundleResult {
    // `key_b64` is the secret that unlocks the blob, so it is redacted.
    // `blob` is summarized by its length instead of printed byte by byte.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BuildBundleResult")
            .field("key_b64", &format_args!("<redacted>"))
            .field("blob", &format_args!("<{} bytes>", self.blob.len()))
            .field("size", &self.size)
            .field("key_count", &self.key_count)
            .field("skipped_count", &self.skipped_count)
            .finish()
    }
}
87
88/// Errors produced by [`build_bundle`].
89#[derive(Debug, Error)]
90pub enum BuildError {
91    /// The live config fetch via [`ConfigClient`] failed (transport, OAuth, or non-2xx).
92    #[error("failed to fetch config values: {0}")]
93    Fetch(#[from] ConfigClientError),
94    /// Underlying reqwest transport error (legacy variant kept for compat).
95    #[error("config fetch transport error: {0}")]
96    Request(#[from] reqwest::Error),
97    /// Serializing the partitioned config to JSON failed.
98    #[error("failed to serialize config values to JSON: {0}")]
99    Serialize(#[from] serde_json::Error),
100    /// AES-GCM encryption failed. In practice this only happens if the AEAD
101    /// implementation itself rejects the inputs — effectively unreachable.
102    #[error("aes-gcm encryption failed: {0}")]
103    Encrypt(String),
104}
105
106/// Fetch + encrypt config values for an environment.
107///
108/// Pulls all values via [`ConfigClient::get_all_values`], runs each through
109/// `options.classify` (default: everything goes into `public`), JSON-encodes
110/// the `{public, secret}` partition, and encrypts with a fresh random 32-byte
111/// AES-256 key + 12-byte nonce. Returns the ciphertext blob and the base64
112/// key so the caller can ship both.
113pub async fn build_bundle(options: BuildBundleOptions) -> Result<BuildBundleResult, BuildError> {
114    let BuildBundleOptions {
115        base_url,
116        auth_url,
117        client_id,
118        api_key,
119        org_id,
120        environment,
121        classify,
122    } = options;
123
124    let resolved_client_id = client_id.unwrap_or_else(|| api_key.clone());
125    // Apply the optional auth_url override so the runtime client's
126    // TokenProvider targets the test's mock issuer when supplied.
127    if let Some(url) = &auth_url {
128        std::env::set_var("SMOOAI_CONFIG_AUTH_URL", url);
129    }
130
131    let mut client = match &environment {
132        Some(env) => ConfigClient::with_environment(&base_url, &resolved_client_id, &api_key, &org_id, env),
133        None => ConfigClient::new(&base_url, &resolved_client_id, &api_key, &org_id),
134    };
135
136    let all = client.get_all_values(environment.as_deref()).await?;
137
138    let mut public_map: HashMap<String, Value> = HashMap::new();
139    let mut secret_map: HashMap<String, Value> = HashMap::new();
140    let mut skipped_count: usize = 0;
141
142    for (key, value) in all {
143        let section = match classify {
144            Some(ref f) => f(&key, &value),
145            None => Classification::Public,
146        };
147        match section {
148            Classification::Public => {
149                public_map.insert(key, value);
150            }
151            Classification::Secret => {
152                secret_map.insert(key, value);
153            }
154            Classification::Skip => {
155                skipped_count += 1;
156            }
157        }
158    }
159
160    let key_count = public_map.len() + secret_map.len();
161
162    // Serialize the partition with a stable shape that the hydrator can parse.
163    let partitioned = serde_json::json!({
164        "public": public_map,
165        "secret": secret_map,
166    });
167    let plaintext = serde_json::to_vec(&partitioned)?;
168
169    // Generate key and nonce.
170    let key_bytes: [u8; 32] = {
171        let k = Aes256Gcm::generate_key(&mut OsRng);
172        k.into()
173    };
174    let nonce_bytes: [u8; 12] = {
175        let n = Aes256Gcm::generate_nonce(&mut OsRng);
176        n.into()
177    };
178
179    let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&key_bytes));
180    let nonce = Nonce::from_slice(&nonce_bytes);
181
182    let ciphertext_and_tag = cipher
183        .encrypt(
184            nonce,
185            Payload {
186                msg: &plaintext,
187                aad: &[],
188            },
189        )
190        .map_err(|e| BuildError::Encrypt(e.to_string()))?;
191
192    // Blob layout: nonce || ciphertext || authTag. aes-gcm returns ciphertext
193    // with the 16-byte tag already appended, matching the TS and Python wire
194    // format.
195    let mut blob = Vec::with_capacity(nonce_bytes.len() + ciphertext_and_tag.len());
196    blob.extend_from_slice(&nonce_bytes);
197    blob.extend_from_slice(&ciphertext_and_tag);
198
199    let size = blob.len() as u64;
200    let key_b64 = B64.encode(key_bytes);
201
202    Ok(BuildBundleResult {
203        key_b64,
204        blob,
205        size,
206        key_count,
207        skipped_count,
208    })
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path_regex, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // NOTE(review): `build_bundle` writes `auth_url` into the process-wide
    // `SMOOAI_CONFIG_AUTH_URL` variable, and these tests may run in parallel.
    // One test's client can therefore pick up the other test's mock issuer.
    // Both issuers mint the same "stub-jwt", which keeps that cross-talk
    // harmless while both servers are alive — revisit if more tests are added.

    /// Mounts the SMOODEV-975 OAuth handshake stub, which mints "stub-jwt".
    async fn mount_token_stub(server: &MockServer) {
        Mock::given(method("POST"))
            .and(path_regex(r"^/token$"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "stub-jwt",
                "expires_in": 3600
            })))
            .mount(server)
            .await;
    }

    /// Decrypts a bundle with its own returned key and parses the
    /// `{public, secret}` JSON payload. This mirrors what the runtime
    /// hydrator has to do.
    fn open_bundle(result: &BuildBundleResult) -> Value {
        let key = B64.decode(&result.key_b64).expect("key_b64 is valid base64");
        assert_eq!(key.len(), 32, "AES-256 key must be 32 bytes");
        let (nonce, ciphertext_and_tag) = result.blob.split_at(12);
        let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&key));
        let plaintext = cipher
            .decrypt(Nonce::from_slice(nonce), ciphertext_and_tag)
            .expect("blob decrypts with the returned key");
        serde_json::from_slice(&plaintext).expect("plaintext is JSON")
    }

    #[tokio::test]
    async fn build_bundle_encrypts_and_reports_counts() {
        let mock_server = MockServer::start().await;

        mount_token_stub(&mock_server).await;
        // The values endpoint only answers requests carrying the stub JWT.
        Mock::given(method("GET"))
            .and(path_regex(r"/organizations/.+/config/values"))
            .and(query_param("environment", "production"))
            .and(header("Authorization", "Bearer stub-jwt"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "values": {
                    "apiUrl": "https://api.example.com",
                    "tavilyApiKey": "tvly-abc",
                    "newFlow": true,
                }
            })))
            .mount(&mock_server)
            .await;

        let classify: Classifier = Box::new(|key, _v| match key {
            "tavilyApiKey" => Classification::Secret,
            "newFlow" => Classification::Skip,
            _ => Classification::Public,
        });

        let result = build_bundle(BuildBundleOptions {
            base_url: mock_server.uri(),
            auth_url: Some(mock_server.uri()),
            client_id: Some("test-api-key".to_string()),
            api_key: "test-api-key".to_string(),
            org_id: "test-org".to_string(),
            environment: Some("production".to_string()),
            classify: Some(classify),
        })
        .await
        .unwrap();

        assert_eq!(result.key_count, 2); // apiUrl + tavilyApiKey
        assert_eq!(result.skipped_count, 1); // newFlow
        assert!(result.blob.len() > 12 + 16); // nonce + tag at minimum
        assert_eq!(result.size, result.blob.len() as u64);

        // The blob must round-trip with the returned key, and each key must
        // land in the partition its classifier chose (newFlow nowhere).
        let bundle = open_bundle(&result);
        assert_eq!(
            bundle["public"],
            serde_json::json!({"apiUrl": "https://api.example.com"})
        );
        assert_eq!(bundle["secret"], serde_json::json!({"tavilyApiKey": "tvly-abc"}));
    }

    #[tokio::test]
    async fn build_bundle_default_classifier_makes_everything_public() {
        let mock_server = MockServer::start().await;

        mount_token_stub(&mock_server).await;
        Mock::given(method("GET"))
            .and(path_regex(r"/organizations/.+/config/values"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "values": {"FOO": "bar", "BAZ": 42}
            })))
            .mount(&mock_server)
            .await;

        let result = build_bundle(BuildBundleOptions {
            base_url: mock_server.uri(),
            auth_url: Some(mock_server.uri()),
            client_id: Some("k".to_string()),
            api_key: "k".to_string(),
            org_id: "o".to_string(),
            environment: Some("test".to_string()),
            classify: None,
        })
        .await
        .unwrap();

        assert_eq!(result.key_count, 2);
        assert_eq!(result.skipped_count, 0);

        let bundle = open_bundle(&result);
        assert_eq!(bundle["public"], serde_json::json!({"FOO": "bar", "BAZ": 42}));
        assert_eq!(bundle["secret"], serde_json::json!({}));
    }
}