1use {
3 crate::{
4 accounts_hash::CalculateHashIntermediate, cache_hash_data_stats::CacheHashDataStats,
5 pubkey_bins::PubkeyBinCalculator24,
6 },
7 log::*,
8 memmap2::MmapMut,
9 safecoin_measure::measure::Measure,
10 std::{
11 collections::HashSet,
12 fs::{self, remove_file, OpenOptions},
13 io::{Seek, SeekFrom, Write},
14 path::{Path, PathBuf},
15 sync::{Arc, Mutex},
16 },
17};
18
/// A single cached entry: the intermediate (hash, lamports, pubkey) value
/// used while calculating the accounts hash.
pub type EntryType = CalculateHashIntermediate;
/// Entries grouped by pubkey bin: outer index is the bin, inner vec holds that bin's entries.
pub type SavedType = Vec<Vec<EntryType>>;
/// Borrowed view of `SavedType`, used when saving.
pub type SavedTypeSlice = [Vec<EntryType>];
22
/// On-disk header at the start of every cache file.
/// `#[repr(C)]` keeps the layout stable, since it is read and written through an mmap.
#[repr(C)]
pub struct Header {
    // number of `EntryType` cells that follow the header in the file
    count: usize,
}
27
/// A memory-mapped cache file: a `Header` followed by `count` fixed-size entry cells.
struct CacheHashDataFile {
    // size in bytes of one entry cell (set to `size_of::<EntryType>()` by callers)
    cell_size: u64,
    // writable memory map of the whole file
    mmap: MmapMut,
    // total file size in bytes: header + cell_size * count
    capacity: u64,
}
33
34impl CacheHashDataFile {
35 fn get_mut<T: Sized>(&mut self, ix: u64) -> &mut T {
36 let start = (ix * self.cell_size) as usize + std::mem::size_of::<Header>();
37 let end = start + std::mem::size_of::<T>();
38 assert!(
39 end <= self.capacity as usize,
40 "end: {}, capacity: {}, ix: {}, cell size: {}",
41 end,
42 self.capacity,
43 ix,
44 self.cell_size
45 );
46 let item_slice: &[u8] = &self.mmap[start..end];
47 unsafe {
48 let item = item_slice.as_ptr() as *mut T;
49 &mut *item
50 }
51 }
52
53 fn get_header_mut(&mut self) -> &mut Header {
54 let start = 0_usize;
55 let end = start + std::mem::size_of::<Header>();
56 let item_slice: &[u8] = &self.mmap[start..end];
57 unsafe {
58 let item = item_slice.as_ptr() as *mut Header;
59 &mut *item
60 }
61 }
62
63 fn new_map(file: &Path, capacity: u64) -> Result<MmapMut, std::io::Error> {
64 let mut data = OpenOptions::new()
65 .read(true)
66 .write(true)
67 .create(true)
68 .open(file)?;
69
70 data.seek(SeekFrom::Start(capacity - 1)).unwrap();
74 data.write_all(&[0]).unwrap();
75 data.seek(SeekFrom::Start(0)).unwrap();
76 data.flush().unwrap();
77 Ok(unsafe { MmapMut::map_mut(&data).unwrap() })
78 }
79
80 fn load_map(file: &Path) -> Result<MmapMut, std::io::Error> {
81 let data = OpenOptions::new()
82 .read(true)
83 .write(true)
84 .create(false)
85 .open(file)?;
86
87 Ok(unsafe { MmapMut::map_mut(&data).unwrap() })
88 }
89}
90
/// File names of cache files found on disk at startup. A name is removed when
/// `load` reuses that file; whatever remains is deleted when `CacheHashData` drops.
pub type PreExistingCacheFiles = HashSet<String>;
/// Manages a directory of accounts-hash cache files: saving, loading, and
/// cleanup of stale files.
pub struct CacheHashData {
    // directory containing all cache files
    cache_folder: PathBuf,
    // files present at startup that have not (yet) been reused by `load`
    pre_existing_cache_files: Arc<Mutex<PreExistingCacheFiles>>,
    // aggregated load/save statistics, reported on drop
    pub stats: Arc<Mutex<CacheHashDataStats>>,
}
97
impl Drop for CacheHashData {
    fn drop(&mut self) {
        // remove cache files that were never reused during this run
        self.delete_old_cache_files();
        // report after deletion, so `unused_cache_files` is included in the stats
        self.stats.lock().unwrap().report();
    }
}
104
105impl CacheHashData {
106 pub fn new<P: AsRef<Path> + std::fmt::Debug>(parent_folder: &P) -> CacheHashData {
107 let cache_folder = Self::get_cache_root_path(parent_folder);
108
109 std::fs::create_dir_all(cache_folder.clone())
110 .unwrap_or_else(|_| panic!("error creating cache dir: {:?}", cache_folder));
111
112 let result = CacheHashData {
113 cache_folder,
114 pre_existing_cache_files: Arc::new(Mutex::new(PreExistingCacheFiles::default())),
115 stats: Arc::new(Mutex::new(CacheHashDataStats::default())),
116 };
117
118 result.get_cache_files();
119 result
120 }
121 fn delete_old_cache_files(&self) {
122 let pre_existing_cache_files = self.pre_existing_cache_files.lock().unwrap();
123 if !pre_existing_cache_files.is_empty() {
124 self.stats.lock().unwrap().unused_cache_files += pre_existing_cache_files.len();
125 for file_name in pre_existing_cache_files.iter() {
126 let result = self.cache_folder.join(file_name);
127 let _ = fs::remove_file(result);
128 }
129 }
130 }
131 fn get_cache_files(&self) {
132 if self.cache_folder.is_dir() {
133 let dir = fs::read_dir(self.cache_folder.clone());
134 if let Ok(dir) = dir {
135 let mut pre_existing = self.pre_existing_cache_files.lock().unwrap();
136 for entry in dir.flatten() {
137 if let Some(name) = entry.path().file_name() {
138 pre_existing.insert(name.to_str().unwrap().to_string());
139 }
140 }
141 self.stats.lock().unwrap().cache_file_count += pre_existing.len();
142 }
143 }
144 }
145
146 fn get_cache_root_path<P: AsRef<Path>>(parent_folder: &P) -> PathBuf {
147 parent_folder.as_ref().join("calculate_accounts_hash_cache")
148 }
149
150 pub fn load<P: AsRef<Path> + std::fmt::Debug>(
152 &self,
153 file_name: &P,
154 accumulator: &mut SavedType,
155 start_bin_index: usize,
156 bin_calculator: &PubkeyBinCalculator24,
157 ) -> Result<(), std::io::Error> {
158 let mut stats = CacheHashDataStats::default();
159 let result = self.load_internal(
160 file_name,
161 accumulator,
162 start_bin_index,
163 bin_calculator,
164 &mut stats,
165 );
166 self.stats.lock().unwrap().merge(&stats);
167 result
168 }
169
170 fn load_internal<P: AsRef<Path> + std::fmt::Debug>(
171 &self,
172 file_name: &P,
173 accumulator: &mut SavedType,
174 start_bin_index: usize,
175 bin_calculator: &PubkeyBinCalculator24,
176 stats: &mut CacheHashDataStats,
177 ) -> Result<(), std::io::Error> {
178 let mut m = Measure::start("overall");
179 let path = self.cache_folder.join(file_name);
180 let file_len = std::fs::metadata(path.clone())?.len();
181 let mut m1 = Measure::start("read_file");
182 let mmap = CacheHashDataFile::load_map(&path)?;
183 m1.stop();
184 stats.read_us = m1.as_us();
185 let header_size = std::mem::size_of::<Header>() as u64;
186 if file_len < header_size {
187 return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof));
188 }
189
190 let cell_size = std::mem::size_of::<EntryType>() as u64;
191 let mut cache_file = CacheHashDataFile {
192 mmap,
193 cell_size,
194 capacity: 0,
195 };
196 let header = cache_file.get_header_mut();
197 let entries = header.count;
198
199 let capacity = cell_size * (entries as u64) + header_size;
200 if file_len < capacity {
201 return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof));
202 }
203 cache_file.capacity = capacity;
204 assert_eq!(
205 capacity, file_len,
206 "expected: {}, len on disk: {} {:?}, entries: {}, cell_size: {}",
207 capacity, file_len, path, entries, cell_size
208 );
209
210 stats.total_entries = entries;
211 stats.cache_file_size += capacity as usize;
212
213 let file_name_lookup = file_name.as_ref().to_str().unwrap().to_string();
214 let found = self
215 .pre_existing_cache_files
216 .lock()
217 .unwrap()
218 .remove(&file_name_lookup);
219 if !found {
220 info!(
221 "tried to mark {:?} as used, but it wasn't in the set, one example: {:?}",
222 file_name_lookup,
223 self.pre_existing_cache_files.lock().unwrap().iter().next()
224 );
225 }
226
227 stats.loaded_from_cache += 1;
228 stats.entries_loaded_from_cache += entries;
229 let mut m2 = Measure::start("decode");
230 for i in 0..entries {
231 let d = cache_file.get_mut::<EntryType>(i as u64);
232 let mut pubkey_to_bin_index = bin_calculator.bin_from_pubkey(&d.pubkey);
233 assert!(
234 pubkey_to_bin_index >= start_bin_index,
235 "{}, {}",
236 pubkey_to_bin_index,
237 start_bin_index
238 ); pubkey_to_bin_index -= start_bin_index;
240 accumulator[pubkey_to_bin_index].push(d.clone()); }
242
243 m2.stop();
244 stats.decode_us += m2.as_us();
245 m.stop();
246 stats.load_us += m.as_us();
247 Ok(())
248 }
249
250 pub fn save(&self, file_name: &Path, data: &SavedTypeSlice) -> Result<(), std::io::Error> {
252 let mut stats = CacheHashDataStats::default();
253 let result = self.save_internal(file_name, data, &mut stats);
254 self.stats.lock().unwrap().merge(&stats);
255 result
256 }
257
258 fn save_internal(
259 &self,
260 file_name: &Path,
261 data: &SavedTypeSlice,
262 stats: &mut CacheHashDataStats,
263 ) -> Result<(), std::io::Error> {
264 let mut m = Measure::start("save");
265 let cache_path = self.cache_folder.join(file_name);
266 let create = true;
267 if create {
268 let _ignored = remove_file(&cache_path);
269 }
270 let cell_size = std::mem::size_of::<EntryType>() as u64;
271 let mut m1 = Measure::start("create save");
272 let entries = data
273 .iter()
274 .map(|x: &Vec<EntryType>| x.len())
275 .collect::<Vec<_>>();
276 let entries = entries.iter().sum::<usize>();
277 let capacity = cell_size * (entries as u64) + std::mem::size_of::<Header>() as u64;
278
279 let mmap = CacheHashDataFile::new_map(&cache_path, capacity)?;
280 m1.stop();
281 stats.create_save_us += m1.as_us();
282 let mut cache_file = CacheHashDataFile {
283 mmap,
284 cell_size,
285 capacity,
286 };
287
288 let mut header = cache_file.get_header_mut();
289 header.count = entries;
290
291 stats.cache_file_size = capacity as usize;
292 stats.total_entries = entries;
293
294 let mut m2 = Measure::start("write_to_mmap");
295 let mut i = 0;
296 data.iter().for_each(|x| {
297 x.iter().for_each(|item| {
298 let d = cache_file.get_mut::<EntryType>(i as u64);
299 i += 1;
300 *d = item.clone();
301 })
302 });
303 assert_eq!(i, entries);
304 m2.stop();
305 stats.write_to_mmap_us += m2.as_us();
306 m.stop();
307 stats.save_us += m.as_us();
308 stats.saved_to_cache += 1;
309 Ok(())
310 }
311}
312
#[cfg(test)]
pub mod tests {
    use {super::*, rand::Rng};

