1mod query_len;
2mod uffd;
3mod vec_writer;
4
5use crate::{uffd::round_up_to_page, vec_writer::VecWriter};
6use dashmap::DashMap;
7use log::info;
8use nix::sys::mman::{mmap, MapFlags, ProtFlags};
9use std::{
10 ops::{Deref, Range},
11 path::Path,
12 slice,
13 sync::Arc,
14};
15use tantivy::{
16 directory::{
17 error::{DeleteError, OpenReadError, OpenWriteError},
18 WatchHandle, WritePtr,
19 },
20 Directory,
21};
22use tantivy_common::{file_slice::FileHandle, HasLen, OwnedBytes, StableDeref};
23use tokio::runtime::Runtime;
24use uffd::UffdFile;
25use userfaultfd::UffdBuilder;
26
27thread_local! {
28 pub(crate) static BLOCKING_HTTP_CLIENT: reqwest::blocking::Client = reqwest::blocking::Client::new();
29}
30
/// Thin wrapper around a `'static` byte slice so it can be handed to
/// [`OwnedBytes::new`].
///
/// In this file the slice is either empty or points into an anonymous memory
/// mapping created by `HttpFileHandle::new`.
#[derive(Clone)]
struct MmapArc {
    /// The wrapped bytes; cloning `MmapArc` only copies this reference.
    slice: &'static [u8],
}
35
36impl Deref for MmapArc {
37 type Target = [u8];
38
39 #[inline]
40 fn deref(&self) -> &[u8] {
41 self.slice
42 }
43}
44unsafe impl StableDeref for MmapArc {}
45
/// Hashable key made of a base URL, a path and a chunk index.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type — confirm whether it is still needed.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct CacheKey {
    /// Base URL component of the key.
    base_url: String,
    /// Path component of the key.
    path: String,
    /// Chunk index component of the key.
    chunk: usize,
}
52
53#[derive(Debug, Clone)]
54struct HttpFileHandle<const CHUNK_SIZE: usize> {
55 owned_bytes: Arc<OwnedBytes>,
56 _uffd_file: Option<Arc<UffdFile<CHUNK_SIZE>>>,
57}
58
59impl<const CHUNK_SIZE: usize> HttpFileHandle<CHUNK_SIZE> {
60 pub(crate) fn new(runtime: Arc<Runtime>, file_size: usize, artifact_url: String) -> Self {
61 let mmap_len = round_up_to_page(file_size, CHUNK_SIZE);
62 let uffd = UffdBuilder::new()
63 .close_on_exec(true)
64 .user_mode_only(true)
65 .create()
66 .unwrap();
67
68 let addr = unsafe {
69 mmap(
70 None,
71 mmap_len.try_into().unwrap(),
72 ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
73 MapFlags::MAP_PRIVATE | MapFlags::MAP_ANONYMOUS | MapFlags::MAP_NORESERVE,
74 None::<std::os::fd::BorrowedFd>,
75 0,
76 )
77 .expect("mmap")
78 };
79
80 let mmap_ptr = addr as usize;
81
82 uffd.register(addr, mmap_len).unwrap();
83
84 let uffd_file = Arc::new(UffdFile::new(
85 Arc::new(uffd),
86 runtime,
87 mmap_ptr,
88 artifact_url.clone(),
89 ));
90 {
91 let uffd_file = uffd_file.clone();
92 std::thread::spawn(move || {
93 uffd_file.handle_faults();
94 });
95 }
96 let owned_bytes = Arc::new(OwnedBytes::new(MmapArc {
97 slice: unsafe { slice::from_raw_parts(mmap_ptr as *const u8, file_size) },
98 }));
99
100 Self {
101 owned_bytes,
102 _uffd_file: Some(uffd_file),
103 }
104 }
105}
106
107impl<const CHUNK_SIZE: usize> FileHandle for HttpFileHandle<CHUNK_SIZE> {
108 fn read_bytes(&self, range: Range<usize>) -> std::io::Result<OwnedBytes> {
109 Ok(self.owned_bytes.slice(range))
110 }
111}
112
113impl<const CHUNK_SIZE: usize> HasLen for HttpFileHandle<CHUNK_SIZE> {
114 fn len(&self) -> usize {
115 self.owned_bytes.len()
116 }
117}
118
119#[derive(Debug, Clone)]
127pub struct RemoteDirectory<const CHUNK_SIZE: usize> {
128 base_url: String,
129 file_handle_cache: Arc<DashMap<String, Arc<HttpFileHandle<CHUNK_SIZE>>>>,
130 atomic_read_cache: Arc<DashMap<String, Vec<u8>>>,
131 uffd_runtime: Arc<Runtime>,
132}
133
134impl<const CHUNK_SIZE: usize> RemoteDirectory<CHUNK_SIZE> {
135 pub fn new(base_url: &str) -> Self {
140 let rt = Runtime::new().unwrap();
141
142 Self {
143 base_url: base_url.to_string(),
144 file_handle_cache: Arc::new(DashMap::new()),
145 atomic_read_cache: Arc::new(DashMap::new()),
146 uffd_runtime: Arc::new(rt),
147 }
148 }
149
150 fn format_url(&self, path: &Path) -> String {
151 if self.base_url.ends_with('/') {
152 format!("{}{}", self.base_url, path.display())
153 } else {
154 format!("{}/{}", self.base_url, path.display())
155 }
156 }
157}
158
159impl<const CHUNK_SIZE: usize> Directory for RemoteDirectory<CHUNK_SIZE> {
160 fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError> {
161 let url = self.format_url(path);
162 {
163 if let Some(file_handle) = self.file_handle_cache.get(&url) {
164 return Ok(file_handle.clone());
165 }
166 }
167 let file_len = query_len::len(&url);
168 let len = round_up_to_page(file_len, CHUNK_SIZE);
169
170 if len == 0 {
171 return Ok(Arc::new(HttpFileHandle::<CHUNK_SIZE> {
172 owned_bytes: Arc::new(OwnedBytes::new(MmapArc { slice: &[] })),
173 _uffd_file: None,
174 }));
175 }
176
177 let file_handle = Arc::new(HttpFileHandle::<CHUNK_SIZE>::new(
178 self.uffd_runtime.clone(),
179 file_len,
180 url.clone(),
181 ));
182 self.file_handle_cache.insert(url, file_handle.clone());
183
184 Ok(file_handle)
185 }
186
187 fn delete(&self, path: &Path) -> Result<(), DeleteError> {
188 if path == Path::new(".tantivy-meta.lock") {
189 return Ok(());
190 }
191
192 Err(DeleteError::IoError {
193 io_error: Arc::new(std::io::Error::new(
194 std::io::ErrorKind::Other,
195 "Delete not supported",
196 )),
197 filepath: path.to_path_buf(),
198 })
199 }
200
201 fn exists(&self, path: &Path) -> Result<bool, OpenReadError> {
202 if path == Path::new(".tantivy-meta.lock") {
203 return Ok(true);
204 }
205 Ok(query_len::len(&self.format_url(path)) > 0)
206 }
207
208 fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
209 if path == Path::new(".tantivy-meta.lock") {
210 return Ok(WritePtr::new(Box::new(VecWriter::new(path.to_path_buf()))));
211 }
212 dbg!(path);
213 Err(OpenWriteError::IoError {
214 io_error: Arc::new(std::io::Error::new(
215 std::io::ErrorKind::Other,
216 "Write not supported",
217 )),
218 filepath: path.to_path_buf(),
219 })
220 }
221
222 fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
223 let url = self.format_url(path);
224 if let Some(bytes) = self.atomic_read_cache.get(&url) {
225 return Ok(bytes.clone());
226 }
227
228 info!("Fetching {} in atomic read.", url);
229 let response = BLOCKING_HTTP_CLIENT.with(|client| client.get(&url).send());
230 let response = if let Err(_e) = response {
231 return Err(OpenReadError::IoError {
232 io_error: Arc::new(std::io::Error::new(
233 std::io::ErrorKind::Other,
234 "Fetch failed for atomic read.",
235 )),
236 filepath: path.to_path_buf(),
237 });
238 } else {
239 response.unwrap()
240 };
241 let bytes = response.bytes().unwrap();
242
243 let bytes = bytes.to_vec();
244 self.atomic_read_cache.insert(url, bytes.clone());
245 Ok(bytes)
246 }
247
248 fn atomic_write(&self, _path: &Path, _data: &[u8]) -> std::io::Result<()> {
249 Err(std::io::Error::new(
250 std::io::ErrorKind::Other,
251 "Write not supported",
252 ))
253 }
254
255 fn sync_directory(&self) -> std::io::Result<()> {
256 Ok(())
257 }
258
259 fn watch(
260 &self,
261 _watch_callback: tantivy::directory::WatchCallback,
262 ) -> tantivy::Result<tantivy::directory::WatchHandle> {
263 Ok(WatchHandle::empty())
264 }
265}
266
/// Test support: builds a small index in RAM, serves its files from an
/// in-process HTTP server, and exercises `RemoteDirectory` against it.
#[cfg(test)]
pub(crate) mod test {

