1use byteorder::{ByteOrder, NetworkEndian};
2use transformable::utils::encoded_u64_varint_len;
3
4use crate::LamportTimeTransformError;
5
6use super::{LamportTime, Transformable};
7
8#[viewit::viewit(setters(prefix = "with"))]
11#[derive(Debug, Clone, Eq, PartialEq)]
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13pub struct JoinMessage<I> {
14 #[viewit(
16 getter(const, attrs(doc = "Returns the lamport time for this message")),
17 setter(
18 const,
19 attrs(doc = "Sets the lamport time for this message (Builder pattern)")
20 )
21 )]
22 ltime: LamportTime,
23 #[viewit(
25 getter(const, style = "ref", attrs(doc = "Returns the node")),
26 setter(attrs(doc = "Sets the node (Builder pattern)"))
27 )]
28 id: I,
29}
30
31impl<I> JoinMessage<I> {
32 pub fn new(ltime: LamportTime, id: I) -> Self {
34 Self { ltime, id }
35 }
36
37 #[inline]
39 pub fn set_ltime(&mut self, ltime: LamportTime) -> &mut Self {
40 self.ltime = ltime;
41 self
42 }
43
44 #[inline]
46 pub fn set_id(&mut self, id: I) -> &mut Self {
47 self.id = id;
48 self
49 }
50}
51
52#[derive(thiserror::Error)]
54pub enum JoinMessageTransformError<I: Transformable> {
55 #[error("not enough bytes to decode JoinMessage")]
57 NotEnoughBytes,
58 #[error("encode buffer too small")]
60 EncodeBufferTooSmall,
61 #[error(transparent)]
63 Id(I::Error),
64
65 #[error(transparent)]
67 LamportTime(#[from] LamportTimeTransformError),
68}
69
70impl<I: Transformable> core::fmt::Debug for JoinMessageTransformError<I> {
71 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
72 write!(f, "{}", self)
73 }
74}
75
76impl<I> Transformable for JoinMessage<I>
77where
78 I: Transformable,
79{
80 type Error = JoinMessageTransformError<I>;
81
82 fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
83 let encoded_len = self.encoded_len();
84 if dst.len() < encoded_len {
85 return Err(Self::Error::EncodeBufferTooSmall);
86 }
87
88 let mut offset = 0;
89 NetworkEndian::write_u32(&mut dst[offset..offset + 4], encoded_len as u32);
90 offset += 4;
91
92 offset += self.ltime.encode(&mut dst[offset..])?;
93 offset += self
94 .id
95 .encode(&mut dst[offset..])
96 .map_err(Self::Error::Id)?;
97
98 debug_assert_eq!(
99 offset, encoded_len,
100 "expect write {} bytes, but actual write {} bytes",
101 encoded_len, offset
102 );
103 Ok(offset)
104 }
105
106 fn encoded_len(&self) -> usize {
107 4 + encoded_u64_varint_len(self.ltime.0) + self.id.encoded_len()
108 }
109
110 fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
111 where
112 Self: Sized,
113 {
114 if src.len() < 4 {
115 return Err(Self::Error::NotEnoughBytes);
116 }
117
118 let encoded_len = NetworkEndian::read_u32(&src[..4]) as usize;
119 if src.len() < encoded_len {
120 return Err(Self::Error::NotEnoughBytes);
121 }
122
123 let mut offset = 4;
124 let (n, ltime) = LamportTime::decode(&src[offset..])?;
125 offset += n;
126
127 let (n, id) = I::decode(&src[offset..]).map_err(Self::Error::Id)?;
128 offset += n;
129
130 debug_assert_eq!(
131 offset, encoded_len,
132 "expect read {} bytes, but actual read {} bytes",
133 encoded_len, offset
134 );
135 Ok((encoded_len, Self { ltime, id }))
136 }
137}
138
#[cfg(test)]
mod tests {
    use rand::{distributions::Alphanumeric, thread_rng, Rng};
    use smol_str::SmolStr;

    use super::*;

    impl JoinMessage<SmolStr> {
        /// Builds a message with a random Lamport time and a random
        /// alphanumeric id that is `size` characters long.
        fn random(size: usize) -> Self {
            let mut rng = thread_rng();
            let id: String = (0..size)
                .map(|_| char::from(rng.sample(Alphanumeric)))
                .collect();

            Self::new(LamportTime::random(), id.into())
        }
    }

    /// Round-trips messages with ids of length 0..100 through `encode` and
    /// each of the three decode entry points (slice, sync reader, async reader).
    #[test]
    fn test_transform_encode_decode() {
        futures::executor::block_on(async {
            for size in 0..100 {
                let msg = JoinMessage::random(size);
                let expected_len = msg.encoded_len();

                let mut buf = vec![0u8; expected_len];
                let written = msg.encode(&mut buf).unwrap();
                assert_eq!(written, expected_len);

                // Decode straight from the byte slice.
                let (read, decoded) = JoinMessage::<SmolStr>::decode(&buf).unwrap();
                assert_eq!(read, written);
                assert_eq!(decoded, msg);

                // Decode through a blocking reader.
                let mut sync_reader = std::io::Cursor::new(&buf);
                let (read, decoded) =
                    JoinMessage::<SmolStr>::decode_from_reader(&mut sync_reader).unwrap();
                assert_eq!(read, written);
                assert_eq!(decoded, msg);

                // Decode through an async reader.
                let mut async_reader = futures::io::Cursor::new(&buf);
                let (read, decoded) =
                    JoinMessage::<SmolStr>::decode_from_async_reader(&mut async_reader)
                        .await
                        .unwrap();
                assert_eq!(read, written);
                assert_eq!(decoded, msg);
            }
        });
    }
}