Skip to main content

tracel_xtask/
environment.rs

1use std::{
2    collections::HashMap,
3    fmt::{self, Display, Write as _},
4    marker::PhantomData,
5    path::PathBuf,
6};
7
8use strum::{EnumIter, EnumString, IntoEnumIterator as _};
9
10use crate::{group_error, group_info, utils::git};
11
/// Marker type for the implicit index style: index `1` is left out of display
/// names, while any other index is appended to the base name.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct ImplicitIndex;
15
/// Marker type for the explicit index style: the index is always appended to
/// the base name in display names.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct ExplicitIndex;
19
/// Strategy deciding how a base name and an index are combined into a
/// display name of the form `{base}{index}`.
pub trait IndexStyle {
    /// Build the display name for `base` with the given `index`.
    fn format(base: &str, index: u8) -> String;
}
24
25impl IndexStyle for ImplicitIndex {
26    fn format(base: &str, index: u8) -> String {
27        if index == 1 {
28            base.to_string()
29        } else {
30            format!("{base}{index}")
31        }
32    }
33}
34
35impl IndexStyle for ExplicitIndex {
36    fn format(base: &str, index: u8) -> String {
37        format!("{base}{index}")
38    }
39}
40
41#[derive(Clone, Debug, Default, PartialEq)]
42pub struct Environment<M = ImplicitIndex> {
43    pub name: EnvironmentName,
44    pub index: EnvironmentIndex,
45    _marker: PhantomData<M>,
46}
47
48impl<M> Environment<M> {
49    pub fn new(name: EnvironmentName, index: u8) -> Self {
50        Self {
51            name,
52            index: index.into(),
53            _marker: PhantomData,
54        }
55    }
56
57    pub fn index(&self) -> u8 {
58        self.index.index
59    }
60}
61
62impl Environment<ImplicitIndex> {
63    /// Turn an non explicit environment into an explicit one.
64    /// An explicit environment will always append the index number to its display names.
65    /// Whereas a non-explicit one (default) only append the index if it is different than 1.
66    pub fn into_explicit(self) -> Environment<ExplicitIndex> {
67        Environment {
68            name: self.name.clone(),
69            index: self.index().into(),
70            _marker: PhantomData,
71        }
72    }
73}
74
75impl Environment<ExplicitIndex> {
76    pub fn into_implicit(self) -> Environment<ImplicitIndex> {
77        Environment {
78            name: self.name.clone(),
79            index: self.index().into(),
80            _marker: PhantomData,
81        }
82    }
83}
84
85impl<M: IndexStyle> Display for Environment<M> {
86    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87        write!(f, "{}", self.medium())
88    }
89}
90
91impl<M: IndexStyle> Environment<M> {
92    pub fn long(&self) -> String {
93        M::format(self.name.long(), self.index())
94    }
95
96    pub fn medium(&self) -> String {
97        M::format(self.name.medium(), self.index())
98    }
99
100    pub fn short(&self) -> String {
101        M::format(&self.name.short().to_string(), self.index())
102    }
103
104    /// Return the two .env files for a given family:
105    /// - Base: `.env`, `.env.<env_medium>`
106    /// - Secrets: `.env.secrets`, `.env.<env_medium>.secrets`
107    /// - Infra: `.env.infra`, `.env.<env_medium>.infra`
108    /// - InfraSecrets: `.env.infra.secrets`, `.env.<env_medium>.infra.secrets`
109    fn dotenv_files_for_family(&self, family: DotEnvFamily) -> [String; 2] {
110        let suffix = family.to_string();
111        let env_medium = self.medium();
112        if suffix.is_empty() {
113            // Base
114            [".env".to_owned(), format!(".env.{env_medium}")]
115        } else {
116            // Other families
117            [
118                format!(".env{suffix}"),
119                format!(".env.{env_medium}{suffix}"),
120            ]
121        }
122    }
123
124    /// Backward-compatible helper for env-specific base filename.
125    pub fn get_dotenv_filename(&self) -> String {
126        // second element of the Base family
127        self.dotenv_files_for_family(DotEnvFamily::Base)[1].clone()
128    }
129
130    /// Backward-compatible helper for env-specific secrets filename.
131    pub fn get_dotenv_secrets_filename(&self) -> String {
132        // second element of the Secrets family
133        self.dotenv_files_for_family(DotEnvFamily::Secrets)[1].clone()
134    }
135
136    /// All possible .env files for this environment, by family.
137    /// Order matters: later files override earlier ones.
138    pub fn get_env_files(&self) -> Vec<String> {
139        DotEnvFamily::iter()
140            .flat_map(|family| self.dotenv_files_for_family(family))
141            .collect()
142    }
143
144    /// Load the .env environment files family.
145    pub fn load(&self, prefix: Option<&str>) -> anyhow::Result<()> {
146        let files = self.get_env_files();
147        for file in files {
148            let path = if let Some(p) = prefix {
149                PathBuf::from(p).join(&file)
150            } else {
151                PathBuf::from(&file)
152            };
153            if path.exists() {
154                match dotenvy::from_path(&path) {
155                    Ok(_) => {
156                        group_info!("loading '{}' file...", path.display());
157                    }
158                    Err(e) => {
159                        group_error!("error while loading '{}' file ({})", path.display(), e);
160                    }
161                }
162            }
163        }
164
165        Ok(())
166    }
167
168    /// Merge all the .env files of the environment with all variable expanded
169    pub fn merge_env_files(&self) -> anyhow::Result<PathBuf> {
170        let repo_root = git::git_repo_root_or_cwd()?;
171        let files = self.get_env_files();
172        // merged set of env vars, the later files override earlier ones
173        // we sort keys to have a more deterministic merged file result
174        let mut merged: HashMap<String, String> = HashMap::new();
175        for filename in files {
176            let path = repo_root.join(&filename);
177            if !path.exists() {
178                eprintln!(
179                    "⚠️ Warning: environment file '{}' ({}) not found, skipping...",
180                    filename,
181                    path.display()
182                );
183                continue;
184            }
185            for item in dotenvy::from_path_iter(&path)? {
186                let (key, value) = item?;
187                unsafe {
188                    std::env::set_var(&key, &value);
189                }
190                merged.insert(key, value);
191            }
192        }
193        let mut keys: Vec<_> = merged.keys().cloned().collect();
194        keys.sort();
195        // write merged file
196        let mut out = String::new();
197        for key in keys {
198            let val = &merged[&key];
199            writeln!(&mut out, "{key}={val}")?;
200        }
201        let tmp_path = std::env::temp_dir().join(format!("merged-env-{}.tmp", std::process::id()));
202        std::fs::write(&tmp_path, out)?;
203        Ok(tmp_path)
204    }
205}
206
/// Name of a deployment environment.
///
/// Each variant has a long, medium and short textual form (see the `long`,
/// `medium` and `short` methods); `Development` is the default.
// NOTE(review): the per-variant doc comments below are presumably picked up by
// the clap derive as command line help — keep them in sync with the aliases.
#[derive(EnumString, EnumIter, Default, Clone, Debug, PartialEq, clap::ValueEnum)]
#[strum(serialize_all = "lowercase")]
pub enum EnvironmentName {
    /// Development environment (alias: dev).
    #[default]
    #[clap(alias = "dev")]
    Development,
    /// Staging environment (alias: stag).
    #[clap(alias = "stag")]
    Staging,
    /// Testing environment (alias: test).
    #[clap(alias = "test")]
    Test,
    /// Production environment (alias: prod).
    #[clap(alias = "prod")]
    Production,
}
224
225impl Display for EnvironmentName {
226    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
227        write!(f, "{}", self.medium())
228    }
229}
230
231impl EnvironmentName {
232    pub fn long(&self) -> &'static str {
233        match self {
234            EnvironmentName::Development => "development",
235            EnvironmentName::Staging => "staging",
236            EnvironmentName::Test => "test",
237            EnvironmentName::Production => "production",
238        }
239    }
240
241    pub fn medium(&self) -> &'static str {
242        match self {
243            EnvironmentName::Development => "dev",
244            EnvironmentName::Staging => "stag",
245            EnvironmentName::Test => "test",
246            EnvironmentName::Production => "prod",
247        }
248    }
249
250    pub fn short(&self) -> char {
251        match self {
252            EnvironmentName::Development => 'd',
253            EnvironmentName::Staging => 's',
254            EnvironmentName::Test => 't',
255            EnvironmentName::Production => 'p',
256        }
257    }
258}
259
/// Numeric index of an environment, wrapping a raw `u8`.
#[derive(Debug, Clone, PartialEq)]
pub struct EnvironmentIndex {
    /// Raw index value.
    pub index: u8,
}
264
265impl Default for EnvironmentIndex {
266    fn default() -> Self {
267        Self { index: 1 }
268    }
269}
270
271impl Display for EnvironmentIndex {
272    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
273        write!(f, "{}", self.index)
274    }
275}
276
277impl From<u8> for EnvironmentIndex {
278    fn from(index: u8) -> Self {
279        Self { index }
280    }
281}
282
/// Family of .env files; each family maps to a filename suffix through its
/// `Display` implementation (`""`, `.secrets`, `.infra`, `.infra.secrets`).
// The families are enumerated with `DotEnvFamily::iter()` to build the list of
// .env files. NOTE(review): that iteration presumably follows declaration
// order, which would make the variant order below significant — confirm
// against strum's `EnumIter` documentation before reordering.
#[derive(EnumString, EnumIter, Clone, Debug, PartialEq, clap::ValueEnum)]
enum DotEnvFamily {
    Base,
    Secrets,
    Infra,
    InfraSecrets,
}
290
291impl Display for DotEnvFamily {
292    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
293        match self {
294            DotEnvFamily::Base => write!(f, ""),
295            DotEnvFamily::Secrets => write!(f, ".secrets"),
296            DotEnvFamily::Infra => write!(f, ".infra"),
297            DotEnvFamily::InfraSecrets => write!(f, ".infra.secrets"),
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use serial_test::serial;
    use std::env;

    // For tests we always use the implicit style
    type TestEnv = Environment<ImplicitIndex>;

    /// Variables (name, value) that the fixture .env files are expected to
    /// define for `env`: one from the shared `.env`, one from the
    /// environment-specific file and one from its secrets file.
    fn expected_vars(env: &TestEnv) -> Vec<(String, String)> {
        let suffix = match env.name {
            EnvironmentName::Development => "DEV",
            EnvironmentName::Staging => "STAG",
            EnvironmentName::Test => "TEST",
            EnvironmentName::Production => "PROD",
        };

        vec![
            ("FROM_DOTENV".to_string(), ".env".to_string()),
            (format!("FROM_DOTENV_{suffix}"), env.get_dotenv_filename()),
            (
                format!("FROM_DOTENV_{suffix}_SECRETS"),
                env.get_dotenv_secrets_filename(),
            ),
        ]
    }

    /// Remove the expected variables from the process environment so that a
    /// test starts from a clean state.
    fn clear_expected_vars(env: &TestEnv) {
        for (key, _) in expected_vars(env) {
            // SAFETY: the callers are `#[serial]` tests, which keeps these
            // environment-mutating tests from running concurrently with one another.
            unsafe {
                env::remove_var(key);
            }
        }
    }

    /// Parse a merged env file back into a map, panicking on any read or
    /// parse error.
    fn read_env_file(path: &std::path::Path) -> std::collections::HashMap<String, String> {
        dotenvy::from_path_iter(path)
            .expect("Reading merged env file should succeed")
            .map(|item| item.expect("Parsing key/value from merged env file should succeed"))
            .collect()
    }

    #[rstest]
    #[case::dev(TestEnv::new(EnvironmentName::Development, 1))]
    #[case::stag(TestEnv::new(EnvironmentName::Staging, 1))]
    #[case::test(TestEnv::new(EnvironmentName::Test, 1))]
    #[case::prod(TestEnv::new(EnvironmentName::Production, 1))]
    #[serial]
    fn test_environment_load(#[case] env: TestEnv) {
        // Remove possible prior values
        clear_expected_vars(&env);

        // Run the actual function under test
        env.load(Some("../.."))
            .expect("Environment load should succeed");

        // Assert each expected env var is present and has the correct value
        for (key, expected_value) in expected_vars(&env) {
            let actual_value =
                env::var(&key).unwrap_or_else(|_| panic!("Missing expected env var: {key}"));
            assert_eq!(
                actual_value, expected_value,
                "Environment variable {key} should be set to {expected_value} but was {actual_value}"
            );
        }
    }

    #[rstest]
    #[case::dev(TestEnv::new(EnvironmentName::Development, 1))]
    #[case::stag(TestEnv::new(EnvironmentName::Staging, 1))]
    #[case::test(TestEnv::new(EnvironmentName::Test, 1))]
    #[case::prod(TestEnv::new(EnvironmentName::Production, 1))]
    #[serial]
    fn test_environment_merge_env_files(#[case] env: TestEnv) {
        // Make sure we start from a clean state
        clear_expected_vars(&env);
        // Generate the merged env file
        let merged_path = env
            .merge_env_files()
            .expect("merge_env_files should succeed");
        assert!(
            merged_path.exists(),
            "Merged env file should exist at {}",
            merged_path.display()
        );
        // Parse the merged file as a .env file again
        let merged_map = read_env_file(&merged_path);
        // Best-effort cleanup so the merged file does not linger in the temp directory
        let _ = std::fs::remove_file(&merged_path);
        // All the vars we expect from the individual files must be present
        for (key, expected_value) in expected_vars(&env) {
            let actual_value = merged_map
                .get(&key)
                .unwrap_or_else(|| panic!("Missing expected merged env var: {key}"));
            assert_eq!(
                actual_value, &expected_value,
                "Merged env var {key} should be {expected_value} but was {actual_value}"
            );
        }
    }

    #[test]
    #[serial]
    fn test_environment_merge_env_files_expansion() {
        let env = Environment::<ImplicitIndex>::new(EnvironmentName::Staging, 1);
        // Clean any prior values that could interfere
        // SAFETY: this is a `#[serial]` test, which keeps these
        // environment-mutating tests from running concurrently with one another.
        unsafe {
            env::remove_var("LOG_LEVEL_TEST");
            env::remove_var("RUST_LOG_TEST");
            env::remove_var("RUST_LOG_STAG_TEST");
        }

        let merged_path = env
            .merge_env_files()
            .expect("merge_env_files should succeed");
        let merged_map = read_env_file(&merged_path);
        // Best-effort cleanup so the merged file does not linger in the temp directory
        let _ = std::fs::remove_file(&merged_path);

        let log_level = merged_map
            .get("LOG_LEVEL_TEST")
            .expect("LOG_LEVEL_TEST should be present in merged env file");
        let rust_log = merged_map
            .get("RUST_LOG_TEST")
            .expect("RUST_LOG_TEST should be present in merged env file");

        // 1) We should not see the raw placeholder anymore
        assert!(
            !rust_log.contains("${LOG_LEVEL_TEST}"),
            "RUST_LOG_TEST should not contain the raw placeholder '${{LOG_LEVEL_TEST}}', got: {rust_log}"
        );
        // 2) The expanded LOG_LEVEL_TEST value should appear in RUST_LOG_TEST
        assert!(
            rust_log.contains(log_level),
            "RUST_LOG_TEST should contain the expanded LOG_LEVEL_TEST value; LOG_LEVEL_TEST={log_level}, RUST_LOG_TEST={rust_log}"
        );
        // Cross-file expansion with RUST_LOG_STAG_TEST that references LOG_LEVEL_TEST from base .env
        let rust_log_stag = merged_map
            .get("RUST_LOG_STAG_TEST")
            .expect("RUST_LOG_STAG_TEST should be present in merged env file");
        // 3) No raw placeholder in the cross-file value either
        assert!(
            !rust_log_stag.contains("${LOG_LEVEL_TEST}"),
            "RUST_LOG_STAG_TEST should not contain the raw placeholder '${{LOG_LEVEL_TEST}}', got: {rust_log_stag}"
        );
        // 4) The expanded LOG_LEVEL_TEST value should appear in RUST_LOG_STAG_TEST
        assert!(
            rust_log_stag.contains(log_level),
            "RUST_LOG_STAG_TEST should contain the expanded LOG_LEVEL_TEST value; LOG_LEVEL_TEST={log_level}, RUST_LOG_STAG_TEST={rust_log_stag}"
        );
    }
}