1use std::fs;
9use std::io;
10use std::path::{Component, Path, PathBuf};
11
12use serde::{Deserialize, Serialize};
13use sha2::{Digest, Sha256};
14use thiserror::Error;
15
16use crate::types::{CaseFingerprint, Invocation};
17
18#[derive(Debug, Clone, Default, Serialize)]
20pub struct FingerprintContext {
21 pub initial_session: Option<serde_json::Value>,
23 pub tool_set_hash: Option<String>,
25 pub agent_model: Option<String>,
27}
28
29#[derive(Debug, Serialize)]
30struct CanonicalCacheInput<'a> {
31 fingerprint: &'a CaseFingerprint,
32 context: &'a FingerprintContext,
33}
34
35#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
37pub struct CacheKey(String);
38
39impl CacheKey {
40 #[must_use]
42 pub fn from_fingerprint(fingerprint: &CaseFingerprint, context: &FingerprintContext) -> Self {
43 Self::from_bytes(&canonicalize_fingerprint(fingerprint, context))
44 }
45
46 #[must_use]
48 pub fn from_bytes(bytes: &[u8]) -> Self {
49 Self(hex_lower(&Sha256::digest(bytes)))
50 }
51
52 #[must_use]
54 pub fn as_hex(&self) -> &str {
55 &self.0
56 }
57
58 #[must_use]
60 pub fn into_hex(self) -> String {
61 self.0
62 }
63}
64
65#[must_use]
69pub fn canonicalize_fingerprint(
70 fingerprint: &CaseFingerprint,
71 context: &FingerprintContext,
72) -> Vec<u8> {
73 serde_json::to_vec(&CanonicalCacheInput {
74 fingerprint,
75 context,
76 })
77 .expect("CaseFingerprint + FingerprintContext always serialize")
78}
79
80#[must_use]
83pub fn tool_set_hash<'a, I>(tools: I) -> String
84where
85 I: IntoIterator<Item = (&'a str, &'a str)>,
86{
87 let mut hasher = Sha256::new();
88 let mut sorted: Vec<(&str, &str)> = tools.into_iter().collect();
89 sorted.sort_by(|a, b| a.0.cmp(b.0));
90 for (name, schema) in sorted {
91 hasher.update((name.len() as u64).to_le_bytes());
92 hasher.update(name.as_bytes());
93 hasher.update((schema.len() as u64).to_le_bytes());
94 hasher.update(schema.as_bytes());
95 }
96 hex_lower(&hasher.finalize())
97}
98
/// Encodes `bytes` as lowercase hexadecimal. Each byte becomes two characters,
/// high nibble first.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(2 * bytes.len());
    for &byte in bytes {
        for nibble in [byte >> 4, byte & 0x0f] {
            encoded.push(char::from(DIGITS[usize::from(nibble)]));
        }
    }
    encoded
}
108
109#[derive(Debug, Error)]
111pub enum StoreError {
112 #[error("store io error: {0}")]
114 Io(String),
115 #[error("store serde error: {0}")]
117 Serde(String),
118 #[error("invalid identifier: {0}")]
120 InvalidIdentifier(String),
121}
122
123impl From<io::Error> for StoreError {
124 fn from(err: io::Error) -> Self {
125 Self::Io(err.to_string())
126 }
127}
128
129impl From<serde_json::Error> for StoreError {
130 fn from(err: serde_json::Error) -> Self {
131 Self::Serde(err.to_string())
132 }
133}
134
135pub trait EvaluationDataStore: Send + Sync {
137 fn get(
143 &self,
144 eval_set_id: &str,
145 case_id: &str,
146 key: &CacheKey,
147 ) -> Result<Option<Invocation>, StoreError>;
148
149 fn put(
156 &self,
157 eval_set_id: &str,
158 case_id: &str,
159 key: &CacheKey,
160 invocation: &Invocation,
161 ) -> Result<(), StoreError>;
162}
163
164pub struct LocalFileTaskResultStore {
166 root: PathBuf,
167}
168
169impl LocalFileTaskResultStore {
170 #[must_use]
172 pub fn new(root: impl Into<PathBuf>) -> Self {
173 Self { root: root.into() }
174 }
175
176 #[must_use]
178 pub fn root(&self) -> &Path {
179 &self.root
180 }
181
182 fn case_dir(&self, eval_set_id: &str, case_id: &str) -> Result<PathBuf, StoreError> {
183 validate_identifier(eval_set_id)?;
184 validate_identifier(case_id)?;
185 Ok(self.root.join(eval_set_id).join(case_id))
186 }
187}
188
189impl EvaluationDataStore for LocalFileTaskResultStore {
190 fn get(
191 &self,
192 eval_set_id: &str,
193 case_id: &str,
194 key: &CacheKey,
195 ) -> Result<Option<Invocation>, StoreError> {
196 let path = self
197 .case_dir(eval_set_id, case_id)?
198 .join(format!("{}.json", key.as_hex()));
199 match fs::read(&path) {
200 Ok(bytes) => Ok(Some(serde_json::from_slice(&bytes)?)),
201 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
202 Err(err) => Err(err.into()),
203 }
204 }
205
206 fn put(
207 &self,
208 eval_set_id: &str,
209 case_id: &str,
210 key: &CacheKey,
211 invocation: &Invocation,
212 ) -> Result<(), StoreError> {
213 let dir = self.case_dir(eval_set_id, case_id)?;
214 fs::create_dir_all(&dir)?;
215 let path = dir.join(format!("{}.json", key.as_hex()));
216 let bytes = serde_json::to_vec_pretty(invocation)?;
217 let tmp = path.with_extension("json.tmp");
218 fs::write(&tmp, &bytes)?;
219 fs::rename(&tmp, &path)?;
220 Ok(())
221 }
222}
223
224fn validate_identifier(id: &str) -> Result<(), StoreError> {
225 if id.is_empty() {
226 return Err(StoreError::InvalidIdentifier(
227 "identifier must not be empty".into(),
228 ));
229 }
230 let path = Path::new(id);
231 if path.is_absolute()
232 || path
233 .components()
234 .any(|c| !matches!(c, Component::Normal(_)))
235 || id.contains(['/', '\\'])
236 {
237 return Err(StoreError::InvalidIdentifier(id.to_string()));
238 }
239 Ok(())
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::CanonicalJsonValue;

    /// Minimal fingerprint whose id and name are both `id`.
    fn fp(id: &str) -> CaseFingerprint {
        CaseFingerprint {
            id: id.into(),
            name: id.into(),
            description: None,
            system_prompt: "sp".into(),
            user_messages: vec!["hi".into()],
            expected_trajectory: None,
            expected_response: None,
            expected_assertion: None,
            expected_interactions: None,
            few_shot_examples: vec![],
            budget: None,
            evaluators: vec![],
            metadata: CanonicalJsonValue::Null,
            attachments: vec![],
            expected_environment_state: None,
            expected_tool_intent: None,
            semantic_tool_selection: false,
        }
    }

    #[test]
    fn cache_key_deterministic_and_context_sensitive() {
        let f = fp("c1");
        let empty = FingerprintContext::default();
        let a = CacheKey::from_fingerprint(&f, &empty);
        assert_eq!(a, CacheKey::from_fingerprint(&f, &empty));
        assert_eq!(a.as_hex().len(), 64);
        let b = CacheKey::from_fingerprint(
            &f,
            &FingerprintContext {
                initial_session: Some(serde_json::json!({"k": 1})),
                ..Default::default()
            },
        );
        assert_ne!(a, b);
    }

    #[test]
    fn cache_key_depends_on_fingerprint() {
        let ctx = FingerprintContext::default();
        assert_ne!(
            CacheKey::from_fingerprint(&fp("c1"), &ctx),
            CacheKey::from_fingerprint(&fp("c2"), &ctx)
        );
    }

    #[test]
    fn hex_lower_encodes_two_lowercase_digits_per_byte() {
        assert_eq!(hex_lower(&[]), "");
        assert_eq!(hex_lower(&[0x00, 0x0f, 0xab, 0xff]), "000fabff");
    }

    #[test]
    fn tool_set_hash_is_order_independent() {
        assert_eq!(
            tool_set_hash([("a", "{}"), ("b", "{}")]),
            tool_set_hash([("b", "{}"), ("a", "{}")])
        );
        assert_ne!(
            tool_set_hash([("a", "{}")]),
            tool_set_hash([("a", "{}"), ("b", "{}")])
        );
    }

    #[test]
    fn tool_set_hash_length_prefix_separates_fields() {
        assert_ne!(tool_set_hash([("ab", "c")]), tool_set_hash([("a", "bc")]));
    }

    #[test]
    fn validate_identifier_rejects_path_traversal() {
        assert!(validate_identifier("../evil").is_err());
        assert!(validate_identifier("..").is_err());
        assert!(validate_identifier(".").is_err());
        assert!(validate_identifier("a/b").is_err());
        assert!(validate_identifier("a\\b").is_err());
        assert!(validate_identifier("/abs").is_err());
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("ok-id_1.0").is_ok());
    }

    #[test]
    fn case_dir_joins_validated_ids_under_root() {
        let store = LocalFileTaskResultStore::new("root");
        assert!(store.case_dir("..", "c1").is_err());
        assert!(store.case_dir("set", "a/b").is_err());
        assert_eq!(
            store.case_dir("set", "c1").unwrap(),
            Path::new("root").join("set").join("c1")
        );
    }
}