1use std::path::Path;
2
3use anyhow::{Context, Result};
4use rust_embed::Embed;
5use tera::Tera;
6
/// Built-in templates, served by rust-embed from the `templates/` folder.
///
/// Each asset path (e.g. `dependabot/gradle.yml.tera`) is used verbatim as the Tera
/// template name when the templates are loaded.
#[derive(Embed)]
#[folder = "templates/"]
pub struct TemplateAssets;
10
11pub fn load_templates() -> Result<Tera> {
13 load_templates_with_custom_dir(None)
14}
15
16pub fn load_templates_with_custom_dir(custom_dir: Option<&Path>) -> Result<Tera> {
19 let mut tera = Tera::default();
20
21 for file in TemplateAssets::iter() {
22 let path = file.as_ref();
23 if let Some(content) = TemplateAssets::get(path) {
24 let text = std::str::from_utf8(content.data.as_ref())?;
25 tera.add_raw_template(path, text)?;
26 }
27 }
28
29 if let Some(dir) = custom_dir {
30 if dir.is_dir() {
31 load_custom_dir(&mut tera, dir)?;
32 }
33 } else if let Some(default_dir) = dirs_default_templates()
34 && default_dir.is_dir()
35 {
36 load_custom_dir(&mut tera, &default_dir)?;
37 }
38
39 Ok(tera)
40}
41
/// Returns the default user template override directory, `$HOME/.ward/templates`.
///
/// Returns `None` when `HOME` is unset or empty. An empty `HOME` is rejected because
/// joining onto it would yield the relative path `.ward/templates`, silently resolving
/// against the current working directory. The variable is read with
/// [`std::env::var_os`], so a home path that is not valid Unicode is still honoured.
///
/// The returned path is not checked for existence.
fn dirs_default_templates() -> Option<std::path::PathBuf> {
    std::env::var_os("HOME")
        .filter(|home| !home.is_empty())
        .map(|home| std::path::PathBuf::from(home).join(".ward").join("templates"))
}
47
48fn load_custom_dir(tera: &mut Tera, dir: &Path) -> Result<()> {
49 walk_dir(tera, dir, dir)
50}
51
52fn walk_dir(tera: &mut Tera, base: &Path, current: &Path) -> Result<()> {
53 for entry in std::fs::read_dir(current)
54 .with_context(|| format!("Failed to read custom template dir: {}", current.display()))?
55 {
56 let entry = entry?;
57 let path = entry.path();
58
59 if path.is_dir() {
60 walk_dir(tera, base, &path)?;
61 } else if path.is_file() {
62 let rel = path.strip_prefix(base).with_context(|| {
63 format!("Failed to compute relative path for {}", path.display())
64 })?;
65 let name = rel.to_string_lossy();
66 let text = std::fs::read_to_string(&path)
67 .with_context(|| format!("Failed to read custom template: {}", path.display()))?;
68 tera.add_raw_template(&name, &text)?;
69 }
70 }
71 Ok(())
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads only the built-in templates, ignoring any `$HOME/.ward/templates` overrides
    /// on the machine running the tests, so rendering assertions are reproducible.
    fn embedded_templates() -> Tera {
        // Passing any explicit directory bypasses the `$HOME` default; an empty one adds
        // no templates.
        let empty = tempfile::tempdir().unwrap();
        load_templates_with_custom_dir(Some(empty.path())).unwrap()
    }

    #[test]
    fn loads_embedded_templates() {
        let tera = load_templates().unwrap();
        let names: Vec<&str> = tera.get_template_names().collect();
        assert!(names.iter().any(|n| n.contains("dependabot")));
        assert!(names.iter().any(|n| n.contains("codeql")));
        assert!(names.iter().any(|n| n.contains("dependency-submission")));
        assert!(names.iter().any(|n| n.contains("copilot-review")));
    }

    #[test]
    fn loads_with_nonexistent_custom_dir() {
        let tera =
            load_templates_with_custom_dir(Some(Path::new("/nonexistent/path/templates"))).unwrap();
        let names: Vec<&str> = tera.get_template_names().collect();
        assert!(names.iter().any(|n| n.contains("dependabot")));
    }

    #[test]
    fn custom_dir_overrides_embedded() {
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("dependabot");
        std::fs::create_dir_all(&sub).unwrap();
        std::fs::write(
            sub.join("gradle.yml.tera"),
            "custom: {{ custom_var | default(value='hello') }}",
        )
        .unwrap();

        let tera = load_templates_with_custom_dir(Some(dir.path())).unwrap();
        let ctx = tera::Context::new();
        let result = tera.render("dependabot/gradle.yml.tera", &ctx).unwrap();
        assert_eq!(result, "custom: hello");
    }

    #[test]
    fn custom_dir_adds_new_templates() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("my-custom.tera"), "hello world").unwrap();

        let tera = load_templates_with_custom_dir(Some(dir.path())).unwrap();
        let names: Vec<&str> = tera.get_template_names().collect();
        assert!(names.contains(&"my-custom.tera"));

        let ctx = tera::Context::new();
        let result = tera.render("my-custom.tera", &ctx).unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn custom_dir_walks_recursively() {
        let dir = tempfile::tempdir().unwrap();
        let nested = dir.path().join("a").join("b");
        std::fs::create_dir_all(&nested).unwrap();
        std::fs::write(nested.join("deep.tera"), "deep template").unwrap();

        let tera = load_templates_with_custom_dir(Some(dir.path())).unwrap();
        let ctx = tera::Context::new();
        let result = tera.render("a/b/deep.tera", &ctx).unwrap();
        assert_eq!(result, "deep template");
    }

    #[test]
    fn rendered_templates_produce_valid_yaml() {
        let tera = embedded_templates();
        let mut ctx = tera::Context::new();
        ctx.insert("java_version", "21");
        ctx.insert("node_version", "20");
        ctx.insert("default_branch", "main");
        ctx.insert("registry_url", "https://example.com/maven");
        ctx.insert("jfrog_oidc_provider", "test-provider");

        for name in tera
            .get_template_names()
            .filter(|name| name.ends_with(".yml.tera"))
        {
            let rendered = tera
                .render(name, &ctx)
                .unwrap_or_else(|e| panic!("Template {name} failed to render: {e:?}"));
            if let Err(e) = serde_yaml::from_str::<serde_yaml::Value>(&rendered) {
                panic!("Template {name} produced invalid YAML: {e}");
            }
        }
    }

    #[test]
    fn renders_dependabot_gradle_template() {
        let tera = embedded_templates();
        let mut ctx = tera::Context::new();
        ctx.insert("registry_url", "https://example.com/maven");
        ctx.insert("jfrog_oidc_provider", "my-provider");
        let result = tera.render("dependabot/gradle.yml.tera", &ctx).unwrap();
        assert!(result.contains("https://example.com/maven"));
        assert!(result.contains("my-provider"));
        assert!(result.contains("package-ecosystem: gradle"));
    }

    #[test]
    fn renders_dependabot_gradle_with_defaults() {
        let tera = embedded_templates();
        let ctx = tera::Context::new();
        let result = tera.render("dependabot/gradle.yml.tera", &ctx).unwrap();
        assert!(result.contains("https://repo.maven.apache.org/maven2"));
        assert!(!result.contains("jfrog-oidc-provider-name"));
    }

    #[test]
    fn renders_codeql_gradle_template() {
        let tera = embedded_templates();
        let mut ctx = tera::Context::new();
        ctx.insert("java_version", "17");
        let result = tera.render("codeql/gradle.yml.tera", &ctx).unwrap();
        assert!(result.contains("JAVA_VERSION: 17"));
        assert!(result.contains("java-kotlin"));
    }

    #[test]
    fn renders_codeql_npm_template() {
        let tera = embedded_templates();
        let mut ctx = tera::Context::new();
        ctx.insert("node_version", "20");
        let result = tera.render("codeql/npm.yml.tera", &ctx).unwrap();
        assert!(result.contains("NODE_VERSION: 20"));
        assert!(result.contains("javascript-typescript"));
    }

    #[test]
    fn renders_dependency_submission_template() {
        let tera = embedded_templates();
        let mut ctx = tera::Context::new();
        ctx.insert("java_version", "21");
        ctx.insert("default_branch", "main");
        let result = tera
            .render("dependency-submission/gradle.yml.tera", &ctx)
            .unwrap();
        assert!(result.contains("JAVA_VERSION: 21"));
        assert!(result.contains("branches:"));
    }
}
216}