wombatkv_node/
kv_blob_cache.rs1#![forbid(unsafe_code)]
2use bytes::Bytes;
26use std::path::{Path, PathBuf};
27
28pub trait KvBlobCache: Send + Sync {
32 fn get(&self, key: &str) -> Option<(Bytes, &'static str)>;
35
36 fn put(&self, key: &str, payload: Bytes);
38
39 fn contains(&self, key: &str) -> bool;
41
42 fn clear(&self);
46
47 fn remove(&self, _key: &str) -> bool {
52 false
53 }
54}
55
/// Errors produced by WombatKV blob-cache setup (e.g. `FlatFileKvBlobCache::open`).
#[derive(Debug)]
pub enum KvBlobCacheError {
    /// A filesystem operation failed; the message carries the operation, the
    /// path involved and the underlying I/O error text.
    Io(String),
    /// The supplied configuration was rejected.
    InvalidConfig(String),
}

impl std::fmt::Display for KvBlobCacheError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the category label first, then render one shared message shape.
        let (category, detail) = match self {
            KvBlobCacheError::Io(detail) => ("io error", detail),
            KvBlobCacheError::InvalidConfig(detail) => ("invalid config", detail),
        };
        write!(f, "WombatKV cache {category}: {detail}")
    }
}

/// Lets the error travel through `Box<dyn Error>` / `?` chains.
impl std::error::Error for KvBlobCacheError {}
73
74#[derive(Debug)]
84pub struct FlatFileKvBlobCache {
85 root: PathBuf,
86}
87
88impl FlatFileKvBlobCache {
89 pub fn open(root: PathBuf) -> Result<Self, KvBlobCacheError> {
91 if let Some(parent) = root.parent() {
92 std::fs::create_dir_all(parent).map_err(|err| {
93 KvBlobCacheError::Io(format!("create parent {}: {err}", parent.display()))
94 })?;
95 }
96 std::fs::create_dir_all(&root).map_err(|err| {
97 KvBlobCacheError::Io(format!("create root {}: {err}", root.display()))
98 })?;
99 Ok(Self { root })
100 }
101
102 #[must_use]
103 pub fn root(&self) -> &Path {
104 &self.root
105 }
106
107 fn path_for_key(&self, key: &str) -> PathBuf {
111 let mut safe = String::with_capacity(key.len() + 4);
112 for ch in key.chars() {
113 if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.' || ch == '=' {
114 safe.push(ch);
115 } else {
116 safe.push('_');
117 }
118 }
119 safe.push_str(".kv");
120 self.root.join(safe)
121 }
122}
123
124impl KvBlobCache for FlatFileKvBlobCache {
125 fn get(&self, key: &str) -> Option<(Bytes, &'static str)> {
126 let path = self.path_for_key(key);
127 let vec = std::fs::read(&path).ok()?;
128 Some((Bytes::from(vec), "load_flat"))
129 }
130
131 fn put(&self, key: &str, payload: Bytes) {
132 let path = self.path_for_key(key);
133 let tmp = path.with_extension("kv.tmp");
134 if let Err(err) = std::fs::write(&tmp, &payload) {
135 eprintln!("WombatKV flat cache: write tmp {} failed: {err}", tmp.display());
136 return;
137 }
138 if let Err(err) = std::fs::rename(&tmp, &path) {
139 let _ = std::fs::remove_file(&tmp);
140 eprintln!(
141 "WombatKV flat cache: rename {} -> {} failed: {err}",
142 tmp.display(),
143 path.display()
144 );
145 }
146 }
147
148 fn contains(&self, key: &str) -> bool {
149 self.path_for_key(key).exists()
150 }
151
152 fn clear(&self) {
153 let _ = std::fs::remove_dir_all(&self.root);
156 let _ = std::fs::create_dir_all(&self.root);
157 }
158
159 fn remove(&self, key: &str) -> bool {
160 std::fs::remove_file(self.path_for_key(key)).is_ok()
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::{FlatFileKvBlobCache, KvBlobCache};
    use bytes::Bytes;
    use tempfile::tempdir;

    /// Opens a fresh cache in a new temporary directory. The `TempDir` guard is
    /// returned alongside so the directory outlives the test body.
    fn fresh_cache() -> (tempfile::TempDir, FlatFileKvBlobCache) {
        let dir = tempdir().unwrap();
        let cache = FlatFileKvBlobCache::open(dir.path().join("puffer")).unwrap();
        (dir, cache)
    }

    #[test]
    fn flat_cache_roundtrip() {
        let (_dir, cache) = fresh_cache();
        let key = "ns/v1/sha=abc";
        let payload = Bytes::from(vec![7_u8; 4096]);
        assert!(!cache.contains(key));
        assert!(cache.get(key).is_none());
        cache.put(key, payload.clone());
        assert!(cache.contains(key));
        let (got, op) = cache.get(key).unwrap();
        assert_eq!(got.as_ref(), payload.as_ref());
        assert_eq!(op, "load_flat");
    }

    #[test]
    fn clear_removes_entries() {
        let (_dir, cache) = fresh_cache();
        cache.put("k", Bytes::from(vec![1_u8; 16]));
        assert!(cache.contains("k"));
        cache.clear();
        assert!(!cache.contains("k"));
    }

    #[test]
    fn put_after_clear_succeeds() {
        let (_dir, cache) = fresh_cache();
        cache.put("k", Bytes::from(vec![1_u8; 4]));
        cache.clear();
        cache.put("k2", Bytes::from(vec![2_u8; 4]));
        let (got, _) = cache.get("k2").unwrap();
        assert_eq!(got.as_ref(), &[2_u8; 4]);
    }

    #[test]
    fn put_overwrites_previous_payload() {
        let (_dir, cache) = fresh_cache();
        cache.put("k", Bytes::from(vec![1_u8; 32]));
        cache.put("k", Bytes::from(vec![9_u8; 8]));
        let (got, _) = cache.get("k").unwrap();
        assert_eq!(got.as_ref(), &[9_u8; 8]);
    }

    #[test]
    fn remove_reports_whether_entry_existed() {
        let (_dir, cache) = fresh_cache();
        cache.put("k", Bytes::from(vec![5_u8; 4]));
        assert!(cache.remove("k"));
        assert!(!cache.contains("k"));
        assert!(cache.get("k").is_none());
        assert!(!cache.remove("k"));
    }

    #[test]
    fn unsafe_key_chars_get_sanitized() {
        let (_dir, cache) = fresh_cache();
        let key = "ns/v1/model=abc/sha=def";
        cache.put(key, Bytes::from(vec![3_u8; 8]));
        let (got, _) = cache.get(key).unwrap();
        assert_eq!(got.len(), 8);
    }
}
207}