rust_eigenda_v2_common/core/
encoded_payload.rs1use crate::{
2 core::{Payload, PayloadEncodingVersion, BYTES_PER_SYMBOL},
3 errors::ConversionError,
4};
5use ark_bn254::Fr;
6use rust_kzg_bn254_primitives::helpers::{to_byte_array, to_fr_array};
7
/// An encoded payload: a 32-byte header followed by the payload body, where the body is split
/// into 32-byte symbols that each begin with a 0x00 byte.
///
/// Header layout as written by `EncodedPayload::new`:
/// - byte 0: `0x00`
/// - byte 1: payload encoding version
/// - bytes 2..6: serialized payload length, big-endian `u32`
/// - bytes 6..32: zero
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncodedPayload {
    /// Raw encoded bytes: the header immediately followed by the padded payload body.
    bytes: Vec<u8>,
}
32
33impl EncodedPayload {
34 pub fn new(payload: &Payload) -> Result<EncodedPayload, ConversionError> {
36 let mut header = [0u8; 32].to_vec();
37 header[1] = PayloadEncodingVersion::Zero as u8;
38
39 let payload_bytes: Vec<u8> = payload.serialize();
40
41 let payload_length: u32 = payload_bytes.len() as u32;
43 header[2..6].copy_from_slice(&payload_length.to_be_bytes());
44
45 let encoded_data = pad_to_bn254(&payload_bytes);
47
48 let mut bytes = Vec::new();
49 bytes.extend_from_slice(&header);
50 bytes.extend_from_slice(&encoded_data);
51
52 Ok(EncodedPayload { bytes })
53 }
54
55 pub fn decode(&self) -> Result<Payload, ConversionError> {
57 let expected_data_length = match self.bytes[2..6].try_into() {
58 Ok(arr) => u32::from_be_bytes(arr),
59 Err(_) => {
60 return Err(ConversionError::Payload(
61 "Invalid header format: couldn't read data length".to_string(),
62 ))
63 }
64 };
65 let unpadded_data = remove_internal_padding(&self.bytes[32..])?;
67 let unpadded_data_length = unpadded_data.len() as u32;
68
69 if unpadded_data_length < expected_data_length {
72 return Err(ConversionError::Payload(
73 "Invalid header format: data length is less than expected".to_string(),
74 ));
75 }
76
77 if unpadded_data_length > expected_data_length + 31 {
78 return Err(ConversionError::Payload(
79 "Invalid header format: data length is greater than expected".to_string(),
80 ));
81 }
82
83 Ok(Payload::new(
84 unpadded_data[0..expected_data_length as usize].to_vec(),
85 ))
86 }
87
88 pub fn to_field_elements(&self) -> Vec<Fr> {
90 to_fr_array(&self.bytes)
91 }
92
93 pub fn from_field_elements(
96 field_elements: &[Fr],
97 max_payload_length: usize,
98 ) -> Result<EncodedPayload, ConversionError> {
99 let serialized_felts = to_byte_array(field_elements, usize::MAX);
100 let payload_length = match serialized_felts[2..6].try_into() {
102 Ok(arr) => u32::from_be_bytes(arr),
103 Err(_) => {
104 return Err(ConversionError::EncodedPayload(
105 "invalid serialized field elements: couldn't read payload length".to_string(),
106 ))
107 }
108 };
109
110 if payload_length > max_payload_length as u32 {
111 return Err(ConversionError::EncodedPayload(
112 "invalid serialized field elements: payload length is greater than maximum allowed"
113 .to_string(),
114 ));
115 }
116
117 let padded_length = get_padded_data_length(payload_length as usize);
118 let encoded_payload_length = padded_length + 32;
120
121 let serialized_felts_length = serialized_felts.len();
122 let length_to_copy = encoded_payload_length.min(serialized_felts_length);
123
124 if encoded_payload_length < serialized_felts_length {
125 let remaining_serialized_felts = serialized_felts
128 .iter()
129 .enumerate()
130 .skip(encoded_payload_length);
131 for (index, &byte) in remaining_serialized_felts {
132 if byte != 0 {
133 return Err(ConversionError::EncodedPayload(format!(
134 "byte at index {} was expected to be 0x00, but instead was 0x{:02x}",
135 index, byte
136 )));
137 }
138 }
139 }
140
141 let mut encoded_payload_bytes = vec![0u8; encoded_payload_length];
143
144 encoded_payload_bytes[..length_to_copy]
146 .copy_from_slice(&serialized_felts[..length_to_copy]);
147
148 Ok(EncodedPayload {
150 bytes: encoded_payload_bytes,
151 })
152 }
153}
154
155fn remove_internal_padding(padded_data: &[u8]) -> Result<Vec<u8>, ConversionError> {
160 if padded_data.len() % BYTES_PER_SYMBOL != 0 {
161 return Err(ConversionError::EncodedPayload(format!(
162 "padded data (length {}) must be multiple of BYTES_PER_SYMBOL ({})",
163 padded_data.len(),
164 BYTES_PER_SYMBOL
165 )));
166 }
167
168 let bytes_per_chunk = BYTES_PER_SYMBOL - 1;
169 let symbol_count = padded_data.len() / BYTES_PER_SYMBOL;
170 let output_length = symbol_count * bytes_per_chunk;
171
172 let mut output_data = vec![0u8; output_length];
173
174 for i in 0..symbol_count {
175 let dst_index = i * bytes_per_chunk;
176 let src_index = i * BYTES_PER_SYMBOL + 1;
177
178 output_data[dst_index..dst_index + bytes_per_chunk]
179 .copy_from_slice(&padded_data[src_index..src_index + bytes_per_chunk]);
180 }
181
182 Ok(output_data)
183}
184
185fn get_padded_data_length(data_length: usize) -> usize {
190 let bytes_per_chunk = BYTES_PER_SYMBOL - 1;
191 let mut chunk_count = data_length / bytes_per_chunk;
192
193 if data_length % bytes_per_chunk != 0 {
194 chunk_count += 1;
195 }
196
197 chunk_count * BYTES_PER_SYMBOL
198}
199
200fn pad_to_bn254(data: &[u8]) -> Vec<u8> {
202 let bytes_per_chunk = BYTES_PER_SYMBOL - 1;
203 let output_length = get_padded_data_length(data.len());
204 let mut padded_output = vec![0u8; output_length];
205
206 let required_pad = (bytes_per_chunk - data.len() % bytes_per_chunk) % bytes_per_chunk;
209 let pre_padded_payload = [data, &vec![0u8; required_pad]].concat();
210
211 for elem in 0..output_length / 32 {
212 let zero_byte_index = elem * BYTES_PER_SYMBOL;
213 padded_output[zero_byte_index] = 0x00;
214
215 let destination_index = zero_byte_index + 1;
216 let source_index = elem * bytes_per_chunk;
217
218 let pre_padded_chunk = &pre_padded_payload[source_index..source_index + bytes_per_chunk];
219 padded_output[destination_index..destination_index + bytes_per_chunk]
220 .copy_from_slice(pre_padded_chunk);
221 }
222
223 padded_output
224}
225
#[cfg(test)]
mod tests {
    use crate::core::{encoded_payload::BYTES_PER_SYMBOL, EncodedPayload, Payload};
    use rand::{thread_rng, Rng};

