rust_lcm_codegen/
lib.rs

1//! Code generation for LCM serialization and deserialization in Rust
2#![allow(unused_variables)]
3#![allow(dead_code)]
4#![deny(warnings)]
5
6pub mod fingerprint;
7pub mod parser;
8
9use crate::parser::{ArrayDimension, ArrayType, PrimitiveType, StructType, Type};
10use proc_macro2::{Ident, TokenStream};
11use quote::{format_ident, quote};
12use std::fs::File;
13use std::io::{Read, Write};
14use std::path::Path;
15use std::process::Command;
16
17/// Generate a single Rust file from a collection of LCM schema files.
18pub fn generate<P1: AsRef<Path>, SF: IntoIterator<Item = P1>, P2: AsRef<Path>>(
19    schema_files: SF,
20    out_file_path: P2,
21) {
22    let out_file_path: &Path = out_file_path.as_ref();
23    let mut out_file = File::create(out_file_path).expect("Create out file");
24
25    let mut all_schemas = vec![];
26    for schema_file in schema_files.into_iter() {
27        let mut schema = File::open(schema_file.as_ref()).expect("Open schema");
28        let mut schema_content = String::new();
29        schema
30            .read_to_string(&mut schema_content)
31            .expect("Read schema");
32
33        let (remaining, ast) = parser::schema(&schema_content).expect("Parse schema");
34        assert_eq!(remaining, "", "Unparsed text at end of schema");
35        all_schemas.push(ast);
36    }
37    // TODO - either merge schema contents in the same package
38    // or error out when more than one file declares the same package
39
40    let schemas_code = all_schemas.iter().map(|schema| {
41        let env = Environment {
42            local_schema: schema.clone(),
43            all_schemas: all_schemas.clone(),
44        };
45
46        emit_schema(&schema, &env)
47    });
48
49    let tokens = quote! {
50        #(#schemas_code)*
51    };
52
53    write!(out_file, "{}", tokens).expect("Write out file");
54    rustfmt(out_file_path);
55}
56
/// Format the file at `path` in place by running `rustfmt --edition 2018`.
///
/// Formatting is best-effort. If `rustfmt` runs but exits unsuccessfully, this
/// function does not panic. Instead it reports the exit status and rustfmt's
/// stderr on this process's stderr, where they would otherwise be silently
/// discarded.
///
/// # Panics
///
/// Panics if the `rustfmt` process cannot be spawned (for example, when it is
/// not installed).
fn rustfmt<P: AsRef<Path>>(path: P) {
    let path = path.as_ref();

    let output = Command::new("rustfmt")
        .arg("--edition")
        .arg("2018")
        .arg(path.as_os_str())
        .output()
        .expect("rustfmt");

    if !output.status.success() {
        eprintln!(
            "warning: rustfmt exited with {} while formatting {}:\n{}",
            output.status,
            path.display(),
            String::from_utf8_lossy(&output.stderr)
        );
    }
}
67
68fn emit_schema(schema: &parser::Schema, env: &Environment) -> TokenStream {
69    let structs_code = schema.structs.iter().map(|s| emit_struct(s, env));
70    match &schema.package {
71        Some(name) => {
72            let mod_ident = format_ident!("{}", name);
73            quote! {
74                #[allow(non_camel_case_types)]
75                pub mod #mod_ident {
76                    #(#structs_code)*
77                }
78            }
79        }
80        None => quote! {
81            #(#structs_code)*
82        },
83    }
84}
85
/// Identifies one state of the generated per-struct writer/reader state machine.
#[derive(Debug, PartialEq, Eq)]
enum StateName {
    /// Entry state of a struct's codec state machine.
    Ready,
    /// State named after the field handled when transitioning out of it.
    HandlingField(String),
    /// Terminal state; it has no field to handle.
    Done,
}

