struct_compression_analyzer/comparison/compare_groups/generate_bytes/
write_struct.rs1use super::{GenerateBytesError, GenerateBytesResult};
2use crate::{
3 analyzer::AnalyzerFieldState,
4 schema::{GroupComponent, GroupComponentStruct},
5 utils::analyze_utils::{bit_writer_to_reader, BitReaderContainer},
6};
7use ahash::AHashMap;
8use bitstream_io::{BitWrite, BitWriter, Endianness};
9use core::cell::UnsafeCell;
10use std::io::{self};
11
12pub(crate) fn write_struct<TWrite: io::Write, TEndian: Endianness>(
20 field_states: &mut AHashMap<String, AnalyzerFieldState>,
21 writer: &mut BitWriter<TWrite, TEndian>,
22 strct_ref: &GroupComponentStruct,
23) -> GenerateBytesResult<()> {
24 let mut strct = strct_ref.clone();
26 let field_states_unsafe = UnsafeCell::new(field_states);
27
28 let mut field_readers = AHashMap::<String, BitReaderContainer>::new();
30
31 for field in &mut strct.fields {
33 let field_name = match field {
34 GroupComponent::Array(_) | GroupComponent::Struct(_) => {
35 return Err(GenerateBytesError::UnsupportedNestedComponent)
36 }
37 GroupComponent::Field(field) => Some(field.field.clone()),
38 GroupComponent::Skip(skip) => Some(skip.field.clone()),
39 GroupComponent::Padding(_) => None,
40 };
41
42 if let Some(field_name) = field_name {
43 let field_states = unsafe { (*field_states_unsafe.get()).get_mut(&field_name) }
44 .ok_or_else(|| GenerateBytesError::FieldNotFound(field_name.clone()))?;
45
46 field_readers.insert(
48 field_name.clone(),
49 bit_writer_to_reader(&mut field_states.writer),
50 );
51
52 if let GroupComponent::Field(field) = field {
54 field.set_bits(field_states.lenbits);
55 };
56 }
57 }
58
59 loop {
61 let mut read_anything = false;
62
63 for field in &strct.fields {
64 match field {
65 GroupComponent::Array(_) | GroupComponent::Struct(_) => {
66 return Err(GenerateBytesError::UnsupportedNestedComponent)
67 }
68 GroupComponent::Padding(padding) => {
69 writer
70 .write(padding.bits as u32, padding.value)
71 .map_err(|e| GenerateBytesError::WriteError {
72 source: e,
73 context: "writing padding bits".into(),
74 })?;
75 }
76 GroupComponent::Field(field) => {
77 let reader = field_readers
78 .get_mut(&field.field)
79 .ok_or_else(|| GenerateBytesError::FieldNotFound(field.field.clone()))?;
80
81 let read_result = reader.read(field.bits);
83 match read_result {
84 Ok(value) => {
85 writer.write(field.bits, value).map_err(|e| {
87 GenerateBytesError::WriteError {
88 source: e,
89 context: format!(
90 "writing {}-bit field '{}'",
91 field.bits, field.field
92 ),
93 }
94 })?;
95 read_anything = true;
96 }
97 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
98 }
100 Err(e) => {
101 return Err(GenerateBytesError::ReadError {
102 source: e,
103 context: format!(
104 "reading {}-bit field '{}'",
105 field.bits, field.field
106 ),
107 })
108 }
109 }
110 }
111 GroupComponent::Skip(skip) => {
112 let reader = field_readers
113 .get_mut(&skip.field)
114 .ok_or_else(|| GenerateBytesError::FieldNotFound(skip.field.clone()))?;
115
116 let seek_result = reader.seek_bits(io::SeekFrom::Current(skip.bits as i64));
118 match seek_result {
119 Ok(_) => read_anything = true,
120 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
121 }
123 Err(e) => {
124 return Err(GenerateBytesError::SeekError {
125 source: e,
126 operation: format!(
127 "skipping {} bits in field '{}'",
128 skip.bits, skip.field
129 ),
130 })
131 }
132 }
133 }
134 }
135 }
136
137 if !read_anything {
138 return Ok(());
139 }
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::comparison::compare_groups::test_helpers::{
        create_mock_field_states, TEST_FIELD_NAME,
    };
    use crate::schema::{
        default_entropy_multiplier, default_lz_match_multiplier, BitOrder, GroupComponentField,
        GroupComponentPadding, GroupComponentSkip,
    };
    use bitstream_io::{BigEndian, BitWriter, LittleEndian};
    use std::io::Cursor;

    /// Wraps `fields` in a struct component that uses the schema's default multipliers.
    fn layout(fields: Vec<GroupComponent>) -> GroupComponentStruct {
        GroupComponentStruct {
            fields,
            lz_match_multiplier: default_lz_match_multiplier(),
            entropy_multiplier: default_entropy_multiplier(),
        }
    }

    /// A `Field` component of width `bits` that reads from the shared test field.
    fn test_field(bits: u32) -> GroupComponent {
        GroupComponent::Field(GroupComponentField {
            field: TEST_FIELD_NAME.to_string(),
            bits,
        })
    }

    /// Runs [`write_struct`] into a fresh buffer and returns the bytes produced.
    ///
    /// When `pad_to_byte` is set, the writer is byte-aligned and flushed afterwards so a
    /// trailing partial byte appears in the result.
    fn render<E: Endianness>(
        field_states: &mut AHashMap<String, AnalyzerFieldState>,
        component: &GroupComponentStruct,
        endian: E,
        pad_to_byte: bool,
    ) -> Vec<u8> {
        let mut output = Vec::new();
        {
            let mut writer = BitWriter::endian(Cursor::new(&mut output), endian);
            write_struct(field_states, &mut writer, component).unwrap();
            if pad_to_byte {
                writer.byte_align().unwrap();
                writer.flush().unwrap();
            }
        }
        output
    }

    #[test]
    fn field_can_round_trip_lsb() {
        let input_data = [0b0010_0001, 0b1000_0100];
        let mut field_states =
            create_mock_field_states(TEST_FIELD_NAME, &input_data, 4, BitOrder::Lsb, BitOrder::Lsb);

        // Width 0 in the layout; the field's recorded width comes from its state.
        let output = render(&mut field_states, &layout(vec![test_field(0)]), LittleEndian, false);

        assert_eq!(output, input_data);
    }

    #[test]
    fn field_can_round_trip_msb() {
        let input_data = [0b0001_0010, 0b0100_1000];
        let mut field_states =
            create_mock_field_states(TEST_FIELD_NAME, &input_data, 4, BitOrder::Msb, BitOrder::Msb);

        let output = render(&mut field_states, &layout(vec![test_field(0)]), BigEndian, false);

        assert_eq!(output, input_data);
    }

    #[test]
    fn field_can_read_slices_with_skip() {
        let input_data = [0b0010_1101, 0b0000_1100];
        // Each pass skips 2 bits and keeps the next 2, so alternate 2-bit slices survive.
        let expected_output: [u8; 1] = [0b00_11_00_11];
        let mut field_states =
            create_mock_field_states(TEST_FIELD_NAME, &input_data, 4, BitOrder::Lsb, BitOrder::Lsb);

        let skip_then_read = layout(vec![
            GroupComponent::Skip(GroupComponentSkip {
                field: TEST_FIELD_NAME.to_string(),
                bits: 2,
            }),
            test_field(2),
        ]);
        let output = render(&mut field_states, &skip_then_read, LittleEndian, false);

        assert_eq!(output, expected_output);
    }

    #[test]
    fn padding_writes_correct_bits_lsb() {
        let mut field_states =
            create_mock_field_states(TEST_FIELD_NAME, &[], 0, BitOrder::Lsb, BitOrder::Lsb);
        let padding_only = layout(vec![GroupComponent::Padding(GroupComponentPadding {
            bits: 4,
            value: 0b1010,
        })]);

        let output = render(&mut field_states, &padding_only, LittleEndian, true);

        assert_eq!(output, [0b0000_1010]);
    }

    #[test]
    fn padding_writes_correct_bits_msb() {
        let mut field_states =
            create_mock_field_states(TEST_FIELD_NAME, &[], 0, BitOrder::Msb, BitOrder::Msb);
        let padding_only = layout(vec![GroupComponent::Padding(GroupComponentPadding {
            bits: 4,
            value: 0b1010,
        })]);

        let output = render(&mut field_states, &padding_only, BigEndian, true);

        assert_eq!(output, [0b1010_0000]);
    }
}