1use std::any::Any;
2use std::collections::HashMap;
3use std::ops::Range;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::{fmt, io};
7
8use async_trait::async_trait;
9use parking_lot::lock_api::RwLockWriteGuard;
10use parking_lot::{RawRwLock, RwLock, RwLockUpgradableReadGuard};
11use tantivy::directory::error::{DeleteError, LockError, OpenReadError, OpenWriteError};
12use tantivy::directory::{DirectoryLock, FileHandle, Lock, OwnedBytes, WatchCallback, WatchHandle, WritePtr};
13use tantivy::{Directory, HasLen};
14
15use crate::directories::byte_range_cache::ByteRangeCache;
16
/// Tracked metadata for a single file: its length (when known) and a
/// generation counter bumped on every mutation (write, delete, atomic write)
/// of the file, so cached byte ranges from older versions become unreachable.
#[derive(Default, Debug, Clone)]
pub struct FileStat {
    /// Length of the file in bytes, or `None` when not (or no longer) known.
    pub file_length: Option<u64>,
    /// Incremented each time the file is rewritten or deleted.
    pub generation: u32,
}
22
/// A `FileStat` whose length is definitely known (i.e. "materialized"):
/// the `Option` has been resolved, possibly by querying the underlying file.
#[derive(Default, Debug, Clone)]
pub struct MaterializedFileStat {
    /// Length of the file in bytes.
    pub file_length: u64,
    /// Generation of the file at the time the stat was materialized.
    pub generation: u32,
}
28
29impl FileStat {
30 pub fn inc_gen(&mut self, new_len: Option<u64>) {
31 self.file_length = new_len;
32 self.generation += 1;
33 }
34}
35
/// Shared, thread-safe map from file path to its `FileStat` (length +
/// generation). `Clone` is cheap: clones share the same underlying map
/// through the `Arc`.
#[derive(Default, Clone)]
pub struct FileStats(Arc<RwLock<HashMap<PathBuf, FileStat>>>);
38
39impl FileStats {
40 pub fn from_file_lengths(file_lengths: HashMap<PathBuf, u64>) -> Self {
41 FileStats(Arc::new(RwLock::new(HashMap::from_iter(file_lengths.into_iter().map(|(k, v)| {
42 (
43 k,
44 FileStat {
45 file_length: Some(v),
46 generation: 0,
47 },
48 )
49 })))))
50 }
51
52 pub fn inc_gen(&self, path: &Path, new_len: Option<u64>) -> RwLockWriteGuard<'_, RawRwLock, HashMap<PathBuf, FileStat>> {
53 let mut write_lock = self.0.write();
54 write_lock.entry(path.to_path_buf()).or_default().inc_gen(new_len);
55 write_lock
56 }
57
58 pub fn get_or_set(&self, path: &Path, f: impl FnOnce() -> u64) -> MaterializedFileStat {
59 let read_lock = self.0.upgradable_read();
60 let file_stat = read_lock.get(path);
61 let file_length = file_stat.and_then(|file_stat| file_stat.file_length);
62 let generation = file_stat.map(|file_stat| file_stat.generation).unwrap_or_default();
63 match file_length {
64 None => {
65 let file_length = f();
66 let file_stat = FileStat {
67 file_length: Some(file_length),
68 generation,
69 };
70 RwLockUpgradableReadGuard::upgrade(read_lock).insert(path.to_path_buf(), file_stat);
71 MaterializedFileStat { file_length, generation }
72 }
73 Some(file_length) => MaterializedFileStat { file_length, generation },
74 }
75 }
76}
77
/// A `Directory` wrapper that caches byte ranges read from the underlying
/// directory, keyed by file path and generation so that a rewritten or
/// deleted file never serves stale cached bytes.
#[derive(Clone)]
pub struct CachingDirectory {
    /// The wrapped directory all operations ultimately delegate to.
    underlying: Arc<dyn Directory>,
    /// Cache of previously read byte ranges, shared by all file handles.
    cache: Arc<ByteRangeCache>,
    /// Per-path length/generation bookkeeping used to key cache entries.
    file_stats: FileStats,
}
85
86impl CachingDirectory {
87 pub fn bounded(underlying: Arc<dyn Directory>, _capacity_in_bytes: usize, file_stats: FileStats) -> CachingDirectory {
92 CachingDirectory {
93 underlying,
94 cache: Arc::new(ByteRangeCache::with_infinite_capacity()),
95 file_stats,
96 }
97 }
98
99 pub fn unbounded(underlying: Arc<dyn Directory>, file_stats: FileStats) -> CachingDirectory {
100 CachingDirectory {
101 underlying,
102 cache: Arc::new(ByteRangeCache::with_infinite_capacity()),
103 file_stats,
104 }
105 }
106}
107
108impl fmt::Debug for CachingDirectory {
109 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
110 write!(f, "CachingDirectory({:?})", self.underlying)
111 }
112}
113
/// A `FileHandle` that serves reads through a `ByteRangeCache`, falling back
/// to the wrapped handle on a cache miss and populating the cache with the
/// bytes it reads.
struct CachingFileHandle {
    path: PathBuf,
    cache: Arc<ByteRangeCache>,
    underlying_filehandle: Arc<dyn FileHandle>,
    /// Generation of the file at handle-creation time; part of the cache key.
    generation: u32,
    /// File length captured at handle-creation time.
    len: u64,
}
121
122impl CachingFileHandle {
123 pub fn get_key(&self) -> PathBuf {
124 PathBuf::from(format!("{}@{}", self.path.to_string_lossy(), self.generation))
125 }
126}
127
128impl fmt::Debug for CachingFileHandle {
129 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
130 write!(
131 f,
132 "CachingFileHandle(path={:?}, underlying={:?})",
133 &self.path,
134 self.underlying_filehandle.as_ref()
135 )
136 }
137}
138
139#[async_trait]
140impl FileHandle for CachingFileHandle {
141 fn read_bytes(&self, byte_range: Range<u64>) -> io::Result<OwnedBytes> {
142 if let Some(bytes) = self.cache.get_slice(&self.get_key(), byte_range.clone()) {
143 return Ok(bytes);
144 }
145 let owned_bytes = self.underlying_filehandle.read_bytes(byte_range.clone())?;
146 self.cache.put_slice(self.get_key(), byte_range, owned_bytes.clone());
147 Ok(owned_bytes)
148 }
149
150 async fn read_bytes_async(&self, byte_range: Range<u64>) -> io::Result<OwnedBytes> {
151 if let Some(owned_bytes) = self.cache.get_slice(&self.get_key(), byte_range.clone()) {
152 return Ok(owned_bytes);
153 }
154 let read_bytes = self.underlying_filehandle.read_bytes_async(byte_range.clone()).await?;
155 self.cache.put_slice(self.get_key(), byte_range, read_bytes.clone());
156 Ok(read_bytes)
157 }
158}
159
impl HasLen for CachingFileHandle {
    /// Returns the file length captured when the handle was created
    /// (from `FileStats`), not a live query of the underlying file.
    fn len(&self) -> u64 {
        self.len
    }
}
165
#[async_trait]
impl Directory for CachingDirectory {
    /// Returns a caching file handle for `path`.
    ///
    /// The file's length/generation is fetched from `file_stats`, falling back
    /// to the underlying handle's length when unknown. The generation becomes
    /// part of the handle's cache key, so reads never hit slices cached for an
    /// older version of the file.
    fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError> {
        let underlying_filehandle = self.underlying.get_file_handle(path)?;
        let underlying_filehandle_ref = underlying_filehandle.as_ref();
        let file_stat = self.file_stats.get_or_set(path, || underlying_filehandle_ref.len());
        Ok(Arc::new(CachingFileHandle {
            path: path.to_path_buf(),
            cache: self.cache.clone(),
            len: file_stat.file_length,
            generation: file_stat.generation,
            underlying_filehandle,
        }))
    }

