Skip to main content

systemprompt_template_provider/traits/
loader.rs

1use std::path::Path;
2
3use async_trait::async_trait;
4use systemprompt_provider_contracts::TemplateSource;
5
6use super::error::{Result, TemplateLoaderError};
7
8#[cfg(feature = "tokio")]
9use std::io::ErrorKind;
10#[cfg(feature = "tokio")]
11use std::path::{Component, PathBuf};
12#[cfg(feature = "tokio")]
13use tokio::fs;
14
15#[async_trait]
16pub trait TemplateLoader: Send + Sync {
17    async fn load(&self, source: &TemplateSource) -> Result<String>;
18
19    fn can_load(&self, source: &TemplateSource) -> bool;
20
21    async fn load_directory(&self, _path: &Path) -> Result<Vec<(String, String)>> {
22        Err(TemplateLoaderError::DirectoryLoadingUnsupported)
23    }
24}
25
/// Loads templates from disk, confined to a list of base directories.
///
/// Relative template paths are resolved against each base directory in turn, and every
/// resolved path is checked (after symlink resolution) to stay inside a base directory.
#[cfg(feature = "tokio")]
#[derive(Debug, Default)]
pub struct FileSystemLoader {
    /// Directories that template paths are resolved against, in search order.
    base_paths: Vec<PathBuf>,
}
31
#[cfg(feature = "tokio")]
impl FileSystemLoader {
    /// Creates a loader that searches `base_paths` in the given order.
    #[must_use]
    pub const fn new(base_paths: Vec<PathBuf>) -> Self {
        Self { base_paths }
    }

    /// Creates a loader with a single base directory.
    #[must_use]
    pub fn with_path(path: impl Into<PathBuf>) -> Self {
        Self::new(vec![path.into()])
    }

    /// Returns the loader with `path` appended as the last base directory to search.
    #[must_use]
    pub fn add_path(mut self, path: impl Into<PathBuf>) -> Self {
        let extra = path.into();
        self.base_paths.push(extra);
        self
    }

    /// Returns `true` if `path` contains a `..` component anywhere.
    fn has_traversal_components(path: &Path) -> bool {
        for part in path.components() {
            if part == Component::ParentDir {
                return true;
            }
        }
        false
    }