    /// Round-trips randomly generated data through `save`/`load` across
    /// multiple bin counts, pass splits, and flattened/pre-binned layouts.
    #[test]
    fn test_read_write() {
        use tempfile::TempDir;
        // `TempDir::new` creates the directory itself; no extra
        // `create_dir_all` needed
        let tmpdir = TempDir::new().unwrap();

        for bins in [1, 2, 4] {
            let bin_calculator = PubkeyBinCalculator24::new(bins);
            let num_points = 5;
            let (data, _total_points) = generate_test_data(num_points, bins, &bin_calculator);
            for passes in [1, 2] {
                let bins_per_pass = bins / passes;
                if bins_per_pass == 0 {
                    // more passes than bins: not a valid configuration
                    continue;
                }
                for pass in 0..passes {
                    for flatten_data in [true, false] {
                        // flattened layout puts everything in bin 0 and leaves
                        // an empty second bin; otherwise bins are pushed as-is
                        let mut data_this_pass = if flatten_data {
                            vec![vec![], vec![]]
                        } else {
                            vec![]
                        };
                        let start_bin_this_pass = pass * bins_per_pass;
                        for bin in 0..bins_per_pass {
                            let mut this_bin_data = data[bin + start_bin_this_pass].clone();
                            if flatten_data {
                                data_this_pass[0].append(&mut this_bin_data);
                            } else {
                                data_this_pass.push(this_bin_data);
                            }
                        }
                        let cache = CacheHashData::new(&tmpdir);
                        let file_name = "test";
                        let file = Path::new(file_name).to_path_buf();
                        cache.save(&file, &data_this_pass).unwrap();
                        cache.get_cache_files();
                        // the saved file must be visible as a pre-existing file
                        assert_eq!(
                            cache
                                .pre_existing_cache_files
                                .lock()
                                .unwrap()
                                .iter()
                                .collect::<Vec<_>>(),
                            vec![file_name]
                        );
                        // ranges are already iterators; no `.into_iter()` needed
                        let mut accum = (0..bins_per_pass).map(|_| vec![]).collect();
                        cache
                            .load(&file, &mut accum, start_bin_this_pass, &bin_calculator)
                            .unwrap();
                        if flatten_data {
                            // re-bin the flattened expectation so it compares
                            // against the binned `accum`
                            bin_data(
                                &mut data_this_pass,
                                &bin_calculator,
                                bins_per_pass,
                                start_bin_this_pass,
                            );
                        }
                        assert_eq!(
                            accum, data_this_pass,
                            "bins: {}, start_bin_this_pass: {}, pass: {}, flatten: {}, passes: {}",
                            bins, start_bin_this_pass, pass, flatten_data, passes
                        );
                    }
                }
            }
        }
    }

