struct_compression_analyzer/comparison/compare_groups/generate_bytes/
write_array.rs1use super::{GenerateBytesError, GenerateBytesResult};
2use crate::utils::{
3 analyze_utils::{get_writer_buffer, BitWriterContainer},
4 bitstream_ext::BitReaderExt,
5};
6use crate::{analyzer::AnalyzerFieldState, schema::GroupComponentArray};
7use ahash::AHashMap;
8use bitstream_io::{BigEndian, BitRead, BitReader, BitWrite, BitWriter, Endianness, LittleEndian};
9use std::io::{self, Cursor, SeekFrom};
10
11pub(crate) fn write_array<TWrite: io::Write, TEndian: Endianness>(
21 field_stats: &mut AHashMap<String, AnalyzerFieldState>,
22 writer: &mut BitWriter<TWrite, TEndian>,
23 array: &GroupComponentArray,
24) -> GenerateBytesResult<()> {
25 let field = field_stats
26 .get_mut(&array.field)
27 .ok_or_else(|| GenerateBytesError::FieldNotFound(array.field.clone()))?;
28
29 let bits: u32 = array.get_bits(field);
30 let offset = array.offset;
31 let field_len = field.lenbits;
32 match &field.writer {
33 BitWriterContainer::Msb(_) => {
34 let bytes = get_writer_buffer(&mut field.writer);
35 let mut reader = BitReader::endian(Cursor::new(bytes), BigEndian);
36 write_array_inner(&mut reader, bits, offset, field_len, writer)
37 }
38 BitWriterContainer::Lsb(_) => {
39 let bytes = get_writer_buffer(&mut field.writer);
40 let mut reader = BitReader::endian(Cursor::new(bytes), LittleEndian);
41 write_array_inner(&mut reader, bits, offset, field_len, writer)
42 }
43 }
44}
45
46fn write_array_inner<
52 TWrite: io::Write,
53 TEndian: Endianness,
54 TReader: io::Read + io::Seek,
55 TReaderEndian: Endianness,
56>(
57 reader: &mut BitReader<TReader, TReaderEndian>,
58 bits: u32,
59 offset: u32,
60 field_len: u32,
61 writer: &mut BitWriter<TWrite, TEndian>,
62) -> GenerateBytesResult<()> {
63 loop {
65 let ending_pos = reader
67 .position_in_bits()
68 .map_err(|e| GenerateBytesError::SeekError {
69 source: e,
70 operation: "getting array position".into(),
71 })?
72 + field_len as u64;
73
74 let remaining = reader
76 .remaining_bits()
77 .map_err(|e| GenerateBytesError::SeekError {
78 source: e,
79 operation: "checking remaining bits".into(),
80 })?;
81
82 if remaining < field_len as u64 {
83 return Ok(());
84 }
85
86 reader
88 .seek_bits(SeekFrom::Current(offset as i64))
89 .map_err(|e| GenerateBytesError::SeekError {
90 source: e,
91 operation: format!("seeking to array offset {}", offset),
92 })?;
93
94 let value = reader
96 .read::<u64>(bits)
97 .map_err(|e| GenerateBytesError::ReadError {
98 source: e,
99 context: format!("reading {bits}-bit array element"),
100 })?;
101
102 writer
104 .write::<u64>(bits, value)
105 .map_err(|e| GenerateBytesError::WriteError {
106 source: e,
107 context: format!("writing {bits}-bit array element"),
108 })?;
109
110 reader.seek_bits(SeekFrom::Start(ending_pos)).map_err(|e| {
112 GenerateBytesError::SeekError {
113 source: e,
114 operation: format!("seeking to array end position {}", ending_pos),
115 }
116 })?;
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::comparison::compare_groups::test_helpers::{
        create_mock_field_states, TEST_FIELD_NAME,
    };
    use crate::schema::{default_entropy_multiplier, default_lz_match_multiplier, BitOrder};
    use bitstream_io::BitWriter;
    use std::io::Cursor;

    /// Builds an array component over the shared test field with the given slice geometry.
    fn array_component(offset: u32, bits: u32) -> GroupComponentArray {
        GroupComponentArray {
            field: TEST_FIELD_NAME.to_string(),
            offset,
            bits,
            lz_match_multiplier: default_lz_match_multiplier(),
            entropy_multiplier: default_entropy_multiplier(),
        }
    }

    /// Runs `write_array` for `component` through a fresh writer using `endian` and
    /// returns the bytes it produced.
    fn collect_output<E: Endianness>(
        field_stats: &mut AHashMap<String, AnalyzerFieldState>,
        endian: E,
        component: &GroupComponentArray,
    ) -> Vec<u8> {
        let mut written: Vec<u8> = Vec::new();
        {
            let mut writer = BitWriter::endian(Cursor::new(&mut written), endian);
            write_array(field_stats, &mut writer, component).unwrap();
        }
        written
    }

    #[test]
    fn can_round_trip_lsb() {
        // Read LSB-first, the 4-bit elements are 1, 2, 4, 8.
        let input_data: [u8; 2] = [0b0010_0001, 0b1000_0100];
        let mut field_stats = create_mock_field_states(
            TEST_FIELD_NAME,
            &input_data,
            4,
            BitOrder::Lsb,
            BitOrder::Lsb,
        );

        let output = collect_output(&mut field_stats, LittleEndian, &array_component(0, 4));

        assert_eq!(output, input_data);
    }

    #[test]
    fn can_round_trip_msb() {
        // Read MSB-first, the 4-bit elements are 1, 2, 4, 8.
        let input_data: [u8; 2] = [0b0001_0010, 0b0100_1000];
        let mut field_stats = create_mock_field_states(
            TEST_FIELD_NAME,
            &input_data,
            4,
            BitOrder::Msb,
            BitOrder::Msb,
        );

        // `bits == 0` here; a lossless round trip means whole elements are expected back.
        let output = collect_output(&mut field_stats, BigEndian, &array_component(0, 0));

        assert_eq!(output, input_data);
    }

    #[test]
    fn can_read_slices() {
        // LSB-first 4-bit elements: 0b1101, 0b0010, 0b1100, 0b0000.
        let input_data: [u8; 2] = [0b0010_1101, 0b0000_1100];
        // Bits [2, 4) of each element: 0b11, 0b00, 0b11, 0b00, packed LSB-first.
        let expected_output: [u8; 1] = [0b00_11_00_11];
        let mut field_stats = create_mock_field_states(
            TEST_FIELD_NAME,
            &input_data,
            4,
            BitOrder::Lsb,
            BitOrder::Lsb,
        );

        let output = collect_output(&mut field_stats, LittleEndian, &array_component(2, 2));

        assert_eq!(output, expected_output);
    }
}