systemprompt_template_provider/traits/
loader.rs1use std::path::Path;
2
3use async_trait::async_trait;
4use systemprompt_provider_contracts::TemplateSource;
5
6use super::error::{Result, TemplateLoaderError};
7
8#[cfg(feature = "tokio")]
9use std::io::ErrorKind;
10#[cfg(feature = "tokio")]
11use std::path::{Component, PathBuf};
12#[cfg(feature = "tokio")]
13use tokio::fs;
14
15#[async_trait]
19pub trait TemplateLoader: Send + Sync {
20 async fn load(&self, source: &TemplateSource) -> Result<String>;
21
22 fn can_load(&self, source: &TemplateSource) -> bool;
23
24 async fn load_directory(&self, _path: &Path) -> Result<Vec<(String, String)>> {
25 Err(TemplateLoaderError::DirectoryLoadingUnsupported)
26 }
27}
28
/// Template loader backed by the file system.
///
/// Holds the list of base directories that file and directory lookups are
/// resolved against and confined to. The default value has no base paths.
///
/// `Clone` is derived so a configured loader can be duplicated before being
/// extended with the consuming `add_path` builder.
#[cfg(feature = "tokio")]
#[derive(Debug, Default, Clone)]
pub struct FileSystemLoader {
    /// Base directories, searched in insertion order.
    base_paths: Vec<PathBuf>,
}
34
#[cfg(feature = "tokio")]
impl FileSystemLoader {
    /// Creates a loader that resolves templates against `base_paths`, in order.
    #[must_use]
    pub const fn new(base_paths: Vec<PathBuf>) -> Self {
        Self { base_paths }
    }

    /// Creates a loader with a single base path.
    #[must_use]
    pub fn with_path(path: impl Into<PathBuf>) -> Self {
        Self::new(vec![path.into()])
    }

    /// Appends another base path and returns the loader (builder style).
    #[must_use]
    pub fn add_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.base_paths.push(path.into());
        self
    }

    /// Returns `true` when `path` contains at least one `..` component.
    fn has_traversal_components(path: &Path) -> bool {
        path.components().any(|part| part == Component::ParentDir)
    }

    /// Checks whether the already-canonical path `canonical` lies under any
    /// base path.
    ///
    /// Base paths that do not exist are skipped; any other failure while
    /// canonicalizing a base path is returned as an I/O error for that base.
    async fn is_within_base_paths(&self, canonical: &Path) -> Result<bool> {
        for root in &self.base_paths {
            let resolved_root = match fs::canonicalize(root).await {
                Ok(resolved) => resolved,
                Err(err) if err.kind() == ErrorKind::NotFound => continue,
                Err(err) => return Err(TemplateLoaderError::io(root, err)),
            };

            if canonical.starts_with(&resolved_root) {
                return Ok(true);
            }
        }

        Ok(false)
    }

    /// Canonicalizes `path` and ensures the result lies under a base path.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when canonicalization fails, and
    /// `OutsideBasePath` (carrying the original `path`) when the canonical
    /// location is not inside any base path.
    async fn canonicalize_and_validate(&self, path: &Path) -> Result<PathBuf> {
        let resolved = match fs::canonicalize(path).await {
            Ok(resolved) => resolved,
            Err(err) => return Err(TemplateLoaderError::io(path, err)),
        };

        if self.is_within_base_paths(&resolved).await? {
            Ok(resolved)
        } else {
            Err(TemplateLoaderError::OutsideBasePath(path.to_path_buf()))
        }
    }

    /// Attempts to read `relative` underneath `base`.
    ///
    /// Returns `None` when the joined path does not exist, so the caller can
    /// move on to the next base path. Every other outcome, success or error,
    /// is returned as `Some(..)`.
    async fn try_read_from_base(&self, base: &Path, relative: &Path) -> Option<Result<String>> {
        let candidate = base.join(relative);

        let resolved = match fs::canonicalize(&candidate).await {
            Ok(resolved) => resolved,
            Err(err) if err.kind() == ErrorKind::NotFound => return None,
            Err(err) => return Some(Err(TemplateLoaderError::io(&candidate, err))),
        };

        let resolved_base = match fs::canonicalize(base).await {
            Ok(resolved_base) => resolved_base,
            Err(err) => return Some(Err(TemplateLoaderError::io(base, err))),
        };

        // Canonicalization resolves symlinks, so this comparison is made on
        // the real location of the file rather than on the joined path.
        if !resolved.starts_with(&resolved_base) {
            return Some(Err(TemplateLoaderError::OutsideBasePath(candidate)));
        }

        let contents = fs::read_to_string(&resolved).await;
        Some(contents.map_err(|err| TemplateLoaderError::io(&candidate, err)))
    }
}
108
#[cfg(feature = "tokio")]
#[async_trait]
impl TemplateLoader for FileSystemLoader {
    /// Loads a template from an embedded string or from a file.
    ///
    /// * `Embedded` sources are returned as an owned copy of the content.
    /// * `File` sources with a `..` component are rejected up front. An
    ///   absolute path must canonicalize to a location inside one of the base
    ///   paths. A relative path is tried against each base path in order; the
    ///   first base that yields anything other than "not found" decides the
    ///   outcome.
    /// * `Directory` sources are rejected.
    ///
    /// # Errors
    ///
    /// Returns `DirectoryTraversal`, `OutsideBasePath`, `NoBasePaths`,
    /// `NotFound`, `DirectoryNotSupported`, or an I/O error, depending on
    /// which of the checks above fails.
    async fn load(&self, source: &TemplateSource) -> Result<String> {
        match source {
            TemplateSource::Embedded(content) => Ok((*content).to_string()),
            TemplateSource::File(path) => {
                if Self::has_traversal_components(path) {
                    return Err(TemplateLoaderError::DirectoryTraversal(path.clone()));
                }

                if path.is_absolute() {
                    let canonical = self.canonicalize_and_validate(path).await?;
                    return fs::read_to_string(&canonical)
                        .await
                        .map_err(|e| TemplateLoaderError::io(path, e));
                }

                if self.base_paths.is_empty() {
                    return Err(TemplateLoaderError::NoBasePaths);
                }

                for base in &self.base_paths {
                    if let Some(result) = self.try_read_from_base(base, path).await {
                        return result;
                    }
                }

                Err(TemplateLoaderError::NotFound(path.clone()))
            },
            TemplateSource::Directory(path) => {
                Err(TemplateLoaderError::DirectoryNotSupported(path.clone()))
            },
        }
    }

