// weld_codegen/encode_rust.rs

// CBOR Encode functions
//
// Because we have all the type information for declared types,
// we can invoke the appropriate encode_* functions for each
// simple type, structure, union, map, and list/set (encoded as an array). (later: enums).
// For Rust, it would have been a little easier to use serde code generation,
// but we want to create a code base that is easy to port to other languages
// that don't have serde.

// The encoder is written as a plain function "encode_<S>" where S is the type name
// (snake_cased for the fn name), and scoped to the module where S is defined.
12use std::{fmt::Write as _, string::ToString};
13
14use atelier_core::{
15    model::{
16        shapes::{HasTraits, ShapeKind, Simple, StructureOrUnion},
17        HasIdentity, ShapeID,
18    },
19    prelude::{
20        prelude_namespace_id, SHAPE_BIGDECIMAL, SHAPE_BIGINTEGER, SHAPE_BLOB, SHAPE_BOOLEAN,
21        SHAPE_BYTE, SHAPE_DOCUMENT, SHAPE_DOUBLE, SHAPE_FLOAT, SHAPE_INTEGER, SHAPE_LONG,
22        SHAPE_PRIMITIVEBOOLEAN, SHAPE_PRIMITIVEBYTE, SHAPE_PRIMITIVEDOUBLE, SHAPE_PRIMITIVEFLOAT,
23        SHAPE_PRIMITIVEINTEGER, SHAPE_PRIMITIVELONG, SHAPE_PRIMITIVESHORT, SHAPE_SHORT,
24        SHAPE_STRING, SHAPE_TIMESTAMP,
25    },
26};
27
28use crate::{
29    codegen_rust::{is_optional_type, is_rust_primitive, RustCodeGen},
30    error::{Error, Result},
31    gen::CodeGen,
32    model::wasmcloud_model_namespace,
33    writer::Writer,
34};
35
/// Flag returned alongside a generated encoder body: `true` when the shape is a structure
/// with no fields, so the body never reads the value parameter and the parameter can be
/// named `_val` to avoid an unused-variable warning in generated code.
type IsEmptyStruct = bool;
37
/// Source text of a Rust expression naming a value to be encoded, tagged with whether the
/// expression evaluates to a value (`Plain`) or to a reference (`Ref`).
///
/// Different encoder methods need the argument in different forms, so the expression can be
/// rendered borrowed (`as_ref`), unchanged (`as_str`), or dereferenced for copy types
/// (`as_copy`).
#[derive(Clone, Copy, Debug)]
pub(crate) enum ValExpr<'s> {
    /// The expression evaluates to a value, e.g. a field access such as `val.name`.
    Plain(&'s str),
    /// The expression evaluates to a reference, e.g. the `val` parameter or a loop variable.
    Ref(&'s str),
}

impl<'s> ValExpr<'s> {
    /// Returns an expression that borrows the value: `&expr` for `Plain`, `expr` for `Ref`.
    pub(crate) fn as_ref(&self) -> String {
        match self {
            ValExpr::Plain(s) => format!("&{s}"),
            ValExpr::Ref(s) => s.to_string(),
        }
    }

    /// Returns the expression text unchanged.
    ///
    /// The returned slice carries the lifetime `'s` of the wrapped text rather than the
    /// lifetime of `&self`, so it remains usable after a temporary `ValExpr` is dropped.
    pub(crate) fn as_str(&self) -> &'s str {
        match *self {
            ValExpr::Plain(s) | ValExpr::Ref(s) => s,
        }
    }

    /// Returns an expression for the value itself, intended for `Copy` types:
    /// `expr` for `Plain`, `*expr` for `Ref`.
    pub(crate) fn as_copy(&self) -> String {
        match self {
            ValExpr::Plain(s) => s.to_string(),
            ValExpr::Ref(s) => format!("*{s}"),
        }
    }
}
68
69// encode_* methods encode a base/simple type
70
71fn encode_blob(val: ValExpr) -> String {
72    format!("e.bytes({})?;\n", &val.as_ref())
73}
74fn encode_boolean(val: ValExpr) -> String {
75    format!("e.bool({})?;\n", val.as_copy())
76}
77fn encode_str(val: ValExpr) -> String {
78    format!("e.str({})?;\n", val.as_ref())
79}
80fn encode_byte(val: ValExpr) -> String {
81    format!("e.i8({})?;\n", val.as_copy())
82}
83fn encode_unsigned_byte(val: ValExpr) -> String {
84    format!("e.u8({})?;\n", val.as_copy())
85}
86fn encode_short(val: ValExpr) -> String {
87    format!("e.i16({})?;\n", val.as_copy())
88}
89fn encode_unsigned_short(val: ValExpr) -> String {
90    format!("e.u16({})?;\n", val.as_copy())
91}
92fn encode_integer(val: ValExpr) -> String {
93    format!("e.i32({})?;\n", val.as_copy())
94}
95fn encode_unsigned_integer(val: ValExpr) -> String {
96    format!("e.u32({})?;\n", val.as_copy())
97}
98fn encode_long(val: ValExpr) -> String {
99    format!("e.i64({})?;\n", val.as_copy())
100}
101fn encode_unsigned_long(val: ValExpr) -> String {
102    format!("e.u64({})?;\n", val.as_copy())
103}
104fn encode_float(val: ValExpr) -> String {
105    format!("e.f32({})?;\n", val.as_copy())
106}
107fn encode_double(val: ValExpr) -> String {
108    format!("e.f64({})?;\n", val.as_copy())
109}
110fn encode_document(val: ValExpr) -> String {
111    format!(
112        "wasmbus_rpc::common::encode_document(&mut e, {})?;\n",
113        val.as_ref()
114    )
115}
/// Emits the statement that encodes a unit value as CBOR null.
fn encode_unit() -> String {
    String::from("e.null()?;\n")
}
119fn encode_timestamp(val: ValExpr) -> String {
120    format!(
121        "e.i64({}.sec)?;\ne.u32({}.nsec)?;\n",
122        val.as_str(),
123        val.as_str()
124    )
125}
126fn encode_big_integer(_val: ValExpr) -> String {
127    todo!(); // tag big int
128}
129fn encode_big_decimal(_val: ValExpr) -> String {
130    todo!() // tag big decimal
131}
132
133impl<'model> RustCodeGen<'model> {
134    /// Generates cbor encode statements "e.func()" for the id.
135    /// If id is a primitive type, writes the direct encode function, otherwise,
136    /// delegates to an encode_* function created in the same module where the symbol is defined
137    pub(crate) fn encode_shape_id(
138        &self,
139        id: &ShapeID,
140        val: ValExpr,
141        enc_owned: bool, // true if encoder is owned in current fn
142    ) -> Result<String> {
143        let name = id.shape_name().to_string();
144        let stmt = if id.namespace() == prelude_namespace_id() {
145            match name.as_ref() {
146                SHAPE_BLOB => encode_blob(val),
147                SHAPE_BOOLEAN | SHAPE_PRIMITIVEBOOLEAN => encode_boolean(val),
148                SHAPE_STRING => encode_str(val),
149                SHAPE_BYTE | SHAPE_PRIMITIVEBYTE => encode_byte(val),
150                SHAPE_SHORT | SHAPE_PRIMITIVESHORT => encode_short(val),
151                SHAPE_INTEGER | SHAPE_PRIMITIVEINTEGER => encode_integer(val),
152                SHAPE_LONG | SHAPE_PRIMITIVELONG => encode_long(val),
153                SHAPE_FLOAT | SHAPE_PRIMITIVEFLOAT => encode_float(val),
154                SHAPE_DOUBLE | SHAPE_PRIMITIVEDOUBLE => encode_double(val),
155                SHAPE_TIMESTAMP => encode_timestamp(val),
156                SHAPE_BIGINTEGER => encode_big_integer(val),
157                SHAPE_BIGDECIMAL => encode_big_decimal(val),
158                SHAPE_DOCUMENT => encode_document(val),
159                _ => return Err(Error::UnsupportedType(name)),
160            }
161        } else if id.namespace() == wasmcloud_model_namespace() {
162            match name.as_bytes() {
163                b"U64" => encode_unsigned_long(val),
164                b"U32" => encode_unsigned_integer(val),
165                b"U16" => encode_unsigned_short(val),
166                b"U8" => encode_unsigned_byte(val),
167                b"I64" => encode_long(val),
168                b"I32" => encode_integer(val),
169                b"I16" => encode_short(val),
170                b"I8" => encode_byte(val),
171                b"F64" => encode_double(val),
172                b"F32" => encode_float(val),
173                _ => format!(
174                    "{}encode_{}( e, {})?;\n",
175                    &self.get_model_crate(),
176                    crate::strings::to_snake_case(&id.shape_name().to_string()),
177                    val.as_ref()
178                ),
179            }
180        } else {
181            format!(
182                "{}encode_{}({} e, {})?;\n",
183                self.get_crate_path(id)?,
184                crate::strings::to_snake_case(&id.shape_name().to_string()),
185                if enc_owned { "&mut " } else { "" },
186                val.as_ref(),
187            )
188        };
189        Ok(stmt)
190    }
191
192    /// Generates statements to encode the shape.
193    /// Second Result field is true if structure has no fields, e.g., "MyStruct {}"
194    fn encode_shape_kind(
195        &self,
196        id: &ShapeID,
197        kind: &ShapeKind,
198        val: ValExpr,
199    ) -> Result<(String, IsEmptyStruct)> {
200        let mut empty_struct: IsEmptyStruct = false;
201        let s = match kind {
202            ShapeKind::Simple(simple) => match simple {
203                Simple::Blob => encode_blob(val),
204                Simple::Boolean => encode_boolean(val),
205                Simple::String => encode_str(val),
206                Simple::Byte => encode_byte(val),
207                Simple::Short => encode_short(val),
208                Simple::Integer => encode_integer(val),
209                Simple::Long => encode_long(val),
210                Simple::Float => encode_float(val),
211                Simple::Double => encode_double(val),
212                Simple::Timestamp => encode_timestamp(val),
213                Simple::BigInteger => encode_big_integer(val),
214                Simple::BigDecimal => encode_big_decimal(val),
215                Simple::Document => encode_document(val),
216            },
217            ShapeKind::Map(map) => {
218                let mut s = format!(
219                    r#"
220                    e.map({}.len() as u64)?;
221                    for (k,v) in {} {{
222                    "#,
223                    val.as_str(),
224                    val.as_str()
225                );
226                s.push_str(&self.encode_shape_id(map.key().target(), ValExpr::Ref("k"), false)?);
227                s.push_str(&self.encode_shape_id(
228                    map.value().target(),
229                    ValExpr::Ref("v"),
230                    false,
231                )?);
232                s.push_str(
233                    r#"
234                    }
235                    "#,
236                );
237                s
238            }
239            ShapeKind::List(list) => {
240                let mut s = format!(
241                    r#"
242                    e.array({}.len() as u64)?;
243                    for item in {}.iter() {{
244                    "#,
245                    val.as_str(),
246                    val.as_str()
247                );
248
249                s.push_str(&self.encode_shape_id(
250                    list.member().target(),
251                    ValExpr::Ref("item"),
252                    false,
253                )?);
254                s.push('}');
255                s
256            }
257            ShapeKind::Set(set) => {
258                let mut s = format!(
259                    r#"
260                    e.array({}.len() as u64)?;
261                    for v in {}.iter() {{
262                    "#,
263                    val.as_str(),
264                    val.as_str()
265                );
266                s.push_str(&self.encode_shape_id(
267                    set.member().target(),
268                    ValExpr::Ref("v"),
269                    false,
270                )?);
271                s.push_str(
272                    r#"
273                    }
274                    "#,
275                );
276                s
277            }
278            ShapeKind::Structure(struct_) => {
279                if id != crate::model::unit_shape() {
280                    let (s, is_empty_struct) = self.encode_struct(id, struct_, val)?;
281                    empty_struct = is_empty_struct;
282                    s
283                } else {
284                    encode_unit()
285                }
286            }
287            ShapeKind::Union(union_) => {
288                let (s, _) = self.encode_union(id, union_, val)?;
289                s
290            }
291            ShapeKind::Operation(_)
292            | ShapeKind::Resource(_)
293            | ShapeKind::Service(_)
294            | ShapeKind::Unresolved => String::new(),
295        };
296        Ok((s, empty_struct))
297    }
298
299    /// Generate string to encode union.
300    fn encode_union(
301        &self,
302        id: &ShapeID,
303        strukt: &StructureOrUnion,
304        val: ValExpr,
305    ) -> Result<(String, IsEmptyStruct)> {
306        let (fields, _) = crate::model::get_sorted_fields(id.shape_name(), strukt)?;
307        // FUTURE: if all variants are unit, this can be encoded as an int, not array
308        //   .. but decoder would have to peek to distinguish array from int
309        //let is_all_unit = fields
310        //    .iter()
311        //    .all(|f| f.target() == crate::model::unit_shape());
312        let is_all_unit = false; // for now, stick with array
313        let mut s = String::new();
314        writeln!(
315            s,
316            "// encoding union {}\n e.array(2)?;\n match {} {{",
317            id.shape_name(),
318            val.as_str()
319        )
320        .unwrap();
321        for field in fields.iter() {
322            let target = field.target();
323            let field_name = self.to_type_name_case(&field.id().to_string());
324            if target == crate::model::unit_shape() {
325                writeln!(
326                    s,
327                    "{}::{} => {{ e.u16({})?;",
328                    id.shape_name(),
329                    &field_name,
330                    &field.field_num().unwrap()
331                )
332                .unwrap();
333                if !is_all_unit {
334                    s.push_str(&encode_unit());
335                }
336            } else {
337                writeln!(
338                    s,
339                    "{}::{}(v) => {{ e.u16({})?;",
340                    id.shape_name(),
341                    &field_name,
342                    &field.field_num().unwrap()
343                )
344                .unwrap();
345                s.push_str(&self.encode_shape_id(target, ValExpr::Ref("v"), false)?);
346            }
347            s.push_str("},\n");
348        }
349        s.push_str("}\n");
350        Ok((s, fields.is_empty()))
351    }
352
353    /// Generate string to encode structure.
354    /// Second Result field is true if structure has no fields, e.g., "MyStruct {}"
355    fn encode_struct(
356        &self,
357        id: &ShapeID,
358        strukt: &StructureOrUnion,
359        val: ValExpr,
360    ) -> Result<(String, IsEmptyStruct)> {
361        let (fields, is_numbered) = crate::model::get_sorted_fields(id.shape_name(), strukt)?;
362        // use array encoding if fields are declared with numbers
363        let as_array = is_numbered;
364        let field_max_index = if as_array && !fields.is_empty() {
365            fields.iter().map(|f| f.field_num().unwrap()).max().unwrap()
366        } else {
367            fields.len() as u16
368        };
369        let mut s = String::new();
370        if as_array {
371            writeln!(s, "e.array({})?;", field_max_index + 1).unwrap();
372        } else {
373            writeln!(s, "e.map({})?;", fields.len()).unwrap();
374        }
375        let mut current_index = 0;
376        for field in fields.iter() {
377            if let Some(field_num) = field.field_num() {
378                if as_array {
379                    while current_index < *field_num {
380                        writeln!(s, "e.null()?;").unwrap();
381                        current_index += 1;
382                    }
383                }
384            }
385            let field_name = self.to_field_name(field.id(), field.traits())?;
386            let field_val = self.encode_shape_id(field.target(), ValExpr::Ref("val"), false)?;
387            if is_optional_type(field) {
388                writeln!(
389                    s,
390                    "if let Some(val) =  {}.{}.as_ref() {{",
391                    val.as_str(),
392                    &field_name
393                )
394                .unwrap();
395                if !as_array {
396                    // map key is declared name, not target language name
397                    writeln!(s, "e.str(\"{}\")?;", field.id()).unwrap();
398                }
399                writeln!(s, "{} }} else {{ e.null()?; }}", &field_val).unwrap();
400            } else {
401                if !as_array {
402                    // map key is declared name, not target language name
403                    writeln!(s, "e.str(\"{}\")?;", field.id()).unwrap();
404                }
405                let val = format!("{}.{}", val.as_str(), &field_name);
406                s.push_str(&self.encode_shape_id(field.target(), ValExpr::Plain(&val), false)?);
407            }
408            current_index += 1;
409        }
410        Ok((s, fields.is_empty()))
411    }
412
413    pub(crate) fn declare_shape_encoder(
414        &self,
415        w: &mut Writer,
416        id: &ShapeID,
417        kind: &ShapeKind,
418    ) -> Result<()> {
419        // The encoder is written as a plain function "encode_<S>" where S is the type name
420        // (camel cased for the fn name), and scoped to the module where S is defined. This could
421        // have been implemented as 'impl Encode for TYPE ...', but that would make the code more
422        // rust-specific. This code is structured to be easier to port to other target languages.
423        match kind {
424            ShapeKind::Simple(_)
425            | ShapeKind::Structure(_)
426            | ShapeKind::Union(_)
427            | ShapeKind::Map(_)
428            | ShapeKind::List(_)
429            | ShapeKind::Set(_) => {
430                let name = id.shape_name();
431                // use val-by-copy as param to encode if type is rust primitive "copy" type
432                // This is only relevant for aliases of primitive types in wasmbus-model namespace
433                let is_rust_copy = is_rust_primitive(id);
434                // The purpose of is_empty_struct is to determine when the parameter is unused
435                // in the function body, and prepend '_' to the name to avoid a compiler warning.
436                let (body, is_empty_struct) =
437                    self.encode_shape_kind(id, kind, ValExpr::Ref("val"))?;
438                let mut s = format!(
439                    r#" 
440                // Encode {} as CBOR and append to output stream
441                #[doc(hidden)] #[allow(unused_mut)] {}
442                pub fn encode_{}<W: {}::cbor::Write>(
443                    mut e: &mut {}::cbor::Encoder<W>, {}: &{}) -> RpcResult<()>
444                    where <W as {}::cbor::Write>::Error: std::fmt::Display
445                {{
446                "#,
447                    &name,
448                    if is_rust_copy { "#[inline]" } else { "" },
449                    crate::strings::to_snake_case(&name.to_string()),
450                    self.import_core,
451                    self.import_core,
452                    if is_empty_struct { "_val" } else { "val" },
453                    &id.shape_name(),
454                    self.import_core,
455                );
456                s.push_str(&body);
457                s.push_str("Ok(())\n}\n");
458                w.write(s.as_bytes());
459            }
460            ShapeKind::Operation(_)
461            | ShapeKind::Resource(_)
462            | ShapeKind::Service(_)
463            | ShapeKind::Unresolved => { /* write nothing */ }
464        }
465        Ok(())
466    }
467}