1use std::{
2 collections::HashMap,
3 fmt::{self, Display, Write as _},
4 marker::PhantomData,
5 path::PathBuf,
6};
7
8use strum::{EnumIter, EnumString, IntoEnumIterator as _};
9
10use crate::{group_error, group_info, utils::git};
11
/// Index style that leaves the index out of the rendered name when it is `1`
/// and appends it otherwise (`dev`, `dev2`, `dev3`, ...).
#[derive(Default, PartialEq, Debug, Clone)]
pub struct ImplicitIndex;
15
/// Index style that always appends the index to the rendered name, including
/// for index `1` (`dev1`, `dev2`, ...).
#[derive(Default, PartialEq, Debug, Clone)]
pub struct ExplicitIndex;
19
/// Strategy deciding how an environment's numeric index is combined with one of
/// its textual names.
pub trait IndexStyle {
    /// Renders `base` (a long, medium or short environment name) together with
    /// `index`.
    fn format(base: &str, index: u8) -> String;
}
24
25impl IndexStyle for ImplicitIndex {
26 fn format(base: &str, index: u8) -> String {
27 if index == 1 {
28 base.to_string()
29 } else {
30 format!("{base}{index}")
31 }
32 }
33}
34
35impl IndexStyle for ExplicitIndex {
36 fn format(base: &str, index: u8) -> String {
37 format!("{base}{index}")
38 }
39}
40
41#[derive(Clone, Debug, Default, PartialEq)]
42pub struct Environment<M = ImplicitIndex> {
43 pub name: EnvironmentName,
44 pub index: EnvironmentIndex,
45 _marker: PhantomData<M>,
46}
47
48impl<M> Environment<M> {
49 pub fn new(name: EnvironmentName, index: u8) -> Self {
50 Self {
51 name,
52 index: index.into(),
53 _marker: PhantomData,
54 }
55 }
56
57 pub fn index(&self) -> u8 {
58 self.index.index
59 }
60}
61
62impl Environment<ImplicitIndex> {
63 pub fn into_explicit(self) -> Environment<ExplicitIndex> {
67 Environment {
68 name: self.name.clone(),
69 index: self.index().into(),
70 _marker: PhantomData,
71 }
72 }
73}
74
75impl Environment<ExplicitIndex> {
76 pub fn into_implicit(self) -> Environment<ImplicitIndex> {
77 Environment {
78 name: self.name.clone(),
79 index: self.index().into(),
80 _marker: PhantomData,
81 }
82 }
83}
84
85impl<M: IndexStyle> Display for Environment<M> {
86 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87 write!(f, "{}", self.medium())
88 }
89}
90
91impl<M: IndexStyle> Environment<M> {
92 pub fn long(&self) -> String {
93 M::format(self.name.long(), self.index())
94 }
95
96 pub fn medium(&self) -> String {
97 M::format(self.name.medium(), self.index())
98 }
99
100 pub fn short(&self) -> String {
101 M::format(&self.name.short().to_string(), self.index())
102 }
103
104 fn dotenv_files_for_family(&self, family: DotEnvFamily) -> [String; 2] {
110 let suffix = family.to_string();
111 let env_medium = self.medium();
112 if suffix.is_empty() {
113 [".env".to_owned(), format!(".env.{env_medium}")]
115 } else {
116 [
118 format!(".env{suffix}"),
119 format!(".env.{env_medium}{suffix}"),
120 ]
121 }
122 }
123
124 pub fn get_dotenv_filename(&self) -> String {
126 self.dotenv_files_for_family(DotEnvFamily::Base)[1].clone()
128 }
129
130 pub fn get_dotenv_secrets_filename(&self) -> String {
132 self.dotenv_files_for_family(DotEnvFamily::Secrets)[1].clone()
134 }
135
136 pub fn get_env_files(&self) -> Vec<String> {
139 DotEnvFamily::iter()
140 .flat_map(|family| self.dotenv_files_for_family(family))
141 .collect()
142 }
143
144 pub fn load(&self, prefix: Option<&str>) -> anyhow::Result<()> {
146 let files = self.get_env_files();
147 for file in files {
148 let path = if let Some(p) = prefix {
149 PathBuf::from(p).join(&file)
150 } else {
151 PathBuf::from(&file)
152 };
153 if path.exists() {
154 match dotenvy::from_path(&path) {
155 Ok(_) => {
156 group_info!("loading '{}' file...", path.display());
157 }
158 Err(e) => {
159 group_error!("error while loading '{}' file ({})", path.display(), e);
160 }
161 }
162 }
163 }
164
165 Ok(())
166 }
167
168 pub fn merge_env_files(&self) -> anyhow::Result<PathBuf> {
170 let repo_root = git::git_repo_root_or_cwd()?;
171 let files = self.get_env_files();
172 let mut merged: HashMap<String, String> = HashMap::new();
175 for filename in files {
176 let path = repo_root.join(&filename);
177 if !path.exists() {
178 eprintln!(
179 "⚠️ Warning: environment file '{}' ({}) not found, skipping...",
180 filename,
181 path.display()
182 );
183 continue;
184 }
185 for item in dotenvy::from_path_iter(&path)? {
186 let (key, value) = item?;
187 unsafe {
188 std::env::set_var(&key, &value);
189 }
190 merged.insert(key, value);
191 }
192 }
193 let mut keys: Vec<_> = merged.keys().cloned().collect();
194 keys.sort();
195 let mut out = String::new();
197 for key in keys {
198 let val = &merged[&key];
199 writeln!(&mut out, "{key}={val}")?;
200 }
201 let tmp_path = std::env::temp_dir().join(format!("merged-env-{}.tmp", std::process::id()));
202 std::fs::write(&tmp_path, out)?;
203 Ok(tmp_path)
204 }
205}
206
207#[derive(EnumString, EnumIter, Default, Clone, Debug, PartialEq, clap::ValueEnum)]
208#[strum(serialize_all = "lowercase")]
209pub enum EnvironmentName {
210 #[default]
212 #[clap(alias = "dev")]
213 Development,
214 #[clap(alias = "stag")]
216 Staging,
217 #[clap(alias = "test")]
219 Test,
220 #[clap(alias = "prod")]
222 Production,
223}
224
225impl Display for EnvironmentName {
226 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
227 write!(f, "{}", self.medium())
228 }
229}
230
231impl EnvironmentName {
232 pub fn long(&self) -> &'static str {
233 match self {
234 EnvironmentName::Development => "development",
235 EnvironmentName::Staging => "staging",
236 EnvironmentName::Test => "test",
237 EnvironmentName::Production => "production",
238 }
239 }
240
241 pub fn medium(&self) -> &'static str {
242 match self {
243 EnvironmentName::Development => "dev",
244 EnvironmentName::Staging => "stag",
245 EnvironmentName::Test => "test",
246 EnvironmentName::Production => "prod",
247 }
248 }
249
250 pub fn short(&self) -> char {
251 match self {
252 EnvironmentName::Development => 'd',
253 EnvironmentName::Staging => 's',
254 EnvironmentName::Test => 't',
255 EnvironmentName::Production => 'p',
256 }
257 }
258}
259
/// Numeric index distinguishing several instances of the same environment
/// (e.g. `dev` vs `dev2`).
#[derive(Debug, PartialEq, Clone)]
pub struct EnvironmentIndex {
    /// The raw index value.
    pub index: u8,
}
264
265impl Default for EnvironmentIndex {
266 fn default() -> Self {
267 Self { index: 1 }
268 }
269}
270
271impl Display for EnvironmentIndex {
272 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
273 write!(f, "{}", self.index)
274 }
275}
276
277impl From<u8> for EnvironmentIndex {
278 fn from(index: u8) -> Self {
279 Self { index }
280 }
281}
282
283#[derive(EnumString, EnumIter, Clone, Debug, PartialEq, clap::ValueEnum)]
284enum DotEnvFamily {
285 Base,
286 Secrets,
287 Infra,
288 InfraSecrets,
289}
290
291impl Display for DotEnvFamily {
292 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
293 match self {
294 DotEnvFamily::Base => write!(f, ""),
295 DotEnvFamily::Secrets => write!(f, ".secrets"),
296 DotEnvFamily::Infra => write!(f, ".infra"),
297 DotEnvFamily::InfraSecrets => write!(f, ".infra.secrets"),
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use serial_test::serial;
    use std::env;

    type TestEnv = Environment<ImplicitIndex>;

    /// Environment variables the tests expect after processing `environment`'s
    /// dotenv files. Each is paired with its expected value, which is the name
    /// of the file defining it.
    fn expected_vars(environment: &TestEnv) -> Vec<(String, String)> {
        let suffix = match environment.name {
            EnvironmentName::Development => "DEV",
            EnvironmentName::Staging => "STAG",
            EnvironmentName::Test => "TEST",
            EnvironmentName::Production => "PROD",
        };

        vec![
            ("FROM_DOTENV".to_string(), ".env".to_string()),
            (
                format!("FROM_DOTENV_{suffix}"),
                environment.get_dotenv_filename(),
            ),
            (
                format!("FROM_DOTENV_{suffix}_SECRETS"),
                environment.get_dotenv_secrets_filename(),
            ),
        ]
    }

    /// Removes `keys` from the process environment, so values left over from
    /// another case cannot satisfy the assertions.
    fn remove_vars<I, K>(keys: I)
    where
        I: IntoIterator<Item = K>,
        K: AsRef<std::ffi::OsStr>,
    {
        for key in keys {
            // SAFETY: every test here that touches the environment is `#[serial]`.
            // NOTE(review): that only excludes other `#[serial]` tests, so any
            // non-serial test reading the environment could still race with this.
            unsafe {
                env::remove_var(key);
            }
        }
    }

    /// Parses the merged dotenv file at `path` into a map. Panics on any read
    /// or parse error.
    fn read_merged_file(path: &std::path::Path) -> std::collections::HashMap<String, String> {
        dotenvy::from_path_iter(path)
            .expect("Reading merged env file should succeed")
            .map(|item| item.expect("Parsing key/value from merged env file should succeed"))
            .collect()
    }

    #[rstest]
    #[case::dev(TestEnv::new(EnvironmentName::Development, 1))]
    #[case::stag(TestEnv::new(EnvironmentName::Staging, 1))]
    #[case::test(TestEnv::new(EnvironmentName::Test, 1))]
    #[case::prod(TestEnv::new(EnvironmentName::Production, 1))]
    #[serial]
    fn test_environment_load(#[case] environment: TestEnv) {
        remove_vars(expected_vars(&environment).into_iter().map(|(key, _)| key));

        environment
            .load(Some("../.."))
            .expect("Environment load should succeed");

        for (key, expected_value) in expected_vars(&environment) {
            let actual_value =
                env::var(&key).unwrap_or_else(|_| panic!("Missing expected env var: {key}"));
            assert_eq!(
                actual_value, expected_value,
                "Environment variable {key} should be set to {expected_value} but was {actual_value}"
            );
        }
    }

    #[rstest]
    #[case::dev(TestEnv::new(EnvironmentName::Development, 1))]
    #[case::stag(TestEnv::new(EnvironmentName::Staging, 1))]
    #[case::test(TestEnv::new(EnvironmentName::Test, 1))]
    #[case::prod(TestEnv::new(EnvironmentName::Production, 1))]
    #[serial]
    fn test_environment_merge_env_files(#[case] environment: TestEnv) {
        remove_vars(expected_vars(&environment).into_iter().map(|(key, _)| key));

        let merged_path = environment
            .merge_env_files()
            .expect("merge_env_files should succeed");
        assert!(
            merged_path.exists(),
            "Merged env file should exist at {}",
            merged_path.display()
        );
        let merged_map = read_merged_file(&merged_path);
        // Best-effort cleanup; a leftover temp file must not fail the test.
        let _ = std::fs::remove_file(&merged_path);

        for (key, expected_value) in expected_vars(&environment) {
            let actual_value = merged_map
                .get(&key)
                .unwrap_or_else(|| panic!("Missing expected merged env var: {key}"));
            assert_eq!(
                actual_value, &expected_value,
                "Merged env var {key} should be {expected_value} but was {actual_value}"
            );
        }
    }

    #[test]
    #[serial]
    fn test_environment_merge_env_files_expansion() {
        let environment = TestEnv::new(EnvironmentName::Staging, 1);
        remove_vars(["LOG_LEVEL_TEST", "RUST_LOG_TEST", "RUST_LOG_STAG_TEST"]);

        let merged_path = environment
            .merge_env_files()
            .expect("merge_env_files should succeed");
        let merged_map = read_merged_file(&merged_path);
        // Best-effort cleanup; a leftover temp file must not fail the test.
        let _ = std::fs::remove_file(&merged_path);

        let log_level = merged_map
            .get("LOG_LEVEL_TEST")
            .expect("LOG_LEVEL_TEST should be present in merged env file");
        let rust_log = merged_map
            .get("RUST_LOG_TEST")
            .expect("RUST_LOG_TEST should be present in merged env file");

        assert!(
            !rust_log.contains("${LOG_LEVEL_TEST}"),
            "RUST_LOG_TEST should not contain the raw placeholder '${{LOG_LEVEL_TEST}}', got: {rust_log}"
        );
        assert!(
            rust_log.contains(log_level),
            "RUST_LOG_TEST should contain the expanded LOG_LEVEL_TEST value; LOG_LEVEL_TEST={log_level}, RUST_LOG_TEST={rust_log}"
        );
        let rust_log_stag = merged_map
            .get("RUST_LOG_STAG_TEST")
            .expect("RUST_LOG_STAG_TEST should be present in merged env file");
        assert!(
            !rust_log_stag.contains("${LOG_LEVEL_TEST}"),
            "RUST_LOG_STAG_TEST should not contain the raw placeholder '${{LOG_LEVEL_TEST}}', got: {rust_log_stag}"
        );
        assert!(
            rust_log_stag.contains(log_level),
            "RUST_LOG_STAG_TEST should contain the expanded LOG_LEVEL_TEST value; LOG_LEVEL_TEST={log_level}, RUST_LOG_STAG_TEST={rust_log_stag}"
        );
    }
}