zeroproto/
builder.rs

1//! Message and vector builders for ZeroProto serialization
2
3#[cfg(not(feature = "std"))]
4extern crate alloc;
5#[cfg(not(feature = "std"))]
6use alloc::{collections::BTreeMap, vec::Vec};
7#[cfg(feature = "std")]
8use std::{collections::BTreeMap, vec::Vec};
9
10use crate::{
11    constants::{FIELD_ENTRY_SIZE, MAX_FIELDS},
12    errors::{Error, Result},
13    primitives::{Endian, PrimitiveType},
14    ZpWrite,
15};
16
/// A message builder for serializing ZeroProto messages
///
/// Values are appended to a payload area as fields are set; the field count and
/// the field table are only laid out when `finish` is called.
#[derive(Debug)]
pub struct MessageBuilder {
    /// Output bytes: a 2-byte field-count placeholder followed by the payload written so far.
    buffer: Vec<u8>,
    /// Field-table entries keyed by field index; setting an index again replaces its entry.
    field_entries: BTreeMap<u16, FieldEntry>,
    /// Absolute position in `buffer` where the next field value will be written.
    payload_offset: usize,
}
24
/// One field-table entry recorded while a message is being built.
///
/// It is plain data (a byte and a `u32`), so it is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct FieldEntry {
    /// `PrimitiveType` discriminant describing how the field's value is encoded.
    type_id: u8,
    /// Position of the field's value in the builder buffer, recorded before the
    /// field table is inserted (the builder shifts it when finishing).
    offset: u32,
}
30
31impl MessageBuilder {
32    /// Create a new message builder
33    pub fn new() -> Self {
34        let mut builder = Self {
35            buffer: Vec::new(),
36            field_entries: BTreeMap::new(),
37            payload_offset: 0,
38        };
39
40        // Reserve space for field count (will be filled later)
41        builder.buffer.extend_from_slice(&[0, 0]);
42        builder.payload_offset = 2;
43        builder
44    }
45
46    /// Get the current number of fields
47    pub fn field_count(&self) -> u16 {
48        self.field_entries_count()
49    }
50
51    /// Add a scalar field
52    pub fn set_scalar<T: ZpWrite>(&mut self, field_index: u16, value: T) -> Result<()> {
53        self.ensure_field_index(field_index)?;
54
55        let type_id = self.get_type_id::<T>()?;
56        let field_offset = self.payload_offset as u32;
57
58        // Resize buffer to fit the value
59        let required_size = self.payload_offset + value.size();
60        if required_size > self.buffer.len() {
61            self.buffer.resize(required_size, 0);
62        }
63
64        // Write the value
65        value.write(&mut self.buffer, self.payload_offset)?;
66
67        // Add/update field entry
68        self.set_field_entry(field_index, type_id, field_offset);
69
70        // Update payload offset
71        self.payload_offset += value.size();
72
73        Ok(())
74    }
75
76    /// Add a string field
77    pub fn set_string(&mut self, field_index: u16, value: &str) -> Result<()> {
78        self.ensure_field_index(field_index)?;
79
80        let type_id = PrimitiveType::String as u8;
81        let field_offset = self.payload_offset as u32;
82
83        // Reserve space for length + string bytes
84        let len = value.len();
85        let required_size = self.payload_offset + 4 + len;
86        if required_size > self.buffer.len() {
87            self.buffer.resize(required_size, 0);
88        }
89
90        // Write length
91        Endian::Little.write_u32(len as u32, &mut self.buffer, self.payload_offset);
92
93        // Write string bytes
94        let string_offset = self.payload_offset + 4;
95        self.buffer[string_offset..string_offset + len].copy_from_slice(value.as_bytes());
96
97        // Add/update field entry
98        self.set_field_entry(field_index, type_id, field_offset);
99
100        // Update payload offset
101        self.payload_offset += 4 + len;
102
103        Ok(())
104    }
105
106    /// Add a bytes field
107    pub fn set_bytes(&mut self, field_index: u16, value: &[u8]) -> Result<()> {
108        self.ensure_field_index(field_index)?;
109
110        let type_id = PrimitiveType::Bytes as u8;
111        let field_offset = self.payload_offset as u32;
112
113        // Reserve space for length + bytes
114        let len = value.len();
115        let required_size = self.payload_offset + 4 + len;
116        if required_size > self.buffer.len() {
117            self.buffer.resize(required_size, 0);
118        }
119
120        // Write length
121        Endian::Little.write_u32(len as u32, &mut self.buffer, self.payload_offset);
122
123        // Write bytes
124        let bytes_offset = self.payload_offset + 4;
125        self.buffer[bytes_offset..bytes_offset + len].copy_from_slice(value);
126
127        // Add/update field entry
128        self.set_field_entry(field_index, type_id, field_offset);
129
130        // Update payload offset
131        self.payload_offset += 4 + len;
132
133        Ok(())
134    }
135
136    /// Add a nested message
137    pub fn set_message(&mut self, field_index: u16, message: &[u8]) -> Result<()> {
138        self.ensure_field_index(field_index)?;
139
140        let type_id = PrimitiveType::Message as u8;
141        let field_offset = self.payload_offset as u32;
142
143        // Reserve space for length + message bytes
144        let len = message.len();
145        let required_size = self.payload_offset + 4 + len;
146        if required_size > self.buffer.len() {
147            self.buffer.resize(required_size, 0);
148        }
149
150        // Write length
151        Endian::Little.write_u32(len as u32, &mut self.buffer, self.payload_offset);
152
153        // Write message bytes
154        let message_offset = self.payload_offset + 4;
155        self.buffer[message_offset..message_offset + len].copy_from_slice(message);
156
157        // Add/update field entry
158        self.set_field_entry(field_index, type_id, field_offset);
159
160        // Update payload offset
161        self.payload_offset += 4 + len;
162
163        Ok(())
164    }
165
166    /// Add a vector field
167    pub fn set_vector<T: ZpWrite>(&mut self, field_index: u16, values: &[T]) -> Result<()> {
168        self.ensure_field_index(field_index)?;
169
170        let type_id = PrimitiveType::Vector as u8;
171        let field_offset = self.payload_offset as u32;
172
173        // Calculate total size needed
174        let element_size = if values.is_empty() {
175            0
176        } else {
177            values[0].size()
178        };
179        let total_size = 4 + values.len() * element_size;
180        let required_size = self.payload_offset + total_size;
181        if required_size > self.buffer.len() {
182            self.buffer.resize(required_size, 0);
183        }
184
185        // Write count
186        Endian::Little.write_u32(values.len() as u32, &mut self.buffer, self.payload_offset);
187
188        // Write elements
189        let mut offset = self.payload_offset + 4;
190        for value in values {
191            value.write(&mut self.buffer, offset)?;
192            offset += value.size();
193        }
194
195        // Add/update field entry
196        self.set_field_entry(field_index, type_id, field_offset);
197
198        // Update payload offset
199        self.payload_offset += total_size;
200
201        Ok(())
202    }
203
204    fn field_entries_count(&self) -> u16 {
205        self.field_entries
206            .keys()
207            .next_back()
208            .copied()
209            .map(|index| index.saturating_add(1))
210            .unwrap_or(0)
211    }
212
213    /// Ensure the field index is valid
214    fn ensure_field_index(&self, field_index: u16) -> Result<()> {
215        if field_index == MAX_FIELDS {
216            return Err(Error::OutOfBounds);
217        }
218
219        Ok(())
220    }
221
222    /// Set a field entry
223    fn set_field_entry(&mut self, field_index: u16, type_id: u8, offset: u32) {
224        self.field_entries
225            .insert(field_index, FieldEntry { type_id, offset });
226    }
227
228    /// Clear a field entry (used for optional setters)
229    pub fn clear_field(&mut self, field_index: u16) -> Result<()> {
230        self.ensure_field_index(field_index)?;
231        self.field_entries.remove(&field_index);
232        Ok(())
233    }
234
235    /// Get the type ID for a type
236    fn get_type_id<T>(&self) -> Result<u8> {
237        let type_id = match core::any::type_name::<T>() {
238            "u8" => PrimitiveType::U8 as u8,
239            "u16" => PrimitiveType::U16 as u8,
240            "u32" => PrimitiveType::U32 as u8,
241            "u64" => PrimitiveType::U64 as u8,
242            "i8" => PrimitiveType::I8 as u8,
243            "i16" => PrimitiveType::I16 as u8,
244            "i32" => PrimitiveType::I32 as u8,
245            "i64" => PrimitiveType::I64 as u8,
246            "f32" => PrimitiveType::F32 as u8,
247            "f64" => PrimitiveType::F64 as u8,
248            "bool" => PrimitiveType::Bool as u8,
249            _ => return Err(Error::InvalidFieldType),
250        };
251        Ok(type_id)
252    }
253
254    /// Finish building and return the serialized message
255    pub fn finish(mut self) -> Vec<u8> {
256        let field_count = self.field_entries_count();
257
258        // Write field count
259        Endian::Little.write_u16(field_count, &mut self.buffer, 0);
260
261        // Reserve space for field table
262        let field_table_size = field_count as usize * FIELD_ENTRY_SIZE;
263        let current_payload_offset = self.payload_offset;
264
265        // Shift payload to make room for field table
266        self.buffer
267            .resize(current_payload_offset + field_table_size, 0);
268
269        // Move payload data
270        for i in (0..(current_payload_offset - 2)).rev() {
271            self.buffer[2 + field_table_size + i] = self.buffer[2 + i];
272        }
273
274        // Update field entry offsets to account for field table
275        for entry in self.field_entries.values_mut() {
276            entry.offset += field_table_size as u32;
277        }
278
279        // Write field table
280        let mut field_table_offset = 2;
281        for field_index in 0..field_count {
282            if let Some(entry) = self.field_entries.get(&field_index) {
283                self.buffer[field_table_offset] = entry.type_id;
284                Endian::Little.write_u32(entry.offset, &mut self.buffer, field_table_offset + 1);
285            } else {
286                self.buffer[field_table_offset] = PrimitiveType::Unset as u8;
287                Endian::Little.write_u32(0, &mut self.buffer, field_table_offset + 1);
288            }
289            field_table_offset += FIELD_ENTRY_SIZE;
290        }
291
292        // Trim buffer to actual size
293        self.buffer
294            .truncate(2 + field_table_size + (current_payload_offset - 2));
295
296        self.buffer
297    }
298}
299
300impl Default for MessageBuilder {
301    fn default() -> Self {
302        Self::new()
303    }
304}
305
/// A vector builder that collects elements in push order.
///
/// The builder does not serialize anything itself. When `T: ZpWrite`, the finished
/// `Vec<T>` can be passed as a slice to `MessageBuilder::set_vector`.
#[derive(Debug)]
pub struct VectorBuilder<T> {
    elements: Vec<T>,
}