    use std::{path::Path, sync::OnceLock};

    use tantivy::{directory::ManagedDirectory, doc, schema::Field, Directory, Index};
    use tiny_http::{Header, Method, Response, Server};

    /// Base URL of the in-process test server; set by `run_test_server`.
    pub(crate) static TEST_SERVER_BASE_URL: OnceLock<String> = OnceLock::new();

    /// The `name` field of the test schema.
    pub(crate) fn test_schema_name() -> Field {
        test_schema().get_field("name").unwrap()
    }

    /// The `doc` field of the test schema.
    pub(crate) fn test_schema_doc() -> Field {
        test_schema().get_field("doc").unwrap()
    }

    /// Schema with two stored text fields, `name` and `doc`.
    pub(crate) fn test_schema() -> tantivy::schema::Schema {
        let mut schema_builder = tantivy::schema::Schema::builder();
        schema_builder.add_text_field("name", tantivy::schema::TEXT | tantivy::schema::STORED);
        schema_builder.add_text_field("doc", tantivy::schema::TEXT | tantivy::schema::STORED);
        schema_builder.build()
    }

    /// Builds a two-document index in RAM, merges its segments, validates the
    /// checksums of its managed files and returns its directory.
    fn init_test_index_no_remote() -> ManagedDirectory {
        let schema = test_schema();
        let index = Index::create_in_ram(schema);
        let index = std::thread::spawn(move || {
            let mut writer = index.writer(15_000_000).unwrap();
            writer
                .add_document(doc!(
                    test_schema_name() => "LICENSE_MIT",
                    test_schema_doc() => include_str!("../LICENSE_MIT"),
                ))
                .unwrap();
            writer
                .add_document(doc!(
                    test_schema_name() => "LICENSE_APACHE",
                    test_schema_doc() => include_str!("../LICENSE_APACHE"),
                ))
                .unwrap();
            writer.commit().unwrap();
            drop(writer);

            // Merge every searchable segment with a fresh writer.
            let ids = index.searchable_segment_ids().unwrap();
            let writer = index.writer(15_000_000).unwrap();

            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let mut writer = writer;
                writer.merge(&ids).await.unwrap()
            });

            index
        })
        .join()
        .unwrap();
        let dir = index.directory().clone();
        drop(index);

        // `meta.json` is skipped; every other managed file must validate.
        for path in dir.list_managed_files() {
            if path.ends_with("meta.json") {
                continue;
            }
            dir.validate_checksum(&path).unwrap();
        }

        dir
    }

    /// Opens the test index through a `RemoteDirectory` pointed at the test
    /// server.
    pub(crate) fn test_index() -> Index {
        let http_directory =
            super::RemoteDirectory::<8192>::new(TEST_SERVER_BASE_URL.get().unwrap());
        Index::open(http_directory).unwrap()
    }

    /// Starts an HTTP server on an ephemeral local port that serves the test
    /// index files, answering GET (with optional `Range`) and HEAD requests.
    fn run_test_server() {
        let test_index = init_test_index_no_remote();

        let server = Server::http("127.0.0.1:0").unwrap();

        // Publish the URL before spawning the serving thread. Setting it from
        // inside the thread lets a test observe an unset `OnceLock` if that
        // thread has not been scheduled yet.
        TEST_SERVER_BASE_URL.get_or_init(|| format!("http://{}", server.server_addr()));

        std::thread::spawn(move || {
            for req in server.incoming_requests() {
                let path = Path::new(req.url().trim_start_matches('/'));
                if req.method() == &Method::Get {
                    let contents = test_index.atomic_read(path).unwrap();
                    let range_header = req
                        .headers()
                        .iter()
                        .find(|h| h.field.as_str().to_ascii_lowercase() == "range");

                    let data = match range_header {
                        Some(header) => {
                            // Expected shape: `bytes=START-END`, END inclusive.
                            let value = header.value.to_string();
                            let spec = value.split('=').last().unwrap();
                            let mut bounds = spec.split('-');
                            let start = bounds.next().unwrap().parse::<usize>().unwrap();
                            let end = (1 + bounds.next().unwrap().parse::<usize>().unwrap())
                                .min(contents.len());
                            contents[start..end].to_vec()
                        }
                        None => contents,
                    };
                    req.respond(Response::from_data(data)).unwrap();
                } else if req.method() == &Method::Head {
                    let len = test_index.atomic_read(path).unwrap().len();
                    let mut response = Response::from_string("".to_string());
                    response.add_header(
                        Header::from_bytes(&b"Content-Length"[..], len.to_string()).unwrap(),
                    );
                    req.respond(response).unwrap();
                }
            }
        });
    }

    /// Starts the test server via the `ctor` attribute, so it is up before
    /// the tests run.
    #[ctor::ctor]
    fn ctor_init() {
        run_test_server();
    }

    #[test]
    fn test_has_meta_json() {
        let http_directory =
            super::RemoteDirectory::<8192>::new(TEST_SERVER_BASE_URL.get().unwrap());
        let meta = http_directory
            .atomic_read(std::path::Path::new("meta.json"))
            .unwrap();
        assert!(!meta.is_empty());
    }

    #[test]
    fn test_has_docs() {
        let reader = test_index().reader().unwrap();
        assert_eq!(reader.searcher().num_docs(), 2);
    }

    #[test]
    fn search_docs() {
        let index = test_index();
        let reader = index.reader().unwrap();
        let searcher = reader.searcher();
        let query_parser = tantivy::query::QueryParser::for_index(&index, vec![test_schema_name()]);
        let query = query_parser.parse_query("LICENSE_MIT").unwrap();
        let top_docs = searcher
            .search(&query, &tantivy::collector::TopDocs::with_limit(10))
            .unwrap();
        assert_eq!(top_docs.len(), 1);
    }
}