1use crate::{GlobalArgs, GranularityArgs, NumThreadsArg};
9use anyhow::Result;
10use clap::Parser;
11use dsi_bitstream::{dispatch::factory::CodesReaderFactoryHelper, prelude::*};
12use dsi_progress_logger::prelude::*;
13use std::path::PathBuf;
14use webgraph::prelude::*;
15
16#[derive(Parser, Debug)]
17#[command(name = "codes", about = "Reads a graph and suggests the best codes to use.", long_about = None)]
18pub struct CliArgs {
19 pub basename: PathBuf,
21
22 #[clap(flatten)]
23 pub num_threads: NumThreadsArg,
24
25 #[clap(flatten)]
26 pub granularity: GranularityArgs,
27
28 #[clap(short = 'k', long, default_value_t = 3)]
29 pub top_k: usize,
32}
33
34pub fn main(global_args: GlobalArgs, args: CliArgs) -> Result<()> {
35 match get_endianness(&args.basename)?.as_str() {
36 #[cfg(feature = "be_bins")]
37 BE::NAME => optimize_codes::<BE>(global_args, args),
38 #[cfg(feature = "le_bins")]
39 LE::NAME => optimize_codes::<LE>(global_args, args),
40 e => panic!("Unknown endianness: {}", e),
41 }
42}
43
/// An iterator splitting a range of `usize` into consecutive, non-overlapping
/// subranges of `chunk_size` elements each; only the last one may be shorter.
#[derive(Debug, Clone)]
pub struct ChunksIter {
    /// The part of the range that has not been yielded yet.
    total: core::ops::Range<usize>,
    /// The maximum length of each yielded subrange (always at least one).
    chunk_size: usize,
}

impl ChunksIter {
    /// Creates an iterator over the chunks of `total` of length `chunk_size`.
    ///
    /// A `chunk_size` of zero is treated as one: otherwise the iterator would
    /// never advance and would yield empty ranges forever.
    pub fn new(total: core::ops::Range<usize>, chunk_size: usize) -> Self {
        Self {
            total,
            chunk_size: chunk_size.max(1),
        }
    }
}

impl Iterator for ChunksIter {
    type Item = core::ops::Range<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.total.start >= self.total.end {
            return None;
        }
        // Saturate so that a huge chunk size cannot overflow past `usize::MAX`.
        let end = self
            .total
            .start
            .saturating_add(self.chunk_size)
            .min(self.total.end);
        let range = self.total.start..end;
        self.total.start = end;
        Some(range)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.total.end.saturating_sub(self.total.start);
        // Ceiling division without the overflow risk of `len + chunk_size - 1`.
        let chunks = len / self.chunk_size + usize::from(len % self.chunk_size != 0);
        (chunks, Some(chunks))
    }
}

impl ExactSizeIterator for ChunksIter {}

