struct_compression_analyzer/comparison/compare_groups/generate_bytes/write_struct.rs

1use super::{GenerateBytesError, GenerateBytesResult};
2use crate::{
3    analyzer::AnalyzerFieldState,
4    schema::{GroupComponent, GroupComponentStruct},
5    utils::analyze_utils::{bit_writer_to_reader, BitReaderContainer},
6};
7use ahash::AHashMap;
8use bitstream_io::{BitWrite, BitWriter, Endianness};
9use core::cell::UnsafeCell;
10use std::io::{self};
11
12/// Processes an [`GroupComponentStruct`], writing its output to a
13/// provided [`BitWriter`].
14///
15/// # Arguments
16/// * `field_states` - A mutable reference to a map of field stats.
17/// * `writer` - The bit writer to write the array to.
18/// * `array` - Contains info about the array to write.
19pub(crate) fn write_struct<TWrite: io::Write, TEndian: Endianness>(
20    field_states: &mut AHashMap<String, AnalyzerFieldState>,
21    writer: &mut BitWriter<TWrite, TEndian>,
22    strct_ref: &GroupComponentStruct,
23) -> GenerateBytesResult<()> {
24    // Clone the struct definition to avoid mutating the original
25    let mut strct = strct_ref.clone();
26    let field_states_unsafe = UnsafeCell::new(field_states);
27
28    // Map field names to their bitstream readers
29    let mut field_readers = AHashMap::<String, BitReaderContainer>::new();
30
31    // Initialize readers for all fields used in the struct
32    for field in &mut strct.fields {
33        let field_name = match field {
34            GroupComponent::Array(_) | GroupComponent::Struct(_) => {
35                return Err(GenerateBytesError::UnsupportedNestedComponent)
36            }
37            GroupComponent::Field(field) => Some(field.field.clone()),
38            GroupComponent::Skip(skip) => Some(skip.field.clone()),
39            GroupComponent::Padding(_) => None,
40        };
41
42        if let Some(field_name) = field_name {
43            let field_states = unsafe { (*field_states_unsafe.get()).get_mut(&field_name) }
44                .ok_or_else(|| GenerateBytesError::FieldNotFound(field_name.clone()))?;
45
46            // Convert field's writer to a reader for reading stored bits
47            field_readers.insert(
48                field_name.clone(),
49                bit_writer_to_reader(&mut field_states.writer),
50            );
51
52            // Set default bits if not specified in schema
53            if let GroupComponent::Field(field) = field {
54                field.set_bits(field_states.lenbits);
55            };
56        }
57    }
58
59    // Process struct components in a loop until no more data
60    loop {
61        let mut read_anything = false;
62
63        for field in &strct.fields {
64            match field {
65                GroupComponent::Array(_) | GroupComponent::Struct(_) => {
66                    return Err(GenerateBytesError::UnsupportedNestedComponent)
67                }
68                GroupComponent::Padding(padding) => {
69                    writer
70                        .write(padding.bits as u32, padding.value)
71                        .map_err(|e| GenerateBytesError::WriteError {
72                            source: e,
73                            context: "writing padding bits".into(),
74                        })?;
75                }
76                GroupComponent::Field(field) => {
77                    let reader = field_readers
78                        .get_mut(&field.field)
79                        .ok_or_else(|| GenerateBytesError::FieldNotFound(field.field.clone()))?;
80
81                    // Attempt read from source field
82                    let read_result = reader.read(field.bits);
83                    match read_result {
84                        Ok(value) => {
85                            // Only write if we successfully read the value
86                            writer.write(field.bits, value).map_err(|e| {
87                                GenerateBytesError::WriteError {
88                                    source: e,
89                                    context: format!(
90                                        "writing {}-bit field '{}'",
91                                        field.bits, field.field
92                                    ),
93                                }
94                            })?;
95                            read_anything = true;
96                        }
97                        Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
98                            // Field is exhausted, continue processing other components
99                        }
100                        Err(e) => {
101                            return Err(GenerateBytesError::ReadError {
102                                source: e,
103                                context: format!(
104                                    "reading {}-bit field '{}'",
105                                    field.bits, field.field
106                                ),
107                            })
108                        }
109                    }
110                }
111                GroupComponent::Skip(skip) => {
112                    let reader = field_readers
113                        .get_mut(&skip.field)
114                        .ok_or_else(|| GenerateBytesError::FieldNotFound(skip.field.clone()))?;
115
116                    // Attempt seek operation
117                    let seek_result = reader.seek_bits(io::SeekFrom::Current(skip.bits as i64));
118                    match seek_result {
119                        Ok(_) => read_anything = true,
120                        Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
121                            // Field is exhausted, continue processing other components
122                        }
123                        Err(e) => {
124                            return Err(GenerateBytesError::SeekError {
125                                source: e,
126                                operation: format!(
127                                    "skipping {} bits in field '{}'",
128                                    skip.bits, skip.field
129                                ),
130                            })
131                        }
132                    }
133                }
134            }
135        }
136
137        if !read_anything {
138            return Ok(());
139        }
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::comparison::compare_groups::test_helpers::{
        create_mock_field_states, TEST_FIELD_NAME,
    };
    use crate::schema::{
        default_entropy_multiplier, default_lz_match_multiplier, BitOrder, GroupComponentField,
        GroupComponentPadding, GroupComponentSkip,
    };
    use bitstream_io::{BigEndian, BitWriter, LittleEndian};
    use std::io::Cursor;

    /// Wraps `components` in a [`GroupComponentStruct`] with default multipliers.
    fn struct_of(components: Vec<GroupComponent>) -> GroupComponentStruct {
        GroupComponentStruct {
            fields: components,
            lz_match_multiplier: default_lz_match_multiplier(),
            entropy_multiplier: default_entropy_multiplier(),
        }
    }

    /// A struct that repeatedly reads `bits` bits from [`TEST_FIELD_NAME`].
    fn single_field_struct_group_component(bits: u32) -> GroupComponentStruct {
        struct_of(vec![GroupComponent::Field(GroupComponentField {
            field: TEST_FIELD_NAME.to_string(),
            bits,
        })])
    }

    /// A struct made of a single 4-bit padding component holding `0b1010`.
    fn padding_only_struct() -> GroupComponentStruct {
        struct_of(vec![GroupComponent::Padding(GroupComponentPadding {
            bits: 4,
            value: 0b1010,
        })])
    }

    #[test]
    fn field_can_round_trip_lsb() {
        // LSB-first nibbles, low nibble first: 1, 2, then 4, 8.
        let source: [u8; 2] = [0b0010_0001, 0b1000_0100];
        let mut states =
            create_mock_field_states(TEST_FIELD_NAME, &source, 4, BitOrder::Lsb, BitOrder::Lsb);

        let mut sink = Vec::new();
        let mut writer = BitWriter::endian(Cursor::new(&mut sink), LittleEndian);
        // A width of 0 inherits the field's own 4-bit length.
        let layout = single_field_struct_group_component(0);
        write_struct(&mut states, &mut writer, &layout).unwrap();

        assert_eq!(source, sink.as_slice());
    }

    #[test]
    fn field_can_round_trip_msb() {
        // MSB-first nibbles, high nibble first: 1, 2, then 4, 8.
        let source: [u8; 2] = [0b0001_0010, 0b0100_1000];
        let mut states =
            create_mock_field_states(TEST_FIELD_NAME, &source, 4, BitOrder::Msb, BitOrder::Msb);

        let mut sink = Vec::new();
        let mut writer = BitWriter::endian(Cursor::new(&mut sink), BigEndian);
        // A width of 0 inherits the field's own 4-bit length.
        let layout = single_field_struct_group_component(0);
        write_struct(&mut states, &mut writer, &layout).unwrap();

        assert_eq!(source, sink.as_slice());
    }

    #[test]
    fn field_can_read_slices_with_skip() {
        // Read LSB-first, every nibble is [2 skipped bits][2 kept bits].
        // Kept values in read order: 0b11, 0b00, 0b11, 0b00.
        let source: [u8; 2] = [0b0010_1101, 0b0000_1100];
        // The kept pairs packed back together, LSB-first.
        let expected: [u8; 1] = [0b00_11_00_11];

        let mut states =
            create_mock_field_states(TEST_FIELD_NAME, &source, 4, BitOrder::Lsb, BitOrder::Lsb);
        let layout = struct_of(vec![
            GroupComponent::Skip(GroupComponentSkip {
                field: TEST_FIELD_NAME.to_string(),
                bits: 2,
            }),
            GroupComponent::Field(GroupComponentField {
                field: TEST_FIELD_NAME.to_string(),
                bits: 2,
            }),
        ]);

        let mut sink = Vec::new();
        let mut writer = BitWriter::endian(Cursor::new(&mut sink), LittleEndian);
        write_struct(&mut states, &mut writer, &layout).unwrap();

        assert_eq!(expected, sink.as_slice());
    }

    #[test]
    fn padding_writes_correct_bits_lsb() {
        let mut states =
            create_mock_field_states(TEST_FIELD_NAME, &[], 0, BitOrder::Lsb, BitOrder::Lsb);
        let mut sink = Vec::new();
        let mut writer = BitWriter::endian(Cursor::new(&mut sink), LittleEndian);

        write_struct(&mut states, &mut writer, &padding_only_struct()).unwrap();
        // Complete the partial byte so the 4 padding bits reach `sink`.
        writer.byte_align().unwrap();
        writer.flush().unwrap();

        // LSB-first: the padding occupies the low nibble.
        assert_eq!(sink, [0b0000_1010]);
    }

    #[test]
    fn padding_writes_correct_bits_msb() {
        let mut states =
            create_mock_field_states(TEST_FIELD_NAME, &[], 0, BitOrder::Msb, BitOrder::Msb);
        let mut sink = Vec::new();
        let mut writer = BitWriter::endian(Cursor::new(&mut sink), BigEndian);

        write_struct(&mut states, &mut writer, &padding_only_struct()).unwrap();
        // Complete the partial byte so the 4 padding bits reach `sink`.
        writer.byte_align().unwrap();
        writer.flush().unwrap();

        // MSB-first: the padding occupies the high nibble.
        assert_eq!(sink, [0b1010_0000]);
    }
}