Skip to main content

rust_eigenda_v2_common/core/
encoded_payload.rs

1use crate::{
2    core::{Payload, PayloadEncodingVersion, BYTES_PER_SYMBOL},
3    errors::ConversionError,
4};
5use ark_bn254::Fr;
6use rust_kzg_bn254_primitives::helpers::{to_byte_array, to_fr_array};
7
/// An [`EncodedPayload`] is a payload after an encoding has been applied to it.
///
/// # Layout
///
/// The byte representation is a 32-byte header followed by the encoded data.
///
/// **Header** (32 bytes):
/// - byte `0`: reserved, always `0x00`
/// - byte `1`: encoding version byte (for example `0x00` for `PayloadEncodingVersion::Zero`)
/// - bytes `2..=5`: length of the original payload, as a big-endian `u32`
/// - bytes `6..=31`: reserved, filled with `0x00`
///
/// **Data** (a multiple of 32 bytes), made of 32-byte chunks in which:
/// - byte `0` is a `0x00` padding byte that keeps the chunk in the valid field element range
/// - bytes `1..=31` hold 31 bytes of payload data (the last chunk is filled out with padding)
///
/// Each 32-byte chunk is interpreted as one field element, so this padding keeps all of the data
/// compatible with the limits of the bn254 curve's scalar field.
#[derive(PartialEq, Debug)]
pub struct EncodedPayload {
    /// Header followed by the encoded data; the constructors keep its length a multiple of 32.
    bytes: Vec<u8>,
}
32
33impl EncodedPayload {
34    /// Creates a new [`EncodedPayload`] from a [`Payload`], performing the `PayloadEncodingVersion0` encoding.
35    pub fn new(payload: &Payload) -> Result<EncodedPayload, ConversionError> {
36        let mut header = [0u8; 32].to_vec();
37        header[1] = PayloadEncodingVersion::Zero as u8;
38
39        let payload_bytes: Vec<u8> = payload.serialize();
40
41        // add payload length to the header
42        let payload_length: u32 = payload_bytes.len() as u32;
43        header[2..6].copy_from_slice(&payload_length.to_be_bytes());
44
45        // encode payload modulo bn254, and align to 32 bytes
46        let encoded_data = pad_to_bn254(&payload_bytes);
47
48        let mut bytes = Vec::new();
49        bytes.extend_from_slice(&header);
50        bytes.extend_from_slice(&encoded_data);
51
52        Ok(EncodedPayload { bytes })
53    }
54
55    /// Decodes the [`EncodedPayload`] back into a [`Payload`].
56    pub fn decode(&self) -> Result<Payload, ConversionError> {
57        let expected_data_length = match self.bytes[2..6].try_into() {
58            Ok(arr) => u32::from_be_bytes(arr),
59            Err(_) => {
60                return Err(ConversionError::Payload(
61                    "Invalid header format: couldn't read data length".to_string(),
62                ))
63            }
64        };
65        // decode raw data modulo bn254
66        let unpadded_data = remove_internal_padding(&self.bytes[32..])?;
67        let unpadded_data_length = unpadded_data.len() as u32;
68
69        // data length is checked when constructing an encoded payload. If this error is encountered, that means there
70        // must be a flaw in the logic at construction time (or someone was bad and didn't use the proper construction methods)
71        if unpadded_data_length < expected_data_length {
72            return Err(ConversionError::Payload(
73                "Invalid header format: data length is less than expected".to_string(),
74            ));
75        }
76
77        if unpadded_data_length > expected_data_length + 31 {
78            return Err(ConversionError::Payload(
79                "Invalid header format: data length is greater than expected".to_string(),
80            ));
81        }
82
83        Ok(Payload::new(
84            unpadded_data[0..expected_data_length as usize].to_vec(),
85        ))
86    }
87
88    /// Converts the encoded payload to an array of field elements.
89    pub fn to_field_elements(&self) -> Vec<Fr> {
90        to_fr_array(&self.bytes)
91    }
92
93    /// Creates an `EncodedPayload` from an array of field elements.
94    /// `max_payload_length` is the maximum length in bytes that the contained [`Payload`] is permitted to be.
95    pub fn from_field_elements(
96        field_elements: &[Fr],
97        max_payload_length: usize,
98    ) -> Result<EncodedPayload, ConversionError> {
99        let serialized_felts = to_byte_array(field_elements, usize::MAX);
100        // read payload length from the payload header
101        let payload_length = match serialized_felts[2..6].try_into() {
102            Ok(arr) => u32::from_be_bytes(arr),
103            Err(_) => {
104                return Err(ConversionError::EncodedPayload(
105                    "invalid serialized field elements: couldn't read payload length".to_string(),
106                ))
107            }
108        };
109
110        if payload_length > max_payload_length as u32 {
111            return Err(ConversionError::EncodedPayload(
112                "invalid serialized field elements: payload length is greater than maximum allowed"
113                    .to_string(),
114            ));
115        }
116
117        let padded_length = get_padded_data_length(payload_length as usize);
118        // add 32 to take into account the payload header
119        let encoded_payload_length = padded_length + 32;
120
121        let serialized_felts_length = serialized_felts.len();
122        let length_to_copy = encoded_payload_length.min(serialized_felts_length);
123
124        if encoded_payload_length < serialized_felts_length {
125            // serialized_felts is longer than encoded_payload_length,
126            // so we need to check that the remaining bytes are all 0.
127            let remaining_serialized_felts = serialized_felts
128                .iter()
129                .enumerate()
130                .skip(encoded_payload_length);
131            for (index, &byte) in remaining_serialized_felts {
132                if byte != 0 {
133                    return Err(ConversionError::EncodedPayload(format!(
134                        "byte at index {} was expected to be 0x00, but instead was 0x{:02x}",
135                        index, byte
136                    )));
137                }
138            }
139        }
140
141        // Create a byte vector of size encoded_payload_length filled with zeros
142        let mut encoded_payload_bytes = vec![0u8; encoded_payload_length];
143
144        // Copy data from serialized_felts up to length_to_copy
145        encoded_payload_bytes[..length_to_copy]
146            .copy_from_slice(&serialized_felts[..length_to_copy]);
147
148        // Return a new EncodedPayload with the byte vector
149        Ok(EncodedPayload {
150            bytes: encoded_payload_bytes,
151        })
152    }
153}
154
155/// Accepts an array of padded data, and removes the internal padding.
156///
157/// This function assumes that the input aligns to 32 bytes. Since it is removing 1 byte for every 31 bytes kept, the
158/// output from this function is not guaranteed to align to 32 bytes.
159fn remove_internal_padding(padded_data: &[u8]) -> Result<Vec<u8>, ConversionError> {
160    if padded_data.len() % BYTES_PER_SYMBOL != 0 {
161        return Err(ConversionError::EncodedPayload(format!(
162            "padded data (length {}) must be multiple of BYTES_PER_SYMBOL ({})",
163            padded_data.len(),
164            BYTES_PER_SYMBOL
165        )));
166    }
167
168    let bytes_per_chunk = BYTES_PER_SYMBOL - 1;
169    let symbol_count = padded_data.len() / BYTES_PER_SYMBOL;
170    let output_length = symbol_count * bytes_per_chunk;
171
172    let mut output_data = vec![0u8; output_length];
173
174    for i in 0..symbol_count {
175        let dst_index = i * bytes_per_chunk;
176        let src_index = i * BYTES_PER_SYMBOL + 1;
177
178        output_data[dst_index..dst_index + bytes_per_chunk]
179            .copy_from_slice(&padded_data[src_index..src_index + bytes_per_chunk]);
180    }
181
182    Ok(output_data)
183}
184
185/// Accepts the length of a byte array, and returns the length that the array would be after
186/// adding internal byte padding.
187///
188/// The value returned from this function will always be a multiple of [`BYTES_PER_SYMBOL`]
189fn get_padded_data_length(data_length: usize) -> usize {
190    let bytes_per_chunk = BYTES_PER_SYMBOL - 1;
191    let mut chunk_count = data_length / bytes_per_chunk;
192
193    if data_length % bytes_per_chunk != 0 {
194        chunk_count += 1;
195    }
196
197    chunk_count * BYTES_PER_SYMBOL
198}
199
200/// Accepts an array of data, and returns the array after adding padding to be bn254 friendly.
201fn pad_to_bn254(data: &[u8]) -> Vec<u8> {
202    let bytes_per_chunk = BYTES_PER_SYMBOL - 1;
203    let output_length = get_padded_data_length(data.len());
204    let mut padded_output = vec![0u8; output_length];
205
206    // pre-pad the input, so that it aligns to 31 bytes. This means that the internally padded result will automatically
207    // align to 32 bytes. Doing this padding in advance simplifies the for loop.
208    let required_pad = (bytes_per_chunk - data.len() % bytes_per_chunk) % bytes_per_chunk;
209    let pre_padded_payload = [data, &vec![0u8; required_pad]].concat();
210
211    for elem in 0..output_length / 32 {
212        let zero_byte_index = elem * BYTES_PER_SYMBOL;
213        padded_output[zero_byte_index] = 0x00;
214
215        let destination_index = zero_byte_index + 1;
216        let source_index = elem * bytes_per_chunk;
217
218        let pre_padded_chunk = &pre_padded_payload[source_index..source_index + bytes_per_chunk];
219        padded_output[destination_index..destination_index + bytes_per_chunk]
220            .copy_from_slice(pre_padded_chunk);
221    }
222
223    padded_output
224}
225
#[cfg(test)]
mod tests {
    use crate::{
        core::{encoded_payload::BYTES_PER_SYMBOL, EncodedPayload, Payload},
        errors::ConversionError,
    };
    use rand::{thread_rng, Rng};