    /// Builds a payload of random bytes whose length is drawn uniformly from `lengths`.
    fn random_payload(lengths: std::ops::Range<usize>) -> Payload {
        let mut rng = thread_rng();
        let length = rng.gen_range(lengths);
        let bytes: Vec<u8> = (0..length).map(|_| rng.r#gen()).collect();
        Payload::new(bytes)
    }

    #[test]
    fn test_encoding_decoding() {
        let bytes: Vec<u8> = (1..=32).collect();
        let payload = Payload::new(bytes);

        let encoded = EncodedPayload::new(&payload).expect("encoding should succeed");
        let decoded = encoded.decode().expect("decoding should succeed");

        assert_eq!(payload, decoded);
    }

    #[test]
    fn test_decode_short_bytes() {
        let encoded = EncodedPayload::new(&random_payload(33..1057)).unwrap();

        // Chop off the final symbol so the body is shorter than the header claims.
        let keep = encoded.bytes.len() - 32;
        let truncated = EncodedPayload {
            bytes: encoded.bytes[..keep].to_vec(),
        };

        assert!(truncated.decode().is_err());
    }

    #[test]
    fn test_decode_long_bytes() {
        let encoded = EncodedPayload::new(&random_payload(1..1025)).unwrap();

        // 33 extra bytes: more than a symbol's worth and not a whole number of symbols.
        let mut bytes = encoded.bytes;
        bytes.extend([0u8; 33]);

        assert!(EncodedPayload { bytes }.decode().is_err());
    }

    #[test]
    fn test_from_to_field_elements() {
        let payload = Payload::new(b"0123456789012345678901234567890123".to_vec());
        let encoded = EncodedPayload::new(&payload).unwrap();

        let round_tripped =
            EncodedPayload::from_field_elements(&encoded.to_field_elements(), usize::MAX)
                .unwrap();

        assert_eq!(encoded, round_tripped);
    }

    #[test]
    fn test_trailing_non_zeros() {
        let encoded = EncodedPayload::new(&random_payload(1..1025)).unwrap();
        let original_elements = encoded.to_field_elements();

        // Appends one extra field element after the genuine encoded payload.
        let with_trailing = |extra: ark_bn254::Fr| {
            let mut elements = original_elements.clone();
            elements.push(extra);
            elements
        };

        let zero_tail = with_trailing(ark_bn254::Fr::from(0));
        let zero_limit = zero_tail.len() * BYTES_PER_SYMBOL;
        assert!(EncodedPayload::from_field_elements(&zero_tail, zero_limit).is_ok());

        let nonzero_tail = with_trailing(ark_bn254::Fr::from(1));
        let nonzero_limit = nonzero_tail.len() * BYTES_PER_SYMBOL;
        assert!(EncodedPayload::from_field_elements(&nonzero_tail, nonzero_limit).is_err());
    }
}