    /// Checks whether an already-canonical path lies under any canonicalized base path.
    ///
    /// Base paths that do not exist are ignored; any other I/O failure while resolving a
    /// base path is returned as an error.
    async fn is_within_base_paths(&self, canonical: &Path) -> Result<bool> {
        for base in &self.base_paths {
            let resolved_base = match fs::canonicalize(base).await {
                Ok(resolved) => resolved,
                // A missing base directory cannot contain anything; move on to the next one.
                Err(err) if err.kind() == ErrorKind::NotFound => continue,
                Err(err) => return Err(TemplateLoaderError::io(base, err)),
            };

            if canonical.starts_with(&resolved_base) {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Canonicalizes `path` (resolving symlinks) and ensures the result is inside a base path.
    ///
    /// # Errors
    ///
    /// An I/O error if `path` cannot be canonicalized, or `OutsideBasePath` (carrying the
    /// original, non-canonical `path`) if it resolves outside every base path.
    async fn canonicalize_and_validate(&self, path: &Path) -> Result<PathBuf> {
        let resolved = match fs::canonicalize(path).await {
            Ok(resolved) => resolved,
            Err(err) => return Err(TemplateLoaderError::io(path, err)),
        };

        if self.is_within_base_paths(&resolved).await? {
            Ok(resolved)
        } else {
            Err(TemplateLoaderError::OutsideBasePath(path.to_path_buf()))
        }
    }

    /// Attempts to read `relative` under a single `base` directory.
    ///
    /// Returns `None` when the joined path does not exist, so the caller can try the next
    /// base. Otherwise returns `Some` with either the file contents or an error: an I/O
    /// failure, or `OutsideBasePath` when the file resolves outside `base` (e.g. via a symlink).
    async fn try_read_from_base(&self, base: &Path, relative: &Path) -> Option<Result<String>> {
        let candidate = base.join(relative);

        let resolved = match fs::canonicalize(&candidate).await {
            Ok(resolved) => resolved,
            Err(err) if err.kind() == ErrorKind::NotFound => return None,
            Err(err) => return Some(Err(TemplateLoaderError::io(&candidate, err))),
        };

        let resolved_base = match fs::canonicalize(base).await {
            Ok(resolved_base) => resolved_base,
            Err(err) => return Some(Err(TemplateLoaderError::io(base, err))),
        };

        if !resolved.starts_with(&resolved_base) {
            return Some(Err(TemplateLoaderError::OutsideBasePath(candidate)));
        }

        let contents = fs::read_to_string(&resolved)
            .await
            .map_err(|err| TemplateLoaderError::io(&candidate, err));
        Some(contents)
    }
}
105
#[cfg(feature = "tokio")]
#[async_trait]
impl TemplateLoader for FileSystemLoader {
    /// Loads a single template.
    ///
    /// * `Embedded` sources are returned directly.
    /// * Absolute `File` paths are canonicalized and must resolve inside one of the base paths.
    /// * Relative `File` paths are tried against each base path in order. The first base that
    ///   contains the file wins; a match whose canonical form escapes that base is rejected
    ///   rather than skipped.
    ///
    /// # Errors
    ///
    /// - `DirectoryTraversal` if the path contains a `..` component.
    /// - `NoBasePaths` for a relative path when no base paths are configured.
    /// - `OutsideBasePath` if the resolved file is not inside a permitted base.
    /// - `NotFound` if no base contains the file.
    /// - `DirectoryNotSupported` for `Directory` sources.
    /// - An I/O error if canonicalization or reading fails.
    async fn load(&self, source: &TemplateSource) -> Result<String> {
        match source {
            TemplateSource::Embedded(content) => Ok((*content).to_string()),
            TemplateSource::File(path) => {
                if Self::has_traversal_components(path) {
                    return Err(TemplateLoaderError::DirectoryTraversal(path.clone()));
                }

                if path.is_absolute() {
                    let canonical = self.canonicalize_and_validate(path).await?;
                    return fs::read_to_string(&canonical)
                        .await
                        .map_err(|e| TemplateLoaderError::io(path, e));
                }

                if self.base_paths.is_empty() {
                    return Err(TemplateLoaderError::NoBasePaths);
                }

                for base in &self.base_paths {
                    if let Some(result) = self.try_read_from_base(base, path).await {
                        return result;
                    }
                }

                Err(TemplateLoaderError::NotFound(path.clone()))
            },
            TemplateSource::Directory(path) => {
                Err(TemplateLoaderError::DirectoryNotSupported(path.clone()))
            },
        }
    }

    /// Only embedded text and single files are handled; directories go through
    /// [`TemplateLoader::load_directory`].
    fn can_load(&self, source: &TemplateSource) -> bool {
        matches!(
            source,
            TemplateSource::Embedded(_) | TemplateSource::File(_)
        )
    }

    /// Loads every `*.html` regular file directly inside a directory.
    ///
    /// Each result is a `(file_stem, contents)` pair, and results are sorted by name.
    ///
    /// The directory itself is resolved like a `File` source: absolute paths must lie inside
    /// a base path, and relative paths are searched across base paths in order. Each entry is
    /// also canonicalized and re-validated. Without this, a symlink placed inside an allowed
    /// directory could be used to read a file outside every base path, which `load` refuses to
    /// do for the same file.
    ///
    /// Entries that are not regular files (e.g. a directory named `foo.html`) are skipped.
    /// Subdirectories are not traversed.
    ///
    /// # Errors
    ///
    /// - `DirectoryTraversal` if `path` contains `..`.
    /// - `NoBasePaths` if no base paths are configured.
    /// - `NotFound` if no base contains the directory.
    /// - `OutsideBasePath` if the directory or any `*.html` entry resolves outside the
    ///   permitted bases.
    /// - `InvalidEncoding` for a non-UTF-8 file stem.
    /// - An I/O error if listing, resolving, or reading fails.
    async fn load_directory(&self, path: &Path) -> Result<Vec<(String, String)>> {
        if Self::has_traversal_components(path) {
            return Err(TemplateLoaderError::DirectoryTraversal(path.to_path_buf()));
        }

        if self.base_paths.is_empty() {
            return Err(TemplateLoaderError::NoBasePaths);
        }

        let dir_path = if path.is_absolute() {
            self.canonicalize_and_validate(path).await?
        } else {
            let mut found_path = None;
            for base in &self.base_paths {
                let candidate = base.join(path);
                match fs::canonicalize(&candidate).await {
                    Ok(canonical) => {
                        let canonical_base = fs::canonicalize(base)
                            .await
                            .map_err(|e| TemplateLoaderError::io(base, e))?;

                        if !canonical.starts_with(&canonical_base) {
                            return Err(TemplateLoaderError::OutsideBasePath(candidate));
                        }

                        found_path = Some(canonical);
                        break;
                    },
                    Err(e) if e.kind() == ErrorKind::NotFound => {},
                    Err(e) => return Err(TemplateLoaderError::io(&candidate, e)),
                }
            }
            found_path.ok_or_else(|| TemplateLoaderError::NotFound(path.to_path_buf()))?
        };

        let mut templates = Vec::new();
        let mut entries = fs::read_dir(&dir_path)
            .await
            .map_err(|e| TemplateLoaderError::io(&dir_path, e))?;

        while let Some(entry) = entries
            .next_entry()
            .await
            .map_err(|e| TemplateLoaderError::io(&dir_path, e))?
        {
            let entry_path = entry.path();

            if !entry_path.extension().is_some_and(|ext| ext == "html") {
                continue;
            }

            let Some(file_stem) = entry_path.file_stem() else {
                continue;
            };

            let template_name = file_stem
                .to_str()
                .ok_or_else(|| TemplateLoaderError::InvalidEncoding(entry_path.clone()))?
                .to_owned();

            // `read_dir` yields symlinks unresolved and `read_to_string` would follow them
            // anywhere, so resolve each entry and re-check that it stays inside a base path.
            let canonical_entry = self.canonicalize_and_validate(&entry_path).await?;

            // A directory (or other non-file) can carry an `.html` extension; skip it instead
            // of failing the whole load on a read error.
            let metadata = fs::metadata(&canonical_entry)
                .await
                .map_err(|e| TemplateLoaderError::io(&entry_path, e))?;
            if !metadata.is_file() {
                continue;
            }

            let content = fs::read_to_string(&canonical_entry)
                .await
                .map_err(|e| TemplateLoaderError::io(&entry_path, e))?;

            templates.push((template_name, content));
        }

        // Directory listing order is platform/filesystem dependent; sort for a stable result.
        // Names are unique (one file per stem with the `.html` extension), so an unstable sort
        // is sufficient.
        templates.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        Ok(templates)
    }
}
217
/// A loader that serves only [`TemplateSource::Embedded`] sources and never touches the
/// file system.
#[derive(Clone, Copy, Debug, Default)]
pub struct EmbeddedLoader;
220
221#[async_trait]
222impl TemplateLoader for EmbeddedLoader {
223    async fn load(&self, source: &TemplateSource) -> Result<String> {
224        match source {
225            TemplateSource::Embedded(content) => Ok((*content).to_string()),
226            _ => Err(TemplateLoaderError::EmbeddedOnly),
227        }
228    }
229
230    fn can_load(&self, source: &TemplateSource) -> bool {
231        matches!(source, TemplateSource::Embedded(_))
232    }
233}