1use std::{
2 collections::HashMap,
3 fmt::{self, Display, Write as _},
4 marker::PhantomData,
5 path::PathBuf,
6};
7
8use strum::{EnumIter, EnumString};
9
10use crate::{group_error, group_info, utils::git};
11
/// Marker type for the implicit index style.
///
/// Its [`IndexStyle`] implementation leaves the index off when it equals `1`
/// and appends it otherwise.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct ImplicitIndex;
15
/// Marker type for the explicit index style.
///
/// Its [`IndexStyle`] implementation always appends the index to the base label.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct ExplicitIndex;
19
/// Strategy for combining an environment label with its numeric index.
///
/// Implemented by zero-sized marker types that are used as the type parameter
/// of [`Environment`].
pub trait IndexStyle {
    /// Builds the final label from the `base` name and the environment `index`.
    fn format(base: &str, index: u8) -> String;
}
24
25impl IndexStyle for ImplicitIndex {
26 fn format(base: &str, index: u8) -> String {
27 if index == 1 {
28 base.to_string()
29 } else {
30 format!("{base}{index}")
31 }
32 }
33}
34
35impl IndexStyle for ExplicitIndex {
36 fn format(base: &str, index: u8) -> String {
37 format!("{base}{index}")
38 }
39}
40
41#[derive(Clone, Debug, Default, PartialEq)]
42pub struct Environment<M = ImplicitIndex> {
43 pub name: EnvironmentName,
44 pub index: EnvironmentIndex,
45 _marker: PhantomData<M>,
46}
47
48impl<M> Environment<M> {
49 pub fn new(name: EnvironmentName, index: u8) -> Self {
50 Self {
51 name,
52 index: index.into(),
53 _marker: PhantomData,
54 }
55 }
56
57 pub fn index(&self) -> u8 {
58 self.index.index
59 }
60}
61
62impl Environment<ImplicitIndex> {
63 pub fn into_explicit(self) -> Environment<ExplicitIndex> {
67 Environment {
68 name: self.name.clone(),
69 index: self.index().into(),
70 _marker: PhantomData,
71 }
72 }
73}
74
75impl Environment<ExplicitIndex> {
76 pub fn into_implicit(self) -> Environment<ImplicitIndex> {
77 Environment {
78 name: self.name.clone(),
79 index: self.index().into(),
80 _marker: PhantomData,
81 }
82 }
83}
84
85impl<M: IndexStyle> Display for Environment<M> {
86 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87 write!(f, "{}", self.medium())
88 }
89}
90
91impl<M: IndexStyle> Environment<M> {
92 pub fn long(&self) -> String {
93 M::format(self.name.long(), self.index())
94 }
95
96 pub fn medium(&self) -> String {
97 M::format(self.name.medium(), self.index())
98 }
99
100 pub fn short(&self) -> String {
101 M::format(&self.name.short().to_string(), self.index())
102 }
103
104 pub fn get_dotenv_filename(&self) -> String {
105 format!(".env.{self}")
106 }
107
108 pub fn get_dotenv_secrets_filename(&self) -> String {
109 format!("{}.secrets", self.get_dotenv_filename())
110 }
111
112 pub fn get_env_files(&self) -> [String; 3] {
113 let filename = self.get_dotenv_filename();
114 let secrets_filename = self.get_dotenv_secrets_filename();
115 [
116 ".env".to_owned(),
117 filename.to_owned(),
118 secrets_filename.to_owned(),
119 ]
120 }
121
122 pub fn load(&self, prefix: Option<&str>) -> anyhow::Result<()> {
124 let files = self.get_env_files();
125 files.iter().for_each(|f| {
126 let path = if let Some(p) = prefix {
127 std::path::PathBuf::from(p).join(f)
128 } else {
129 std::path::PathBuf::from(f)
130 };
131 if path.exists() {
132 match dotenvy::from_filename(f) {
133 Ok(_) => {
134 group_info!("loading '{}' file...", f);
135 }
136 Err(e) => {
137 group_error!("error while loading '{}' file ({})", f, e);
138 }
139 }
140 }
141 });
142 Ok(())
143 }
144
145 pub fn merge_env_files(&self) -> anyhow::Result<PathBuf> {
147 let repo_root = git::git_repo_root_or_cwd()?;
148 let files = self.get_env_files();
149 let mut merged: HashMap<String, String> = HashMap::new();
152 for filename in files {
153 let path = repo_root.join(&filename);
154 if !path.exists() {
155 eprintln!(
156 "⚠️ Warning: environment file '{}' ({}) not found, skipping...",
157 filename,
158 path.display()
159 );
160 continue;
161 }
162 for item in dotenvy::from_path_iter(&path)? {
163 let (key, value) = item?;
164 std::env::set_var(&key, &value);
165 merged.insert(key, value);
166 }
167 }
168 let mut keys: Vec<_> = merged.keys().cloned().collect();
169 keys.sort();
170 let mut out = String::new();
172 for key in keys {
173 let val = &merged[&key];
174 writeln!(&mut out, "{key}={val}")?;
175 }
176 let tmp_path = std::env::temp_dir().join(format!("merged-env-{}.tmp", std::process::id()));
177 std::fs::write(&tmp_path, out)?;
178 Ok(tmp_path)
179 }
180}
181
182#[derive(EnumString, EnumIter, Default, Clone, Debug, PartialEq, clap::ValueEnum)]
183#[strum(serialize_all = "lowercase")]
184pub enum EnvironmentName {
185 #[default]
187 #[clap(alias = "dev")]
188 Development,
189 #[clap(alias = "stag")]
191 Staging,
192 #[clap(alias = "test")]
194 Test,
195 #[clap(alias = "prod")]
197 Production,
198}
199
200impl Display for EnvironmentName {
201 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
202 write!(f, "{}", self.medium())
203 }
204}
205
206impl EnvironmentName {
207 pub fn long(&self) -> &'static str {
208 match self {
209 EnvironmentName::Development => "development",
210 EnvironmentName::Staging => "staging",
211 EnvironmentName::Test => "test",
212 EnvironmentName::Production => "production",
213 }
214 }
215
216 pub fn medium(&self) -> &'static str {
217 match self {
218 EnvironmentName::Development => "dev",
219 EnvironmentName::Staging => "stag",
220 EnvironmentName::Test => "test",
221 EnvironmentName::Production => "prod",
222 }
223 }
224
225 pub fn short(&self) -> char {
226 match self {
227 EnvironmentName::Development => 'd',
228 EnvironmentName::Staging => 's',
229 EnvironmentName::Test => 't',
230 EnvironmentName::Production => 'p',
231 }
232 }
233}
234
/// Numeric index of an environment, wrapped in its own type.
#[derive(Debug, Clone, PartialEq)]
pub struct EnvironmentIndex {
    /// The raw index value.
    pub index: u8,
}
239
240impl Default for EnvironmentIndex {
241 fn default() -> Self {
242 Self { index: 1 }
243 }
244}
245
246impl Display for EnvironmentIndex {
247 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
248 write!(f, "{}", self.index)
249 }
250}
251
252impl From<u8> for EnvironmentIndex {
253 fn from(index: u8) -> Self {
254 Self { index }
255 }
256}
257
#[cfg(test)]
mod tests {
    //! Tests for loading and merging dotenv files.
    //!
    //! They mutate the process environment, hence `#[serial]`. They appear to
    //! rely on fixture files (`.env`, `.env.<name>`, `.env.<name>.secrets`)
    //! defining the variables checked below — TODO confirm where these live.

