Skip to main content

systemprompt_template_provider/traits/
loader.rs

1use std::path::Path;
2
3use async_trait::async_trait;
4use systemprompt_provider_contracts::TemplateSource;
5
6use super::error::{Result, TemplateLoaderError};
7
8#[cfg(feature = "tokio")]
9use std::io::ErrorKind;
10#[cfg(feature = "tokio")]
11use std::path::{Component, PathBuf};
12#[cfg(feature = "tokio")]
13use tokio::fs;
14
15/// `#[async_trait]` is required: this trait is consumed as
16/// `Arc<dyn TemplateLoader>` (see `DynTemplateLoader`), so it must be
17/// `dyn`-compatible — native `async fn` in traits is not.
18#[async_trait]
19pub trait TemplateLoader: Send + Sync {
20    async fn load(&self, source: &TemplateSource) -> Result<String>;
21
22    fn can_load(&self, source: &TemplateSource) -> bool;
23
24    async fn load_directory(&self, _path: &Path) -> Result<Vec<(String, String)>> {
25        Err(TemplateLoaderError::DirectoryLoadingUnsupported)
26    }
27}
28
/// Template loader that reads files from a list of base directories.
///
/// Every path that is read is first canonicalised and must resolve to a
/// location underneath one of the configured base paths.
#[cfg(feature = "tokio")]
#[derive(Debug, Default)]
pub struct FileSystemLoader {
    /// Directories searched, in order, when resolving a template path.
    base_paths: Vec<PathBuf>,
}

#[cfg(feature = "tokio")]
impl FileSystemLoader {
    /// Creates a loader that searches the given base paths in order.
    #[must_use]
    pub const fn new(base_paths: Vec<PathBuf>) -> Self {
        Self { base_paths }
    }

    /// Creates a loader with a single base path.
    #[must_use]
    pub fn with_path(path: impl Into<PathBuf>) -> Self {
        Self::new(vec![path.into()])
    }

    /// Appends another base path and returns the loader (builder style).
    #[must_use]
    pub fn add_path(mut self, path: impl Into<PathBuf>) -> Self {
        let extra = path.into();
        self.base_paths.push(extra);
        self
    }

    /// Returns `true` when `path` contains at least one `..` component.
    fn has_traversal_components(path: &Path) -> bool {
        path.components()
            .any(|part| part == Component::ParentDir)
    }

    /// Checks whether an already-canonical path lies under any base path.
    ///
    /// Base paths that do not exist are skipped; any other failure while
    /// canonicalising a base path is returned as an I/O error for that base.
    async fn is_within_base_paths(&self, canonical: &Path) -> Result<bool> {
        for root in &self.base_paths {
            let resolved_root = match fs::canonicalize(root).await {
                Ok(resolved) => resolved,
                Err(err) if err.kind() == ErrorKind::NotFound => continue,
                Err(err) => return Err(TemplateLoaderError::io(root, err)),
            };

            if canonical.starts_with(&resolved_root) {
                return Ok(true);
            }
        }

        Ok(false)
    }

    /// Canonicalises `path` and confirms the result is inside a base path.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when canonicalisation fails, and
    /// [`TemplateLoaderError::OutsideBasePath`] when the resolved location is
    /// not under any base path.
    async fn canonicalize_and_validate(&self, path: &Path) -> Result<PathBuf> {
        let resolved = match fs::canonicalize(path).await {
            Ok(resolved) => resolved,
            Err(err) => return Err(TemplateLoaderError::io(path, err)),
        };

        if self.is_within_base_paths(&resolved).await? {
            Ok(resolved)
        } else {
            Err(TemplateLoaderError::OutsideBasePath(path.to_path_buf()))
        }
    }

    /// Attempts to read `relative` from underneath `base`.
    ///
    /// Returns `None` when the joined path does not exist, so the caller can
    /// move on to the next base path. Otherwise returns `Some` with either the
    /// file contents or the error that stopped the read.
    async fn try_read_from_base(&self, base: &Path, relative: &Path) -> Option<Result<String>> {
        let candidate = base.join(relative);

        let resolved = match fs::canonicalize(&candidate).await {
            Ok(resolved) => resolved,
            Err(err) if err.kind() == ErrorKind::NotFound => return None,
            Err(err) => return Some(Err(TemplateLoaderError::io(&candidate, err))),
        };

        let resolved_base = match fs::canonicalize(base).await {
            Ok(resolved_base) => resolved_base,
            Err(err) => return Some(Err(TemplateLoaderError::io(base, err))),
        };

        // Canonicalisation follows symlinks, so this rejects links that
        // escape the base directory.
        if !resolved.starts_with(&resolved_base) {
            return Some(Err(TemplateLoaderError::OutsideBasePath(candidate)));
        }

        let contents = fs::read_to_string(&resolved)
            .await
            .map_err(|err| TemplateLoaderError::io(&candidate, err));
        Some(contents)
    }
}
108
#[cfg(feature = "tokio")]
#[async_trait]
impl TemplateLoader for FileSystemLoader {
    /// Loads an embedded or file-backed template.
    ///
    /// Embedded content is returned as-is. File paths containing `..` are
    /// rejected up front. Absolute paths must canonicalise to a location
    /// inside a base path; relative paths are tried against each base path
    /// in order and the first one that exists wins.
    ///
    /// # Errors
    ///
    /// Returns `DirectoryTraversal`, `OutsideBasePath`, `NoBasePaths`,
    /// `NotFound`, `DirectoryNotSupported`, or an I/O error, depending on
    /// which step fails.
    async fn load(&self, source: &TemplateSource) -> Result<String> {
        match source {
            TemplateSource::Embedded(content) => Ok((*content).to_string()),
            TemplateSource::File(path) => {
                if Self::has_traversal_components(path) {
                    return Err(TemplateLoaderError::DirectoryTraversal(path.clone()));
                }

                if path.is_absolute() {
                    let canonical = self.canonicalize_and_validate(path).await?;
                    return fs::read_to_string(&canonical)
                        .await
                        .map_err(|e| TemplateLoaderError::io(path, e));
                }

                if self.base_paths.is_empty() {
                    return Err(TemplateLoaderError::NoBasePaths);
                }

                for base in &self.base_paths {
                    if let Some(result) = self.try_read_from_base(base, path).await {
                        return result;
                    }
                }

                Err(TemplateLoaderError::NotFound(path.clone()))
            },
            TemplateSource::Directory(path) => {
                Err(TemplateLoaderError::DirectoryNotSupported(path.clone()))
            },
        }
    }

