struct_compression_analyzer/comparison/compare_groups/generate_bytes/
write_array.rs

1use super::{GenerateBytesError, GenerateBytesResult};
2use crate::utils::{
3    analyze_utils::{get_writer_buffer, BitWriterContainer},
4    bitstream_ext::BitReaderExt,
5};
6use crate::{analyzer::AnalyzerFieldState, schema::GroupComponentArray};
7use ahash::AHashMap;
8use bitstream_io::{BigEndian, BitRead, BitReader, BitWrite, BitWriter, Endianness, LittleEndian};
9use std::io::{self, Cursor, SeekFrom};
10
11/// Processes an [`GroupComponentArray`], writing its output to a
12/// provided [`BitWriter`].
13///
14/// # Arguments
15/// * `bit_order` - The bit order to use when reading the data.
16///   This is normally inherited from the schema root.
17/// * `field_stats` - A mutable reference to a map of field stats.
18/// * `writer` - The bit writer to write the array to.
19/// * `array` - Contains info about the array to write.
20pub(crate) fn write_array<TWrite: io::Write, TEndian: Endianness>(
21    field_stats: &mut AHashMap<String, AnalyzerFieldState>,
22    writer: &mut BitWriter<TWrite, TEndian>,
23    array: &GroupComponentArray,
24) -> GenerateBytesResult<()> {
25    let field = field_stats
26        .get_mut(&array.field)
27        .ok_or_else(|| GenerateBytesError::FieldNotFound(array.field.clone()))?;
28
29    let bits: u32 = array.get_bits(field);
30    let offset = array.offset;
31    let field_len = field.lenbits;
32    match &field.writer {
33        BitWriterContainer::Msb(_) => {
34            let bytes = get_writer_buffer(&mut field.writer);
35            let mut reader = BitReader::endian(Cursor::new(bytes), BigEndian);
36            write_array_inner(&mut reader, bits, offset, field_len, writer)
37        }
38        BitWriterContainer::Lsb(_) => {
39            let bytes = get_writer_buffer(&mut field.writer);
40            let mut reader = BitReader::endian(Cursor::new(bytes), LittleEndian);
41            write_array_inner(&mut reader, bits, offset, field_len, writer)
42        }
43    }
44}
45
46/// Processes an array component by reading bits from a field's stored data
47/// and writing them to the output writer according to array configuration.
48///
49/// Handles both MSB and LSB bit orders by creating appropriate readers
50/// from the field's stored bitstream data.
51fn write_array_inner<
52    TWrite: io::Write,
53    TEndian: Endianness,
54    TReader: io::Read + io::Seek,
55    TReaderEndian: Endianness,
56>(
57    reader: &mut BitReader<TReader, TReaderEndian>,
58    bits: u32,
59    offset: u32,
60    field_len: u32,
61    writer: &mut BitWriter<TWrite, TEndian>,
62) -> GenerateBytesResult<()> {
63    // Loop until we run out of bits in the source field data
64    loop {
65        // Calculate ending position before reading to maintain alignment
66        let ending_pos = reader
67            .position_in_bits()
68            .map_err(|e| GenerateBytesError::SeekError {
69                source: e,
70                operation: "getting array position".into(),
71            })?
72            + field_len as u64;
73
74        // Check remaining bits before attempting read
75        let remaining = reader
76            .remaining_bits()
77            .map_err(|e| GenerateBytesError::SeekError {
78                source: e,
79                operation: "checking remaining bits".into(),
80            })?;
81
82        if remaining < field_len as u64 {
83            return Ok(());
84        }
85
86        // Seek to the array element offset
87        reader
88            .seek_bits(SeekFrom::Current(offset as i64))
89            .map_err(|e| GenerateBytesError::SeekError {
90                source: e,
91                operation: format!("seeking to array offset {}", offset),
92            })?;
93
94        // Read the actual value from the source bitstream
95        let value = reader
96            .read::<u64>(bits)
97            .map_err(|e| GenerateBytesError::ReadError {
98                source: e,
99                context: format!("reading {bits}-bit array element"),
100            })?;
101
102        // Write the value to the output stream
103        writer
104            .write::<u64>(bits, value)
105            .map_err(|e| GenerateBytesError::WriteError {
106                source: e,
107                context: format!("writing {bits}-bit array element"),
108            })?;
109
110        // Return to calculated end position for next iteration
111        reader.seek_bits(SeekFrom::Start(ending_pos)).map_err(|e| {
112            GenerateBytesError::SeekError {
113                source: e,
114                operation: format!("seeking to array end position {}", ending_pos),
115            }
116        })?;
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::comparison::compare_groups::test_helpers::{
        create_mock_field_states, TEST_FIELD_NAME,
    };
    use crate::schema::{default_entropy_multiplier, default_lz_match_multiplier, BitOrder};
    use bitstream_io::BitWriter;
    use std::io::Cursor;

    /// Builds an array component that targets the shared test field and uses
    /// the schema's default multipliers.
    fn test_array_group_component(offset: u32, bits: u32) -> GroupComponentArray {
        let field = TEST_FIELD_NAME.to_string();
        GroupComponentArray {
            field,
            offset,
            bits,
            lz_match_multiplier: default_lz_match_multiplier(),
            entropy_multiplier: default_entropy_multiplier(),
        }
    }

    #[test]
    fn can_round_trip_lsb() {
        // Two 4-bit elements per byte; LSB order, so the low nibble comes first.
        let source = [
            0b0010_0001, // 1, 2
            0b1000_0100, // 4, 8
        ];
        // Whole element: offset 0, bit count given explicitly as 4.
        let component = test_array_group_component(0, 4);
        let mut states = create_mock_field_states(
            TEST_FIELD_NAME,
            &source,
            4,
            BitOrder::Lsb,
            BitOrder::Lsb,
        );

        let mut written = Vec::new();
        {
            // Write using LSB
            let mut sink = BitWriter::endian(Cursor::new(&mut written), LittleEndian);
            write_array(&mut states, &mut sink, &component).unwrap();
        }

        // Copying whole elements must reproduce the input exactly.
        assert_eq!(source, written.as_slice());
    }

    #[test]
    fn can_round_trip_msb() {
        // Two 4-bit elements per byte; MSB order, so the high nibble comes first.
        let source = [
            0b0001_0010, // 1, 2
            0b0100_1000, // 4, 8
        ];
        // Bit count 0: expected to inherit the bit count from the field.
        let component = test_array_group_component(0, 0);
        let mut states = create_mock_field_states(
            TEST_FIELD_NAME,
            &source,
            4,
            BitOrder::Msb,
            BitOrder::Msb,
        );

        let mut written = Vec::new();
        {
            // Write using MSB
            let mut sink = BitWriter::endian(Cursor::new(&mut written), BigEndian);
            write_array(&mut states, &mut sink, &component).unwrap();
        }

        // Copying whole elements must reproduce the input exactly.
        assert_eq!(source, written.as_slice());
    }

    #[test]
    fn can_read_slices() {
        // Four 4-bit elements in LSB order (low nibble first):
        // 0b1101, 0b0010, 0b1100, 0b0000.
        let source = [
            0b0010_1101, // upper 2 bits of each nibble: 00, 11
            0b0000_1100, // upper 2 bits of each nibble: 00, 11
        ];
        // Skipping 2 bits and reading 2 keeps only the upper half of every
        // element. After each read the stream still advances to the next
        // element, so the slices are 11, 00, 11, 00 in stream order, packed
        // LSB-first into a single byte.
        let expected = [
            0b00_11_00_11, // 00, 11, 00, 11
        ]; // the low 2 bits of each element are dropped by the offset

        let component = test_array_group_component(2, 2); // only upper 2 bits.
        let mut states = create_mock_field_states(
            TEST_FIELD_NAME,
            &source,
            4,
            BitOrder::Lsb,
            BitOrder::Lsb,
        );

        let mut written = Vec::new();
        {
            // Write using LSB
            let mut sink = BitWriter::endian(Cursor::new(&mut written), LittleEndian);
            write_array(&mut states, &mut sink, &component).unwrap();
        }

        assert_eq!(expected, written.as_slice());
    }
}