impl StateName {
    /// The fragment used when building identifiers for this state: `"ready"`,
    /// `"done"`, or the field name itself.
    fn name(&self) -> &str {
        let fragment: &str = match *self {
            StateName::HandlingField(ref field_name) => field_name.as_str(),
            StateName::Ready => "ready",
            StateName::Done => "done",
        };
        fragment
    }
}
102
103#[derive(Debug)]
104struct CodecState {
105    state_name: StateName,
106    /// The name of the LCM struct this state is for
107    struct_name: String,
108    /// field that's written when transitioning out of this state,
109    /// and whether that value needs to be captured for use in
110    /// tracking the length of an array
111    field: Option<(parser::Field, bool)>,
112    /// The array-length values that this state needs to pass along to
113    /// future states for the purposes of correctly sizing arrays,
114    /// identified by the name of the field they serve
115    baggage_dimensions: Vec<BaggageDimension>,
116}
117
118impl CodecState {
119    fn writer_struct_state_decl_ident(struct_name: &str, state_name: &StateName) -> Ident {
120        format_ident!("{}_write_{}", struct_name, state_name.name())
121    }
122    fn reader_struct_state_decl_ident(struct_name: &str, state_name: &StateName) -> Ident {
123        format_ident!("{}_read_{}", struct_name, state_name.name())
124    }
125    fn writer_ident(&self) -> Ident {
126        CodecState::writer_struct_state_decl_ident(&self.struct_name, &self.state_name)
127    }
128    fn reader_ident(&self) -> Ident {
129        CodecState::reader_struct_state_decl_ident(&self.struct_name, &self.state_name)
130    }
131}
132
133#[derive(Debug, Clone, PartialEq, Eq, Hash)]
134struct BaggageDimension {
135    array_field_name: String,
136    len_field_name: String,
137    dimension_depth: usize,
138}
139
140impl BaggageDimension {
141    fn field_declarations(baggage_dimensions: &[BaggageDimension]) -> Vec<TokenStream> {
142        baggage_dimensions
143            .iter()
144            .map(|d| {
145                let field_ident = format_ident!("baggage_{}", d.len_field_name);
146                quote!(#field_ident: usize,)
147            })
148            .collect()
149    }
150    fn field_initializations_from_self<'a>(
151        baggage_dimensions: impl IntoIterator<Item = &'a BaggageDimension>,
152    ) -> Vec<TokenStream> {
153        baggage_dimensions
154            .into_iter()
155            .map(|d| {
156                let baggage_field_ident = format_ident!("baggage_{}", d.len_field_name);
157                quote!(#baggage_field_ident: self.#baggage_field_ident)
158            })
159            .collect()
160    }
161}
162
163fn to_underscored_literal(v: u64) -> proc_macro2::Literal {
164    use std::str::FromStr;
165    let raw = format!("{}", v);
166    let original_len = raw.len();
167    let mut s = String::with_capacity(original_len);
168    for (index, digit) in raw.chars().rev().enumerate() {
169        if index % 3 == 0 && index != 0 && index != original_len {
170            s.insert(0, '_')
171        }
172        s.insert(0, digit)
173    }
174    s.push_str("u64");
175    if let proc_macro2::TokenTree::Literal(l) = proc_macro2::TokenStream::from_str(&s)
176        .expect("Invalid underscored literal creation, failed lexing")
177        .into_iter()
178        .next()
179        .expect("Should have made at least one token")
180    {
181        l
182    } else {
183        panic!("Created the wrong type of token when trying to make an underscored literal")
184    }
185}
186
187fn emit_struct(s: &parser::Struct, env: &Environment) -> TokenStream {
188    let schema_hash_ident = format_ident!("{}_SCHEMA_HASH", s.name.to_uppercase());
189    let schema_hash = fingerprint::struct_hash(&s, &env);
190    let schema_hash = to_underscored_literal(schema_hash);
191
192    let codec_states = gather_states(s);
193
194    let writer_states_decl_code = codec_states
195        .iter()
196        .map(|ws| emit_writer_state_decl(&ws, &env));
197
198    let reader_states_decl_code = codec_states
199        .iter()
200        .map(|rs| emit_reader_state_decl(&rs, &env));
201
202    let mut writer_states_transition_code = vec![];
203    let mut reader_states_transition_code = vec![];
204    for window in codec_states.windows(2) {
205        if let [start_state, end_state] = window {
206            writer_states_transition_code.push(emit_writer_state_transition(
207                &start_state,
208                &end_state,
209                &env,
210            ));
211            reader_states_transition_code.push(emit_reader_state_transition(
212                &start_state,
213                &end_state,
214                &env,
215            ));
216        } else {
217            panic!("Unexpected window size in state transitions")
218        }
219    }
220
221    let write_ready_type = codec_states[0].writer_ident();
222    let read_ready_type = codec_states[0].reader_ident();
223    let begin_write = format_ident!("begin_{}_write", s.name);
224    let begin_read = format_ident!("begin_{}_read", s.name);
225
226    quote! {
227        pub const #schema_hash_ident : u64 = #schema_hash;
228
229            #[inline]
230            pub fn #begin_write<W: rust_lcm_codec::StreamingWriter>(writer: &'_ mut W)
231                    -> Result<#write_ready_type<'_, W>, rust_lcm_codec::EncodeFingerprintError<W::Error>> {
232                writer.write_bytes(&#schema_hash.to_be_bytes())?;
233
234                Ok(#write_ready_type {
235                    writer
236                })
237            }
238            #[inline]
239            pub fn #begin_read<R: rust_lcm_codec::StreamingReader>(reader: &'_ mut R)
240                    -> Result<#read_ready_type<'_, R>, rust_lcm_codec::DecodeFingerprintError<R::Error>> {
241                let mut hash_buffer = 0u64.to_ne_bytes();
242                reader.read_bytes(&mut hash_buffer)?;
243                let found_hash = u64::from_be_bytes(hash_buffer);
244                if found_hash != #schema_hash_ident {
245                    return Err(rust_lcm_codec::DecodeFingerprintError::InvalidFingerprint(found_hash));
246                }
247
248                Ok(#read_ready_type {
249                    reader
250                })
251            }
252
253        #( #writer_states_decl_code )*
254
255        #( #writer_states_transition_code )*
256
257
258        #( #reader_states_decl_code )*
259
260        #( #reader_states_transition_code )*
261    }
262}
263
264fn gather_states(s: &parser::Struct) -> Vec<CodecState> {
265    let mut codec_states = Vec::new();
266
267    codec_states.insert(
268        0,
269        CodecState {
270            state_name: StateName::Done,
271            struct_name: s.name.clone(),
272            field: None,
273            baggage_dimensions: Vec::with_capacity(0),
274        },
275    );
276
277    // Iterate backwards to collect and manage required dimensional metadata
278    // Note that the current approach does not handle multidimensional arrays
279    let mut baggage_dimensions = Vec::new();
280    for (i, member) in s.members.iter().enumerate().rev() {
281        if let parser::StructMember::Field(f) = member {
282            let mut local_dynamic_dimensions = vec![];
283            if let Type::Array(at) = &f.ty {
284                for (depth, dim) in at.dimensions.iter().enumerate() {
285                    if let ArrayDimension::Dynamic { field_name } = dim {
286                        local_dynamic_dimensions.push(BaggageDimension {
287                            array_field_name: f.name.clone(),
288                            len_field_name: field_name.clone(),
289                            dimension_depth: depth,
290                        })
291                    }
292                }
293            }
294            baggage_dimensions.extend(local_dynamic_dimensions.clone());
295            if local_dynamic_dimensions.len() > 1 {
296                panic!("Arrays with more than one dimension are not yet supported");
297            }
298            let mut field_serves_as_dimension = false;
299            while let Some(bi) = baggage_dimensions
300                .iter()
301                .position(|dim| dim.len_field_name == f.name.as_str())
302            {
303                // This dimension will be discharged by the transition out of this state, no need to
304                // keep tracking it
305                let bd = baggage_dimensions.remove(bi);
306                field_serves_as_dimension = true;
307            }
308            codec_states.insert(
309                0,
310                CodecState {
311                    state_name: if i == 0 {
312                        StateName::Ready
313                    } else {
314                        StateName::HandlingField(f.name.to_owned())
315                    },
316                    struct_name: s.name.clone(),
317                    field: Some((f.clone(), field_serves_as_dimension)),
318                    baggage_dimensions: baggage_dimensions.clone(),
319                },
320            );
321        }
322    }
323    codec_states
324}
325
326fn emit_writer_state_decl(ws: &CodecState, env: &Environment) -> TokenStream {
327    let struct_ident = ws.writer_ident();
328    let allow_dead = if ws.state_name == StateName::Done {
329        Some(quote!(#[allow(dead_code)]))
330    } else {
331        None
332    };
333    let dimensions_fields = BaggageDimension::field_declarations(&ws.baggage_dimensions);
334    let (current_iter_count_field, array_item_writer_decl) = if let Some((
335        parser::Field {
336            ty: parser::Type::Array(at),
337            name,
338        },
339        _,
340    )) = &ws.field
341    {
342        let current_count_field_ident = at
343            .array_current_count_field_ident(name.as_str(), 0)
344            .expect("Arrays must have at least one dimension");
345        let item_writer_struct_ident = format_ident!("{}_item", struct_ident);
346        let array_item_writer_decl = quote! {
347            #[must_use]
348            pub struct #item_writer_struct_ident<'a, W: rust_lcm_codec::StreamingWriter> {
349                parent: &'a mut #struct_ident<'a, W>,
350            }
351        };
352        (
353            Some(quote!(#current_count_field_ident: usize, )),
354            Some(array_item_writer_decl),
355        )
356    } else {
357        (None, None)
358    };
359    let maybe_must_use = if ws.state_name != StateName::Done {
360        Some(quote!(#[must_use]))
361    } else {
362        None
363    };
364    quote! {
365        #maybe_must_use
366        pub struct #struct_ident<'a, W: rust_lcm_codec::StreamingWriter> {
367            #allow_dead
368            pub(super) writer: &'a mut W,
369            #current_iter_count_field
370            #( #dimensions_fields )*
371        }
372
373        #array_item_writer_decl
374    }
375}
376
377fn emit_reader_state_decl(rs: &CodecState, env: &Environment) -> TokenStream {
378    let struct_ident = rs.reader_ident();
379    let allow_dead = if rs.state_name == StateName::Done {
380        Some(quote!(#[allow(dead_code)]))
381    } else {
382        None
383    };
384    let dimensions_fields = BaggageDimension::field_declarations(&rs.baggage_dimensions);
385    let (current_iter_count_field, array_item_reader_decl) = if let Some((
386        parser::Field {
387            ty: parser::Type::Array(at),
388            name,
389        },
390        _,
391    )) = &rs.field
392    {
393        let current_count_field_ident = at
394            .array_current_count_field_ident(name.as_str(), 0)
395            .expect("Arrays must have at least one dimension");
396        let item_reader_struct_ident = format_ident!("{}_item", struct_ident);
397        let array_item_reader_decl = quote! {
398            #[must_use]
399            pub struct #item_reader_struct_ident<'a, R: rust_lcm_codec::StreamingReader> {
400                parent: &'a mut #struct_ident<'a, R>,
401            }
402        };
403        (
404            Some(quote!(#current_count_field_ident: usize, )),
405            Some(array_item_reader_decl),
406        )
407    } else {
408        (None, None)
409    };
410    let maybe_must_use = if rs.state_name != StateName::Done {
411        Some(quote!(#[must_use]))
412    } else {
413        None
414    };
415    quote! {
416        #maybe_must_use
417        pub struct #struct_ident<'a, W: rust_lcm_codec::StreamingReader> {
418            #allow_dead
419            pub(super) reader: &'a mut W,
420            #current_iter_count_field
421            #( #dimensions_fields )*
422        }
423
424        #array_item_reader_decl
425    }
426}
427
428fn primitive_type_to_rust(pt: &parser::PrimitiveType) -> &str {
429    match pt {
430        parser::PrimitiveType::Int8 => "i8",
431        parser::PrimitiveType::Int16 => "i16",
432        parser::PrimitiveType::Int32 => "i32",
433        parser::PrimitiveType::Int64 => "i64",
434        parser::PrimitiveType::Float => "f32",
435        parser::PrimitiveType::Double => "f64",
436        parser::PrimitiveType::String => "str",
437        parser::PrimitiveType::Boolean => "bool",
438        parser::PrimitiveType::Byte => "u8",
439    }
440}
441
442fn emit_writer_state_transition(
443    ws: &CodecState,
444    ws_next: &CodecState,
445    env: &Environment,
446) -> TokenStream {
447    match ws.field {
448        Some((ref f, serves_as_dimension)) => {
449            let start_type = ws.writer_ident();
450            let next_type = ws_next.writer_ident();
451            let write_method_ident = format_ident!("write_{}", f.name);
452            match &f.ty {
453                parser::Type::Primitive(pt) => emit_writer_field_state_transition_primitive(
454                    start_type,
455                    ws_next,
456                    f.name.as_str(),
457                    *pt,
458                    serves_as_dimension,
459                ),
460                parser::Type::Struct(st) => emit_writer_field_state_transition_struct(
461                    start_type,
462                    ws_next,
463                    f.name.as_str(),
464                    st,
465                ),
466                parser::Type::Array(at) => emit_writer_field_state_transition_array(
467                    start_type,
468                    ws_next,
469                    f.name.as_str(),
470                    at,
471                ),
472            }
473        }
474        None => quote! {},
475    }
476}
477
478#[derive(Copy, Clone, Debug)]
479enum WriterPath {
480    Bare,
481    ViaSelf,
482    ViaSelfParent,
483}
484
485impl WriterPath {
486    fn path(self) -> TokenStream {
487        match self {
488            WriterPath::Bare => quote!(writer),
489            WriterPath::ViaSelf => quote!(self.writer),
490            WriterPath::ViaSelfParent => quote!(self.parent.writer),
491        }
492    }
493}
494fn emit_write_primitive_invocation(pt: PrimitiveType, writer_path: WriterPath) -> TokenStream {
495    let path = writer_path.path();
496    match pt {
497        PrimitiveType::String => quote! {
498            rust_lcm_codec::write_str_value(val, #path)?;
499        },
500        _ => quote! {
501            rust_lcm_codec::SerializeValue::write_value(val, #path)?;
502        },
503    }
504}
505
506fn emit_next_field_current_iter_count_initialization(
507    next_state: &CodecState,
508) -> Option<TokenStream> {
509    if let Some((
510        parser::Field {
511            ty: parser::Type::Array(at),
512            name,
513        },
514        _,
515    )) = &next_state.field
516    {
517        let current_iter_count_field_ident = at
518            .array_current_count_field_ident(name.as_str(), 0)
519            .expect("Arrays must have at least one dimension");
520        Some(quote!(#current_iter_count_field_ident: 0, ))
521    } else {
522        None
523    }
524}
525
526fn emit_writer_field_state_transition_primitive(
527    start_type: Ident,
528    next_state: &CodecState,
529    field_name: &str,
530    pt: PrimitiveType,
531    field_serves_as_dimension: bool,
532) -> TokenStream {
533    let write_method_ident = format_ident!("write_{}", field_name);
534    let write_method = {
535        let maybe_ref = if pt == PrimitiveType::String {
536            Some(quote!(&))
537        } else {
538            None
539        };
540        let rust_field_type = format_ident!("{}", primitive_type_to_rust(&pt));
541        let write_invocation = emit_write_primitive_invocation(pt, WriterPath::ViaSelf);
542        let dimensional_capture = if field_serves_as_dimension {
543            let baggage_field_ident = format_ident!("baggage_{}", field_name);
544            Some(quote!(#baggage_field_ident: val as usize,))
545        } else {
546            None
547        };
548        let next_type = next_state.writer_ident();
549        let next_dimensions_fields = BaggageDimension::field_initializations_from_self(
550            next_state
551                .baggage_dimensions
552                .iter()
553                .filter(|d| !field_serves_as_dimension || d.len_field_name.as_str() != field_name),
554        );
555        let current_iter_count_initialization =
556            emit_next_field_current_iter_count_initialization(next_state);
557        quote! {
558            #[inline]
559            pub fn #write_method_ident(self, val: #maybe_ref #rust_field_type) -> Result<#next_type<'a, W>, rust_lcm_codec::EncodeValueError<W::Error>> {
560                #write_invocation
561                Ok(#next_type {
562                    writer: self.writer,
563                    #dimensional_capture
564                    #current_iter_count_initialization
565                    #( #next_dimensions_fields )*
566                })
567            }
568        }
569    };
570
571    quote! {
572        impl<'a, W: rust_lcm_codec::StreamingWriter> #start_type<'a, W> {
573            #[inline]
574            #write_method
575        }
576    }
577}
578
579fn emit_write_struct_method(
580    st: &StructType,
581    write_method_ident: Ident,
582    pre_field_write: Option<TokenStream>,
583    post_field_write: Option<TokenStream>,
584    after_field_type: TokenStream,
585    after_field_constructor: TokenStream,
586    writer_path: WriterPath,
587) -> TokenStream {
588    let field_struct_write_ready: Ident =
589        CodecState::writer_struct_state_decl_ident(&st.name, &StateName::Ready);
590    let field_struct_write_done: Ident =
591        CodecState::writer_struct_state_decl_ident(&st.name, &StateName::Done);
592    let struct_ns_prefix = if let Some(ns) = &st.namespace {
593        let namespace_ident = format_ident!("{}", ns);
594        Some(quote!(super::#namespace_ident::))
595    } else {
596        None
597    };
598    let writer_path_tokens = writer_path.path();
599    quote! {
600        #[inline]
601        pub fn #write_method_ident<F>(self, f: F) -> Result<#after_field_type, rust_lcm_codec::EncodeValueError<W::Error>>
602            where F: FnOnce(#struct_ns_prefix#field_struct_write_ready<'a, W>)
603                -> Result<#struct_ns_prefix#field_struct_write_done<'a, W>, rust_lcm_codec::EncodeValueError<W::Error>>
604        {
605            #pre_field_write
606            let ready = #struct_ns_prefix#field_struct_write_ready {
607                writer: #writer_path_tokens,
608            };
609            #[allow(unused_variables)]
610            let done = f(ready)?;
611            #post_field_write
612            Ok(#after_field_constructor)
613        }
614    }
615}
616
617fn emit_writer_field_state_transition_struct(
618    start_type: Ident,
619    next_state: &CodecState,
620    field_name: &str,
621    st: &StructType,
622) -> TokenStream {
623    let next_type = next_state.writer_ident();
624    let write_method_ident = format_ident!("write_{}", field_name);
625    let after_field_type = quote!(#next_type<'a, W>);
626
627    let current_iter_count_initialization =
628        emit_next_field_current_iter_count_initialization(next_state);
629    let next_dimensions_fields =
630        BaggageDimension::field_initializations_from_self(&next_state.baggage_dimensions);
631    let after_field_constructor = quote! {
632                #next_type {
633                    writer: done.writer,
634                    #current_iter_count_initialization
635                    #( #next_dimensions_fields )*
636                }
637    };
638    let write_method = emit_write_struct_method(
639        st,
640        write_method_ident,
641        None,
642        None,
643        after_field_type,
644        after_field_constructor,
645        WriterPath::ViaSelf,
646    );
647    quote! {
648        impl<'a, W: rust_lcm_codec::StreamingWriter> #start_type<'a, W> {
649            #[inline]
650            #write_method
651        }
652    }
653}
654
655impl ArrayType {
656    fn array_current_count_field_ident(
657        &self,
658        array_field_name: &str,
659        index: usize,
660    ) -> Option<Ident> {
661        match self.dimensions.get(index) {
662            Some(ArrayDimension::Static { size }) => {
663                // Use the field_name of the array
664                Some(format_ident!("current_{}_count", array_field_name))
665            }
666            Some(ArrayDimension::Dynamic { field_name }) => {
667                // Use the field_name of the field supplying the dynamic array length
668                Some(format_ident!("current_{}_count", field_name))
669            }
670            None => None,
671        }
672    }
673    fn array_current_count_gte_expected_check(
674        &self,
675        array_field_name: &str,
676        index: usize,
677        use_parent: bool,
678    ) -> Option<TokenStream> {
679        self.array_current_count_vs_expected(array_field_name, index, use_parent)
680            .map(
681                |CountComparisonParts {
682                     current_count,
683                     expected_count,
684                 }| quote!(#current_count >= #expected_count ),
685            )
686    }
687    fn array_current_count_under_expected_check(
688        &self,
689        array_field_name: &str,
690        index: usize,
691        use_parent: bool,
692    ) -> Option<TokenStream> {
693        self.array_current_count_vs_expected(array_field_name, index, use_parent)
694            .map(
695                |CountComparisonParts {
696                     current_count,
697                     expected_count,
698                 }| quote!(#current_count < #expected_count ),
699            )
700    }
701    fn array_current_count_remainder_value(
702        &self,
703        array_field_name: &str,
704        index: usize,
705        use_parent: bool,
706    ) -> Option<TokenStream> {
707        self.array_current_count_vs_expected(array_field_name, index, use_parent)
708            .map(
709                |CountComparisonParts {
710                     current_count,
711                     expected_count,
712                 }| quote!(#expected_count - #current_count),
713            )
714    }
715
716    fn array_current_count_vs_expected(
717        &self,
718        array_field_name: &str,
719        index: usize,
720        use_parent: bool,
721    ) -> Option<CountComparisonParts> {
722        let current_count_ident = self.array_current_count_field_ident(array_field_name, index)?;
723        let path_prefix = if use_parent {
724            quote!(self.parent)
725        } else {
726            quote!(self)
727        };
728        match self.dimensions.get(index) {
729            Some(ArrayDimension::Static { size }) => Some(CountComparisonParts {
730                current_count: quote!(#path_prefix.#current_count_ident),
731                expected_count: quote!(#size),
732            }),
733            Some(ArrayDimension::Dynamic { field_name }) => {
734                let expected_count_ident = format_ident!("baggage_{}", field_name);
735                Some(CountComparisonParts {
736                    current_count: quote!(#path_prefix.#current_count_ident),
737                    expected_count: quote!(#path_prefix.#expected_count_ident),
738                })
739            }
740            None => None,
741        }
742    }
743}
744
745struct CountComparisonParts {
746    current_count: TokenStream,
747    expected_count: TokenStream,
748}
749
750/// The goal here is to make this current state implement an Iterator
751/// which returns a number items equal to the previously-written size
752/// of this array. The items produced by the iterator are single-shot
753/// "ItemWriter" instances that exist to facilitate writing a single
754/// value.
755///
756/// After the Iterator has been exhausted, the user is expected to
757/// call `done` on this state instance to consume it and move on.
758///
759/// If the array is over bytes, provide alternatives to iterating
760/// which allow direct slice operations.
761fn emit_writer_field_state_transition_array(
762    start_type: Ident,
763    next_state: &CodecState,
764    field_name: &str,
765    at: &ArrayType,
766) -> TokenStream {
767    let current_count_ident = at
768        .array_current_count_field_ident(field_name, 0)
769        .expect("Arrays should have at least one dimension");
770    let next_type = next_state.writer_ident();
771    let next_dimensions_fields =
772        BaggageDimension::field_initializations_from_self(&next_state.baggage_dimensions);
773    let item_writer_struct_ident = format_ident!("{}_item", start_type);
774    let write_item_method_ident = format_ident!("write");
775
776    let item_writer_over_len_check = at
777        .array_current_count_gte_expected_check(field_name, 0, true)
778        .expect("Arrays should have at least one dimension");
779    let pre_field_write = Some(quote! {
780        if #item_writer_over_len_check {
781            return Err(rust_lcm_codec::EncodeValueError::ArrayLengthMismatch(
782                "array length mismatch discovered while iterating",
783            ));
784        }
785    });
786    let post_field_write = Some(quote! {
787        self.parent.#current_count_ident += 1;
788    });
789    let write_item_method = match &*at.item_type {
790        Type::Primitive(pt) => {
791            let maybe_ref = if *pt == PrimitiveType::String {
792                Some(quote!(&))
793            } else {
794                None
795            };
796            let rust_field_type = Some(format_ident!("{}", primitive_type_to_rust(&pt)));
797            let write_invocation = emit_write_primitive_invocation(*pt, WriterPath::ViaSelfParent);
798            quote! {
799                #[inline]
800                pub fn #write_item_method_ident(self, val: #maybe_ref #rust_field_type) -> Result<(), rust_lcm_codec::EncodeValueError<W::Error>> {
801                    #pre_field_write
802                    #write_invocation
803                    #post_field_write
804                    Ok(())
805                }
806            }
807        }
808        Type::Struct(st) => {
809            let after_field_type = quote!(()); // unit
810            let after_field_constructor = quote!(()); // unit instantiation looks like its typedef
811            emit_write_struct_method(
812                st,
813                write_item_method_ident,
814                pre_field_write,
815                post_field_write,
816                after_field_type,
817                after_field_constructor,
818                WriterPath::ViaSelfParent,
819            )
820        }
821        Type::Array(at) => panic!("Multidimensional arrays are not supported yet."),
822    };
823
824    let current_iter_count_initialization =
825        emit_next_field_current_iter_count_initialization(next_state);
826    let top_level_under_len_check = at
827        .array_current_count_under_expected_check(field_name, 0, false)
828        .expect("Arrays should have at least one dimension");
829
830    let (maybe_slice_writer_methods, maybe_slice_writer_outcome_definition) = match &*at.item_type {
831        Type::Primitive(PrimitiveType::Byte) => {
832            let remainder_value = at.array_current_count_remainder_value(field_name, 0, false);
833            let copy_field_from_slice_ident = format_ident!("{}_copy_from_slice", field_name);
834            let get_field_as_mut_slice_ident = format_ident!("{}_as_mut_slice", field_name);
835            let slice_writer_outcome_type_ident = format_ident!("{}AsMutSliceOutcome", field_name);
836            let slice_writer_outcome_type_definition = quote! {
837                type #slice_writer_outcome_type_ident<'a, W> = (&'a mut [core::mem::MaybeUninit<u8>], #next_type<'a, W>);
838            };
839            (
840                Some(quote! {
841                #[inline]
842                pub fn #copy_field_from_slice_ident(self, val: &[u8]) -> Result<#next_type<'a, W>, rust_lcm_codec::EncodeValueError<W::Error>> {
843                    if #remainder_value != val.len() {
844                        Err(rust_lcm_codec::EncodeValueError::ArrayLengthMismatch(
845                            "slice provided to copy_FIELD_from_slice had a length which did not match the remaining expected size of the array",
846                        ))
847                    } else {
848                        self.writer.write_bytes(val)?;
849                        Ok(#next_type {
850                            writer: self.writer,
851                            #current_iter_count_initialization
852                            #( #next_dimensions_fields )*
853                        })
854                    }
855                }
856                /// This method exposes the underlying writer's raw bytes for a region of size equal
857                /// to the previously-written array length field value (minus any values already written
858                /// via iteration).  This provides a mechanism
859                /// for doing direct operations into byte blob style fields without extraneous copies,
860                ///
861                /// Since we don't know anything about the underlying writer's bytes preceding content,
862                /// return the bytes with a type hint showing they may be uninitialized.
863                /// In implementations where the writer's backing storage mechanism is understood by the
864                /// user (e.g. backed by a previously initialized array buffer), it may be safe to
865                /// transmute the slice to a plain byte slice.
866                #[inline]
867                pub fn #get_field_as_mut_slice_ident(self) -> Result<#slice_writer_outcome_type_ident<'a, W>, rust_lcm_codec::EncodeValueError<W::Error>> {
868                        // Use transmute to help link the generated bytes reference to the underlying Writer's lifetime
869                        //
870                        // Here we depend on the documented invariant of share_bytes_mut wherein the Writer
871                        // promises not to allow itself to mutate the shared bytes at any point in the future.
872                        let shared_bytes = unsafe { core::mem::transmute(self.writer.share_bytes_mut(#remainder_value)?) };
873                        Ok((shared_bytes,
874                            #next_type {
875                                writer: self.writer,
876                                #current_iter_count_initialization
877                                #( #next_dimensions_fields )*
878                            }))
879                }
880                }),
881                Some(slice_writer_outcome_type_definition),
882            )
883        }
884        _ => (None, None),
885    };
886    // TODO - create location-specific error message for array length mismatch
887    quote! {
888
889        impl<'a, W: rust_lcm_codec::StreamingWriter> Iterator for #start_type<'a, W> {
890            type Item = #item_writer_struct_ident<'a, W>;
891            fn next(&mut self) -> Option<Self::Item> {
892                if #top_level_under_len_check {
893                    // We cheat here to allow normally-evil multiple parent-mutable
894                    // references because we know that the generated code in the
895                    // child acts on the parent in a convergent manner:
896                    // * Each child consumes itself when it exercises its only method,
897                    //   and is thus limited to a single shot at mutating the parent.
898                    // * The child mutation of the parent is gated on boundary checks in the parent
899                    //   (max child operations and the underlying writer bounds checks)
900                    unsafe {
901                        Some(#item_writer_struct_ident {
902                            parent: core::mem::transmute(self),
903                        })
904                    }
905                } else {
906                    None
907                }
908            }
909        }
910        impl<'a, W: rust_lcm_codec::StreamingWriter> #item_writer_struct_ident<'a, W> {
911            #[inline]
912            #write_item_method
913        }
914
915        #maybe_slice_writer_outcome_definition
916
917        impl<'a, W: rust_lcm_codec::StreamingWriter> #start_type<'a, W> {
918
919            #maybe_slice_writer_methods
920
921            #[inline]
922            pub fn done(self) -> Result<#next_type<'a, W>, rust_lcm_codec::EncodeValueError<W::Error>> {
923                if #top_level_under_len_check {
924                    Err(rust_lcm_codec::EncodeValueError::ArrayLengthMismatch(
925                        "array length mismatch discovered when `done` called",
926                    ))
927                } else {
928                    Ok(#next_type {
929                        writer: self.writer,
930                        #current_iter_count_initialization
931                        #( #next_dimensions_fields )*
932                    })
933                }
934            }
935        }
936    }
937}
938
939impl parser::StructType {
940    fn namespace_prefix(&self) -> Option<TokenStream> {
941        if let Some(ns) = &self.namespace {
942            let namespace_ident = format_ident!("{}", ns);
943            Some(quote!(super::#namespace_ident::))
944        } else {
945            None
946        }
947    }
948}
949
950fn emit_reader_state_transition(
951    rs: &CodecState,
952    next_state: &CodecState,
953    env: &Environment,
954) -> TokenStream {
955    match rs.field {
956        Some((ref f, field_serves_as_dimension)) => {
957            let start_type = rs.reader_ident();
958            let next_type = next_state.reader_ident();
959            let read_method_ident = format_ident!("read_{}", f.name);
960            let next_dimensions_fields = BaggageDimension::field_initializations_from_self(
961                next_state.baggage_dimensions.iter().filter(|d| {
962                    !field_serves_as_dimension || d.len_field_name.as_str() != f.name.as_str()
963                }),
964            );
965            let current_iter_count_initialization =
966                emit_next_field_current_iter_count_initialization(next_state);
967            match &f.ty {
968                Type::Primitive(pt) => {
969                    let rust_field_type = Some(format_ident!("{}", primitive_type_to_rust(&pt)));
970                    let dimensional_capture = if field_serves_as_dimension {
971                        let baggage_field_ident = format_ident!("baggage_{}", f.name);
972                        Some(quote!(#baggage_field_ident: v as usize,))
973                    } else {
974                        None
975                    };
976                    let next_state = quote! {
977                        #next_type {
978                            reader: self.reader,
979                            #dimensional_capture
980                            #current_iter_count_initialization
981                            #( #next_dimensions_fields )*
982                        }
983                    };
984                    let read_methods = match pt {
985                        PrimitiveType::String => quote! {
986                            pub fn #read_method_ident(self) -> Result<(&'a #rust_field_type, #next_type<'a, R>), rust_lcm_codec::DecodeValueError<R::Error>> {
987                                // Use transmute to link the generated string reference to the underlying Reader's lifetime
988                                let v = unsafe { core::mem::transmute(rust_lcm_codec::read_str_value(self.reader)?) };
989                                Ok((v, #next_state))
990                            }
991                        },
992                        _ => {
993                            let capture_binding = if dimensional_capture.is_some() {
994                                Some(quote!(let v = *val;))
995                            } else {
996                                None
997                            };
998                            quote! {
999                                pub fn #read_method_ident(self) -> Result<(#rust_field_type, #next_type<'a, R>), rust_lcm_codec::DecodeValueError<R::Error>> {
1000                                    let v = rust_lcm_codec::SerializeValue::read_value(self.reader)?;
1001                                    Ok((v, #next_state))
1002                                }
1003                            }
1004                        }
1005                    };
1006
1007                    quote! {
1008                        impl<'a, R: rust_lcm_codec::StreamingReader> #start_type<'a, R> {
1009
1010                            #[inline]
1011                            #read_methods
1012                        }
1013                    }
1014                }
1015                Type::Struct(st) => {
1016                    let field_struct_read_ready: Ident =
1017                        CodecState::reader_struct_state_decl_ident(&st.name, &StateName::Ready);
1018                    let field_struct_read_done: Ident =
1019                        CodecState::reader_struct_state_decl_ident(&st.name, &StateName::Done);
1020                    let struct_ns_prefix = st.namespace_prefix();
1021                    quote! {
1022                        impl<'a, R: rust_lcm_codec::StreamingReader> #start_type<'a, R> {
1023
1024                            #[inline]
1025                            pub fn #read_method_ident<F>(self, f: F) -> Result<#next_type<'a, R>, rust_lcm_codec::DecodeValueError<R::Error>>
1026                                where F: FnOnce(#struct_ns_prefix#field_struct_read_ready<'a, R>) -> Result<#struct_ns_prefix#field_struct_read_done<'a, R>, rust_lcm_codec::DecodeValueError<R::Error>>
1027                            {
1028                                let ready = #struct_ns_prefix#field_struct_read_ready {
1029                                    reader: self.reader,
1030                                };
1031                                let done = f(ready)?;
1032                                Ok(#next_type {
1033                                    reader: done.reader,
1034                                    #current_iter_count_initialization
1035                                    #( #next_dimensions_fields )*
1036                                })
1037                            }
1038                        }
1039                    }
1040                }
1041                Type::Array(at) => {
1042                    let read_method_ident = format_ident!("read");
1043                    let current_iter_count_field_ident = at
1044                        .array_current_count_field_ident(f.name.as_str(), 0)
1045                        .expect("Arrays should have at least one dimension");
1046                    let item_reader_over_len_check = at
1047                        .array_current_count_gte_expected_check(f.name.as_str(), 0, true)
1048                        .expect("Arrays should have at least one dimension");
1049                    let pre_field_read = quote! {
1050                        if #item_reader_over_len_check {
1051                            return Err(rust_lcm_codec::DecodeValueError::ArrayLengthMismatch(
1052                                "array length mismatch discovered while iterating to read",
1053                            ));
1054                        }
1055                    };
1056                    let post_field_read = quote!(self.parent.#current_iter_count_field_ident += 1;);
1057                    let read_item_method = match &*at.item_type {
1058                        Type::Primitive(pt) => {
1059                            let rust_field_type = format_ident!("{}", primitive_type_to_rust(pt));
1060                            match pt {
1061                                PrimitiveType::String => quote! {
1062                                    pub fn #read_method_ident(self) -> Result<&'a #rust_field_type, rust_lcm_codec::DecodeValueError<R::Error>> {
1063                                        #pre_field_read
1064                                        // Use transmute to link the generated string reference to the underlying Reader's lifetime
1065                                        let v = unsafe { core::mem::transmute(rust_lcm_codec::read_str_value(self.parent.reader)?) };
1066                                        #post_field_read
1067                                        Ok(v)
1068                                    }
1069                                },
1070                                _ => quote! {
1071                                    pub fn #read_method_ident(self) -> Result<#rust_field_type, rust_lcm_codec::DecodeValueError<R::Error>> {
1072                                        #pre_field_read
1073                                        let v = rust_lcm_codec::SerializeValue::read_value(self.parent.reader)?;
1074                                        #post_field_read
1075                                        Ok(v)
1076                                    }
1077                                },
1078                            }
1079                        }
1080
1081                        Type::Struct(st) => {
1082                            let struct_ns_prefix = st.namespace_prefix();
1083                            let field_struct_read_ready: Ident =
1084                                CodecState::reader_struct_state_decl_ident(
1085                                    &st.name,
1086                                    &StateName::Ready,
1087                                );
1088                            let field_struct_read_done: Ident =
1089                                CodecState::reader_struct_state_decl_ident(
1090                                    &st.name,
1091                                    &StateName::Done,
1092                                );
1093                            quote! {
1094                                pub fn #read_method_ident<F>(self, f: F) -> Result<(), rust_lcm_codec::DecodeValueError<R::Error>>
1095                                    where F: FnOnce(#struct_ns_prefix#field_struct_read_ready<'a, R>) -> Result<#struct_ns_prefix#field_struct_read_done<'a, R>, rust_lcm_codec::DecodeValueError<R::Error>>
1096                                {
1097                                    #pre_field_read
1098                                    let ready = #struct_ns_prefix#field_struct_read_ready {
1099                                        reader: self.parent.reader,
1100                                    };
1101                                    let _done = f(ready)?;
1102                                    #post_field_read
1103                                    Ok(())
1104                                }
1105                            }
1106                        }
1107                        Type::Array(at) => panic!("Multidimensional arrays are not supported yet."),
1108                    };
1109                    let item_reader_struct_ident = format_ident!("{}_item", start_type);
1110                    let read_item_method_ident = format_ident!("read");
1111                    let top_level_under_len_check = at
1112                        .array_current_count_under_expected_check(f.name.as_str(), 0, false)
1113                        .expect("Arrays should have at least one dimension");
1114                    let (maybe_slice_reader_method, maybe_slice_reader_outcome_type) = match &*at
1115                        .item_type
1116                    {
1117                        Type::Primitive(PrimitiveType::Byte) => {
1118                            let field_name = f.name.as_str();
1119                            let remainder_value =
1120                                at.array_current_count_remainder_value(field_name, 0, false);
1121                            let get_field_as_slice_ident = format_ident!("{}_as_slice", field_name);
1122                            let slice_reader_outcome_type_ident =
1123                                format_ident!("{}AsSliceOutcome", field_name);
1124                            let slice_reader_outcome_type_definition = quote! {
1125                                type #slice_reader_outcome_type_ident<'a, R> = (&'a [u8], #next_type<'a, R>);
1126                            };
1127                            (
1128                                Some(quote! {
1129                                /// This method exposes the underlying reader's raw bytes for a region of size equal
1130                                /// to the previously-written array length field value (minus any values
1131                                /// previously read through iteration).  This provides a mechanism
1132                                /// for doing direct operations from byte blob style fields without extraneous copies,
1133                                #[inline]
1134                                pub fn #get_field_as_slice_ident(self) -> Result<#slice_reader_outcome_type_ident<'a, R>, rust_lcm_codec::DecodeValueError<R::Error>> {
1135                                        // Use transmute to help link the generated bytes reference to the underlying Reader's lifetime
1136                                        //
1137                                        // Here we depend on the documented invariant of share_bytes wherein the Reader
1138                                        // promises not to allow itself to mutate the shared bytes at any point in the future.
1139                                        let shared_bytes = unsafe { core::mem::transmute(self.reader.share_bytes(#remainder_value)?) };
1140                                        Ok((shared_bytes,
1141                                            #next_type {
1142                                                reader: self.reader,
1143                                                #current_iter_count_initialization
1144                                                #( #next_dimensions_fields )*
1145                                            }))
1146                                }
1147                                }),
1148                                Some(slice_reader_outcome_type_definition),
1149                            )
1150                        }
1151                        _ => (None, None),
1152                    };
1153                    quote! {
1154                        impl<'a, R: rust_lcm_codec::StreamingReader> Iterator for #start_type<'a, R> {
1155                            type Item = #item_reader_struct_ident<'a, R>;
1156                            fn next(&mut self) -> Option<Self::Item> {
1157                                if #top_level_under_len_check {
1158                                    // We cheat here to allow normally-evil multiple parent-mutable
1159                                    // references because we know that the generated code in the
1160                                    // child acts on the parent in a convergent manner:
1161                                    // * Each child consumes itself when it exercises its only method,
1162                                    //   and is thus limited to a single shot at mutating the parent.
1163                                    // * The child mutation of the parent is gated on boundary checks in the parent
1164                                    //   (max child operations and the underlying reader bounds checks)
1165                                    unsafe {
1166                                        Some(#item_reader_struct_ident {
1167                                            parent: core::mem::transmute(self),
1168                                        })
1169                                    }
1170                                } else {
1171                                    None
1172                                }
1173                            }
1174                        }
1175                        impl<'a, R: rust_lcm_codec::StreamingReader> #item_reader_struct_ident<'a, R> {
1176                            #[inline]
1177                            #read_item_method
1178                        }
1179
1180                        #maybe_slice_reader_outcome_type
1181
1182                        impl<'a, R: rust_lcm_codec::StreamingReader> #start_type<'a, R> {
1183                            #[inline]
1184                            pub fn done(self) -> Result<#next_type<'a, R>, rust_lcm_codec::DecodeValueError<R::Error>> {
1185                                if #top_level_under_len_check {
1186                                    Err(rust_lcm_codec::DecodeValueError::ArrayLengthMismatch(
1187                                        "array length mismatch discovered when read `done` called",
1188                                    ))
1189                                } else {
1190                                    Ok(#next_type {
1191                                        reader: self.reader,
1192                                        #current_iter_count_initialization
1193                                        #( #next_dimensions_fields )*
1194                                    })
1195                                }
1196                            }
1197
1198                            #maybe_slice_reader_method
1199                        }
1200                    }
1201                }
1202            }
1203        }
1204        None => quote! {},
1205    }
1206}
1207
1208/// Collection of a schema and its peers.
1209pub struct Environment {
1210    local_schema: parser::Schema,
1211    all_schemas: Vec<parser::Schema>,
1212}
1213
1214impl Environment {
1215    /// Find a struct in the environment by it's StructType (name + ns)
1216    fn resolve_struct_type(&self, st: &parser::StructType) -> Option<&parser::Struct> {
1217        match &st.namespace {
1218            None => self
1219                .local_schema
1220                .structs
1221                .iter()
1222                .find(|curr_st| curr_st.name == st.name),
1223            Some(ns) => {
1224                for sch in self.all_schemas.iter() {
1225                    match &sch.package {
1226                        Some(this_ns) => {
1227                            if this_ns == ns {
1228                                for curr_st in sch.structs.iter() {
1229                                    if curr_st.name == st.name {
1230                                        return Some(curr_st);
1231                                    }
1232                                }
1233                            }
1234                        }
1235                        None => (),
1236                    }
1237                }
1238                None
1239            }
1240        }
1241    }
1242}
1243
#[cfg(test)]
mod tests {
    use super::*;

    /// `to_underscored_literal` should render each value with a `u64` suffix and an
    /// underscore between every group of three digits.
    #[test]
    fn manual_underscore_integer_check() {
        let cases = [
            ("0u64", 0),
            ("1u64", 1),
            ("10u64", 10),
            ("100u64", 100),
            ("1_000u64", 1_000),
            ("10_000u64", 10_000),
            ("100_000u64", 100_000),
            ("1_000_000u64", 1_000_000),
        ];
        for &(expected, value) in &cases {
            assert_eq!(expected, format!("{}", to_underscored_literal(value)));
        }
    }
}