    use super::*;
    use rstest::rstest;
    use serial_test::serial;
    use std::env;

    type TestEnv = Environment<ImplicitIndex>;

    /// Variables each fixture file is expected to define, paired with the
    /// expected value (the name of the file that defines it).
    fn expected_vars(env: &TestEnv) -> Vec<(String, String)> {
        let suffix = match env.name {
            EnvironmentName::Development => "DEV",
            EnvironmentName::Staging => "STAG",
            EnvironmentName::Test => "TEST",
            EnvironmentName::Production => "PROD",
        };

        vec![
            ("FROM_DOTENV".to_string(), ".env".to_string()),
            (format!("FROM_DOTENV_{suffix}"), env.get_dotenv_filename()),
            (
                format!("FROM_DOTENV_{suffix}_SECRETS"),
                env.get_dotenv_secrets_filename(),
            ),
        ]
    }

    /// Parses the merged env file at `path` into a key/value map, panicking
    /// with a descriptive message when it cannot be read or parsed.
    fn read_merged_map(path: &std::path::Path) -> std::collections::HashMap<String, String> {
        dotenvy::from_path_iter(path)
            .expect("Reading merged env file should succeed")
            .map(|item| item.expect("Parsing key/value from merged env file should succeed"))
            .collect()
    }

    #[rstest]
    #[case::dev(TestEnv::new(EnvironmentName::Development, 1))]
    #[case::stag(TestEnv::new(EnvironmentName::Staging, 1))]
    #[case::test(TestEnv::new(EnvironmentName::Test, 1))]
    #[case::prod(TestEnv::new(EnvironmentName::Production, 1))]
    #[serial]
    fn test_environment_load(#[case] env: TestEnv) {
        // Start from a clean slate so values left by earlier cases cannot mask a failure.
        for (key, _) in expected_vars(&env) {
            env::remove_var(key);
        }

        env.load(Some("../.."))
            .expect("Environment load should succeed");

        for (key, expected_value) in expected_vars(&env) {
            let actual_value =
                env::var(&key).unwrap_or_else(|_| panic!("Missing expected env var: {key}"));
            assert_eq!(
                actual_value, expected_value,
                "Environment variable {key} should be set to {expected_value} but was {actual_value}"
            );
        }
    }

