Skip to main content

ward/config/
templates.rs

1use std::path::Path;
2
3use anyhow::{Context, Result};
4use rust_embed::Embed;
5use tera::Tera;
6
/// Template files from the crate's `templates/` folder, exposed through
/// `rust_embed`'s [`Embed`] derive.
///
/// `load_templates_with_custom_dir` registers each file with Tera under its
/// path relative to that folder, e.g. `dependabot/gradle.yml.tera`.
#[derive(Embed)]
#[folder = "templates/"]
pub struct TemplateAssets;
10
11/// Load embedded templates into a Tera instance.
12pub fn load_templates() -> Result<Tera> {
13    load_templates_with_custom_dir(None)
14}
15
16/// Load embedded templates, then overlay templates from a custom directory.
17/// Custom templates with the same name override embedded ones.
18pub fn load_templates_with_custom_dir(custom_dir: Option<&Path>) -> Result<Tera> {
19    let mut tera = Tera::default();
20
21    for file in TemplateAssets::iter() {
22        let path = file.as_ref();
23        if let Some(content) = TemplateAssets::get(path) {
24            let text = std::str::from_utf8(content.data.as_ref())?;
25            tera.add_raw_template(path, text)?;
26        }
27    }
28
29    if let Some(dir) = custom_dir {
30        if dir.is_dir() {
31            load_custom_dir(&mut tera, dir)?;
32        }
33    } else if let Some(default_dir) = dirs_default_templates()
34        && default_dir.is_dir()
35    {
36        load_custom_dir(&mut tera, &default_dir)?;
37    }
38
39    Ok(tera)
40}
41
/// Per-user template override directory: `<home>/.ward/templates`.
///
/// The home directory is taken from `HOME`. It falls back to `USERPROFILE`,
/// which Windows sets and where `HOME` is usually not defined. Empty values
/// are ignored, so an empty `HOME` cannot turn this into the relative path
/// `.ward/templates` (resolved against the current working directory).
/// Non-UTF-8 values are accepted because they are read with `var_os`.
///
/// Returns `None` when neither variable holds a usable value. This function
/// does not check whether the directory exists.
fn dirs_default_templates() -> Option<std::path::PathBuf> {
    let home = std::env::var_os("HOME")
        .filter(|value| !value.is_empty())
        .or_else(|| std::env::var_os("USERPROFILE").filter(|value| !value.is_empty()))?;
    Some(std::path::PathBuf::from(home).join(".ward").join("templates"))
}
47
48fn load_custom_dir(tera: &mut Tera, dir: &Path) -> Result<()> {
49    walk_dir(tera, dir, dir)
50}
51
52fn walk_dir(tera: &mut Tera, base: &Path, current: &Path) -> Result<()> {
53    for entry in std::fs::read_dir(current)
54        .with_context(|| format!("Failed to read custom template dir: {}", current.display()))?
55    {
56        let entry = entry?;
57        let path = entry.path();
58
59        if path.is_dir() {
60            walk_dir(tera, base, &path)?;
61        } else if path.is_file() {
62            let rel = path.strip_prefix(base).with_context(|| {
63                format!("Failed to compute relative path for {}", path.display())
64            })?;
65            let name = rel.to_string_lossy();
66            let text = std::fs::read_to_string(&path)
67                .with_context(|| format!("Failed to read custom template: {}", path.display()))?;
68            tera.add_raw_template(&name, &text)?;
69        }
70    }
71    Ok(())
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// Load only the embedded templates, with no user overlay.
    ///
    /// Passing an explicit (empty) custom directory stops the loader from
    /// falling back to the per-user default directory. The tests below
    /// therefore do not depend on overrides that happen to exist on the
    /// machine running them.
    fn embedded_only() -> Tera {
        let empty = tempfile::tempdir().unwrap();
        load_templates_with_custom_dir(Some(empty.path())).unwrap()
    }

    #[test]
    fn loads_embedded_templates() {
        let tera = embedded_only();
        let names: Vec<&str> = tera.get_template_names().collect();
        assert!(names.iter().any(|n| n.contains("dependabot")));
        assert!(names.iter().any(|n| n.contains("codeql")));
        assert!(names.iter().any(|n| n.contains("dependency-submission")));
        assert!(names.iter().any(|n| n.contains("copilot-review")));
    }

    #[test]
    fn loads_with_nonexistent_custom_dir() {
        let tera =
            load_templates_with_custom_dir(Some(Path::new("/nonexistent/path/templates"))).unwrap();
        let names: Vec<&str> = tera.get_template_names().collect();
        assert!(names.iter().any(|n| n.contains("dependabot")));
    }

    #[test]
    fn custom_dir_overrides_embedded() {
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("dependabot");
        std::fs::create_dir_all(&sub).unwrap();
        std::fs::write(
            sub.join("gradle.yml.tera"),
            "custom: {{ custom_var | default(value='hello') }}",
        )
        .unwrap();

        let tera = load_templates_with_custom_dir(Some(dir.path())).unwrap();
        let ctx = tera::Context::new();
        let result = tera.render("dependabot/gradle.yml.tera", &ctx).unwrap();
        assert_eq!(result, "custom: hello");
    }

    #[test]
    fn custom_dir_adds_new_templates() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("my-custom.tera"), "hello world").unwrap();

        let tera = load_templates_with_custom_dir(Some(dir.path())).unwrap();
        let names: Vec<&str> = tera.get_template_names().collect();
        assert!(names.contains(&"my-custom.tera"));

        let ctx = tera::Context::new();
        let result = tera.render("my-custom.tera", &ctx).unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn custom_dir_walks_recursively() {
        let dir = tempfile::tempdir().unwrap();
        let nested = dir.path().join("a").join("b");
        std::fs::create_dir_all(&nested).unwrap();
        std::fs::write(nested.join("deep.tera"), "deep template").unwrap();

        let tera = load_templates_with_custom_dir(Some(dir.path())).unwrap();
        let ctx = tera::Context::new();
        let result = tera.render("a/b/deep.tera", &ctx).unwrap();
        assert_eq!(result, "deep template");
    }

    #[test]
    fn rendered_templates_produce_valid_yaml() {
        let tera = embedded_only();
        let mut ctx = tera::Context::new();
        ctx.insert("java_version", "21");
        ctx.insert("node_version", "20");
        ctx.insert("default_branch", "main");
        ctx.insert("registry_url", "https://example.com/maven");
        ctx.insert("jfrog_oidc_provider", "test-provider");

        for name in tera.get_template_names() {
            if !name.ends_with(".yml.tera") {
                continue;
            }
            let rendered = tera
                .render(name, &ctx)
                .unwrap_or_else(|err| panic!("Template {name} failed to render: {err}"));
            if let Err(err) = serde_yaml::from_str::<serde_yaml::Value>(&rendered) {
                panic!("Template {name} produced invalid YAML: {err}");
            }
        }
    }

    #[test]
    fn renders_dependabot_gradle_template() {
        let tera = embedded_only();
        let mut ctx = tera::Context::new();
        ctx.insert("registry_url", "https://example.com/maven");
        ctx.insert("jfrog_oidc_provider", "my-provider");
        let result = tera.render("dependabot/gradle.yml.tera", &ctx).unwrap();
        assert!(result.contains("https://example.com/maven"));
        assert!(result.contains("my-provider"));
        assert!(result.contains("package-ecosystem: gradle"));
    }

    #[test]
    fn renders_dependabot_gradle_with_defaults() {
        let tera = embedded_only();
        let ctx = tera::Context::new();
        let result = tera.render("dependabot/gradle.yml.tera", &ctx).unwrap();
        assert!(result.contains("https://repo.maven.apache.org/maven2"));
        assert!(!result.contains("jfrog-oidc-provider-name"));
    }

    #[test]
    fn renders_codeql_gradle_template() {
        let tera = embedded_only();
        let mut ctx = tera::Context::new();
        ctx.insert("java_version", "17");
        let result = tera.render("codeql/gradle.yml.tera", &ctx).unwrap();
        assert!(result.contains("JAVA_VERSION: 17"));
        assert!(result.contains("java-kotlin"));
    }

    #[test]
    fn renders_codeql_npm_template() {
        let tera = embedded_only();
        let mut ctx = tera::Context::new();
        ctx.insert("node_version", "20");
        let result = tera.render("codeql/npm.yml.tera", &ctx).unwrap();
        assert!(result.contains("NODE_VERSION: 20"));
        assert!(result.contains("javascript-typescript"));
    }

    #[test]
    fn renders_dependency_submission_template() {
        let tera = embedded_only();
        let mut ctx = tera::Context::new();
        ctx.insert("java_version", "21");
        ctx.insert("default_branch", "main");
        let result = tera
            .render("dependency-submission/gradle.yml.tera", &ctx)
            .unwrap();
        assert!(result.contains("JAVA_VERSION: 21"));
        assert!(result.contains("branches:"));
    }
}