1use std::io::Write;
2
3#[cfg(feature = "simd")]
4use core::simd::Select;
5#[cfg(feature = "simd")]
6use core::simd::prelude::*;
7#[cfg(feature = "simd")]
8use core::simd::{Simd, u8x16};
9
10use crate::error::{BinaryError, Result};
11use crate::jid::{self, Jid, JidRef};
12use crate::node::{Node, NodeContent, NodeContentRef, NodeRef, NodeValue, ValueRef};
13use crate::token;
14
15pub trait ByteWriter {
16 fn write_u8(&mut self, value: u8) -> Result<()>;
17 fn write_bytes(&mut self, bytes: &[u8]) -> Result<()>;
18}
19
20pub(crate) struct IoByteWriter<W: Write> {
21 writer: W,
22}
23
24impl<W: Write> IoByteWriter<W> {
25 fn new(writer: W) -> Self {
26 Self { writer }
27 }
28}
29
30impl<W: Write> ByteWriter for IoByteWriter<W> {
31 #[inline]
32 fn write_u8(&mut self, value: u8) -> Result<()> {
33 self.writer.write_all(&[value])?;
34 Ok(())
35 }
36
37 #[inline]
38 fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
39 self.writer.write_all(bytes)?;
40 Ok(())
41 }
42}
43
44pub struct VecByteWriter<'a> {
45 buffer: &'a mut Vec<u8>,
46}
47
48impl<'a> VecByteWriter<'a> {
49 fn new(buffer: &'a mut Vec<u8>) -> Self {
50 Self { buffer }
51 }
52}
53
54impl ByteWriter for VecByteWriter<'_> {
55 #[inline]
56 fn write_u8(&mut self, value: u8) -> Result<()> {
57 self.buffer.push(value);
58 Ok(())
59 }
60
61 #[inline]
62 fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
63 self.buffer.extend_from_slice(bytes);
64 Ok(())
65 }
66}
67
68pub(crate) struct SliceByteWriter<'a> {
69 buffer: &'a mut [u8],
70 position: usize,
71}
72
73impl<'a> SliceByteWriter<'a> {
74 fn new(buffer: &'a mut [u8]) -> Self {
75 Self {
76 buffer,
77 position: 0,
78 }
79 }
80
81 #[inline]
82 fn bytes_written(&self) -> usize {
83 self.position
84 }
85}
86
87impl ByteWriter for SliceByteWriter<'_> {
88 #[inline]
89 fn write_u8(&mut self, value: u8) -> Result<()> {
90 if self.position >= self.buffer.len() {
91 return Err(BinaryError::UnexpectedEof);
92 }
93 self.buffer[self.position] = value;
94 self.position += 1;
95 Ok(())
96 }
97
98 #[inline]
99 fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
100 let end = self.position + bytes.len();
101 if end > self.buffer.len() {
102 return Err(BinaryError::UnexpectedEof);
103 }
104 self.buffer[self.position..end].copy_from_slice(bytes);
105 self.position = end;
106 Ok(())
107 }
108}
109
110pub trait EncodeNode {
114 fn tag(&self) -> &str;
115 fn attrs_len(&self) -> usize;
116 fn has_content(&self) -> bool;
117
118 fn encode_attrs<'a, W: ByteWriter>(&self, encoder: &mut Encoder<'a, W>) -> Result<()>;
120
121 fn encode_content<'a, W: ByteWriter>(&self, encoder: &mut Encoder<'a, W>) -> Result<()>;
123}
124
125impl EncodeNode for Node {
126 fn tag(&self) -> &str {
127 &self.tag
128 }
129
130 fn attrs_len(&self) -> usize {
131 self.attrs.len()
132 }
133
134 fn has_content(&self) -> bool {
135 self.content.is_some()
136 }
137
138 fn encode_attrs<'a, W: ByteWriter>(&self, encoder: &mut Encoder<'a, W>) -> Result<()> {
139 for (k, v) in &self.attrs {
140 encoder.write_string(k)?;
141 match v {
142 NodeValue::String(s) => encoder.write_string(s)?,
143 NodeValue::Jid(jid) => encoder.write_jid_owned(jid)?,
144 }
145 }
146 Ok(())
147 }
148
149 fn encode_content<'a, W: ByteWriter>(&self, encoder: &mut Encoder<'a, W>) -> Result<()> {
150 if let Some(content) = &self.content {
151 match content {
152 NodeContent::String(s) => encoder.write_string(s)?,
153 NodeContent::Bytes(b) => encoder.write_bytes_with_len(b)?,
154 NodeContent::Nodes(nodes) => {
155 encoder.write_list_start(nodes.len())?;
156 for node in nodes {
157 encoder.write_node(node)?;
158 }
159 }
160 }
161 }
162 Ok(())
163 }
164}
165
166impl EncodeNode for NodeRef<'_> {
167 fn tag(&self) -> &str {
168 &self.tag
169 }
170
171 fn attrs_len(&self) -> usize {
172 self.attrs.len()
173 }
174
175 fn has_content(&self) -> bool {
176 self.content.is_some()
177 }
178
179 fn encode_attrs<'a, W: ByteWriter>(&self, encoder: &mut Encoder<'a, W>) -> Result<()> {
180 for (k, v) in self.attrs.iter() {
181 encoder.write_string(k)?;
182 match v {
183 ValueRef::String(s) => encoder.write_string(s)?,
184 ValueRef::Jid(jid) => encoder.write_jid_ref(jid)?,
185 }
186 }
187 Ok(())
188 }
189
190 fn encode_content<'a, W: ByteWriter>(&self, encoder: &mut Encoder<'a, W>) -> Result<()> {
191 if let Some(content) = self.content.as_deref() {
192 match content {
193 NodeContentRef::String(s) => encoder.write_string(s)?,
194 NodeContentRef::Bytes(b) => encoder.write_bytes_with_len(b)?,
195 NodeContentRef::Nodes(nodes) => {
196 encoder.write_list_start(nodes.len())?;
197 for node in nodes.iter() {
198 encoder.write_node(node)?;
199 }
200 }
201 }
202 }
203 Ok(())
204 }
205}
206
/// Offsets and flags extracted from a JID-looking string by `parse_jid_meta`.
///
/// Only positions into the original string are stored, so the user and server
/// parts can later be re-sliced without allocating.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ParsedJidMeta {
    /// Byte offset at which the user part ends (exclusive).
    user_end: usize,
    /// Byte offset of the first server byte, i.e. one past the `@`.
    server_start: usize,
    /// Byte written after the AD_JID marker when a device is present.
    domain_type: u8,
    /// Device id parsed from a `:<u8>` suffix of the user part, if any.
    device: Option<u8>,
}
214
/// Identity of a string slice: where it starts in memory and how long it is.
///
/// Two keys compare equal only when they describe the same address range;
/// string contents are never inspected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct StrKey {
    /// Address of the first byte of the slice.
    ptr: usize,
    /// Length of the slice in bytes.
    len: usize,
}

