1use std::collections::HashSet;
2
3use vexil_lang::ast::{PrimitiveType, SemanticType};
4use vexil_lang::ir::{
5 Encoding, FieldEncoding, MessageDef, ResolvedType, TypeDef, TypeId, TypeRegistry,
6};
7
8use crate::annotations::{emit_field_annotations, emit_tombstones, emit_type_annotations};
9use crate::emit::CodeWriter;
10use crate::types::rust_type;
11
12pub fn is_byte_aligned(ty: &ResolvedType, registry: &TypeRegistry) -> bool {
20 match ty {
21 ResolvedType::Primitive(PrimitiveType::Bool) => false,
22 ResolvedType::SubByte(_) => false,
23 ResolvedType::Named(id) => {
24 if let Some(TypeDef::Enum(e)) = registry.get(*id) {
26 e.wire_bits >= 8
27 } else {
28 true
29 }
30 }
31 ResolvedType::Optional(inner) => is_byte_aligned(inner, registry),
32 _ => true,
33 }
34}
35
36fn primitive_bits(p: &PrimitiveType) -> u8 {
41 match p {
42 PrimitiveType::I8 | PrimitiveType::U8 => 8,
43 PrimitiveType::I16 | PrimitiveType::U16 => 16,
44 PrimitiveType::I32 | PrimitiveType::U32 | PrimitiveType::F32 => 32,
45 PrimitiveType::I64 | PrimitiveType::U64 | PrimitiveType::F64 => 64,
46 _ => 0,
47 }
48}
49
50pub fn emit_write(
59 w: &mut CodeWriter,
60 access: &str,
61 ty: &ResolvedType,
62 enc: &FieldEncoding,
63 registry: &TypeRegistry,
64 field_name: &str,
65) {
66 match &enc.encoding {
68 Encoding::Varint => {
69 if let Some(limit) = enc.limit {
70 w.line(&format!(
71 "if ({access} as u64) > {limit}_u64 {{ return Err(vexil_runtime::EncodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: {access} as u64 }}); }}"
72 ));
73 }
74 w.line(&format!("w.write_leb128({access} as u64);"));
75 return;
76 }
77 Encoding::ZigZag => {
78 if let Some(limit) = enc.limit {
79 w.line(&format!(
80 "if ({access} as i64).unsigned_abs() > {limit}_u64 {{ return Err(vexil_runtime::EncodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: ({access} as i64).unsigned_abs() }}); }}"
81 ));
82 }
83 let type_bits = match ty {
84 ResolvedType::Primitive(p) => primitive_bits(p),
85 _ => 64,
86 };
87 w.line(&format!("w.write_zigzag({access} as i64, {type_bits}_u8);"));
88 return;
89 }
90 Encoding::Delta(inner) => {
91 let base_enc = FieldEncoding {
94 encoding: *inner.clone(),
95 limit: enc.limit,
96 };
97 emit_write(w, access, ty, &base_enc, registry, field_name);
98 return;
99 }
100 Encoding::Default => {} _ => {} }
103
104 if let Some(limit) = enc.limit {
106 match ty {
107 ResolvedType::Array(_)
108 | ResolvedType::Map(_, _)
109 | ResolvedType::Semantic(SemanticType::String)
110 | ResolvedType::Semantic(SemanticType::Bytes) => {
111 w.line(&format!(
112 "if ({access}).len() as u64 > {limit}_u64 {{ return Err(vexil_runtime::EncodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: ({access}).len() as u64 }}); }}"
113 ));
114 }
115 _ => {}
116 }
117 }
118
119 emit_write_type(w, access, ty, registry, field_name);
120}
121
122#[allow(clippy::only_used_in_recursion)]
123fn emit_write_type(
124 w: &mut CodeWriter,
125 access: &str,
126 ty: &ResolvedType,
127 registry: &TypeRegistry,
128 field_name: &str,
129) {
130 match ty {
131 ResolvedType::Primitive(p) => match p {
132 PrimitiveType::Bool => w.line(&format!("w.write_bool({access});")),
133 PrimitiveType::U8 => w.line(&format!("w.write_u8({access});")),
134 PrimitiveType::U16 => w.line(&format!("w.write_u16({access});")),
135 PrimitiveType::U32 => w.line(&format!("w.write_u32({access});")),
136 PrimitiveType::U64 => w.line(&format!("w.write_u64({access});")),
137 PrimitiveType::I8 => w.line(&format!("w.write_i8({access});")),
138 PrimitiveType::I16 => w.line(&format!("w.write_i16({access});")),
139 PrimitiveType::I32 => w.line(&format!("w.write_i32({access});")),
140 PrimitiveType::I64 => w.line(&format!("w.write_i64({access});")),
141 PrimitiveType::F32 => w.line(&format!("w.write_f32({access});")),
142 PrimitiveType::F64 => w.line(&format!("w.write_f64({access});")),
143 PrimitiveType::Void => {} },
145 ResolvedType::SubByte(s) => {
146 let bits = s.bits;
147 if s.signed {
148 w.line(&format!("w.write_bits({access} as u8 as u64, {bits}_u8);"));
149 } else {
150 w.line(&format!("w.write_bits({access} as u64, {bits}_u8);"));
151 }
152 }
153 ResolvedType::Semantic(s) => match s {
154 SemanticType::String => w.line(&format!("w.write_string(&{access});")),
155 SemanticType::Bytes => w.line(&format!("w.write_bytes(&{access});")),
156 SemanticType::Rgb => {
157 w.line(&format!("w.write_u8({access}.0);"));
158 w.line(&format!("w.write_u8({access}.1);"));
159 w.line(&format!("w.write_u8({access}.2);"));
160 }
161 SemanticType::Uuid => w.line(&format!("w.write_raw_bytes(&{access});")),
162 SemanticType::Timestamp => w.line(&format!("w.write_i64({access});")),
163 SemanticType::Hash => w.line(&format!("w.write_raw_bytes(&{access});")),
164 },
165 ResolvedType::Named(_) => {
166 w.line(&format!("{access}.pack(w)?;"));
167 }
168 ResolvedType::Optional(inner) => {
169 w.line(&format!("w.write_bool({access}.is_some());"));
171 if is_byte_aligned(inner, registry) {
173 w.line("w.flush_to_byte_boundary();");
174 }
179 w.open_block(&format!("if let Some(ref inner_val) = {access}"));
180 emit_write_type(w, "inner_val", inner, registry, field_name);
181 w.close_block();
182 }
183 ResolvedType::Array(inner) => {
184 w.line(&format!("w.write_leb128({access}.len() as u64);"));
185 w.open_block(&format!("for item in &{access}"));
186 emit_write_type(w, "item", inner, registry, field_name);
187 w.close_block();
188 }
189 ResolvedType::Map(k, v) => {
190 w.line(&format!("w.write_leb128({access}.len() as u64);"));
191 w.open_block(&format!("for (map_k, map_v) in &{access}"));
192 emit_write_type(w, "map_k", k, registry, field_name);
193 emit_write_type(w, "map_v", v, registry, field_name);
194 w.close_block();
195 }
196 ResolvedType::Result(ok, err) => {
197 w.open_block(&format!("match &{access}"));
198 w.open_block("Ok(ok_val) =>");
199 w.line("w.write_bool(true);");
200 emit_write_type(w, "ok_val", ok, registry, field_name);
201 w.close_block();
202 w.open_block("Err(err_val) =>");
203 w.line("w.write_bool(false);");
204 emit_write_type(w, "err_val", err, registry, field_name);
205 w.close_block();
206 w.close_block();
207 }
208 _ => {} }
210}
211
212pub fn emit_read(
220 w: &mut CodeWriter,
221 var_name: &str,
222 ty: &ResolvedType,
223 enc: &FieldEncoding,
224 registry: &TypeRegistry,
225 field_name: &str,
226) {
227 match &enc.encoding {
228 Encoding::Varint => {
229 w.line(&format!("let {var_name}_raw = r.read_leb128(10_u8)?;"));
231 if let Some(limit) = enc.limit {
232 w.line(&format!(
233 "if {var_name}_raw > {limit}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: {var_name}_raw }}); }}"
234 ));
235 }
236 let rust_ty = read_cast_for_varint(ty);
238 w.line(&format!(
239 "let {var_name}: {rust_ty} = {var_name}_raw as {rust_ty};"
240 ));
241 return;
242 }
243 Encoding::ZigZag => {
244 let type_bits = match ty {
245 ResolvedType::Primitive(p) => primitive_bits(p),
246 _ => 64,
247 };
248 w.line(&format!(
250 "let {var_name}_raw = r.read_zigzag({type_bits}_u8, 10_u8)?;"
251 ));
252 if let Some(limit) = enc.limit {
253 w.line(&format!(
254 "if {var_name}_raw.unsigned_abs() > {limit}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: {var_name}_raw.unsigned_abs() }}); }}"
255 ));
256 }
257 let rust_ty = read_cast_for_zigzag(ty);
258 w.line(&format!(
259 "let {var_name}: {rust_ty} = {var_name}_raw as {rust_ty};"
260 ));
261 return;
262 }
263 Encoding::Delta(inner) => {
264 let base_enc = FieldEncoding {
267 encoding: *inner.clone(),
268 limit: enc.limit,
269 };
270 emit_read(w, var_name, ty, &base_enc, registry, field_name);
271 return;
272 }
273 Encoding::Default => {}
274 _ => {} }
276
277 emit_read_type(w, var_name, ty, registry, field_name, enc.limit);
278}
279
280fn emit_read_type(
281 w: &mut CodeWriter,
282 var_name: &str,
283 ty: &ResolvedType,
284 registry: &TypeRegistry,
285 field_name: &str,
286 limit: Option<u64>,
287) {
288 match ty {
289 ResolvedType::Primitive(p) => match p {
290 PrimitiveType::Bool => w.line(&format!("let {var_name} = r.read_bool()?;")),
291 PrimitiveType::U8 => w.line(&format!("let {var_name} = r.read_u8()?;")),
292 PrimitiveType::U16 => w.line(&format!("let {var_name} = r.read_u16()?;")),
293 PrimitiveType::U32 => w.line(&format!("let {var_name} = r.read_u32()?;")),
294 PrimitiveType::U64 => w.line(&format!("let {var_name} = r.read_u64()?;")),
295 PrimitiveType::I8 => w.line(&format!("let {var_name} = r.read_i8()?;")),
296 PrimitiveType::I16 => w.line(&format!("let {var_name} = r.read_i16()?;")),
297 PrimitiveType::I32 => w.line(&format!("let {var_name} = r.read_i32()?;")),
298 PrimitiveType::I64 => w.line(&format!("let {var_name} = r.read_i64()?;")),
299 PrimitiveType::F32 => w.line(&format!("let {var_name} = r.read_f32()?;")),
300 PrimitiveType::F64 => w.line(&format!("let {var_name} = r.read_f64()?;")),
301 PrimitiveType::Void => w.line(&format!("let {var_name} = ();")),
302 },
303 ResolvedType::SubByte(s) => {
304 let bits = s.bits;
305 if s.signed {
306 w.line(&format!(
307 "let {var_name} = r.read_bits({bits}_u8)? as u8 as i8;"
308 ));
309 } else {
310 w.line(&format!("let {var_name} = r.read_bits({bits}_u8)? as u8;"));
311 }
312 }
313 ResolvedType::Semantic(s) => match s {
314 SemanticType::String => {
315 w.line(&format!("let {var_name} = r.read_string()?;"));
316 if let Some(lim) = limit {
317 w.line(&format!(
318 "if {var_name}.len() as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}.len() as u64 }}); }}"
319 ));
320 }
321 }
322 SemanticType::Bytes => {
323 w.line(&format!("let {var_name} = r.read_bytes()?;"));
324 if let Some(lim) = limit {
325 w.line(&format!(
326 "if {var_name}.len() as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}.len() as u64 }}); }}"
327 ));
328 }
329 }
330 SemanticType::Rgb => {
331 w.line(&format!("let {var_name}_0 = r.read_u8()?;"));
332 w.line(&format!("let {var_name}_1 = r.read_u8()?;"));
333 w.line(&format!("let {var_name}_2 = r.read_u8()?;"));
334 w.line(&format!(
335 "let {var_name} = ({var_name}_0, {var_name}_1, {var_name}_2);"
336 ));
337 }
338 SemanticType::Uuid => {
339 w.line(&format!(
340 "let {var_name}_bytes = r.read_raw_bytes(16_usize)?;"
341 ));
342 w.line(&format!(
343 "let {var_name}: [u8; 16] = {var_name}_bytes.try_into().map_err(|_| vexil_runtime::DecodeError::UnexpectedEof)?;"
344 ));
345 }
346 SemanticType::Timestamp => {
347 w.line(&format!("let {var_name} = r.read_i64()?;"));
348 }
349 SemanticType::Hash => {
350 w.line(&format!(
351 "let {var_name}_bytes = r.read_raw_bytes(32_usize)?;"
352 ));
353 w.line(&format!(
354 "let {var_name}: [u8; 32] = {var_name}_bytes.try_into().map_err(|_| vexil_runtime::DecodeError::UnexpectedEof)?;"
355 ));
356 }
357 },
358 ResolvedType::Named(_) => {
359 w.line("r.enter_recursive()?;");
360 w.line(&format!(
361 "let {var_name} = vexil_runtime::Unpack::unpack(r)?;"
362 ));
363 w.line("r.leave_recursive();");
364 }
365 ResolvedType::Optional(inner) => {
366 w.line(&format!("let {var_name}_present = r.read_bool()?;"));
367 if is_byte_aligned(inner, registry) {
368 w.line("r.flush_to_byte_boundary();");
369 }
370 w.open_block(&format!("let {var_name} = if {var_name}_present"));
371 emit_read_type(
372 w,
373 &format!("{var_name}_inner"),
374 inner,
375 registry,
376 field_name,
377 None,
378 );
379 w.line(&format!("Some({var_name}_inner)"));
380 w.close_block();
381 w.open_block("else");
382 w.line("None");
383 w.close_block();
384 w.append(";");
385 w.append("\n");
386 }
387 ResolvedType::Array(inner) => {
388 w.line(&format!(
389 "let {var_name}_len = r.read_leb128(10_u8)? as usize;"
390 ));
391 if let Some(lim) = limit {
392 w.line(&format!(
393 "if {var_name}_len as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}_len as u64 }}); }}"
394 ));
395 }
396 w.line(&format!(
397 "let mut {var_name} = Vec::with_capacity({var_name}_len);"
398 ));
399 w.open_block(&format!("for _ in 0..{var_name}_len"));
400 emit_read_type(
401 w,
402 &format!("{var_name}_item"),
403 inner,
404 registry,
405 field_name,
406 None,
407 );
408 w.line(&format!("{var_name}.push({var_name}_item);"));
409 w.close_block();
410 }
411 ResolvedType::Map(k, v) => {
412 w.line(&format!(
413 "let {var_name}_len = r.read_leb128(10_u8)? as usize;"
414 ));
415 if let Some(lim) = limit {
416 w.line(&format!(
417 "if {var_name}_len as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}_len as u64 }}); }}"
418 ));
419 }
420 w.line(&format!(
421 "let mut {var_name} = std::collections::BTreeMap::new();"
422 ));
423 w.open_block(&format!("for _ in 0..{var_name}_len"));
424 emit_read_type(w, &format!("{var_name}_k"), k, registry, field_name, None);
425 emit_read_type(w, &format!("{var_name}_v"), v, registry, field_name, None);
426 w.line(&format!("{var_name}.insert({var_name}_k, {var_name}_v);"));
427 w.close_block();
428 }
429 ResolvedType::Result(ok, err) => {
430 w.line(&format!("let {var_name}_is_ok = r.read_bool()?;"));
431 w.open_block(&format!("let {var_name} = if {var_name}_is_ok"));
432 emit_read_type(w, &format!("{var_name}_ok"), ok, registry, field_name, None);
433 w.line(&format!("Ok({var_name}_ok)"));
434 w.close_block();
435 w.open_block("else");
436 emit_read_type(
437 w,
438 &format!("{var_name}_err"),
439 err,
440 registry,
441 field_name,
442 None,
443 );
444 w.line(&format!("Err({var_name}_err)"));
445 w.close_block();
446 w.append(";");
447 w.append("\n");
448 }
449 _ => {} }
451}
452
453fn read_cast_for_varint(ty: &ResolvedType) -> &'static str {
454 match ty {
455 ResolvedType::Primitive(p) => match p {
456 PrimitiveType::U8 => "u8",
457 PrimitiveType::U16 => "u16",
458 PrimitiveType::U32 => "u32",
459 PrimitiveType::U64 => "u64",
460 _ => "u64",
461 },
462 _ => "u64",
463 }
464}
465
466fn read_cast_for_zigzag(ty: &ResolvedType) -> &'static str {
467 match ty {
468 ResolvedType::Primitive(p) => match p {
469 PrimitiveType::I8 => "i8",
470 PrimitiveType::I16 => "i16",
471 PrimitiveType::I32 => "i32",
472 PrimitiveType::I64 => "i64",
473 _ => "i64",
474 },
475 _ => "i64",
476 }
477}
478
479pub fn emit_message(
485 w: &mut CodeWriter,
486 msg: &MessageDef,
487 registry: &TypeRegistry,
488 needs_box: &HashSet<(TypeId, usize)>,
489 type_id: TypeId,
490) {
491 let name = msg.name.as_str();
492
493 emit_tombstones(w, name, &msg.tombstones);
495
496 emit_type_annotations(w, &msg.annotations);
498 w.line("#[derive(Debug, Clone, PartialEq)]");
499
500 w.open_block(&format!("pub struct {name}"));
502 for (fi, field) in msg.fields.iter().enumerate() {
503 emit_field_annotations(w, &field.annotations);
504 let field_rust_type = rust_type(
505 &field.resolved_type,
506 registry,
507 needs_box,
508 Some((type_id, fi)),
509 );
510 w.line(&format!("pub {}: {},", field.name, field_rust_type));
511 }
512 w.close_block();
513 w.blank();
514
515 w.open_block(&format!("impl vexil_runtime::Pack for {name}"));
517 w.open_block("fn pack(&self, w: &mut vexil_runtime::BitWriter) -> Result<(), vexil_runtime::EncodeError>");
518 for field in &msg.fields {
519 let access = format!("self.{}", field.name);
520 emit_write(
521 w,
522 &access,
523 &field.resolved_type,
524 &field.encoding,
525 registry,
526 field.name.as_str(),
527 );
528 }
529 w.line("w.flush_to_byte_boundary();");
530 w.line("Ok(())");
531 w.close_block();
532 w.close_block();
533 w.blank();
534
535 w.open_block(&format!("impl vexil_runtime::Unpack for {name}"));
537 w.open_block("fn unpack(r: &mut vexil_runtime::BitReader<'_>) -> Result<Self, vexil_runtime::DecodeError>");
538 for field in &msg.fields {
539 let var_name = field.name.as_str();
540 emit_read(
541 w,
542 var_name,
543 &field.resolved_type,
544 &field.encoding,
545 registry,
546 var_name,
547 );
548 }
549 w.line("r.flush_to_byte_boundary();");
550 w.open_block("Ok(Self");
551 for field in &msg.fields {
552 w.line(&format!("{},", field.name));
553 }
554 w.dedent();
555 w.line("})");
556
557 w.close_block();
558 w.close_block();
559 w.blank();
560}