serf_core/delegate/
transform.rs

1use memberlist_core::{
2  bytes::Bytes,
3  transport::{Id, Node, Transformable},
4  CheapClone,
5};
6use serf_types::{
7  FilterTransformError, JoinMessage, LeaveMessage, Member, MessageType, NodeTransformError,
8  PushPullMessage, QueryMessage, QueryResponseMessage, SerfMessageTransformError,
9  TagsTransformError, UserEventMessage,
10};
11
12use crate::{
13  coordinate::{Coordinate, CoordinateTransformError},
14  types::{AsMessageRef, Filter, SerfMessage, Tags, UnknownMessageType},
15};
16
17/// A delegate for encoding and decoding.
18pub trait TransformDelegate: Send + Sync + 'static {
19  /// The error type for the transformation.
20  type Error: std::error::Error + From<UnknownMessageType> + Send + Sync + 'static;
21  /// The Id type.
22  type Id: Id;
23  /// The Address type.
24  type Address: CheapClone + Send + Sync + 'static;
25
26  /// Encodes the filter into bytes.
27  fn encode_filter(filter: &Filter<Self::Id>) -> Result<Bytes, Self::Error>;
28
29  /// Decodes the filter from the given bytes, returning the number of bytes consumed and the filter.
30  fn decode_filter(bytes: &[u8]) -> Result<(usize, Filter<Self::Id>), Self::Error>;
31
32  /// Returns the encoded length of the node.
33  fn node_encoded_len(node: &Node<Self::Id, Self::Address>) -> usize;
34
35  /// Encodes the node into the given buffer, returning the number of bytes written.
36  fn encode_node(
37    node: &Node<Self::Id, Self::Address>,
38    dst: &mut [u8],
39  ) -> Result<usize, Self::Error>;
40
41  /// Decodes [`Node`] from the given bytes, returning the number of bytes consumed and the node.
42  fn decode_node(
43    bytes: impl AsRef<[u8]>,
44  ) -> Result<(usize, Node<Self::Id, Self::Address>), Self::Error>;
45
46  /// Returns the encoded length of the id.
47  fn id_encoded_len(id: &Self::Id) -> usize;
48
49  /// Encodes the id into the given buffer, returning the number of bytes written.
50  fn encode_id(id: &Self::Id, dst: &mut [u8]) -> Result<usize, Self::Error>;
51
52  /// Decodes the id from the given bytes, returning the number of bytes consumed and the id.
53  fn decode_id(bytes: &[u8]) -> Result<(usize, Self::Id), Self::Error>;
54
55  /// Returns the encoded length of the address.
56  fn address_encoded_len(address: &Self::Address) -> usize;
57
58  /// Encodes the address into the given buffer, returning the number of bytes written.
59  fn encode_address(address: &Self::Address, dst: &mut [u8]) -> Result<usize, Self::Error>;
60
61  /// Decodes the address from the given bytes, returning the number of bytes consumed and the address.
62  fn decode_address(bytes: &[u8]) -> Result<(usize, Self::Address), Self::Error>;
63
64  /// Encoded length of the coordinate.
65  fn coordinate_encoded_len(coordinate: &Coordinate) -> usize;
66
67  /// Encodes the coordinate into the given buffer, returning the number of bytes written.
68  fn encode_coordinate(coordinate: &Coordinate, dst: &mut [u8]) -> Result<usize, Self::Error>;
69
70  /// Decodes the coordinate from the given bytes, returning the number of bytes consumed and the coordinate.
71  fn decode_coordinate(bytes: &[u8]) -> Result<(usize, Coordinate), Self::Error>;
72
73  /// Encoded length of the tags.
74  fn tags_encoded_len(tags: &Tags) -> usize;
75
76  /// Encodes the tags into the given buffer, returning the number of bytes written.
77  fn encode_tags(tags: &Tags, dst: &mut [u8]) -> Result<usize, Self::Error>;
78
79  /// Decodes the tags from the given bytes, returning the number of bytes consumed and the tags.
80  fn decode_tags(bytes: &[u8]) -> Result<(usize, Tags), Self::Error>;
81
82  /// Encoded length of the message.
83  fn message_encoded_len(msg: impl AsMessageRef<Self::Id, Self::Address>) -> usize;
84
85  /// Encodes the message into the given buffer, returning the number of bytes written.
86  ///
87  /// **NOTE**:
88  ///
89  /// 1. The buffer must be large enough to hold the encoded message.
90  ///    The length of the buffer can be obtained by calling [`TransformDelegate::message_encoded_len`].
91  /// 2. A message type byte will be automatically prepended to the buffer,
92  ///    so users do not need to encode the message type byte by themselves.
93  fn encode_message(
94    msg: impl AsMessageRef<Self::Id, Self::Address>,
95    dst: impl AsMut<[u8]>,
96  ) -> Result<usize, Self::Error>;
97
98  /// Decodes the message from the given bytes, returning the number of bytes consumed and the message.
99  fn decode_message(
100    ty: MessageType,
101    bytes: impl AsRef<[u8]>,
102  ) -> Result<(usize, SerfMessage<Self::Id, Self::Address>), Self::Error>;
103}
104
105/// The error type for the LPE transformation.
106#[derive(thiserror::Error)]
107pub enum LpeTransformError<I, A>
108where
109  I: Transformable + core::hash::Hash + Eq,
110  A: Transformable + core::hash::Hash + Eq,
111{
112  /// Id transformation error.
113  #[error(transparent)]
114  Id(<I as Transformable>::Error),
115  /// Address transformation error.
116  #[error(transparent)]
117  Address(<A as Transformable>::Error),
118  /// Coordinate transformation error.
119  #[error(transparent)]
120  Coordinate(#[from] CoordinateTransformError),
121  /// Node transformation error.
122  #[error(transparent)]
123  Node(#[from] NodeTransformError<I, A>),
124  /// Filter transformation error.
125  #[error(transparent)]
126  Filter(#[from] FilterTransformError<I>),
127  /// Tags transformation error.
128  #[error(transparent)]
129  Tags(#[from] TagsTransformError),
130  /// Serf message transformation error.
131  #[error(transparent)]
132  Message(#[from] SerfMessageTransformError<I, A>),
133  /// Unknown message type error.
134  #[error(transparent)]
135  UnknownMessage(#[from] UnknownMessageType),
136  /// Unexpected relay message.
137  #[error("unexpected relay message")]
138  UnexpectedRelayMessage,
139}
140
141impl<I, A> core::fmt::Debug for LpeTransformError<I, A>
142where
143  I: Transformable + core::hash::Hash + Eq,
144  A: Transformable + core::hash::Hash + Eq,
145{
146  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
147    write!(f, "{}", self)
148  }
149}
150
/// A length-prefixed encoding [`TransformDelegate`] implementation.
///
/// The type holds no data; `I` (id) and `A` (address) are type markers only.
/// NOTE(review): the type name is spelled `Transfrom`; renaming it would break
/// existing users, so it is kept as-is.
pub struct LpeTransfromDelegate<I, A>(std::marker::PhantomData<(I, A)>);

// Manual impl instead of `#[derive(Debug)]`: the derive would require
// `I: Debug` and `A: Debug` even though no value of either type is stored.
impl<I, A> core::fmt::Debug for LpeTransfromDelegate<I, A> {
  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
    f.debug_struct("LpeTransfromDelegate").finish()
  }
}
153
154impl<I, A> Default for LpeTransfromDelegate<I, A> {
155  fn default() -> Self {
156    Self(Default::default())
157  }
158}
159
160impl<I, A> Clone for LpeTransfromDelegate<I, A> {
161  fn clone(&self) -> Self {
162    *self
163  }
164}
165
166impl<I, A> Copy for LpeTransfromDelegate<I, A> {}
167
168impl<I, A> TransformDelegate for LpeTransfromDelegate<I, A>
169where
170  I: Id,
171  A: Transformable + CheapClone + core::hash::Hash + Eq + Send + Sync + 'static,
172{
173  type Error = LpeTransformError<Self::Id, Self::Address>;
174  type Id = I;
175  type Address = A;
176
177  fn encode_filter(filter: &Filter<Self::Id>) -> Result<Bytes, Self::Error> {
178    filter
179      .encode_to_vec()
180      .map(Bytes::from)
181      .map_err(Self::Error::Filter)
182  }
183
184  fn decode_filter(bytes: &[u8]) -> Result<(usize, Filter<Self::Id>), Self::Error> {
185    Filter::decode(bytes).map_err(Self::Error::Filter)
186  }
187
188  fn node_encoded_len(node: &Node<Self::Id, Self::Address>) -> usize {
189    Transformable::encoded_len(node)
190  }
191
192  fn encode_node(
193    node: &Node<Self::Id, Self::Address>,
194    dst: &mut [u8],
195  ) -> Result<usize, Self::Error> {
196    Transformable::encode(node, dst).map_err(Self::Error::Node)
197  }
198
199  fn decode_node(
200    bytes: impl AsRef<[u8]>,
201  ) -> Result<(usize, Node<Self::Id, Self::Address>), Self::Error> {
202    Transformable::decode(bytes.as_ref()).map_err(Self::Error::Node)
203  }
204
205  fn id_encoded_len(id: &Self::Id) -> usize {
206    Transformable::encoded_len(id)
207  }
208
209  fn encode_id(id: &Self::Id, dst: &mut [u8]) -> Result<usize, Self::Error> {
210    Transformable::encode(id, dst).map_err(Self::Error::Id)
211  }
212
213  fn decode_id(bytes: &[u8]) -> Result<(usize, Self::Id), Self::Error> {
214    Transformable::decode(bytes).map_err(Self::Error::Id)
215  }
216
217  fn address_encoded_len(address: &Self::Address) -> usize {
218    Transformable::encoded_len(address)
219  }
220
221  fn encode_address(address: &Self::Address, dst: &mut [u8]) -> Result<usize, Self::Error> {
222    Transformable::encode(address, dst).map_err(Self::Error::Address)
223  }
224
225  fn decode_address(bytes: &[u8]) -> Result<(usize, Self::Address), Self::Error> {
226    Transformable::decode(bytes).map_err(Self::Error::Address)
227  }
228
229  fn coordinate_encoded_len(coordinate: &Coordinate) -> usize {
230    Transformable::encoded_len(coordinate)
231  }
232
233  fn encode_coordinate(coordinate: &Coordinate, dst: &mut [u8]) -> Result<usize, Self::Error> {
234    Transformable::encode(coordinate, dst).map_err(Self::Error::Coordinate)
235  }
236
237  fn decode_coordinate(bytes: &[u8]) -> Result<(usize, Coordinate), Self::Error> {
238    Transformable::decode(bytes).map_err(Self::Error::Coordinate)
239  }
240
241  fn tags_encoded_len(tags: &Tags) -> usize {
242    Transformable::encoded_len(tags)
243  }
244
245  fn encode_tags(tags: &Tags, dst: &mut [u8]) -> Result<usize, Self::Error> {
246    Transformable::encode(tags, dst).map_err(Self::Error::Tags)
247  }
248
249  fn decode_tags(bytes: &[u8]) -> Result<(usize, Tags), Self::Error> {
250    Transformable::decode(bytes).map_err(Self::Error::Tags)
251  }
252
253  fn message_encoded_len(msg: impl AsMessageRef<Self::Id, Self::Address>) -> usize {
254    let msg = msg.as_message_ref();
255    serf_types::Encodable::encoded_len(&msg)
256  }
257
258  fn encode_message(
259    msg: impl AsMessageRef<Self::Id, Self::Address>,
260    mut dst: impl AsMut<[u8]>,
261  ) -> Result<usize, Self::Error> {
262    let msg = msg.as_message_ref();
263    serf_types::Encodable::encode(&msg, dst.as_mut()).map_err(Into::into)
264  }
265
266  fn decode_message(
267    ty: MessageType,
268    bytes: impl AsRef<[u8]>,
269  ) -> Result<(usize, SerfMessage<Self::Id, Self::Address>), Self::Error> {
270    match ty {
271      MessageType::Leave => LeaveMessage::decode(bytes.as_ref())
272        .map(|(n, m)| (n, SerfMessage::Leave(m)))
273        .map_err(|e| Self::Error::Message(e.into())),
274      MessageType::Join => JoinMessage::decode(bytes.as_ref())
275        .map(|(n, m)| (n, SerfMessage::Join(m)))
276        .map_err(|e| Self::Error::Message(e.into())),
277      MessageType::PushPull => PushPullMessage::decode(bytes.as_ref())
278        .map(|(n, m)| (n, SerfMessage::PushPull(m)))
279        .map_err(|e| Self::Error::Message(e.into())),
280      MessageType::UserEvent => UserEventMessage::decode(bytes.as_ref())
281        .map(|(n, m)| (n, SerfMessage::UserEvent(m)))
282        .map_err(|e| Self::Error::Message(e.into())),
283      MessageType::Query => QueryMessage::decode(bytes.as_ref())
284        .map(|(n, m)| (n, SerfMessage::Query(m)))
285        .map_err(|e| Self::Error::Message(e.into())),
286      MessageType::QueryResponse => QueryResponseMessage::decode(bytes.as_ref())
287        .map(|(n, m)| (n, SerfMessage::QueryResponse(m)))
288        .map_err(|e| Self::Error::Message(e.into())),
289      MessageType::ConflictResponse => Member::decode(bytes.as_ref())
290        .map(|(n, m)| (n, SerfMessage::ConflictResponse(m)))
291        .map_err(|e| Self::Error::Message(e.into())),
292      MessageType::Relay => Err(Self::Error::UnexpectedRelayMessage),
293      #[cfg(feature = "encryption")]
294      MessageType::KeyRequest => serf_types::KeyRequestMessage::decode(bytes.as_ref())
295        .map(|(n, m)| (n, SerfMessage::KeyRequest(m)))
296        .map_err(|e| Self::Error::Message(e.into())),
297      #[cfg(feature = "encryption")]
298      MessageType::KeyResponse => serf_types::KeyResponseMessage::decode(bytes.as_ref())
299        .map(|(n, m)| (n, SerfMessage::KeyResponse(m)))
300        .map_err(|e| Self::Error::Message(e.into())),
301      _ => unreachable!(),
302    }
303  }
304}