impl StrKey {
    /// Builds the key for `s` from its data pointer and byte length.
    #[inline]
    fn from_str(s: &str) -> Self {
        let (ptr, len) = (s.as_ptr() as usize, s.len());
        StrKey { ptr, len }
    }
}
230
231#[derive(Debug, Clone, Copy, PartialEq, Eq)]
232enum StringHint {
233 Empty,
234 SingleToken(u8),
235 DoubleToken { dict: u8, token: u8 },
236 PackedNibble,
237 PackedHex,
238 Jid(ParsedJidMeta),
239 RawBytes,
240}
241
242#[derive(Debug)]
243pub(crate) struct StringHintCache {
244 hints: Vec<(StrKey, StringHint)>,
247}
248
249impl Default for StringHintCache {
250 fn default() -> Self {
251 Self {
252 hints: Vec::with_capacity(32),
253 }
254 }
255}
256
257impl StringHintCache {
258 const MAX_HINT_ENTRIES: usize = 96;
259
260 #[inline]
261 fn hint_for(&self, s: &str) -> Option<StringHint> {
262 let key = StrKey::from_str(s);
263 self.hints
264 .iter()
265 .find_map(|(cached_key, hint)| (*cached_key == key).then_some(*hint))
266 }
267
268 #[inline]
269 fn hint_or_insert(&mut self, s: &str) -> StringHint {
270 if s.len() > token::PACKED_MAX as usize {
271 return StringHint::RawBytes;
272 }
273 let key = StrKey::from_str(s);
274 if let Some(existing) = self
275 .hints
276 .iter()
277 .find_map(|(cached_key, hint)| (*cached_key == key).then_some(*hint))
278 {
279 existing
280 } else {
281 let hint = classify_string_hint(s);
282 if self.hints.len() < Self::MAX_HINT_ENTRIES {
283 self.hints.push((key, hint));
284 }
285 hint
286 }
287 }
288}
289
290#[derive(Debug)]
291pub(crate) struct MarshaledSizePlan {
292 pub(crate) size: usize,
293 pub(crate) hints: StringHintCache,
294}
295
296fn parse_jid_meta(input: &str) -> Option<ParsedJidMeta> {
297 let sep_idx = input.find('@')?;
298 let server_start = sep_idx + 1;
299 let server = &input[server_start..];
300 let user_combined = &input[..sep_idx];
301
302 let (user_agent, device) = if let Some(colon_idx) = user_combined.find(':') {
303 let device_part = &user_combined[colon_idx + 1..];
304 if let Ok(parsed_device) = device_part.parse::<u8>() {
305 (&user_combined[..colon_idx], Some(parsed_device))
306 } else {
307 (user_combined, None)
308 }
309 } else {
310 (user_combined, None)
311 };
312
313 let (user_end, agent_override) = if let Some(underscore_idx) = user_agent.find('_') {
314 let agent_part = &user_agent[underscore_idx + 1..];
315 if let Ok(parsed_agent) = agent_part.parse::<u8>() {
316 (underscore_idx, Some(parsed_agent))
317 } else {
318 (user_agent.len(), None)
319 }
320 } else {
321 (user_agent.len(), None)
322 };
323
324 let agent_byte = agent_override.unwrap_or(0);
325 let domain_type = if server == jid::HIDDEN_USER_SERVER {
326 1
327 } else if server == jid::HOSTED_SERVER {
328 128
329 } else if server == jid::HOSTED_LID_SERVER {
330 129
331 } else {
332 agent_byte
333 };
334
335 Some(ParsedJidMeta {
336 user_end,
337 server_start,
338 domain_type,
339 device,
340 })
341}
342
343#[inline]
344fn split_jid_from_meta(input: &str, meta: ParsedJidMeta) -> (&str, &str) {
345 (&input[..meta.user_end], &input[meta.server_start..])
346}
347
348#[inline]
365fn server_to_domain_type(server: jid::Server, agent: u8) -> u8 {
366 match server {
367 jid::Server::Lid => 1,
368 jid::Server::Hosted => 128,
369 jid::Server::HostedLid => 129,
370 _ => agent,
371 }
372}
373
374#[inline]
375fn classify_string_hint(s: &str) -> StringHint {
376 if s.is_empty() {
377 return StringHint::Empty;
378 }
379
380 let is_likely_jid = s.len() <= 48;
381
382 if let Some(kind) = token::index_of_token(s) {
383 return match kind {
384 token::TokenKind::Single(token) => StringHint::SingleToken(token),
385 token::TokenKind::Double(dict, token) => StringHint::DoubleToken { dict, token },
386 };
387 }
388
389 if validate_nibble(s) {
390 StringHint::PackedNibble
391 } else if validate_hex(s) {
392 StringHint::PackedHex
393 } else if is_likely_jid {
394 parse_jid_meta(s).map_or(StringHint::RawBytes, StringHint::Jid)
395 } else {
396 StringHint::RawBytes
397 }
398}
399
400pub(crate) fn build_marshaled_node_plan(node: &Node) -> MarshaledSizePlan {
401 let mut hints = StringHintCache::default();
402 let size = 1 + node_encoded_size_with_cache(node, &mut hints);
403 MarshaledSizePlan { size, hints }
404}
405
406pub(crate) fn build_marshaled_node_ref_plan(node: &NodeRef<'_>) -> MarshaledSizePlan {
407 let mut hints = StringHintCache::default();
408 let size = 1 + node_ref_encoded_size_with_cache(node, &mut hints);
409 MarshaledSizePlan { size, hints }
410}
411
/// Number of bytes a list header occupies for a list of `len` items:
/// 1 for an empty list, 2 below 256 items, 3 otherwise.
#[inline]
fn list_start_encoded_size(len: usize) -> usize {
    match len {
        0 => 1,
        1..=255 => 2,
        _ => 3,
    }
}
422
/// Size of the marker-plus-length prefix in front of a `len`-byte payload:
/// 2 bytes below 256, 4 bytes below 2^20, 5 bytes otherwise.
#[inline]
fn binary_len_prefix_size(len: usize) -> usize {
    const TWENTY_BIT_LIMIT: usize = 1 << 20;
    match len {
        0..=255 => 2,
        n if n < TWENTY_BIT_LIMIT => 4,
        _ => 5,
    }
}
433
434#[inline]
435fn bytes_with_len_encoded_size(len: usize) -> usize {
436 binary_len_prefix_size(len) + len
437}
438
/// Encoded size of a packed string of `value_len` characters: two header
/// bytes plus one byte per pair of characters, rounded up.
#[inline]
fn packed_encoded_size(value_len: usize) -> usize {
    let payload = value_len / 2 + value_len % 2;
    2 + payload
}
443
444fn node_encoded_size_with_cache(node: &Node, hints: &mut StringHintCache) -> usize {
445 let content_len = usize::from(node.content.is_some());
446 let list_len = 1 + (node.attrs.len() * 2) + content_len;
447
448 let attrs_size: usize = node
449 .attrs
450 .iter()
451 .map(|(k, v)| {
452 let value_size = match v {
453 NodeValue::String(s) => string_encoded_size_with_cache(s, hints),
454 NodeValue::Jid(jid) => owned_jid_encoded_size_with_cache(jid, hints),
455 };
456 string_encoded_size_with_cache(k, hints) + value_size
457 })
458 .sum();
459
460 let content_size = match &node.content {
461 Some(NodeContent::String(s)) => string_encoded_size_with_cache(s, hints),
462 Some(NodeContent::Bytes(b)) => bytes_with_len_encoded_size(b.len()),
463 Some(NodeContent::Nodes(nodes)) => {
464 list_start_encoded_size(nodes.len())
465 + nodes
466 .iter()
467 .map(|child| node_encoded_size_with_cache(child, hints))
468 .sum::<usize>()
469 }
470 None => 0,
471 };
472
473 list_start_encoded_size(list_len)
474 + string_encoded_size_with_cache(&node.tag, hints)
475 + attrs_size
476 + content_size
477}
478
479fn node_ref_encoded_size_with_cache(node: &NodeRef<'_>, hints: &mut StringHintCache) -> usize {
480 let content_len = usize::from(node.content.is_some());
481 let list_len = 1 + (node.attrs.len() * 2) + content_len;
482
483 let attrs_size: usize = node
484 .attrs
485 .iter()
486 .map(|(k, v)| {
487 let value_size = match v {
488 ValueRef::String(s) => string_encoded_size_with_cache(s, hints),
489 ValueRef::Jid(jid) => jid_ref_encoded_size_with_cache(jid, hints),
490 };
491 string_encoded_size_with_cache(k, hints) + value_size
492 })
493 .sum();
494
495 let content_size = match node.content.as_deref() {
496 Some(NodeContentRef::String(s)) => string_encoded_size_with_cache(s, hints),
497 Some(NodeContentRef::Bytes(b)) => bytes_with_len_encoded_size(b.len()),
498 Some(NodeContentRef::Nodes(nodes)) => {
499 list_start_encoded_size(nodes.len())
500 + nodes
501 .iter()
502 .map(|child| node_ref_encoded_size_with_cache(child, hints))
503 .sum::<usize>()
504 }
505 None => 0,
506 };
507
508 list_start_encoded_size(list_len)
509 + string_encoded_size_with_cache(node.tag.as_ref(), hints)
510 + attrs_size
511 + content_size
512}
513
514#[inline]
515fn string_encoded_size_with_cache(s: &str, hints: &mut StringHintCache) -> usize {
516 let hint = hints.hint_or_insert(s);
517 string_encoded_size_from_hint_with_cache(s, hint, hints)
518}
519
520#[inline]
521fn string_encoded_size_from_hint_with_cache(
522 s: &str,
523 hint: StringHint,
524 hints: &mut StringHintCache,
525) -> usize {
526 match hint {
527 StringHint::Empty => 2,
528 StringHint::SingleToken(_) => 1,
529 StringHint::DoubleToken { .. } => 2,
530 StringHint::PackedNibble | StringHint::PackedHex => packed_encoded_size(s.len()),
531 StringHint::RawBytes => bytes_with_len_encoded_size(s.len()),
532 StringHint::Jid(meta) => parsed_jid_encoded_size_with_cache(s, meta, hints),
533 }
534}
535
536#[inline]
537fn parsed_jid_encoded_size_with_cache(
538 jid: &str,
539 meta: ParsedJidMeta,
540 hints: &mut StringHintCache,
541) -> usize {
542 let (user, server) = split_jid_from_meta(jid, meta);
543 if meta.device.is_some() {
544 3 + string_encoded_size_with_cache(user, hints)
545 } else {
546 let user_size = if user.is_empty() {
547 1
548 } else {
549 string_encoded_size_with_cache(user, hints)
550 };
551 1 + user_size + string_encoded_size_with_cache(server, hints)
552 }
553}
554
555#[inline]
556fn owned_jid_encoded_size_with_cache(jid: &Jid, hints: &mut StringHintCache) -> usize {
557 if jid.device > 0 {
558 3 + string_encoded_size_with_cache(&jid.user, hints)
559 } else {
560 let user_size = if jid.user.is_empty() {
561 1
562 } else {
563 string_encoded_size_with_cache(&jid.user, hints)
564 };
565 1 + user_size + string_encoded_size_with_cache(jid.server.as_str(), hints)
566 }
567}
568
569#[inline]
570fn jid_ref_encoded_size_with_cache(jid: &JidRef<'_>, hints: &mut StringHintCache) -> usize {
571 if jid.device > 0 {
572 3 + string_encoded_size_with_cache(&jid.user, hints)
573 } else {
574 let user_size = if jid.user.is_empty() {
575 1
576 } else {
577 string_encoded_size_with_cache(&jid.user, hints)
578 };
579 1 + user_size + string_encoded_size_with_cache(jid.server.as_str(), hints)
580 }
581}
582
583#[inline]
584fn validate_nibble(value: &str) -> bool {
585 if value.len() > token::PACKED_MAX as usize {
586 return false;
587 }
588 value
589 .as_bytes()
590 .iter()
591 .all(|&b| b.is_ascii_digit() || b == b'-' || b == b'.')
592}
593
594#[inline]
595fn validate_hex(value: &str) -> bool {
596 if value.len() > token::PACKED_MAX as usize {
597 return false;
598 }
599 value
600 .as_bytes()
601 .iter()
602 .all(|&b| b.is_ascii_digit() || (b'A'..=b'F').contains(&b))
603}
604
605pub struct Encoder<'a, W: ByteWriter> {
606 writer: W,
607 string_hints: Option<&'a StringHintCache>,
608}
609
610impl<W: Write> Encoder<'static, IoByteWriter<W>> {
611 pub fn new(writer: W) -> Result<Self> {
612 let mut enc = Self {
613 writer: IoByteWriter::new(writer),
614 string_hints: None,
615 };
616 enc.write_u8(0)?;
617 Ok(enc)
618 }
619}
620
621impl<'v> Encoder<'static, VecByteWriter<'v>> {
622 pub fn new_vec(buffer: &'v mut Vec<u8>) -> Result<Self> {
623 buffer.clear();
624 let mut enc = Self {
625 writer: VecByteWriter::new(buffer),
626 string_hints: None,
627 };
628 enc.write_u8(0)?;
629 Ok(enc)
630 }
631}
632
633impl<'a> Encoder<'a, SliceByteWriter<'a>> {
634 pub(crate) fn new_slice(
635 buffer: &'a mut [u8],
636 string_hints: Option<&'a StringHintCache>,
637 ) -> Result<Self> {
638 let mut enc = Self {
639 writer: SliceByteWriter::new(buffer),
640 string_hints,
641 };
642 enc.write_u8(0)?;
643 Ok(enc)
644 }
645
646 #[inline]
647 pub(crate) fn bytes_written(&self) -> usize {
648 self.writer.bytes_written()
649 }
650}
651
652impl<'a, W: ByteWriter> Encoder<'a, W> {
653 #[inline(always)]
654 fn write_u8(&mut self, val: u8) -> Result<()> {
655 self.writer.write_u8(val)
656 }
657
658 #[inline(always)]
659 fn write_u16_be(&mut self, val: u16) -> Result<()> {
660 self.writer.write_bytes(&val.to_be_bytes())
661 }
662
663 #[inline(always)]
664 fn write_u32_be(&mut self, val: u32) -> Result<()> {
665 self.writer.write_bytes(&val.to_be_bytes())
666 }
667
668 #[inline(always)]
669 fn write_u20_be(&mut self, value: u32) -> Result<()> {
670 let bytes = [
671 ((value >> 16) & 0x0F) as u8,
672 ((value >> 8) & 0xFF) as u8,
673 (value & 0xFF) as u8,
674 ];
675 self.writer.write_bytes(&bytes)
676 }
677
678 #[inline(always)]
679 fn write_raw_bytes(&mut self, bytes: &[u8]) -> Result<()> {
680 self.writer.write_bytes(bytes)
681 }
682
683 #[inline(always)]
684 pub fn write_bytes_with_len(&mut self, bytes: &[u8]) -> Result<()> {
685 let len = bytes.len();
686 if len < 256 {
687 self.write_u8(token::BINARY_8)?;
688 self.write_u8(len as u8)?;
689 } else if len < (1 << 20) {
690 self.write_u8(token::BINARY_20)?;
691 self.write_u20_be(len as u32)?;
692 } else {
693 self.write_u8(token::BINARY_32)?;
694 self.write_u32_be(len as u32)?;
695 }
696 self.write_raw_bytes(bytes)
697 }
698
699 #[inline(always)]
700 pub fn write_string(&mut self, s: &str) -> Result<()> {
701 if let Some(string_hints) = self.string_hints
702 && let Some(hint) = string_hints.hint_for(s)
703 {
704 return self.write_string_with_hint(s, hint);
705 }
706 self.write_string_uncached(s)
707 }
708
709 #[inline(always)]
710 fn write_string_uncached(&mut self, s: &str) -> Result<()> {
711 if s.len() > token::PACKED_MAX as usize {
714 return self.write_bytes_with_len(s.as_bytes());
715 }
716 self.write_string_with_hint(s, classify_string_hint(s))
717 }
718
719 #[inline(always)]
720 fn write_string_with_hint(&mut self, s: &str, hint: StringHint) -> Result<()> {
721 match hint {
722 StringHint::Empty => {
723 self.write_u8(token::BINARY_8)?;
724 self.write_u8(0)?;
725 }
726 StringHint::SingleToken(token) => self.write_u8(token)?,
727 StringHint::DoubleToken { dict, token } => {
728 self.write_u8(token::DICTIONARY_0 + dict)?;
729 self.write_u8(token)?;
730 }
731 StringHint::PackedNibble => self.write_packed_bytes(s, token::NIBBLE_8)?,
732 StringHint::PackedHex => self.write_packed_bytes(s, token::HEX_8)?,
733 StringHint::Jid(meta) => self.write_jid_from_meta(s, meta)?,
734 StringHint::RawBytes => self.write_bytes_with_len(s.as_bytes())?,
735 }
736 Ok(())
737 }
738
739 #[inline(always)]
740 fn write_jid_from_meta(&mut self, jid: &str, meta: ParsedJidMeta) -> Result<()> {
741 let (user, server) = split_jid_from_meta(jid, meta);
742 if let Some(device) = meta.device {
743 self.write_u8(token::AD_JID)?;
744 self.write_u8(meta.domain_type)?;
745 self.write_u8(device)?;
746 self.write_string(user)?;
747 } else {
748 self.write_u8(token::JID_PAIR)?;
749 if user.is_empty() {
750 self.write_u8(token::LIST_EMPTY)?;
751 } else {
752 self.write_string(user)?;
753 }
754 self.write_string(server)?;
755 }
756 Ok(())
757 }
758
759 pub fn write_jid_ref(&mut self, jid: &JidRef<'_>) -> Result<()> {
762 if jid.device > 0 {
763 let device = u8::try_from(jid.device).map_err(|_| {
765 BinaryError::AttrParse(format!("AD_JID device id out of range: {}", jid.device))
766 })?;
767 self.write_u8(token::AD_JID)?;
768 self.write_u8(server_to_domain_type(jid.server, jid.agent))?;
769 self.write_u8(device)?;
770 self.write_string(&jid.user)?;
771 } else {
772 self.write_u8(token::JID_PAIR)?;
774 if jid.user.is_empty() {
775 self.write_u8(token::LIST_EMPTY)?;
776 } else {
777 self.write_string(&jid.user)?;
778 }
779 self.write_string(jid.server.as_str())?;
780 }
781 Ok(())
782 }
783
784 pub fn write_jid_owned(&mut self, jid: &Jid) -> Result<()> {
787 if jid.device > 0 {
788 let device = u8::try_from(jid.device).map_err(|_| {
790 BinaryError::AttrParse(format!("AD_JID device id out of range: {}", jid.device))
791 })?;
792 self.write_u8(token::AD_JID)?;
793 self.write_u8(server_to_domain_type(jid.server, jid.agent))?;
794 self.write_u8(device)?;
795 self.write_string(&jid.user)?;
796 } else {
797 self.write_u8(token::JID_PAIR)?;
799 if jid.user.is_empty() {
800 self.write_u8(token::LIST_EMPTY)?;
801 } else {
802 self.write_string(&jid.user)?;
803 }
804 self.write_string(jid.server.as_str())?;
805 }
806 Ok(())
807 }
808
809 #[inline(always)]
810 fn pack_nibble(value: u8) -> u8 {
811 match value {
812 b'-' => 10,
813 b'.' => 11,
814 0 => 15,
815 c if c.is_ascii_digit() => c - b'0',
816 _ => panic!("Invalid char for nibble packing: {value}"),
817 }
818 }
819
820 #[inline(always)]
821 fn pack_hex(value: u8) -> u8 {
822 match value {
823 c if c.is_ascii_digit() => c - b'0',
824 c if (b'A'..=b'F').contains(&c) => 10 + (c - b'A'),
825 0 => 15,
826 _ => panic!("Invalid char for hex packing: {value}"),
827 }
828 }
829
830 #[inline(always)]
831 fn pack_byte_pair(packer: fn(u8) -> u8, part1: u8, part2: u8) -> u8 {
832 (packer(part1) << 4) | packer(part2)
833 }
834
835 fn write_packed_bytes(&mut self, value: &str, data_type: u8) -> Result<()> {
836 if value.len() > token::PACKED_MAX as usize {
837 panic!("String too long to be packed: {}", value.len());
838 }
839
840 self.write_u8(data_type)?;
841
842 let mut rounded_len = value.len().div_ceil(2) as u8;
843 if !value.len().is_multiple_of(2) {
844 rounded_len |= 0x80;
845 }
846 self.write_u8(rounded_len)?;
847
848 #[allow(unused_mut)]
849 let mut input_bytes = value.as_bytes();
850
851 if data_type == token::NIBBLE_8 {
852 #[cfg(feature = "simd")]
853 {
854 const NIBBLE_LOOKUP: [u8; 16] =
855 [10, 11, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255];
856 let lookup = Simd::from_array(NIBBLE_LOOKUP);
857 let nibble_base = Simd::splat(b'-');
858
859 while input_bytes.len() >= 16 {
860 let (chunk, rest) = input_bytes.split_at(16);
861 let input = u8x16::from_slice(chunk);
862 let indices = input.saturating_sub(nibble_base);
863 let nibbles = lookup.swizzle_dyn(indices);
864
865 let (evens, odds) = nibbles.deinterleave(nibbles.rotate_elements_left::<1>());
866 let packed: Simd<u8, 16> = (evens << Simd::splat(4)) | odds;
867 let packed_bytes = packed.to_array();
868 self.write_raw_bytes(&packed_bytes[..8])?;
869
870 input_bytes = rest;
871 }
872 }
873
874 let mut bytes_iter = input_bytes.iter().copied();
875 while let Some(part1) = bytes_iter.next() {
876 let part2 = bytes_iter.next().unwrap_or(0);
877 self.write_u8(Self::pack_byte_pair(Self::pack_nibble, part1, part2))?;
878 }
879 } else {
880 #[cfg(feature = "simd")]
881 {
882 let ascii_0 = Simd::splat(b'0');
883 let ascii_a = Simd::splat(b'A');
884 let ten = Simd::splat(10);
885
886 while input_bytes.len() >= 16 {
887 let (chunk, rest) = input_bytes.split_at(16);
888 let input = u8x16::from_slice(chunk);
889
890 let digit_vals = input - ascii_0;
891 let letter_vals = input - ascii_a + ten;
892 let is_letter = input.simd_ge(ascii_a);
893 let nibbles = is_letter.select(letter_vals, digit_vals);
894
895 let (evens, odds) = nibbles.deinterleave(nibbles.rotate_elements_left::<1>());
896 let packed: Simd<u8, 16> = (evens << Simd::splat(4)) | odds;
897 let packed_bytes = packed.to_array();
898 self.write_raw_bytes(&packed_bytes[..8])?;
899
900 input_bytes = rest;
901 }
902 }
903
904 let mut bytes_iter = input_bytes.iter().copied();
905 while let Some(part1) = bytes_iter.next() {
906 let part2 = bytes_iter.next().unwrap_or(0);
907 self.write_u8(Self::pack_byte_pair(Self::pack_hex, part1, part2))?;
908 }
909 }
910 Ok(())
911 }
912
913 pub fn write_list_start(&mut self, len: usize) -> Result<()> {
914 if len == 0 {
915 self.write_u8(token::LIST_EMPTY)?;
916 } else if len < 256 {
917 self.write_u8(248)?;
918 self.write_u8(len as u8)?;
919 } else if len <= u16::MAX as usize {
920 self.write_u8(249)?;
921 self.write_u16_be(len as u16)?;
922 } else {
923 return Err(BinaryError::InvalidNode);
924 }
925 Ok(())
926 }
927
928 pub fn write_node<N: EncodeNode>(&mut self, node: &N) -> Result<()> {
930 let content_len = if node.has_content() { 1 } else { 0 };
931 let list_len = 1 + (node.attrs_len() * 2) + content_len;
932
933 self.write_list_start(list_len)?;
934 self.write_string(node.tag())?;
935 node.encode_attrs(self)?;
936 node.encode_content(self)?;
937 Ok(())
938 }
939}
940
941#[cfg(test)]
942mod tests {
943 use super::*;
944 use crate::builder::NodeBuilder;
945 use crate::node::Attrs;
946 use std::io::Cursor;
947
948 type TestResult = crate::error::Result<()>;
949
950 #[test]
951 fn test_encode_node() -> TestResult {
952 let node = Node::new(
953 "message",
954 Attrs::new(),
955 Some(NodeContent::String("receipt".into())),
956 );
957
958 let mut buffer = Vec::new();
959 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
960 encoder.write_node(&node)?;
961
962 let expected = vec![0, 248, 2, 19, 7];
963 assert_eq!(buffer, expected);
964 assert_eq!(buffer.len(), 5);
965 Ok(())
966 }
967
968 #[test]
969 fn test_nibble_packing() -> TestResult {
970 let test_str = "-.0123456789";
972 let node = Node::new(
973 "test",
974 Attrs::new(),
975 Some(NodeContent::String(test_str.into())),
976 );
977
978 let mut buffer = Vec::new();
979 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
980 encoder.write_node(&node)?;
981
982 let expected = vec![
983 0, 248, 2, 252, 4, 116, 101, 115, 116, 255, 6, 171, 1, 35, 69, 103, 137,
984 ];
985 assert_eq!(buffer, expected);
986 assert_eq!(buffer.len(), 17);
987 Ok(())
988 }
989
990 #[test]
992 fn test_list_size_list8_boundary() -> TestResult {
993 let mut buffer = Vec::new();
994 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
995
996 encoder.write_list_start(255)?;
998
999 assert_eq!(buffer[1], token::LIST_8);
1001 assert_eq!(buffer[2], 255);
1002 Ok(())
1003 }
1004
1005 #[test]
1007 fn test_list_size_list16_boundary() -> TestResult {
1008 let mut buffer = Vec::new();
1009 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
1010
1011 encoder.write_list_start(256)?;
1013
1014 assert_eq!(buffer[1], token::LIST_16);
1016 assert_eq!(buffer[2], 0x01); assert_eq!(buffer[3], 0x00); Ok(())
1019 }
1020
1021 #[test]
1023 fn test_list_size_empty() -> TestResult {
1024 let mut buffer = Vec::new();
1025 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
1026
1027 encoder.write_list_start(0)?;
1028
1029 assert_eq!(buffer[1], token::LIST_EMPTY);
1031 Ok(())
1032 }
1033
1034 #[test]
1036 fn test_hex_validation() {
1037 assert!(validate_hex("0123456789ABCDEF"));
1039 assert!(validate_hex("DEADBEEF"));
1040 assert!(validate_hex("1234"));
1041
1042 assert!(!validate_hex("abcdef"));
1044 assert!(!validate_hex("DeadBeef"));
1045
1046 assert!(!validate_hex("-"));
1048 assert!(!validate_hex("."));
1049 assert!(!validate_hex(" "));
1050
1051 assert!(validate_hex(""));
1053 }
1054
1055 #[test]
1057 fn test_nibble_validation() {
1058 assert!(validate_nibble("0123456789"));
1060 assert!(validate_nibble("-"));
1061 assert!(validate_nibble("."));
1062 assert!(validate_nibble("123-456.789"));
1063
1064 assert!(!validate_nibble("abc"));
1066 assert!(!validate_nibble("123abc"));
1067
1068 assert!(!validate_nibble("ABC"));
1070
1071 assert!(!validate_nibble("123!456"));
1073 assert!(!validate_nibble("@"));
1074 }
1075
1076 #[test]
1078 fn test_binary_length_boundaries() -> TestResult {
1079 let short_data = vec![0x42; 255];
1081 let mut buffer = Vec::new();
1082 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
1083 encoder.write_bytes_with_len(&short_data)?;
1084 assert_eq!(buffer[1], token::BINARY_8);
1085 assert_eq!(buffer[2], 255);
1086
1087 let medium_data = vec![0x42; 256];
1089 let mut buffer = Vec::new();
1090 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
1091 encoder.write_bytes_with_len(&medium_data)?;
1092 assert_eq!(buffer[1], token::BINARY_20);
1093 assert_eq!(buffer[2], 0x00);
1095 assert_eq!(buffer[3], 0x01);
1096 assert_eq!(buffer[4], 0x00);
1097
1098 Ok(())
1099 }
1100
1101 #[test]
1103 fn test_node_with_255_children() -> TestResult {
1104 let children: Vec<Node> = (0..255)
1105 .map(|_| Node::new("child", Attrs::new(), None))
1106 .collect();
1107
1108 let parent = Node::new("parent", Attrs::new(), Some(NodeContent::Nodes(children)));
1109
1110 let mut buffer = Vec::new();
1111 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
1112 encoder.write_node(&parent)?;
1113
1114 assert!(!buffer.is_empty());
1116 Ok(())
1117 }
1118
1119 #[test]
1121 fn test_node_with_256_children() -> TestResult {
1122 let children: Vec<Node> = (0..256)
1123 .map(|_| Node::new("x", Attrs::new(), None))
1124 .collect();
1125
1126 let parent = Node::new("parent", Attrs::new(), Some(NodeContent::Nodes(children)));
1127
1128 let mut buffer = Vec::new();
1129 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
1130 encoder.write_node(&parent)?;
1131
1132 assert!(!buffer.is_empty());
1134 Ok(())
1135 }
1136
1137 #[test]
1139 fn test_packed_max_boundary() {
1140 let max_nibble = "0".repeat(token::PACKED_MAX as usize);
1142 assert!(validate_nibble(&max_nibble));
1143
1144 let over_max = "0".repeat(token::PACKED_MAX as usize + 1);
1146 assert!(!validate_nibble(&over_max));
1147 }
1148
1149 #[test]
1151 fn test_empty_string_encoding() -> TestResult {
1152 let mut buffer = Vec::new();
1153 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
1154 encoder.write_string("")?;
1155
1156 println!("Empty string encoding: {:?}", &buffer[1..]);
1160 assert_eq!(
1161 buffer.len(),
1162 3,
1163 "Empty string should encode to 2 bytes (plus leading 0)"
1164 );
1165 assert_eq!(
1166 buffer[1],
1167 token::BINARY_8,
1168 "First byte should be BINARY_8 (252)"
1169 );
1170 assert_eq!(buffer[2], 0, "Second byte should be 0 (length)");
1171 Ok(())
1172 }
1173
1174 #[test]
1176 fn test_empty_string_roundtrip() -> TestResult {
1177 use crate::decoder::Decoder;
1178
1179 let mut attrs = Attrs::new();
1180 attrs.insert("key", ""); attrs.insert("", "value"); let node = Node::new("test", attrs, Some(NodeContent::String("".into())));
1184
1185 let mut buffer = Vec::new();
1186 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
1187 encoder.write_node(&node)?;
1188
1189 let mut decoder = Decoder::new(&buffer[1..]);
1190 let decoded = decoder.read_node_ref()?.to_owned();
1191
1192 assert_eq!(decoded.tag, "test");
1193 assert_eq!(
1194 decoded.attrs.get("key"),
1195 Some(&NodeValue::String("".into()))
1196 );
1197 assert_eq!(
1198 decoded.attrs.get(""),
1199 Some(&NodeValue::String("value".into()))
1200 );
1201
1202 match &decoded.content {
1204 Some(NodeContent::Bytes(b)) => assert!(b.is_empty(), "Content should be empty bytes"),
1205 other => panic!("Expected empty bytes, got {:?}", other),
1206 }
1207 Ok(())
1208 }
1209
1210 #[test]
1213 fn test_jid_length_heuristic() -> TestResult {
1214 use crate::decoder::Decoder;
1215 use crate::token;
1216
1217 let short_jid = "user@s.whatsapp.net";
1219 let mut buffer = Vec::new();
1220 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
1221 encoder.write_string(short_jid)?;
1222
1223 assert_eq!(
1225 buffer[1],
1226 token::JID_PAIR,
1227 "Short JID should be encoded as JID_PAIR token"
1228 );
1229
1230 let long_text = "x".repeat(300) + "@s.whatsapp.net";
1232 let mut buffer = Vec::new();
1233 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
1234 encoder.write_string(&long_text)?;
1235
1236 assert_eq!(
1238 buffer[1],
1239 token::BINARY_20,
1240 "Long string should be encoded as BINARY_20, not as JID"
1241 );
1242
1243 let node = Node::new(
1245 "msg",
1246 Attrs::new(),
1247 Some(NodeContent::String(long_text.as_str().into())),
1248 );
1249 let mut buffer = Vec::new();
1250 let mut encoder = Encoder::new(Cursor::new(&mut buffer))?;
1251 encoder.write_node(&node)?;
1252
1253 let mut decoder = Decoder::new(&buffer[1..]);
1254 let decoded = decoder.read_node_ref()?.to_owned();
1255 match &decoded.content {
1256 Some(NodeContent::Bytes(b)) => {
1257 assert_eq!(
1258 String::from_utf8_lossy(b),
1259 long_text,
1260 "Long string should round-trip correctly"
1261 );
1262 }
1263 other => panic!("Expected bytes content, got {:?}", other),
1264 }
1265
1266 Ok(())
1267 }
1268
/// Round-trips the string content `foo:bar@s.whatsapp.net` through the
/// encoder and decoder and expects it back, unchanged, as string content.
///
/// NOTE(review): judging by the test name, the point is that a non-numeric
/// part after `:` must not be interpreted as a JID device id — confirm
/// against the JID parser.
#[test]
fn test_jid_parser_preserves_non_numeric_device_suffix() -> TestResult {
    use crate::decoder::Decoder;

    let value = "foo:bar@s.whatsapp.net";

    let mut wire = Vec::new();
    {
        let node = Node::new("msg", Attrs::new(), Some(NodeContent::String(value.into())));
        Encoder::new(Cursor::new(&mut wire))?.write_node(&node)?;
    }

    // Skip the first byte of the encoder output before decoding.
    let decoded = Decoder::new(&wire[1..]).read_node_ref()?.to_owned();
    let content = decoded.content;
    match content {
        Some(NodeContent::String(text)) => assert_eq!(text, value),
        unexpected => panic!("Expected string content, got {:?}", unexpected),
    }

    Ok(())
}
1288
/// Round-trips the string content `hello_world@s.whatsapp.net` through the
/// encoder and decoder and expects it back, unchanged, as string content.
///
/// NOTE(review): judging by the test name, the point is that a non-numeric
/// part after `_` must not be interpreted as a JID agent id — confirm
/// against the JID parser.
#[test]
fn test_jid_parser_preserves_non_numeric_agent_suffix() -> TestResult {
    use crate::decoder::Decoder;

    let value = "hello_world@s.whatsapp.net";

    let mut wire = Vec::new();
    {
        let node = Node::new("msg", Attrs::new(), Some(NodeContent::String(value.into())));
        Encoder::new(Cursor::new(&mut wire))?.write_node(&node)?;
    }

    // Skip the first byte of the encoder output before decoding.
    let decoded = Decoder::new(&wire[1..]).read_node_ref()?.to_owned();
    let content = decoded.content;
    match content {
        Some(NodeContent::String(text)) => assert_eq!(text, value),
        unexpected => panic!("Expected string content, got {:?}", unexpected),
    }

    Ok(())
}
1308
/// Encodes a node whose `jid` attribute is a LID device JID and inspects the
/// two bytes that follow the `AD_JID` token.
///
/// The first byte (the domain type) must be 1 and the second must carry the
/// device id (39).
#[test]
fn test_ad_jid_domain_type_lid() -> TestResult {
    let node = NodeBuilder::new("to")
        .attr("jid", Jid::lid_device("236395184570386", 39))
        .build();

    let mut wire = Vec::new();
    Encoder::new(Cursor::new(&mut wire))?.write_node(&node)?;

    // Locate the first occurrence of the AD_JID marker in the output.
    let marker_idx = wire
        .iter()
        .position(|byte| *byte == token::AD_JID)
        .expect("AD_JID token (0xF7) must be present for device JID");

    let domain_type = wire[marker_idx + 1];
    let device = wire[marker_idx + 2];

    assert_eq!(
        domain_type, 1,
        "LID JID must encode domain_type=1 (lid), got {domain_type} (0=whatsapp, 128=hosted)"
    );
    assert_eq!(device, 39, "Device byte must be 39");

    Ok(())
}
1348
/// Encodes a node whose `jid` attribute is a phone-number device JID and
/// inspects the bytes that follow the `AD_JID` token.
///
/// The byte right after the token (the domain type) must be 0 for an
/// `s.whatsapp.net` JID, and the byte after that must carry the device id
/// (33). The device check mirrors the one in `test_ad_jid_domain_type_lid`
/// so both domain types are covered by the same assertions.
#[test]
fn test_ad_jid_domain_type_whatsapp() -> TestResult {
    let pn_jid = Jid::pn_device("551199887766", 33);
    let node = NodeBuilder::new("to").attr("jid", pn_jid).build();

    let mut buffer = Vec::new();
    Encoder::new(Cursor::new(&mut buffer))?.write_node(&node)?;

    // Locate the first occurrence of the AD_JID marker in the output.
    let ad_jid_pos = buffer
        .iter()
        .position(|&b| b == token::AD_JID)
        .expect("AD_JID token must be present for device JID");

    let domain_type = buffer[ad_jid_pos + 1];
    assert_eq!(
        domain_type, 0,
        "s.whatsapp.net JID must encode domain_type=0, got {domain_type}"
    );

    // Same layout as the LID case: the device id follows the domain type.
    let device = buffer[ad_jid_pos + 2];
    assert_eq!(device, 33, "Device byte must be 33");

    Ok(())
}
1371
/// For a set of JIDs, checks that passing the JID as a string attribute and
/// passing it as a `Jid` value produce byte-identical encodings, and that
/// the encoded attribute decodes back to a JID with the same `user`,
/// `device` and `server`.
#[test]
fn test_jid_string_vs_direct_encoding_matches() -> TestResult {
    use crate::decoder::Decoder;

    let test_cases: [Jid; 6] = [
        Jid::lid_device("236395184570386", 39),
        Jid::pn_device("551199887766", 33),
        Jid::lid("236395184570386"),
        Jid::pn("551199887766"),
        "5511999887766:99@hosted".parse().unwrap(),
        "100000012345678:99@hosted.lid".parse().unwrap(),
    ];

    for jid in test_cases {
        // Encode once from the textual form ...
        let jid_text = jid.to_string();
        let node_str = NodeBuilder::new("to").attr("jid", jid_text).build();
        let mut buf_str = Vec::new();
        Encoder::new(Cursor::new(&mut buf_str))?.write_node(&node_str)?;

        // ... and once from the structured value.
        let node_jid = NodeBuilder::new("to").attr("jid", jid.clone()).build();
        let mut buf_jid = Vec::new();
        Encoder::new(Cursor::new(&mut buf_jid))?.write_node(&node_jid)?;

        assert_eq!(
            buf_str, buf_jid,
            "String vs direct Jid encoding must produce identical bytes for {jid}"
        );

        // Skip the first byte of the encoder output before decoding.
        let decoded_node = Decoder::new(&buf_jid[1..]).read_node_ref()?.to_owned();
        let decoded_jid: Jid = decoded_node
            .attrs()
            .optional_jid("jid")
            .expect("jid attr must round-trip as JID");

        assert_eq!(
            jid.user, decoded_jid.user,
            "Round-trip user mismatch for {jid}"
        );
        assert_eq!(
            jid.device, decoded_jid.device,
            "Round-trip device mismatch for {jid}"
        );
        assert_eq!(
            jid.server, decoded_jid.server,
            "Round-trip server mismatch for {jid}"
        );
    }

    Ok(())
}
1432
/// Builds hosted and hosted-LID JIDs via `Jid::new` (so their `agent` field
/// stays at 0), gives each a device id, and checks that the byte following
/// the `AD_JID` token is 128 for `Hosted` and 129 for `HostedLid`.
#[test]
fn test_direct_constructed_hosted_encodes_correct_domain_type() -> TestResult {
    let hosted = {
        let mut built = Jid::new("100000000000001", jid::Server::Hosted);
        built.device = 99;
        built
    };
    assert_eq!(
        hosted.agent, 0,
        "default agent for direct construction is 0"
    );

    let hosted_lid = {
        let mut built = Jid::new("100000000000002", jid::Server::HostedLid);
        built.device = 99;
        built
    };
    assert_eq!(hosted_lid.agent, 0);

    let cases = [(&hosted, 128u8), (&hosted_lid, 129u8)];
    for (jid, expected) in cases {
        let node = NodeBuilder::new("to").attr("jid", jid.clone()).build();

        let mut wire = Vec::new();
        Encoder::new(Cursor::new(&mut wire))?.write_node(&node)?;

        // The domain-type byte is the one right after the AD_JID marker.
        let marker_idx = wire
            .iter()
            .position(|byte| *byte == token::AD_JID)
            .expect("AD_JID marker present");
        assert_eq!(
            wire[marker_idx + 1],
            expected,
            "direct-constructed {jid} must emit domain_type {expected} \
             (pre-#391 would have emitted agent=0)"
        );
    }

    Ok(())
}
1466
/// Compares the encoding of an all-digit string of exactly
/// `token::PACKED_MAX` characters with one that is a single character
/// longer.
///
/// The boundary-length string must encode to fewer bytes than the longer
/// one, the longer one must contain a `BINARY_8` token followed by the
/// length byte 128, and both must decode back to their original text.
#[test]
fn test_long_string_skips_classification() -> TestResult {
    use crate::decoder::Decoder;
    use crate::marshal::marshal;

    let packed_max = token::PACKED_MAX as usize;
    let at_boundary = "0".repeat(packed_max);
    let over_boundary = "0".repeat(packed_max + 1);

    // Wraps the text as the content of a bare "test" node and marshals it.
    let encode = |text: &str| {
        marshal(&Node::new(
            "test",
            Attrs::new(),
            Some(NodeContent::String(text.into())),
        ))
    };
    let encoded_at = encode(&at_boundary)?;
    let encoded_over = encode(&over_boundary)?;

    assert!(
        encoded_at.len() < encoded_over.len(),
        "127-char nibble string should pack smaller than 128-char raw: {} vs {}",
        encoded_at.len(),
        encoded_over.len(),
    );

    // Look for the raw-binary header (token + one-byte length) anywhere in the output.
    let has_raw_128 = encoded_over
        .windows(2)
        .any(|pair| pair == [token::BINARY_8, 128]);
    assert!(
        has_raw_128,
        "128-char string must contain BINARY_8 + length=128 sequence"
    );

    let decoded_at = Decoder::new(&encoded_at[1..]).read_node_ref()?.to_owned();
    let decoded_over = Decoder::new(&encoded_over[1..]).read_node_ref()?.to_owned();

    // The boundary-length value may come back as either string or byte content.
    match &decoded_at.content {
        Some(NodeContent::String(text)) => assert_eq!(text.as_str(), at_boundary),
        Some(NodeContent::Bytes(raw)) => {
            assert_eq!(std::str::from_utf8(raw).unwrap(), at_boundary)
        }
        unexpected => panic!("Expected string/bytes content, got {:?}", unexpected),
    }

    let Some(NodeContent::Bytes(raw_over)) = &decoded_over.content else {
        panic!(
            "Expected bytes content for 128-char string, got {:?}",
            decoded_over.content
        );
    };
    assert_eq!(std::str::from_utf8(raw_over).unwrap(), over_boundary);

    Ok(())
}
1535}