1use std::collections::HashMap;
21
22use aes_gcm::aead::{Aead, KeyInit, OsRng, Payload};
23use aes_gcm::{AeadCore, Aes256Gcm, Key, Nonce};
24use base64::engine::general_purpose::STANDARD as B64;
25use base64::Engine as _;
26use serde_json::Value;
27use thiserror::Error;
28
29use crate::client::{ConfigClient, ConfigClientError};
30
/// Destination section for a single config value inside the bundle.
///
/// `build_bundle` asks the classifier for one of these per key/value pair.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Classification {
    /// Store the value under the bundle's `"public"` section.
    Public,
    /// Store the value under the bundle's `"secret"` section.
    Secret,
    /// Leave the value out of the bundle; it is only counted as skipped.
    Skip,
}
41
42pub type Classifier = Box<dyn Fn(&str, &Value) -> Classification + Send + Sync>;
44
45pub struct BuildBundleOptions {
48 pub base_url: String,
50 pub auth_url: Option<String>,
54 pub client_id: Option<String>,
58 pub api_key: String,
61 pub org_id: String,
63 pub environment: Option<String>,
66 pub classify: Option<Classifier>,
70}
71
/// Output of [`build_bundle`].
///
/// NOTE(review): the derived `Debug` impl prints every field, including
/// `key_b64`; avoid logging this struct with `{:?}` if the key must stay secret.
#[derive(Debug)]
pub struct BuildBundleResult {
    /// Standard-alphabet base64 encoding of the freshly generated AES-256 key.
    pub key_b64: String,
    /// Encrypted bundle: the nonce followed by the ciphertext and auth tag.
    pub blob: Vec<u8>,
    /// Length of `blob` in bytes.
    pub size: u64,
    /// Number of values placed in the bundle (public + secret).
    pub key_count: usize,
    /// Number of values the classifier marked as [`Classification::Skip`].
    pub skipped_count: usize,
}
87
88#[derive(Debug, Error)]
90pub enum BuildError {
91 #[error("failed to fetch config values: {0}")]
93 Fetch(#[from] ConfigClientError),
94 #[error("config fetch transport error: {0}")]
96 Request(#[from] reqwest::Error),
97 #[error("failed to serialize config values to JSON: {0}")]
99 Serialize(#[from] serde_json::Error),
100 #[error("aes-gcm encryption failed: {0}")]
103 Encrypt(String),
104}
105
106pub async fn build_bundle(options: BuildBundleOptions) -> Result<BuildBundleResult, BuildError> {
114 let BuildBundleOptions {
115 base_url,
116 auth_url,
117 client_id,
118 api_key,
119 org_id,
120 environment,
121 classify,
122 } = options;
123
124 let resolved_client_id = client_id.unwrap_or_else(|| api_key.clone());
125 if let Some(url) = &auth_url {
128 std::env::set_var("SMOOAI_CONFIG_AUTH_URL", url);
129 }
130
131 let mut client = match &environment {
132 Some(env) => ConfigClient::with_environment(&base_url, &resolved_client_id, &api_key, &org_id, env),
133 None => ConfigClient::new(&base_url, &resolved_client_id, &api_key, &org_id),
134 };
135
136 let all = client.get_all_values(environment.as_deref()).await?;
137
138 let mut public_map: HashMap<String, Value> = HashMap::new();
139 let mut secret_map: HashMap<String, Value> = HashMap::new();
140 let mut skipped_count: usize = 0;
141
142 for (key, value) in all {
143 let section = match classify {
144 Some(ref f) => f(&key, &value),
145 None => Classification::Public,
146 };
147 match section {
148 Classification::Public => {
149 public_map.insert(key, value);
150 }
151 Classification::Secret => {
152 secret_map.insert(key, value);
153 }
154 Classification::Skip => {
155 skipped_count += 1;
156 }
157 }
158 }
159
160 let key_count = public_map.len() + secret_map.len();
161
162 let partitioned = serde_json::json!({
164 "public": public_map,
165 "secret": secret_map,
166 });
167 let plaintext = serde_json::to_vec(&partitioned)?;
168
169 let key_bytes: [u8; 32] = {
171 let k = Aes256Gcm::generate_key(&mut OsRng);
172 k.into()
173 };
174 let nonce_bytes: [u8; 12] = {
175 let n = Aes256Gcm::generate_nonce(&mut OsRng);
176 n.into()
177 };
178
179 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&key_bytes));
180 let nonce = Nonce::from_slice(&nonce_bytes);
181
182 let ciphertext_and_tag = cipher
183 .encrypt(
184 nonce,
185 Payload {
186 msg: &plaintext,
187 aad: &[],
188 },
189 )
190 .map_err(|e| BuildError::Encrypt(e.to_string()))?;
191
192 let mut blob = Vec::with_capacity(nonce_bytes.len() + ciphertext_and_tag.len());
196 blob.extend_from_slice(&nonce_bytes);
197 blob.extend_from_slice(&ciphertext_and_tag);
198
199 let size = blob.len() as u64;
200 let key_b64 = B64.encode(key_bytes);
201
202 Ok(BuildBundleResult {
203 key_b64,
204 blob,
205 size,
206 key_count,
207 skipped_count,
208 })
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path_regex, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // NOTE(review): `build_bundle` exports `auth_url` through a process-wide
    // environment variable and both tests set it to different servers; the
    // test harness may run them in parallel — TODO confirm this cannot
    // cross-wire the token requests.

    /// Mounts the stub `POST /token` endpoint that hands out `stub-jwt`.
    async fn mount_token_stub(server: &MockServer) {
        let token_body = serde_json::json!({
            "access_token": "stub-jwt",
            "expires_in": 3600
        });
        Mock::given(method("POST"))
            .and(path_regex(r"^/token$"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_body))
            .mount(server)
            .await;
    }

    /// A classifier that yields one public, one secret and one skipped value
    /// must be reflected in the counts, and the blob/key must have sane sizes.
    #[tokio::test]
    async fn build_bundle_encrypts_and_reports_counts() {
        let server = MockServer::start().await;
        mount_token_stub(&server).await;

        let values_body = serde_json::json!({
            "values": {
                "apiUrl": "https://api.example.com",
                "tavilyApiKey": "tvly-abc",
                "newFlow": true,
            }
        });
        Mock::given(method("GET"))
            .and(path_regex(r"/organizations/.+/config/values"))
            .and(query_param("environment", "production"))
            .and(header("Authorization", "Bearer stub-jwt"))
            .respond_with(ResponseTemplate::new(200).set_body_json(values_body))
            .mount(&server)
            .await;

        let classify: Classifier = Box::new(|key: &str, _value: &Value| match key {
            "tavilyApiKey" => Classification::Secret,
            "newFlow" => Classification::Skip,
            _ => Classification::Public,
        });

        let options = BuildBundleOptions {
            base_url: server.uri(),
            auth_url: Some(server.uri()),
            client_id: Some("test-api-key".to_string()),
            api_key: "test-api-key".to_string(),
            org_id: "test-org".to_string(),
            environment: Some("production".to_string()),
            classify: Some(classify),
        };
        let result = build_bundle(options).await.unwrap();

        assert_eq!(result.key_count, 2);
        assert_eq!(result.skipped_count, 1);
        // The blob must hold more than just the 12-byte nonce and 16-byte tag.
        assert!(result.blob.len() > 12 + 16);
        assert_eq!(result.size, result.blob.len() as u64);

        let key = B64.decode(&result.key_b64).unwrap();
        assert_eq!(key.len(), 32);
    }

    /// Without a classifier every fetched value is counted and nothing is skipped.
    #[tokio::test]
    async fn build_bundle_default_classifier_makes_everything_public() {
        let server = MockServer::start().await;
        mount_token_stub(&server).await;

        let values_body = serde_json::json!({
            "values": {"FOO": "bar", "BAZ": 42}
        });
        Mock::given(method("GET"))
            .and(path_regex(r"/organizations/.+/config/values"))
            .respond_with(ResponseTemplate::new(200).set_body_json(values_body))
            .mount(&server)
            .await;

        let options = BuildBundleOptions {
            base_url: server.uri(),
            auth_url: Some(server.uri()),
            client_id: Some("k".to_string()),
            api_key: "k".to_string(),
            org_id: "o".to_string(),
            environment: Some("test".to_string()),
            classify: None,
        };
        let result = build_bundle(options).await.unwrap();

        assert_eq!(result.key_count, 2);
        assert_eq!(result.skipped_count, 0);
    }
}