    fn delete(&self, path: &Path) -> Result<(), DeleteError> {
        // Bump the generation (length now unknown) and keep the stats write
        // lock held across the underlying delete — presumably so concurrent
        // get_file_handle calls cannot materialize a stat for the old
        // generation mid-delete. NOTE(review): confirm this lock-ordering
        // intent; the guard scope is load-bearing.
        let _lock = self.file_stats.inc_gen(path, None);
        self.underlying.delete(path)
    }

    fn exists(&self, path: &Path) -> Result<bool, OpenReadError> {
        // Existence checks bypass the cache and stats entirely.
        self.underlying.exists(path)
    }

    fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
        // A fresh write invalidates any cached slices for this path: bump the
        // generation and hold the lock while the writer is opened (see
        // `delete` for the locking rationale).
        let _lock = self.file_stats.inc_gen(path, None);
        self.underlying.open_write(path)
    }

    /// Reads the whole file through a caching file handle (populating the
    /// byte-range cache as a side effect).
    fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
        let file_handle = self.get_file_handle(path)?;
        let owned_bytes = file_handle
            .read_bytes(0..file_handle.len())
            .map_err(|io_error| OpenReadError::wrap_io_error(io_error, path.to_path_buf()))?;
        Ok(owned_bytes.as_slice().to_vec())
    }

    /// Async variant of `atomic_read`.
    async fn atomic_read_async(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
        let file_handle = self.get_file_handle(path)?;
        let owned_bytes = file_handle
            .read_bytes_async(0..file_handle.len())
            .await
            .map_err(|io_error| OpenReadError::wrap_io_error(io_error, path.to_path_buf()))?;
        Ok(owned_bytes.as_slice().to_vec())
    }

    fn atomic_write(&self, path: &Path, data: &[u8]) -> io::Result<()> {
        // Same invalidation pattern as `open_write`: new generation, lock held
        // across the underlying write.
        let _lock = self.file_stats.inc_gen(path, None);
        self.underlying.atomic_write(path, data)
    }

    // The remaining operations delegate straight to the underlying directory;
    // none of them interact with the byte-range cache or file stats.

    fn sync_directory(&self) -> io::Result<()> {
        self.underlying.sync_directory()
    }

    fn acquire_lock(&self, lock: &Lock) -> Result<DirectoryLock, LockError> {
        self.underlying.acquire_lock(lock)
    }

    fn watch(&self, callback: WatchCallback) -> tantivy::Result<WatchHandle> {
        self.underlying.watch(callback)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn underlying_directory(&self) -> Option<&dyn Directory> {
        Some(self.underlying.as_ref())
    }

    fn real_directory(&self) -> &dyn Directory {
        self.underlying.real_directory()
    }
}