    /// Reports support for embedded and single-file sources.
    fn can_load(&self, source: &TemplateSource) -> bool {
        matches!(
            source,
            TemplateSource::Embedded(_) | TemplateSource::File(_)
        )
    }

    /// Loads every `*.html` file directly inside `path`.
    ///
    /// Each result is a `(name, content)` pair where `name` is the file stem.
    /// The directory is resolved like a file in [`Self::load`]: `..` is
    /// rejected, absolute paths must sit inside a base path, and relative
    /// paths are tried against each base path in order.
    ///
    /// Each entry is canonicalised and checked against the base paths before
    /// it is read, so a symlinked `*.html` entry cannot pull in a file from
    /// outside the base paths.
    ///
    /// # Errors
    ///
    /// Returns `DirectoryTraversal`, `NoBasePaths`, `NotFound`,
    /// `OutsideBasePath`, `InvalidEncoding` (non-UTF-8 file stem), or an I/O
    /// error. Any failing entry aborts the whole load.
    async fn load_directory(&self, path: &Path) -> Result<Vec<(String, String)>> {
        if Self::has_traversal_components(path) {
            return Err(TemplateLoaderError::DirectoryTraversal(path.to_path_buf()));
        }

        if self.base_paths.is_empty() {
            return Err(TemplateLoaderError::NoBasePaths);
        }

        let dir_path = if path.is_absolute() {
            self.canonicalize_and_validate(path).await?
        } else {
            let mut found_path = None;
            for base in &self.base_paths {
                let candidate = base.join(path);
                match fs::canonicalize(&candidate).await {
                    Ok(canonical) => {
                        let canonical_base = fs::canonicalize(base)
                            .await
                            .map_err(|e| TemplateLoaderError::io(base, e))?;

                        if !canonical.starts_with(&canonical_base) {
                            return Err(TemplateLoaderError::OutsideBasePath(candidate));
                        }

                        found_path = Some(canonical);
                        break;
                    },
                    // A missing candidate just means "try the next base path".
                    Err(e) if e.kind() == ErrorKind::NotFound => {},
                    Err(e) => return Err(TemplateLoaderError::io(&candidate, e)),
                }
            }
            found_path.ok_or_else(|| TemplateLoaderError::NotFound(path.to_path_buf()))?
        };

        let mut templates = Vec::new();
        let mut entries = fs::read_dir(&dir_path)
            .await
            .map_err(|e| TemplateLoaderError::io(&dir_path, e))?;

        while let Some(entry) = entries
            .next_entry()
            .await
            .map_err(|e| TemplateLoaderError::io(&dir_path, e))?
        {
            let entry_path = entry.path();

            if !entry_path.extension().is_some_and(|ext| ext == "html") {
                continue;
            }

            let Some(file_stem) = entry_path.file_stem() else {
                continue;
            };

            let template_name = file_stem
                .to_str()
                .ok_or_else(|| TemplateLoaderError::InvalidEncoding(entry_path.clone()))?
                .to_owned();

            // Resolve symlinks before reading: the directory was validated,
            // but an entry inside it may still link outside the base paths.
            let canonical_entry = fs::canonicalize(&entry_path)
                .await
                .map_err(|e| TemplateLoaderError::io(&entry_path, e))?;

            if !self.is_within_base_paths(&canonical_entry).await? {
                return Err(TemplateLoaderError::OutsideBasePath(entry_path));
            }

            // Read the path that was validated; report errors against the
            // entry path the caller would recognise.
            let content = fs::read_to_string(&canonical_entry)
                .await
                .map_err(|e| TemplateLoaderError::io(&entry_path, e))?;

            templates.push((template_name, content));
        }

        Ok(templates)
    }
}
220
221#[derive(Debug, Default, Clone, Copy)]
222pub struct EmbeddedLoader;
223
224#[async_trait]
225impl TemplateLoader for EmbeddedLoader {
226    async fn load(&self, source: &TemplateSource) -> Result<String> {
227        match source {
228            TemplateSource::Embedded(content) => Ok((*content).to_string()),
229            _ => Err(TemplateLoaderError::EmbeddedOnly),
230        }
231    }
232
233    fn can_load(&self, source: &TemplateSource) -> bool {
234        matches!(source, TemplateSource::Embedded(_))
235    }
236}