1use super::schema::{Group, Schema};
59use crate::results::analysis_results::compute_analysis_results;
60use crate::results::analysis_results::AnalysisResults;
61use crate::results::ComputeAnalysisResultsError;
62use crate::schema::{BitOrder, Condition, FieldDefinition};
63use crate::utils::analyze_utils::{
64 create_bit_reader, create_bit_writer, reverse_bits, size_estimate, BitReaderContainer,
65 BitWriterContainer,
66};
67use crate::utils::constants::CHILD_MARKER;
68use ahash::{AHashMap, HashMapExt};
69use bitstream_io::{BitRead, BitReader, BitWrite, Endianness};
70use rustc_hash::FxHashMap;
71use std::io::{Cursor, SeekFrom};
72use thiserror::Error;
73
/// Walks binary entries according to a [`Schema`] and accumulates per-field
/// statistics (raw bits, per-bit counts, value frequencies) for later
/// compression analysis.
pub struct SchemaAnalyzer<'a> {
    /// Schema describing the bit-level layout of each entry.
    pub schema: &'a Schema,
    /// Raw bytes of every entry passed to `add_entry`, concatenated in order.
    pub entries: Vec<u8>,
    /// Per-field analysis state, keyed by the field's plain (non-qualified)
    /// name — see `build_field_stats`.
    pub field_states: AHashMap<String, AnalyzerFieldState>,
    /// Options forwarded to the downstream results computation
    /// (compression level, size-estimator function and its weights).
    pub compression_options: CompressionOptions,
}
91
/// Inputs handed to a [`SizeEstimatorFn`] when estimating the compressed
/// size of a field's data.
#[derive(Debug, Clone, Copy)]
pub struct SizeEstimationParameters<'a> {
    /// Name of the field being estimated.
    pub name: &'a str,
    /// Length of the data in bytes.
    pub data_len: usize,
    /// The raw data itself, when available to the estimator.
    pub data: Option<&'a [u8]>,
    /// Number of LZ matches found in the data.
    pub num_lz_matches: usize,
    /// Entropy of the data (units depend on the estimator — see
    /// `size_estimate` in `analyze_utils`).
    pub entropy: f64,
    /// Weight applied to the LZ-match term by the estimator.
    pub lz_match_multiplier: f64,
    /// Weight applied to the entropy term by the estimator.
    pub entropy_multiplier: f64,
}
114
115pub type SizeEstimatorFn = fn(SizeEstimationParameters) -> usize;
119
/// Tunables controlling compression-size estimation during analysis.
/// Construct via [`Default`] and customize with the `with_*` builders.
#[derive(Debug, Clone, Copy)]
pub struct CompressionOptions {
    /// Compression level passed to zstd (default: 16).
    pub zstd_compression_level: i32,
    /// Function used to estimate a field's compressed size
    /// (default: `size_estimate`).
    pub size_estimator_fn: SizeEstimatorFn,
    /// Weight for the LZ-match term fed to the estimator (default: 0.0).
    pub lz_match_multiplier: f64,
    /// Weight for the entropy term fed to the estimator (default: 0.0).
    pub entropy_multiplier: f64,
}
135
136impl Default for CompressionOptions {
137 fn default() -> Self {
138 Self {
139 zstd_compression_level: 16,
140 size_estimator_fn: size_estimate,
141 lz_match_multiplier: 0.0,
142 entropy_multiplier: 0.0,
143 }
144 }
145}
146
147impl CompressionOptions {
148 pub fn with_zstd_compression_level(mut self, level: i32) -> Self {
152 self.zstd_compression_level = level;
153 self
154 }
155
156 pub fn with_size_estimator_fn(mut self, estimator_fn: SizeEstimatorFn) -> Self {
160 self.size_estimator_fn = estimator_fn;
161 self
162 }
163}
164
/// Running analysis state for a single field (or group treated as one
/// aggregate field) in the schema.
pub struct AnalyzerFieldState {
    /// Plain field name — the key used in `SchemaAnalyzer::field_states`.
    pub name: String,
    /// Fully-qualified path, segments joined with `CHILD_MARKER`.
    pub full_path: String,
    /// Nesting depth; 0 for fields directly under the schema root.
    pub depth: usize,
    /// Number of times this field has been processed across all entries.
    pub count: u64,
    /// Width of the field in bits.
    pub lenbits: u32,
    /// Accumulates the raw bits of every observed value.
    pub writer: BitWriterContainer,
    /// Per-bit zero/one tallies; index 0 is the most significant bit of the
    /// value as read. Empty when the field is wider than 64 bits
    /// (see `clamp_bits`).
    pub bit_counts: Vec<BitStats>,
    /// Bit order used when interpreting this field's values.
    pub bit_order: BitOrder,
    /// Observed value → occurrence count; only populated for fields of
    /// 16 bits or fewer that don't opt out of frequency analysis.
    pub value_counts: FxHashMap<u64, u64>,
}
187
/// Zero/one tallies for a single bit position, accumulated across entries.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct BitStats {
    /// Times this bit position was observed as 0.
    pub zeros: u64,
    /// Times this bit position was observed as 1.
    pub ones: u64,
}
195
/// Errors surfaced while feeding entries to [`SchemaAnalyzer::add_entry`].
#[derive(Debug, Error)]
pub enum AnalysisError {
    /// Underlying bit-reader/writer I/O failure during analysis.
    #[error("I/O error in add_entry reader during analysis. This is indicative of a bug in schema parsing or sanitization; and should normally not happen. Details: {0}")]
    Io(#[from] std::io::Error),

    /// A schema field had no pre-built state in `field_states`.
    #[error(
        "Field '{0}' not found in Analyzer. This is indicative of a bug and should not happen."
    )]
    FieldNotFound(String),

    /// The supplied entry holds fewer bits than the schema root requires.
    #[error("Invalid entry length: expected {expected}, got {found}")]
    InvalidEntryLength { expected: usize, found: usize },
}
210
211impl<'a> SchemaAnalyzer<'a> {
212 pub fn new(schema: &'a Schema, options: CompressionOptions) -> Self {
222 Self {
223 schema,
224 entries: Vec::new(),
225 field_states: build_field_stats(&schema.root, "", 0, schema.bit_order),
226 compression_options: options,
227 }
228 }
229
230 pub fn add_entry(&mut self, entry: &[u8]) -> Result<(), AnalysisError> {
239 self.entries.extend_from_slice(entry);
240
241 if entry.len() * 8 < self.schema.root.bits as usize {
243 return Err(AnalysisError::InvalidEntryLength {
244 expected: self.schema.root.bits as usize,
245 found: self.entries.len() * 8,
246 });
247 }
248
249 let reader = create_bit_reader(entry, self.schema.bit_order);
250 match reader {
251 BitReaderContainer::Msb(mut bit_reader) => {
252 self.process_group(&self.schema.root, &mut bit_reader)
253 }
254 BitReaderContainer::Lsb(mut bit_reader) => {
255 self.process_group(&self.schema.root, &mut bit_reader)
256 }
257 }
258 }
259
260 fn process_group<TEndian: Endianness>(
261 &mut self,
262 group: &Group,
263 reader: &mut BitReader<Cursor<&[u8]>, TEndian>,
264 ) -> Result<(), AnalysisError> {
265 if should_skip(reader, &group.skip_if_not)? {
268 return Ok(());
269 }
270
271 for (name, field_def) in &group.fields {
272 match field_def {
273 FieldDefinition::Field(field) => {
274 if should_skip(reader, &field.skip_if_not)? {
276 continue;
277 }
278
279 let bits_left = field.bits;
280 let field_stats = self
281 .field_states
282 .get_mut(name)
283 .ok_or_else(|| AnalysisError::FieldNotFound(name.clone()))?;
284
285 process_field_or_group(
286 reader,
287 bits_left,
288 field_stats,
289 field.skip_frequency_analysis,
290 )?;
291 }
292 FieldDefinition::Group(child_group) => {
293 let bits_left = child_group.bits;
294 let field_stats = self
295 .field_states
296 .get_mut(name)
297 .ok_or_else(|| AnalysisError::FieldNotFound(name.clone()))?;
298
299 let current_offset = reader.position_in_bits()?;
301 process_field_or_group(
302 reader,
303 bits_left,
304 field_stats,
305 child_group.skip_frequency_analysis,
306 )?;
307 reader.seek_bits(SeekFrom::Start(current_offset))?;
308
309 self.process_group(child_group, reader)?;
311 }
312 }
313 }
314 Ok(())
315 }
316
317 pub fn generate_results(&mut self) -> Result<AnalysisResults, ComputeAnalysisResultsError> {
325 compute_analysis_results(self)
326 }
327}
328
/// Reads `bit_count` bits for one field (or a group treated as a single
/// aggregate field) from `reader` and folds them into `field_stats`.
///
/// Data is consumed in chunks of up to 64 bits. Per chunk this:
/// * counts the value's frequency (only when the whole field is <= 16 bits
///   and frequency analysis is not explicitly skipped),
/// * appends the raw bits to the field's bit writer,
/// * tallies per-bit zero/one counts (only when the whole field fits in a
///   single 64-bit chunk, so chunk bit indices line up with `bit_counts`).
fn process_field_or_group<TEndian: Endianness>(
    reader: &mut BitReader<Cursor<&[u8]>, TEndian>,
    mut bit_count: u32,
    field_stats: &mut AnalyzerFieldState,
    skip_frequency_analysis: bool,
) -> Result<(), AnalysisError> {
    let writer = &mut field_stats.writer;
    // Per-bit stats only make sense when the field fits one chunk; wider
    // fields have an empty `bit_counts` (see `clamp_bits`).
    let can_bit_stats = bit_count <= 64;
    // Value counting is capped at 16-bit fields to bound the map size.
    let skip_count_values = bit_count > 16 || skip_frequency_analysis;

    field_stats.count += 1;
    while bit_count > 0 {
        let max_bits = bit_count.min(64);
        let bits = reader.read::<u64>(max_bits)?;

        if !skip_count_values {
            if field_stats.bit_order == BitOrder::Lsb {
                // LSB fields are counted on the bit-reversed value so equal
                // logical values collapse to the same key.
                let reversed_bits = reverse_bits(max_bits, bits);
                *field_stats.value_counts.entry(reversed_bits).or_insert(0) += 1;
            } else {
                *field_stats.value_counts.entry(bits).or_insert(0) += 1;
            }
        }

        // Append the raw (unreversed) bits to the field's stream.
        match writer {
            BitWriterContainer::Msb(w) => w.write(max_bits, bits)?,
            BitWriterContainer::Lsb(w) => w.write(max_bits, bits)?,
        }

        if can_bit_stats {
            for i in 0..max_bits {
                let idx = i as usize;
                // Index 0 is the most significant bit of the read value.
                let bit_value = (bits >> (max_bits - 1 - i)) & 1;
                if bit_value == 0 {
                    field_stats.bit_counts[idx].zeros += 1;
                } else {
                    field_stats.bit_counts[idx].ones += 1;
                }
            }
        }

        bit_count -= max_bits;
    }

    // Flush so partially-filled bytes are visible to later readers.
    match writer {
        BitWriterContainer::Msb(w) => w.flush()?,
        BitWriterContainer::Lsb(w) => w.flush()?,
    }

    Ok(())
}
386
387fn build_field_stats<'a>(
388 group: &'a Group,
389 parent_path: &'a str,
390 depth: usize,
391 file_bit_order: BitOrder,
392) -> AHashMap<String, AnalyzerFieldState> {
393 let mut stats = AHashMap::new();
394
395 for (name, field) in &group.fields {
396 let path = if parent_path.is_empty() {
397 name.clone()
398 } else {
399 format!("{}{CHILD_MARKER}{}", parent_path, name)
400 };
401
402 match field {
403 FieldDefinition::Field(field) => {
404 let writer = create_bit_writer(file_bit_order);
405 stats.insert(
406 name.clone(),
407 AnalyzerFieldState {
408 full_path: path,
409 depth,
410 lenbits: field.bits,
411 count: 0,
412 writer,
413 bit_counts: vec![BitStats::default(); clamp_bits(field.bits as usize)],
414 name: name.clone(),
415 bit_order: field.bit_order.get_with_default_resolve(),
416 value_counts: FxHashMap::new(),
417 },
418 );
419 }
420 FieldDefinition::Group(group) => {
421 let writer = create_bit_writer(file_bit_order);
422
423 stats.insert(
425 name.clone(),
426 AnalyzerFieldState {
427 full_path: path.clone(),
428 depth,
429 lenbits: group.bits,
430 count: 0,
431 writer,
432 bit_counts: vec![BitStats::default(); clamp_bits(group.bits as usize)],
433 name: name.clone(),
434 bit_order: group.bit_order.get_with_default_resolve(),
435 value_counts: FxHashMap::new(),
436 },
437 );
438
439 stats.extend(build_field_stats(group, &path, depth + 1, file_bit_order));
441 }
442 }
443 }
444
445 stats
446}
447
448#[inline]
450fn should_skip<TEndian: Endianness>(
451 reader: &mut BitReader<Cursor<&[u8]>, TEndian>,
452 conditions: &[Condition],
453) -> Result<bool, AnalysisError> {
454 if conditions.is_empty() {
456 return Ok(false);
457 }
458
459 let original_pos_bits = reader.position_in_bits()?;
460 for condition in conditions {
461 let offset = (condition.byte_offset * 8) + condition.bit_offset as u64;
462 let target_pos = original_pos_bits.wrapping_add(offset);
463
464 reader.seek_bits(SeekFrom::Start(target_pos))?;
465 let mut value = reader.read::<u64>(condition.bits as u32)?;
466
467 if condition.bit_order == BitOrder::Lsb {
468 value = reverse_bits(condition.bits as u32, value);
469 }
470
471 if value != condition.value {
472 reader.seek_bits(SeekFrom::Start(original_pos_bits))?;
473 return Ok(true);
474 }
475 }
476
477 reader.seek_bits(SeekFrom::Start(original_pos_bits))?;
478 Ok(false)
479}
480
/// Returns `bits` unchanged when it fits per-bit tracking (<= 64),
/// otherwise 0 — wide fields get no per-bit statistics.
fn clamp_bits(bits: usize) -> usize {
    match bits {
        0..=64 => bits,
        _ => 0,
    }
}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::Schema;

    /// Two-level schema: a 32-bit `id` field plus an LSB `nested` group
    /// containing an 8-bit `value` field.
    fn create_test_schema() -> Schema {
        let yaml = r###"
version: '1.0'
root:
  type: group
  fields:
    id:
      type: field
      bits: 32
      description: "ID field"
    nested:
      type: group
      bit_order: lsb
      fields:
        value:
          type: field
          bits: 8
          description: "Nested value"
"###;

        Schema::from_yaml(yaml).expect("Failed to parse test schema")
    }

    #[test]
    fn test_analyzer_initialization() {
        let schema = create_test_schema();
        let analyzer = SchemaAnalyzer::new(&schema, CompressionOptions::default());

        // `id`, `nested` and `nested.value` each get a state entry.
        assert_eq!(
            analyzer.field_states.len(),
            3,
            "Should have stats for root group + 2 fields"
        );
    }

    #[test]
    fn test_big_endian_bitorder() -> Result<(), AnalysisError> {
        let yaml = r###"
version: '1.0'
root:
  type: group
  fields:
    flags:
      type: field
      bits: 2
      bit_order: msb
"###;
        let schema = Schema::from_yaml(yaml).expect("Failed to parse test schema");
        let mut analyzer = SchemaAnalyzer::new(&schema, CompressionOptions::default());

        // Four entries whose top two bits are 11, 00, 10 and 01.
        analyzer.add_entry(&[0b11000000])?;
        analyzer.add_entry(&[0b00000000])?;
        analyzer.add_entry(&[0b10000000])?;
        analyzer.add_entry(&[0b01000000])?;

        {
            let flags_field = analyzer
                .field_states
                .get_mut("flags")
                .ok_or(AnalysisError::FieldNotFound("flags".to_string()))?;
            assert_eq!(flags_field.count, 4, "Should process 4 entries");
            assert_eq!(
                flags_field.bit_counts.len(),
                2,
                "Should track 2 bits per field"
            );

            // The field writer concatenated 11 00 10 01 = 0xC9.
            let writer = match &mut flags_field.writer {
                BitWriterContainer::Msb(value) => value,
                _ => panic!("Expected MSB variant"),
            };
            writer.byte_align()?;
            writer.flush()?;
            let inner_writer = writer.writer().unwrap();
            let data = inner_writer.get_ref();
            assert_eq!(data[0], 0xC9_u8, "Combined bits should form 0xC9");

            // Each distinct 2-bit value was seen exactly once.
            let expected_counts =
                FxHashMap::from_iter([(0b11, 1), (0b00, 1), (0b10, 1), (0b01, 1)]);
            assert_eq!(
                flags_field.value_counts, expected_counts,
                "Value counts should match"
            );

            // Each bit position saw two zeros and two ones.
            for (x, stats) in flags_field.bit_counts.iter().enumerate() {
                assert_eq!(
                    stats.zeros, 2,
                    "Bit {} should have 2 zeros (actual: {})",
                    x, stats.zeros
                );
                assert_eq!(
                    stats.ones, 2,
                    "Bit {} should have 2 ones (actual: {})",
                    x, stats.ones
                );
            }
        }

        // One more `01` entry bumps that value's count to 2.
        analyzer.add_entry(&[0b01000000])?;
        let flags_field = analyzer
            .field_states
            .get_mut("flags")
            .ok_or(AnalysisError::FieldNotFound("flags".to_string()))?;
        let expected_counts = FxHashMap::from_iter([(0b11, 1), (0b00, 1), (0b10, 1), (0b01, 2)]);
        assert_eq!(
            flags_field.value_counts, expected_counts,
            "Value counts should match"
        );
        Ok(())
    }

    #[test]
    fn test_little_endian_bitorder() {
        let yaml = r###"
version: '1.0'
root:
  type: group
  fields:
    flags:
      type: field
      bits: 2
      bit_order: lsb
"###;
        let schema = Schema::from_yaml(yaml).expect("Failed to parse test schema");
        let mut analyzer = SchemaAnalyzer::new(&schema, CompressionOptions::default());

        // LSB fields are counted bit-reversed:
        // 11→11, 00→00, 10→01, 01→10, 10→01.
        analyzer.add_entry(&[0b11000000]).unwrap();
        analyzer.add_entry(&[0b00000000]).unwrap();
        analyzer.add_entry(&[0b10000000]).unwrap();
        analyzer.add_entry(&[0b01000000]).unwrap();
        analyzer.add_entry(&[0b10000000]).unwrap();

        let flags_field = analyzer.field_states.get_mut("flags").unwrap();
        let expected_counts = FxHashMap::from_iter([(0b11, 1), (0b00, 1), (0b10, 1), (0b01, 2)]);
        assert_eq!(
            flags_field.value_counts, expected_counts,
            "Value counts should match"
        );
    }

    #[test]
    fn test_field_stats_structure() {
        let schema = create_test_schema();
        let analyzer = SchemaAnalyzer::new(&schema, CompressionOptions::default());

        // Top-level 32-bit field.
        let id_field = analyzer.field_states.get("id").unwrap();
        assert_eq!(id_field.name, "id");
        assert_eq!(id_field.full_path, "id");
        assert_eq!(id_field.depth, 0);
        assert_eq!(id_field.count, 0);
        assert_eq!(id_field.lenbits, 32);
        assert_eq!(id_field.bit_counts.len(), id_field.lenbits as usize);
        assert_eq!(id_field.bit_order, BitOrder::Msb);

        // The group itself is tracked as an aggregate 8-bit field.
        let nested_group = analyzer.field_states.get("nested").unwrap();
        assert_eq!(nested_group.full_path, "nested");
        assert_eq!(nested_group.name, "nested");
        assert_eq!(nested_group.depth, 0);
        assert_eq!(nested_group.count, 0);
        assert_eq!(nested_group.lenbits, 8);
        assert_eq!(nested_group.bit_counts.len(), nested_group.lenbits as usize);
        assert_eq!(nested_group.bit_order, BitOrder::Lsb);

        // The child field sits one level deeper, with a qualified path.
        let nested_value = analyzer.field_states.get("value").unwrap();
        assert_eq!(nested_value.full_path, "nested.value");
        assert_eq!(nested_value.name, "value");
        assert_eq!(nested_value.depth, 1);
        assert_eq!(nested_value.count, 0);
        assert_eq!(nested_value.lenbits, 8);
        assert_eq!(nested_value.bit_counts.len(), nested_value.lenbits as usize);
        assert_eq!(nested_value.bit_order, BitOrder::Lsb);
    }

    #[test]
    fn skips_group_based_on_conditions() {
        let yaml = r#"
version: '1.0'
root:
  type: group
  skip_if_not:
    - byte_offset: 0
      bit_offset: 0
      bits: 8
      value: 0x55
  fields:
    dummy: 8
"#;
        let schema = Schema::from_yaml(yaml).unwrap();
        let mut analyzer = SchemaAnalyzer::new(&schema, CompressionOptions::default());

        // Matching entry is processed.
        analyzer.add_entry(&[0x55]).unwrap();
        assert_eq!(analyzer.field_states.get("dummy").unwrap().count, 1);

        // Non-matching entry is skipped; count unchanged.
        analyzer.add_entry(&[0xAA]).unwrap();
        assert_eq!(analyzer.field_states.get("dummy").unwrap().count, 1);

        // Matching again resumes counting.
        analyzer.add_entry(&[0x55]).unwrap();
        assert_eq!(analyzer.field_states.get("dummy").unwrap().count, 2);
    }

    #[test]
    fn skips_field_based_on_conditions() {
        let yaml = r#"
version: '1.0'
root:
  type: group
  fields:
    header:
      type: field
      bits: 7
      skip_if_not:
        - byte_offset: 0
          bit_offset: 0
          bits: 1
          value: 1
"#;
        let schema = Schema::from_yaml(yaml).unwrap();
        let mut analyzer = SchemaAnalyzer::new(&schema, CompressionOptions::default());

        // Leading 1 bit: field processed.
        analyzer.add_entry(&[0b10000000]).unwrap();
        assert_eq!(analyzer.field_states.get("header").unwrap().count, 1);

        // Leading 0 bit: field skipped.
        analyzer.add_entry(&[0b00000000]).unwrap();
        assert_eq!(analyzer.field_states.get("header").unwrap().count, 1);
    }

    #[test]
    fn test_builder() {
        // Builder overrides just the requested option…
        let options = CompressionOptions::default().with_zstd_compression_level(7);
        assert_eq!(options.zstd_compression_level, 7);

        // …while the default stays at 16.
        let options = CompressionOptions::default();
        assert_eq!(options.zstd_compression_level, 16);
    }
}