// None of these methods need `ZpWrite`, so the impl is deliberately unbounded; this
// also matches the unbounded `Default` impl below.
impl<T> VectorBuilder<T> {
    /// Create a new, empty vector builder.
    #[must_use]
    pub fn new() -> Self {
        Self {
            elements: Vec::new(),
        }
    }

    /// Append an element to the end of the vector.
    pub fn push(&mut self, element: T) {
        self.elements.push(element);
    }

    /// Number of elements pushed so far.
    pub fn len(&self) -> usize {
        self.elements.len()
    }

    /// `true` when no element has been pushed.
    pub fn is_empty(&self) -> bool {
        self.elements.is_empty()
    }

    /// Finish building and return the elements in the order they were pushed.
    #[must_use]
    pub fn finish(self) -> Vec<T> {
        self.elements
    }
}

impl<T> Default for VectorBuilder<T> {
    /// Equivalent to [`VectorBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_message() {
        let builder = MessageBuilder::new();
        let buffer = builder.finish();
        // Only the (zero) field count is emitted.
        assert_eq!(buffer, [0u8, 0]);
    }

    #[test]
    fn test_scalar_field() {
        let mut builder = MessageBuilder::new();
        builder.set_scalar(0, 42u16).unwrap();
        let buffer = builder.finish();

        // Parse and verify
        let reader = crate::reader::MessageReader::new(&buffer).unwrap();
        let value: u16 = reader.get_scalar(0).unwrap();
        assert_eq!(value, 42);
    }

    #[test]
    fn test_string_field() {
        let mut builder = MessageBuilder::new();
        builder.set_string(0, "hello").unwrap();
        let buffer = builder.finish();

        // Parse and verify
        let reader = crate::reader::MessageReader::new(&buffer).unwrap();
        let value = reader.get_string(0).unwrap();
        assert_eq!(value, "hello");
    }

    #[test]
    fn test_builder_basic() -> Result<()> {
        let mut builder = MessageBuilder::new();
        builder.set_scalar(0, 42u64)?;
        let data = builder.finish();

        assert_eq!(data.len(), 15); // 2 + 5 + 8
        assert_eq!(data[0], 1); // field count
        assert_eq!(data[2], 3); // u64 type id

        Ok(())
    }

    #[test]
    fn test_builder_multiple_fields() -> Result<()> {
        let mut builder = MessageBuilder::new();
        builder.set_scalar(0, 42u64)?;
        builder.set_scalar(1, 100u32)?;
        let data = builder.finish();

        // Expected: 2 (field count) + 10 (field table: 2 fields * 5 bytes) + 8 (u64) + 4 (u32) = 24
        assert_eq!(data.len(), 24);

        Ok(())
    }

    #[test]
    fn test_builder_string() -> Result<()> {
        let mut builder = MessageBuilder::new();
        builder.set_string(0, "hello")?;
        let data = builder.finish();

        // Expected: 2 (field count) + 5 (field table) + 4 (string length) + 5 (string bytes) = 16
        assert_eq!(data.len(), 16);

        Ok(())
    }

    #[test]
    fn test_builder_bytes_and_vector_sizes() -> Result<()> {
        let mut builder = MessageBuilder::new();
        builder.set_bytes(0, &[1, 2, 3])?;
        builder.set_vector(1, &[10u32, 20, 30])?;
        let data = builder.finish();

        // Expected: 2 (field count) + 10 (field table) + 4 + 3 (bytes) + 4 + 12 (vector) = 35
        assert_eq!(data.len(), 35);
        assert_eq!(data[0], 2);

        Ok(())
    }

    #[test]
    fn test_sparse_fields_are_unset() -> Result<()> {
        let mut builder = MessageBuilder::new();
        builder.set_scalar(2, 7u8)?;
        assert_eq!(builder.field_count(), 3);
        let data = builder.finish();

        // Field count is highest index + 1; indices 0 and 1 are emitted as unset.
        assert_eq!(data[0], 3);
        assert_eq!(data[2], PrimitiveType::Unset as u8);
        assert_eq!(data[2 + FIELD_ENTRY_SIZE], PrimitiveType::Unset as u8);
        assert_eq!(data[2 + 2 * FIELD_ENTRY_SIZE], PrimitiveType::U8 as u8);
        assert_eq!(data.len(), 2 + 3 * FIELD_ENTRY_SIZE + 1);
        assert_eq!(data.last().copied(), Some(7));

        Ok(())
    }

    #[test]
    fn test_field_index_at_limit_is_rejected() {
        let mut builder = MessageBuilder::new();
        assert!(matches!(
            builder.set_scalar(MAX_FIELDS, 1u8),
            Err(Error::OutOfBounds)
        ));
        assert!(matches!(
            builder.clear_field(MAX_FIELDS),
            Err(Error::OutOfBounds)
        ));
        assert_eq!(builder.field_count(), 0);
    }
}
431        Ok(())
432    }
433}