    #[rstest]
    #[case::dev(TestEnv::new(EnvironmentName::Development, 1))]
    #[case::stag(TestEnv::new(EnvironmentName::Staging, 1))]
    #[case::test(TestEnv::new(EnvironmentName::Test, 1))]
    #[case::prod(TestEnv::new(EnvironmentName::Production, 1))]
    #[serial]
    fn test_environment_merge_env_files(#[case] env: TestEnv) {
        for (key, _) in expected_vars(&env) {
            env::remove_var(key);
        }

        let merged_path = env
            .merge_env_files()
            .expect("merge_env_files should succeed");
        assert!(
            merged_path.exists(),
            "Merged env file should exist at {}",
            merged_path.display()
        );

        let merged_map = read_merged_map(&merged_path);
        for (key, expected_value) in expected_vars(&env) {
            let actual_value = merged_map
                .get(&key)
                .unwrap_or_else(|| panic!("Missing expected merged env var: {key}"));
            assert_eq!(
                actual_value, &expected_value,
                "Merged env var {key} should be {expected_value} but was {actual_value}"
            );
        }
    }

    #[test]
    #[serial]
    fn test_environment_merge_env_files_expansion() {
        let env = Environment::<ImplicitIndex>::new(EnvironmentName::Staging, 1);
        // Drop any inherited values so the expansion must come from the files themselves.
        env::remove_var("LOG_LEVEL_TEST");
        env::remove_var("RUST_LOG_TEST");
        env::remove_var("RUST_LOG_STAG_TEST");

        let merged_path = env
            .merge_env_files()
            .expect("merge_env_files should succeed");
        let merged_map = read_merged_map(&merged_path);

        let log_level = merged_map
            .get("LOG_LEVEL_TEST")
            .expect("LOG_LEVEL_TEST should be present in merged env file");
        let rust_log = merged_map
            .get("RUST_LOG_TEST")
            .expect("RUST_LOG_TEST should be present in merged env file");

        assert!(
            !rust_log.contains("${LOG_LEVEL_TEST}"),
            "RUST_LOG_TEST should not contain the raw placeholder '${{LOG_LEVEL_TEST}}', got: {rust_log}"
        );
        assert!(
            rust_log.contains(log_level),
            "RUST_LOG_TEST should contain the expanded LOG_LEVEL_TEST value; LOG_LEVEL_TEST={log_level}, RUST_LOG_TEST={rust_log}"
        );

        let rust_log_stag = merged_map
            .get("RUST_LOG_STAG_TEST")
            .expect("RUST_LOG_STAG_TEST should be present in merged env file");
        assert!(
            !rust_log_stag.contains("${LOG_LEVEL_TEST}"),
            "RUST_LOG_STAG_TEST should not contain the raw placeholder '${{LOG_LEVEL_TEST}}', got: {rust_log_stag}"
        );
        assert!(
            rust_log_stag.contains(log_level),
            "RUST_LOG_STAG_TEST should contain the expanded LOG_LEVEL_TEST value; LOG_LEVEL_TEST={log_level}, RUST_LOG_STAG_TEST={rust_log_stag}"
        );
    }
}