    /// Accepts `Embedded` and `File` sources; everything else is declined.
    fn can_load(&self, source: &TemplateSource) -> bool {
        matches!(
            source,
            TemplateSource::Embedded(_) | TemplateSource::File(_)
        )
    }

    /// Loads every `*.html` file directly inside `path` (non-recursive).
    ///
    /// Returns `(template_name, content)` pairs, where the name is the file
    /// stem. The pairs are sorted by name so the result does not depend on
    /// the platform-specific order in which directory entries are listed.
    ///
    /// An absolute `path` must canonicalize to a location inside one of the
    /// base paths. A relative `path` is resolved against the base paths in
    /// order, and the first one where the directory exists is used.
    ///
    /// Each entry is canonicalized and checked against the base paths before
    /// it is read, so a symlinked `*.html` entry cannot expose a file that
    /// `load` would refuse with `OutsideBasePath`.
    ///
    /// # Errors
    ///
    /// Returns `DirectoryTraversal`, `NoBasePaths`, `OutsideBasePath`,
    /// `NotFound`, `InvalidEncoding` (file stem is not valid UTF-8), or an
    /// I/O error.
    async fn load_directory(&self, path: &Path) -> Result<Vec<(String, String)>> {
        if Self::has_traversal_components(path) {
            return Err(TemplateLoaderError::DirectoryTraversal(path.to_path_buf()));
        }

        if self.base_paths.is_empty() {
            return Err(TemplateLoaderError::NoBasePaths);
        }

        let dir_path = if path.is_absolute() {
            self.canonicalize_and_validate(path).await?
        } else {
            let mut found_path = None;
            for base in &self.base_paths {
                let candidate = base.join(path);
                match fs::canonicalize(&candidate).await {
                    Ok(canonical) => {
                        let canonical_base = fs::canonicalize(base)
                            .await
                            .map_err(|e| TemplateLoaderError::io(base, e))?;

                        if !canonical.starts_with(&canonical_base) {
                            return Err(TemplateLoaderError::OutsideBasePath(candidate));
                        }

                        found_path = Some(canonical);
                        break;
                    },
                    // The directory is absent under this base: try the next one.
                    Err(e) if e.kind() == ErrorKind::NotFound => {},
                    Err(e) => return Err(TemplateLoaderError::io(&candidate, e)),
                }
            }
            found_path.ok_or_else(|| TemplateLoaderError::NotFound(path.to_path_buf()))?
        };

        let mut templates = Vec::new();
        let mut entries = fs::read_dir(&dir_path)
            .await
            .map_err(|e| TemplateLoaderError::io(&dir_path, e))?;

        while let Some(entry) = entries
            .next_entry()
            .await
            .map_err(|e| TemplateLoaderError::io(&dir_path, e))?
        {
            let entry_path = entry.path();

            if !entry_path.extension().is_some_and(|ext| ext == "html") {
                continue;
            }

            let Some(file_stem) = entry_path.file_stem() else {
                continue;
            };

            let template_name = file_stem
                .to_str()
                .ok_or_else(|| TemplateLoaderError::InvalidEncoding(entry_path.clone()))?
                .to_owned();

            // The entry itself may be a symlink. Resolve it and confirm the
            // real file still lives under a base path before reading it.
            let canonical_entry = self.canonicalize_and_validate(&entry_path).await?;

            let content = fs::read_to_string(&canonical_entry)
                .await
                .map_err(|e| TemplateLoaderError::io(&entry_path, e))?;

            templates.push((template_name, content));
        }

        // Directory listing order is platform and file-system dependent.
        templates.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));

        Ok(templates)
    }
}
220
/// Stateless loader that serves only embedded template sources.
///
/// It is a unit struct, so copies are free and `Default` yields the only
/// possible value.
#[derive(Clone, Copy, Debug, Default)]
pub struct EmbeddedLoader;
223
224#[async_trait]
225impl TemplateLoader for EmbeddedLoader {
226 async fn load(&self, source: &TemplateSource) -> Result<String> {
227 match source {
228 TemplateSource::Embedded(content) => Ok((*content).to_string()),
229 _ => Err(TemplateLoaderError::EmbeddedOnly),
230 }
231 }
232
233 fn can_load(&self, source: &TemplateSource) -> bool {
234 matches!(source, TemplateSource::Embedded(_))
235 }
236}