impl core::iter::FusedIterator for ChunksIter {}
72
73pub fn optimize_codes<E: Endianness>(global_args: GlobalArgs, args: CliArgs) -> Result<()>
74where
75 MmapHelper<u32>: CodesReaderFactoryHelper<E>,
76 for<'a> LoadModeCodesReader<'a, E, Mmap>: BitSeek,
77{
78 let mut stats = Default::default();
79 let has_ef = std::fs::metadata(args.basename.with_extension("ef")).is_ok_and(|x| x.is_file());
80
81 let (_, _, comp_flags) =
83 parse_properties::<E>(args.basename.with_extension(PROPERTIES_EXTENSION))?;
84
85 if has_ef {
86 log::info!(
87 "Analyzing codes in parallel using {} threads",
88 args.num_threads.num_threads
89 );
90 let graph = BvGraph::with_basename(&args.basename)
91 .endianness::<E>()
92 .load()?;
93
94 let mut pl = concurrent_progress_logger![item_name = "node"];
95 pl.display_memory(true)
96 .expected_updates(Some(graph.num_nodes()));
97 pl.start("Scanning...");
98
99 if let Some(duration) = global_args.log_interval {
100 pl.log_interval(duration);
101 }
102
103 let thread_pool = rayon::ThreadPoolBuilder::new()
104 .num_threads(args.num_threads.num_threads)
105 .build()?;
106
107 let node_granularity = args
108 .granularity
109 .into_granularity()
110 .node_granularity(graph.num_nodes(), Some(graph.num_arcs()));
111
112 stats = thread_pool.install(|| {
115 ChunksIter::new(0..graph.num_nodes(), node_granularity).par_map_fold_with(
116 pl.clone(),
117 |pl, range| {
118 let mut iter = graph
119 .offset_deg_iter_from(range.start)
120 .map_decoder(|d| StatsDecoder::new(d, Default::default()));
121
122 for _ in (&mut iter).take(range.len()) {
123 pl.light_update();
124 }
125
126 let mut stats = Default::default();
127 iter.map_decoder(|d| {
128 stats = d.stats;
129 d.codes_reader });
131 stats
132 },
133 |mut acc1, acc2| {
134 acc1 += &acc2;
135 acc1
136 },
137 )
138 });
139
140 pl.done();
141 } else {
142 if args.num_threads.num_threads != 1 {
143 log::info!(SEQ_PROC_WARN![], args.basename.display());
144 }
145
146 let graph = BvGraphSeq::with_basename(args.basename)
147 .endianness::<E>()
148 .load()?;
149
150 let mut pl = ProgressLogger::default();
151 pl.display_memory(true)
152 .item_name("node")
153 .expected_updates(Some(graph.num_nodes()));
154
155 pl.start("Scanning...");
156
157 let mut iter = graph
159 .offset_deg_iter()
160 .map_decoder(|d| StatsDecoder::new(d, Default::default()));
161 for _ in iter.by_ref() {
163 pl.light_update();
164 }
165 pl.done();
166 iter.map_decoder(|d| {
168 stats = d.stats;
169 d.codes_reader });
171 }
172
173 println!("Default codes");
174 compare_codes(&stats, CompFlags::default(), args.top_k);
175
176 print!("\n\n\n");
177
178 println!("Current codes");
179 compare_codes(&stats, comp_flags, args.top_k);
180
181 Ok(())
182}
183
184fn get_size_by_code(stats: &CodesStats, code: Codes) -> Option<u64> {
187 match code {
188 Codes::Unary => Some(stats.unary),
189 Codes::Gamma => Some(stats.gamma),
190 Codes::Delta => Some(stats.delta),
191 Codes::Omega => Some(stats.omega),
192 Codes::VByteBe | Codes::VByteLe => Some(stats.vbyte),
193 Codes::Zeta(k) => stats.zeta.get(k - 1).copied(),
194 Codes::Golomb(b) => stats.golomb.get(b as usize - 1).copied(),
195 Codes::ExpGolomb(k) => stats.exp_golomb.get(k).copied(),
196 Codes::Rice(k) => stats.rice.get(k).copied(),
197 Codes::Pi(0) => Some(stats.gamma), Codes::Pi(1) => Some(stats.zeta[1]), Codes::Pi(k) => stats.pi.get(k - 2).copied(),
200 _ => unreachable!("Code {:?} not supported", code),
201 }
202}
203
204pub fn compare_codes(stats: &DecoderStats, reference: CompFlags, top_k: usize) {
206 macro_rules! impl_best_code {
207 ($new_bits:expr, $old_bits:expr, $stats:expr, $($code:ident -> $old:expr),*) => {
208 println!("{:>17} {:>20} {:>12} {:>10} {:>10} {:>16}",
209 "Type", "Code", "Improvement", "Weight", "Bytes", "Bits",
210 );
211 $(
212 let (_, new) = $stats.$code.best_code();
213 $new_bits += new;
214 $old_bits += $old;
215 )*
216
217 $(
218 let codes = $stats.$code.get_codes();
219 let (best_code, best_size) = codes[0];
220
221 let improvement = 100.0 * ($old - best_size) as f64 / $old as f64;
222 let weight = 100.0 * ($old as f64 - best_size as f64) / ($old_bits as f64 - $new_bits as f64);
223
224 println!("{:>17} {:>20} {:>12.3}% {:>9.3}% {:>10} {:>16}",
225 stringify!($code),
226 format!("{:?}", best_code),
227 improvement,
228 weight,
229 normalize(best_size as f64 / 8.0),
230 best_size,
231 );
232 for i in 1..top_k.min(codes.len()).max(1) {
233 let (code, size) = codes[i];
234 let improvement = 100.0 * ($old as f64 - size as f64) / $old as f64;
235 println!("{:>17} {:>20} {:>12.3}% {:>10.3} {:>10} {:>16}",
236 stringify!($code),
237 format!("{:?}", code),
238 improvement,
239 "",
240 normalize(size as f64 / 8.0),
241 size,
242 );
243 }
244 print!("\n");
245 )*
246 };
247 }
248
249 println!("Code optimization results against:");
250 for (name, code) in [
251 ("outdegrees", reference.outdegrees),
252 ("reference offsets", reference.references),
253 ("block counts", reference.blocks),
254 ("blocks", reference.blocks),
255 ("interval counts", reference.intervals),
256 ("interval starts", reference.intervals),
257 ("interval lengths", reference.intervals),
258 ("first residuals", reference.residuals),
259 ("residuals", reference.residuals),
260 ] {
261 println!("\t{:>18} : {:?}", name, code);
262 }
263
264 let mut new_bits = 0;
265 let mut old_bits = 0;
266 impl_best_code!(
267 new_bits,
268 old_bits,
269 stats,
270 outdegrees -> get_size_by_code(&stats.outdegrees, reference.outdegrees).unwrap(),
271 reference_offsets -> get_size_by_code(&stats.reference_offsets, reference.references).unwrap(),
272 block_counts -> get_size_by_code(&stats.block_counts, reference.blocks).unwrap(),
273 blocks -> get_size_by_code(&stats.blocks, reference.blocks).unwrap(),
274 interval_counts -> get_size_by_code(&stats.interval_counts, reference.intervals).unwrap(),
275 interval_starts -> get_size_by_code(&stats.interval_starts, reference.intervals).unwrap(),
276 interval_lens -> get_size_by_code(&stats.interval_lens, reference.intervals).unwrap(),
277 first_residuals -> get_size_by_code(&stats.first_residuals, reference.residuals).unwrap(),
278 residuals -> get_size_by_code(&stats.residuals, reference.residuals).unwrap()
279 );
280
281 println!();
282 println!(" Old bit size: {:>16}", old_bits);
283 println!(" New bit size: {:>16}", new_bits);
284 println!(" Saved bits: {:>16}", old_bits - new_bits);
285
286 println!("Old byte size: {:>16}", normalize(old_bits as f64 / 8.0));
287 println!("New byte size: {:>16}", normalize(new_bits as f64 / 8.0));
288 println!(
289 " Saved bytes: {:>16}",
290 normalize((old_bits - new_bits) as f64 / 8.0)
291 );
292
293 println!(
294 " Improvement: {:>15.3}%",
295 100.0 * (old_bits - new_bits) as f64 / old_bits as f64
296 );
297}
298
/// Formats `value` with three decimals after scaling it down by powers of 1000.
///
/// Each division by 1000 happens only while the value is still greater than
/// 1000, and appends the matching suffix (`K`, `M`, `G`, `T`). Scaling stops at
/// `T`. Values that are never scaled get a trailing space instead of a suffix.
fn normalize(value: f64) -> String {
    const SUFFIXES: [char; 4] = ['K', 'M', 'G', 'T'];

    let mut scaled = value;
    let mut suffix = ' ';
    for &next in SUFFIXES.iter() {
        // Written as `>` (not `<=`) so NaN is left untouched.
        if scaled > 1000.0 {
            scaled /= 1000.0;
            suffix = next;
        } else {
            break;
        }
    }
    format!("{:.3}{}", scaled, suffix)
}