    /// Checks that encoding and decoding a payload works correctly.
    #[test]
    fn test_encoding_decoding() {
        // TODO: add proptest
        let payload = Payload::new(vec![
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
            25, 26, 27, 28, 29, 30, 31, 32,
        ]);
        let encoded_payload = EncodedPayload::new(&payload);
        assert!(encoded_payload.is_ok());

        let decoded_payload = encoded_payload.unwrap().decode();
        assert!(decoded_payload.is_ok());
        assert_eq!(payload, decoded_payload.unwrap());
    }

    /// Checks that an encoded payload with a length less than claimed length fails at decode time
    #[test]
    fn test_decode_short_bytes() {
        // TODO: add proptest
        let mut rng = thread_rng();
        let random_length = rng.gen_range(33..1057); // 33 + random value up to 1024
        let original_data: Vec<u8> = (0..random_length).map(|_| rng.r#gen()).collect();

        // Create payload and encode it
        let payload = Payload::new(original_data);
        let encoded_payload = EncodedPayload::new(&payload).unwrap();

        // Create a truncated version by removing the last 32 bytes
        let truncated_bytes = encoded_payload.bytes[..encoded_payload.bytes.len() - 32].to_vec();
        let truncated_payload = EncodedPayload {
            bytes: truncated_bytes,
        };

        // Dropping a whole symbol keeps the data 32-byte aligned and removes 31 unpadded bytes.
        // That always leaves fewer bytes than the header claims, so decoding must fail on the
        // length check (`ConversionError::Payload`), not on the alignment check.
        let decode_result = truncated_payload.decode();
        assert!(matches!(decode_result, Err(ConversionError::Payload(_))));
    }

    /// Checks that an encoded payload with length greater than claimed fails at decode
    #[test]
    fn test_decode_long_bytes() {
        // Generate random data
        // TODO: add proptest
        let mut rng = thread_rng();
        let random_length = rng.gen_range(1..1025); // 1 + random value up to 1024
        let original_data: Vec<u8> = (0..random_length).map(|_| rng.r#gen()).collect();

        // Create payload and encode it
        let payload = Payload::new(original_data);
        let encoded_payload = EncodedPayload::new(&payload).unwrap();

        // Append two whole zero symbols. This keeps the data 32-byte aligned, so
        // `remove_internal_padding` succeeds. It also adds 62 unpadded bytes, which always
        // exceeds the 31 bytes of trailing padding `decode` tolerates beyond the claimed length.
        // An unaligned extension would only exercise the alignment check, not this one.
        let mut extended_bytes = encoded_payload.bytes;
        extended_bytes.extend_from_slice(&vec![0u8; 2 * BYTES_PER_SYMBOL]);

        let extended_payload = EncodedPayload {
            bytes: extended_bytes,
        };

        // Try to decode the extended payload, it should fail since it has too many bytes
        let decode_result = extended_payload.decode();
        assert!(matches!(decode_result, Err(ConversionError::Payload(_))));
    }

    /// Checks that converting an `EncodedPayload` to an array of field elements and
    /// then back to an `EncodedPayload` results in the same data.
    #[test]
    fn test_from_to_field_elements() {
        // TODO: add proptest
        let payload = Payload::new(
            "0123456789012345678901234567890123"
                .to_string()
                .into_bytes(),
        );
        let encoded_payload = EncodedPayload::new(&payload).unwrap();

        let field_elements = encoded_payload.to_field_elements();
        let max_payload_length = usize::MAX;
        let new_encoded_payload =
            EncodedPayload::from_field_elements(&field_elements, max_payload_length).unwrap();

        assert_eq!(encoded_payload, new_encoded_payload);
    }

    /// Checks that an encoded payload with trailing non-zero bytes fails at decode
    #[test]
    fn test_trailing_non_zeros() {
        // TODO: add proptest
        // Generate random data
        let mut rng = thread_rng();
        let random_length = rng.gen_range(1..1025); // 1 + random value up to 1024
        let original_data: Vec<u8> = (0..random_length).map(|_| rng.r#gen()).collect();

        // Create payload and encode it
        let payload = Payload::new(original_data);
        let encoded_payload = EncodedPayload::new(&payload).unwrap();

        // Get the field elements
        let original_elements = encoded_payload.to_field_elements();

        // Create a copy with a zero element appended
        let mut field_elements1 = original_elements.clone();
        // Append zero element
        field_elements1.push(ark_bn254::Fr::from(0));

        // This should succeed - adding a zero is fine
        let max_payload_length = field_elements1.len() * BYTES_PER_SYMBOL;
        let result1 = EncodedPayload::from_field_elements(&field_elements1, max_payload_length);
        assert!(result1.is_ok());

        // Create another copy with a non-zero element appended
        let mut field_elements2 = original_elements;
        // Append non-zero element
        field_elements2.push(ark_bn254::Fr::from(1));

        // This should fail on the trailing-byte check - a trailing non-zero value is not fine
        let max_payload_length = field_elements2.len() * BYTES_PER_SYMBOL;
        let result2 = EncodedPayload::from_field_elements(&field_elements2, max_payload_length);
        assert!(matches!(result2, Err(ConversionError::EncodedPayload(_))));
    }
}