Skip to main content

volans_core/
multiaddr.rs

1use std::{
2    fmt,
3    net::{IpAddr, Ipv4Addr, Ipv6Addr},
4    str::FromStr,
5};
6
7use bytes::{BufMut, Bytes, BytesMut};
8use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
9
10use crate::PeerId;
11
12mod error;
13mod from_url;
14mod protocol;
15
16pub use error::Error;
17pub use from_url::{FromUrlErr, from_url, from_url_lossy};
18pub use protocol::Protocol;
19
20#[allow(clippy::rc_buffer)]
21#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Hash)]
22pub struct Multiaddr {
23    bytes: Bytes,
24}
25
26impl Multiaddr {
27    pub fn empty() -> Self {
28        Self {
29            bytes: Bytes::new(),
30        }
31    }
32
33    pub fn with_capacity(n: usize) -> Self {
34        Self {
35            bytes: BytesMut::with_capacity(n).freeze(),
36        }
37    }
38
39    pub fn len(&self) -> usize {
40        self.bytes.len()
41    }
42
43    pub fn is_empty(&self) -> bool {
44        self.bytes.len() == 0
45    }
46
47    pub fn to_vec(&self) -> Vec<u8> {
48        Vec::from(&self.bytes[..])
49    }
50
51    pub fn push(&mut self, p: Protocol<'_>) {
52        let mut bytes = BytesMut::from(std::mem::take(&mut self.bytes));
53        p.write_bytes(&mut (&mut bytes).writer())
54            .expect("Writing to a `BytesMut` never fails.");
55        self.bytes = bytes.freeze();
56    }
57
58    pub fn pop<'a>(&mut self) -> Option<Protocol<'a>> {
59        let mut slice = &self.bytes[..]; // the remaining multiaddr slice
60        if slice.is_empty() {
61            return None;
62        }
63        let protocol = loop {
64            let (p, s) = Protocol::from_bytes(slice).expect("`slice` is a valid `Protocol`.");
65            if s.is_empty() {
66                break p.acquire();
67            }
68            slice = s
69        };
70        let remaining_len = self.len() - slice.len();
71        let mut bytes = BytesMut::from(std::mem::take(&mut self.bytes));
72        bytes.truncate(remaining_len);
73        self.bytes = bytes.freeze();
74        Some(protocol)
75    }
76
77    pub fn with(mut self, p: Protocol<'_>) -> Self {
78        let mut bytes = BytesMut::from(std::mem::take(&mut self.bytes));
79        p.write_bytes(&mut (&mut bytes).writer())
80            .expect("Writing to a `BytesMut` never fails.");
81        self.bytes = bytes.freeze();
82        self
83    }
84
85    pub fn with_peer(self, peer: PeerId) -> std::result::Result<Self, Self> {
86        match self.iter().last() {
87            Some(Protocol::Peer(p)) if p == peer => Ok(self),
88            Some(Protocol::Peer(_)) => Err(self),
89            _ => Ok(self.with(Protocol::Peer(peer))),
90        }
91    }
92
93    pub fn iter(&self) -> Iter<'_> {
94        Iter(&self.bytes)
95    }
96
97    pub fn replace<'a, F>(&self, at: usize, by: F) -> Option<Multiaddr>
98    where
99        F: FnOnce(&Protocol<'_>) -> Option<Protocol<'a>>,
100    {
101        let mut address = Multiaddr::with_capacity(self.len());
102        let mut fun = Some(by);
103        let mut replaced = false;
104
105        for (i, p) in self.iter().enumerate() {
106            if i == at {
107                let f = fun.take().expect("i == at only happens once");
108                if let Some(q) = f(&p) {
109                    address = address.with(q);
110                    replaced = true;
111                    continue;
112                }
113                return None;
114            }
115            address = address.with(p)
116        }
117
118        if replaced { Some(address) } else { None }
119    }
120
121    pub fn ends_with(&self, other: &Multiaddr) -> bool {
122        let n = self.bytes.len();
123        let m = other.bytes.len();
124        if n < m {
125            return false;
126        }
127        self.bytes[(n - m)..] == other.bytes[..]
128    }
129
130    pub fn starts_with(&self, other: &Multiaddr) -> bool {
131        let n = self.bytes.len();
132        let m = other.bytes.len();
133        if n < m {
134            return false;
135        }
136        self.bytes[..m] == other.bytes[..]
137    }
138
139    pub fn protocol_stack(&self) -> ProtoStackIter {
140        ProtoStackIter { parts: self.iter() }
141    }
142}
143
144impl fmt::Debug for Multiaddr {
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        fmt::Display::fmt(self, f)
147    }
148}
149
150impl fmt::Display for Multiaddr {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        for s in self.iter() {
153            s.fmt(f)?;
154        }
155        Ok(())
156    }
157}
158
159impl AsRef<[u8]> for Multiaddr {
160    fn as_ref(&self) -> &[u8] {
161        self.bytes.as_ref()
162    }
163}
164
165impl<'a> IntoIterator for &'a Multiaddr {
166    type Item = Protocol<'a>;
167    type IntoIter = Iter<'a>;
168
169    fn into_iter(self) -> Iter<'a> {
170        Iter(&self.bytes)
171    }
172}
173
174impl<'a> FromIterator<Protocol<'a>> for Multiaddr {
175    fn from_iter<T>(iter: T) -> Self
176    where
177        T: IntoIterator<Item = Protocol<'a>>,
178    {
179        let mut bytes = BytesMut::new();
180        for cmp in iter {
181            cmp.write_bytes(&mut (&mut bytes).writer())
182                .expect("Writing to a `BytesMut` never fails.");
183        }
184        Multiaddr {
185            bytes: bytes.freeze(),
186        }
187    }
188}
189
190impl FromStr for Multiaddr {
191    type Err = Error;
192
193    fn from_str(input: &str) -> Result<Self, Error> {
194        let mut bytes = BytesMut::new();
195        let mut parts = input.split('/').peekable();
196
197        if Some("") != parts.next() {
198            // A multiaddr must start with `/`
199            return Err(Error::InvalidMultiaddr);
200        }
201
202        while parts.peek().is_some() {
203            let p = Protocol::from_str_parts(&mut parts)?;
204            p.write_bytes(&mut (&mut bytes).writer())
205                .expect("Writing to a `BytesMut` never fails.");
206        }
207
208        Ok(Multiaddr {
209            bytes: bytes.freeze(),
210        })
211    }
212}
213
214/// Iterator over `Multiaddr` [`Protocol`]s.
215pub struct Iter<'a>(&'a [u8]);
216
217impl<'a> Iterator for Iter<'a> {
218    type Item = Protocol<'a>;
219
220    fn next(&mut self) -> Option<Self::Item> {
221        if self.0.is_empty() {
222            return None;
223        }
224
225        let (p, next_data) =
226            Protocol::from_bytes(self.0).expect("`Multiaddr` is known to be valid.");
227
228        self.0 = next_data;
229        Some(p)
230    }
231}
232
233/// Iterator over the string identifiers of the protocols (not addrs) in a multiaddr
234pub struct ProtoStackIter<'a> {
235    parts: Iter<'a>,
236}
237
238impl Iterator for ProtoStackIter<'_> {
239    type Item = &'static str;
240    fn next(&mut self) -> Option<Self::Item> {
241        self.parts.next().as_ref().map(Protocol::tag)
242    }
243}
244
245impl<'a> From<Protocol<'a>> for Multiaddr {
246    fn from(p: Protocol<'a>) -> Multiaddr {
247        let mut bytes = BytesMut::new();
248        p.write_bytes(&mut (&mut bytes).writer())
249            .expect("Writing to a `BytesMut` never fails.");
250        Multiaddr {
251            bytes: bytes.freeze(),
252        }
253    }
254}
255
256impl From<IpAddr> for Multiaddr {
257    fn from(v: IpAddr) -> Multiaddr {
258        match v {
259            IpAddr::V4(a) => a.into(),
260            IpAddr::V6(a) => a.into(),
261        }
262    }
263}
264
265impl From<Ipv4Addr> for Multiaddr {
266    fn from(v: Ipv4Addr) -> Multiaddr {
267        Protocol::Ip4(v).into()
268    }
269}
270
271impl From<Ipv6Addr> for Multiaddr {
272    fn from(v: Ipv6Addr) -> Multiaddr {
273        Protocol::Ip6(v).into()
274    }
275}
276
277impl TryFrom<Vec<u8>> for Multiaddr {
278    type Error = Error;
279
280    fn try_from(v: Vec<u8>) -> Result<Self, Error> {
281        // Check if the argument is a valid `Multiaddr` by reading its protocols.
282        let mut slice = &v[..];
283        while !slice.is_empty() {
284            let (_, s) = Protocol::from_bytes(slice)?;
285            slice = s
286        }
287        Ok(Multiaddr {
288            bytes: Bytes::from(v),
289        })
290    }
291}
292
293impl TryFrom<String> for Multiaddr {
294    type Error = Error;
295
296    fn try_from(s: String) -> Result<Multiaddr, Error> {
297        s.parse()
298    }
299}
300
301impl<'a> TryFrom<&'a str> for Multiaddr {
302    type Error = Error;
303
304    fn try_from(s: &'a str) -> Result<Multiaddr, Error> {
305        s.parse()
306    }
307}
308
309impl Serialize for Multiaddr {
310    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
311    where
312        S: Serializer,
313    {
314        if serializer.is_human_readable() {
315            serializer.serialize_str(&self.to_string())
316        } else {
317            serializer.serialize_bytes(self.as_ref())
318        }
319    }
320}
321
322impl<'de> Deserialize<'de> for Multiaddr {
323    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
324    where
325        D: Deserializer<'de>,
326    {
327        struct Visitor {
328            is_human_readable: bool,
329        }
330
331        impl<'de> de::Visitor<'de> for Visitor {
332            type Value = Multiaddr;
333
334            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
335                formatter.write_str("multiaddress")
336            }
337            fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
338                let mut buf: Vec<u8> =
339                    Vec::with_capacity(std::cmp::min(seq.size_hint().unwrap_or(0), 4096));
340                while let Some(e) = seq.next_element()? {
341                    buf.push(e);
342                }
343                if self.is_human_readable {
344                    let s = String::from_utf8(buf).map_err(de::Error::custom)?;
345                    s.parse().map_err(de::Error::custom)
346                } else {
347                    Multiaddr::try_from(buf).map_err(de::Error::custom)
348                }
349            }
350            fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
351                v.parse().map_err(de::Error::custom)
352            }
353            fn visit_borrowed_str<E: de::Error>(self, v: &'de str) -> Result<Self::Value, E> {
354                self.visit_str(v)
355            }
356            fn visit_string<E: de::Error>(self, v: String) -> Result<Self::Value, E> {
357                self.visit_str(&v)
358            }
359            fn visit_bytes<E: de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
360                self.visit_byte_buf(v.into())
361            }
362            fn visit_borrowed_bytes<E: de::Error>(self, v: &'de [u8]) -> Result<Self::Value, E> {
363                self.visit_byte_buf(v.into())
364            }
365            fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
366                Multiaddr::try_from(v).map_err(de::Error::custom)
367            }
368        }
369
370        if deserializer.is_human_readable() {
371            deserializer.deserialize_str(Visitor {
372                is_human_readable: true,
373            })
374        } else {
375            deserializer.deserialize_bytes(Visitor {
376                is_human_readable: false,
377            })
378        }
379    }
380}