    /// Re-bins `data` in place by pubkey into `bins` buckets offset by `start_bin`.
    fn bin_data(
        data: &mut SavedType,
        bin_calculator: &PubkeyBinCalculator24,
        bins: usize,
        start_bin: usize,
    ) {
        let mut accum: SavedType = (0..bins).map(|_| vec![]).collect();
        // `drain(..)` already yields an iterator; the former `.into_iter()`
        // calls were no-ops
        data.drain(..).for_each(|mut x| {
            x.drain(..).for_each(|item| {
                let bin = bin_calculator.bin_from_pubkey(&item.pubkey);
                accum[bin - start_bin].push(item);
            })
        });
        *data = accum;
    }

    /// Generates random entries spread across `bins` bins; returns the binned
    /// data and the total number of entries generated.
    fn generate_test_data(
        count: usize,
        bins: usize,
        binner: &PubkeyBinCalculator24,
    ) -> (SavedType, usize) {
        let mut rng = rand::thread_rng();
        let mut ct = 0;
        (
            (0..bins)
                .map(|bin| {
                    // randomly leave some bins empty
                    let rnd = rng.gen::<u64>() % (bins as u64);
                    if rnd < count as u64 {
                        (0..std::cmp::max(1, count / bins))
                            .map(|_| {
                                ct += 1;
                                // rejection-sample a pubkey that falls in `bin`
                                let mut pk;
                                loop {
                                    pk = solana_sdk::pubkey::new_rand();
                                    if binner.bin_from_pubkey(&pk) == bin {
                                        break;
                                    }
                                }

                                CalculateHashIntermediate::new(
                                    solana_sdk::hash::new_rand(&mut rng),
                                    ct as u64,
                                    pk,
                                )
                            })
                            .collect::<Vec<_>>()
                    } else {
                        vec![]
                    }
                })
                .collect::<Vec<_>>(),
            ct,
        )
    }
}