tycho_common/
hex_bytes.rs

1use std::{
2    borrow::Borrow,
3    clone::Clone,
4    fmt::{Debug, Display, Formatter, LowerHex, Result as FmtResult},
5    ops::Deref,
6    str::FromStr,
7};
8
9use deepsize::{Context, DeepSizeOf};
10#[cfg(feature = "diesel")]
11use diesel::{
12    deserialize::{self, FromSql, FromSqlRow},
13    expression::AsExpression,
14    pg::Pg,
15    serialize::{self, ToSql},
16    sql_types::Binary,
17};
18use rand::Rng;
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
22use crate::serde_primitives::hex_bytes;
23
24/// Wrapper type around Bytes to deserialize/serialize from/to hex
25#[derive(Clone, Default, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
26#[cfg_attr(feature = "diesel", derive(AsExpression, FromSqlRow,))]
27#[cfg_attr(feature = "diesel", diesel(sql_type = Binary))]
28pub struct Bytes(#[serde(with = "hex_bytes")] pub bytes::Bytes);
29
30impl DeepSizeOf for Bytes {
31    fn deep_size_of_children(&self, _ctx: &mut Context) -> usize {
32        // Note: This may overcount memory if the underlying bytes are shared (e.g. via Arc).
33        // We cannot detect shared ownership here because Context’s internal tracking is private
34        // At the same time, this might also underreport memory if:
35        // - the bytes::Bytes are instantiated as Shared internally, which adds 24 bytes of overhead
36        // - the bytes::Bytes has capacity greater than its length, as we only count the length here
37        self.0.len()
38    }
39}
40
41fn bytes_to_hex(b: &Bytes) -> String {
42    hex::encode(b.0.as_ref())
43}
44
45impl Debug for Bytes {
46    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
47        write!(f, "Bytes(0x{})", bytes_to_hex(self))
48    }
49}
50
51impl Display for Bytes {
52    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
53        write!(f, "0x{}", bytes_to_hex(self))
54    }
55}
56
57impl LowerHex for Bytes {
58    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
59        write!(f, "0x{}", bytes_to_hex(self))
60    }
61}
62
63impl Bytes {
64    pub fn new() -> Self {
65        Self(bytes::Bytes::new())
66    }
67    /// This function converts the internal byte array into a `Vec<u8>`
68    ///
69    /// # Returns
70    ///
71    /// A `Vec<u8>` containing the bytes from the `Bytes` struct.
72    ///
73    /// # Example
74    ///
75    /// ```
76    /// let bytes = Bytes::from(vec![0x01, 0x02, 0x03]);
77    /// let vec = bytes.to_vec();
78    /// assert_eq!(vec, vec![0x01, 0x02, 0x03]);
79    /// ```
80    pub fn to_vec(&self) -> Vec<u8> {
81        self.as_ref().to_vec()
82    }
83
84    /// Left-pads the byte array to the specified length with the given padding byte.
85    ///
86    /// This function creates a new `Bytes` instance by prepending the specified padding byte
87    /// to the current byte array until its total length matches the desired length.
88    ///
89    /// If the current length of the byte array is greater than or equal to the specified length,
90    /// the original byte array is returned unchanged.
91    ///
92    /// # Arguments
93    ///
94    /// * `length` - The desired total length of the resulting byte array.
95    /// * `pad_byte` - The byte value to use for padding. Commonly `0x00`.
96    ///
97    /// # Returns
98    ///
99    /// A new `Bytes` instance with the byte array left-padded to the desired length.
100    ///
101    /// # Example
102    ///
103    /// ```
104    /// let bytes = Bytes::from(vec![0x01, 0x02, 0x03]);
105    /// let padded = bytes.lpad(6, 0x00);
106    /// assert_eq!(padded.to_vec(), vec![0x00, 0x00, 0x00, 0x01, 0x02, 0x03]);
107    /// ```
108    pub fn lpad(&self, length: usize, pad_byte: u8) -> Bytes {
109        let mut padded_vec = vec![pad_byte; length.saturating_sub(self.len())];
110        padded_vec.extend_from_slice(self.as_ref());
111
112        Bytes(bytes::Bytes::from(padded_vec))
113    }
114
115    /// Right-pads the byte array to the specified length with the given padding byte.
116    ///
117    /// This function creates a new `Bytes` instance by appending the specified padding byte
118    /// to the current byte array until its total length matches the desired length.
119    ///
120    /// If the current length of the byte array is greater than or equal to the specified length,
121    /// the original byte array is returned unchanged.
122    ///
123    /// # Arguments
124    ///
125    /// * `length` - The desired total length of the resulting byte array.
126    /// * `pad_byte` - The byte value to use for padding. Commonly `0x00`.
127    ///
128    /// # Returns
129    ///
130    /// A new `Bytes` instance with the byte array right-padded to the desired length.
131    ///
132    /// # Example
133    ///
134    /// ```
135    /// let bytes = Bytes::from(vec![0x01, 0x02, 0x03]);
136    /// let padded = bytes.rpad(6, 0x00);
137    /// assert_eq!(padded.to_vec(), vec![0x01, 0x02, 0x03, 0x00, 0x00, 0x00]);
138    /// ```
139    pub fn rpad(&self, length: usize, pad_byte: u8) -> Bytes {
140        let mut padded_vec = self.to_vec();
141        padded_vec.resize(length, pad_byte);
142
143        Bytes(bytes::Bytes::from(padded_vec))
144    }
145
146    /// Creates a `Bytes` object of the specified length, filled with zeros.
147    ///
148    /// # Arguments
149    ///
150    /// * `length` - The length of the `Bytes` object to be created.
151    ///
152    /// # Returns
153    ///
154    /// A `Bytes` object of the specified length, where each byte is set to zero.
155    ///
156    /// # Example
157    ///
158    /// ```
159    /// let b = Bytes::zero(5);
160    /// assert_eq!(b, Bytes::from(vec![0, 0, 0, 0, 0]));
161    /// ```
162    pub fn zero(length: usize) -> Bytes {
163        Bytes::from(vec![0u8; length])
164    }
165
166    /// Creates a `Bytes` object of the specified length, filled with random bytes.
167    ///
168    /// # Arguments
169    ///
170    /// * `length` - The length of the `Bytes` object to be created.
171    ///
172    /// # Returns
173    ///
174    /// A `Bytes` object of the specified length, filled with random bytes.
175    ///
176    /// # Example
177    ///
178    /// ```
179    /// let random_bytes = Bytes::random(5);
180    /// assert_eq!(random_bytes.len(), 5);
181    /// ```
182    pub fn random(length: usize) -> Bytes {
183        let mut data = vec![0u8; length];
184        rand::thread_rng().fill(&mut data[..]);
185        Bytes::from(data)
186    }
187
188    /// Checks if the byte array is full of zeros.
189    ///
190    /// # Returns
191    ///
192    /// A boolean value indicating whether all bytes in the byte array are zero.
193    ///
194    /// # Example
195    ///
196    /// ```
197    /// let b = Bytes::zero(5);
198    /// assert!(b.is_zero());
199    /// ```
200    pub fn is_zero(&self) -> bool {
201        self.as_ref().iter().all(|b| *b == 0)
202    }
203}
204
205impl Deref for Bytes {
206    type Target = [u8];
207
208    #[inline]
209    fn deref(&self) -> &[u8] {
210        self.as_ref()
211    }
212}
213
214impl AsRef<[u8]> for Bytes {
215    fn as_ref(&self) -> &[u8] {
216        self.0.as_ref()
217    }
218}
219
220impl Borrow<[u8]> for Bytes {
221    fn borrow(&self) -> &[u8] {
222        self.as_ref()
223    }
224}
225
226impl IntoIterator for Bytes {
227    type Item = u8;
228    type IntoIter = bytes::buf::IntoIter<bytes::Bytes>;
229
230    fn into_iter(self) -> Self::IntoIter {
231        self.0.into_iter()
232    }
233}
234
235impl<'a> IntoIterator for &'a Bytes {
236    type Item = &'a u8;
237    type IntoIter = core::slice::Iter<'a, u8>;
238
239    fn into_iter(self) -> Self::IntoIter {
240        self.as_ref().iter()
241    }
242}
243
244impl From<&[u8]> for Bytes {
245    fn from(src: &[u8]) -> Self {
246        Self(bytes::Bytes::copy_from_slice(src))
247    }
248}
249
250impl From<bytes::Bytes> for Bytes {
251    fn from(src: bytes::Bytes) -> Self {
252        Self(src)
253    }
254}
255
256impl From<Bytes> for bytes::Bytes {
257    fn from(src: Bytes) -> Self {
258        src.0
259    }
260}
261
262impl From<Vec<u8>> for Bytes {
263    fn from(src: Vec<u8>) -> Self {
264        Self(src.into())
265    }
266}
267
268impl From<Bytes> for Vec<u8> {
269    fn from(value: Bytes) -> Self {
270        value.to_vec()
271    }
272}
273
274impl<const N: usize> From<[u8; N]> for Bytes {
275    fn from(src: [u8; N]) -> Self {
276        src.to_vec().into()
277    }
278}
279
280impl<'a, const N: usize> From<&'a [u8; N]> for Bytes {
281    fn from(src: &'a [u8; N]) -> Self {
282        src.to_vec().into()
283    }
284}
285
286impl PartialEq<[u8]> for Bytes {
287    fn eq(&self, other: &[u8]) -> bool {
288        self.as_ref() == other
289    }
290}
291
292impl PartialEq<Bytes> for [u8] {
293    fn eq(&self, other: &Bytes) -> bool {
294        *other == *self
295    }
296}
297
298impl PartialEq<Vec<u8>> for Bytes {
299    fn eq(&self, other: &Vec<u8>) -> bool {
300        self.as_ref() == &other[..]
301    }
302}
303
304impl PartialEq<Bytes> for Vec<u8> {
305    fn eq(&self, other: &Bytes) -> bool {
306        *other == *self
307    }
308}
309
310impl PartialEq<bytes::Bytes> for Bytes {
311    fn eq(&self, other: &bytes::Bytes) -> bool {
312        other == self.as_ref()
313    }
314}
315
316#[derive(Debug, Clone, Error)]
317#[error("Failed to parse bytes: {0}")]
318pub struct ParseBytesError(String);
319
320impl FromStr for Bytes {
321    type Err = ParseBytesError;
322
323    fn from_str(value: &str) -> Result<Self, Self::Err> {
324        if let Some(value) = value.strip_prefix("0x") {
325            hex::decode(value)
326        } else {
327            hex::decode(value)
328        }
329        .map(Into::into)
330        .map_err(|e| ParseBytesError(format!("Invalid hex: {e}")))
331    }
332}
333
334impl From<&str> for Bytes {
335    fn from(value: &str) -> Self {
336        value.parse().unwrap()
337    }
338}
339
#[cfg(feature = "diesel")]
impl ToSql<Binary, Pg> for Bytes {
    /// Serializes the wrapped bytes by delegating to the `&[u8]` implementation of `ToSql`.
    fn to_sql<'b>(&'b self, out: &mut serialize::Output<'b, '_, Pg>) -> serialize::Result {
        let raw: &[u8] = self.0.as_ref();
        // The slice reference is a local, so hand it a reborrowed output with a shorter lifetime.
        <&[u8] as ToSql<Binary, Pg>>::to_sql(&raw, &mut out.reborrow())
    }
}
347
#[cfg(feature = "diesel")]
impl FromSql<Binary, Pg> for Bytes {
    /// Deserializes a binary column by decoding it as `Vec<u8>` and wrapping the result.
    fn from_sql(
        bytes: <diesel::pg::Pg as diesel::backend::Backend>::RawValue<'_>,
    ) -> deserialize::Result<Self> {
        let decoded: Vec<u8> = <Vec<u8> as FromSql<Binary, Pg>>::from_sql(bytes)?;
        let inner = bytes::Bytes::from(decoded);
        Ok(Bytes(inner))
    }
}
357
358macro_rules! impl_from_uint_for_bytes {
359    ($($t:ty),*) => {
360        $(
361            impl From<$t> for Bytes {
362                fn from(src: $t) -> Self {
363                    let size = std::mem::size_of::<$t>();
364                    let mut buf = vec![0u8; size];
365                    buf.copy_from_slice(&src.to_be_bytes());
366
367                    Self(bytes::Bytes::from(buf))
368                }
369            }
370        )*
371    };
372}
373
374impl_from_uint_for_bytes!(u8, u16, u32, u64, u128);
375
376macro_rules! impl_from_bytes_for_uint {
377    ($($t:ty),*) => {
378        $(
379            impl From<Bytes> for $t {
380                fn from(src: Bytes) -> Self {
381                    let bytes_slice = src.as_ref();
382
383                    // Create an array with zeros.
384                    let mut buf = [0u8; std::mem::size_of::<$t>()];
385
386                    // Copy bytes from `bytes_slice` to the end of `buf` to maintain big-endian order.
387                    buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
388
389                    // Convert to the integer type using big-endian.
390                    <$t>::from_be_bytes(buf)
391                }
392            }
393        )*
394    };
395}
396
397impl_from_bytes_for_uint!(u8, u16, u32, u64, u128);
398
399macro_rules! impl_from_bytes_for_signed_int {
400    ($($t:ty),*) => {
401        $(
402            impl From<Bytes> for $t {
403                fn from(src: Bytes) -> Self {
404                    let bytes_slice = src.as_ref();
405
406                    // Create an array with zeros or ones for negative numbers.
407                    let mut buf = if bytes_slice.get(0).map_or(false, |&b| b & 0x80 != 0) {
408                        [0xFFu8; std::mem::size_of::<$t>()] // Sign-extend with 0xFF for negative numbers.
409                    } else {
410                        [0x00u8; std::mem::size_of::<$t>()] // Fill with 0x00 for positive numbers.
411                    };
412
413                    // Copy bytes from `bytes_slice` to the end of `buf` to maintain big-endian order.
414                    buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
415
416                    // Convert to the signed integer type using big-endian.
417                    <$t>::from_be_bytes(buf)
418                }
419            }
420        )*
421    };
422}
423
424impl_from_bytes_for_signed_int!(i8, i16, i32, i64, i128);
425
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_bytes() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        let expected = Bytes(b);

        assert_eq!(wrapped_b, expected);
    }

    #[test]
    fn test_from_slice() {
        let arr = [1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(&arr);
        let expected = Bytes(bytes::Bytes::from(arr.to_vec()));

        assert_eq!(b, expected);
    }

    #[test]
    fn hex_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        let expected = String::from("0x0123456789abcdef");
        assert_eq!(format!("{b:x}"), expected);
        assert_eq!(format!("{b}"), expected);
    }

    #[test]
    fn test_from_str() {
        let b = Bytes::from_str("0x1213");
        assert!(b.is_ok());
        let b = b.unwrap();
        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());

        let b = Bytes::from_str("1213");
        let b = b.unwrap();
        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());
    }

    #[test]
    fn test_from_str_invalid() {
        // Non-hex characters and an odd number of digits are both rejected.
        assert!(Bytes::from_str("0xzz").is_err());
        assert!(Bytes::from_str("0x123").is_err());
    }

    #[test]
    fn test_debug_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        assert_eq!(format!("{b:?}"), "Bytes(0x0123456789abcdef)");
        assert_eq!(format!("{b:#?}"), "Bytes(0x0123456789abcdef)");
    }

    #[test]
    fn test_to_vec() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());

        assert_eq!(b.to_vec(), vec);
    }

    #[test]
    fn test_vec_partialeq() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());
        assert_eq!(b, vec);
        assert_eq!(vec, b);

        let wrong_vec = vec![1, 3, 52, 137];
        assert_ne!(b, wrong_vec);
        assert_ne!(wrong_vec, b);
    }

    #[test]
    fn test_bytes_partialeq() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        assert_eq!(wrapped_b, b);

        // Compare against the wrapper so the negative case exercises
        // `PartialEq<bytes::Bytes> for Bytes` rather than plain `bytes::Bytes` equality.
        let wrong_b = bytes::Bytes::from("0123absd");
        assert_ne!(wrapped_b, wrong_b);
    }

    #[test]
    fn test_lpad_rpad() {
        let b = Bytes::from(vec![1u8, 2, 3]);
        assert_eq!(b.lpad(6, 0).to_vec(), vec![0u8, 0, 0, 1, 2, 3]);
        assert_eq!(b.rpad(6, 0).to_vec(), vec![1u8, 2, 3, 0, 0, 0]);

        // Padding to the current length, or left-padding to a shorter one, keeps the contents.
        assert_eq!(b.lpad(3, 0xff), b);
        assert_eq!(b.rpad(3, 0xff), b);
        assert_eq!(b.lpad(2, 0xff), b);
    }

    #[test]
    fn test_zero_and_is_zero() {
        let zeros = Bytes::zero(5);
        assert_eq!(zeros.len(), 5);
        assert!(zeros.is_zero());
        assert!(!Bytes::from(vec![0u8, 1]).is_zero());
    }

    #[test]
    fn test_u128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: u128 = u128::from(data);
        assert_eq!(result, u128::from_str("67305985").unwrap());
    }

    #[test]
    fn test_i128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: i128 = i128::from(data);
        assert_eq!(result, i128::from_str("67305985").unwrap());
    }

    #[test]
    fn test_i32_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: i32 = i32::from(data);
        assert_eq!(result, i32::from_str("67305985").unwrap());
    }
}
526
#[cfg(feature = "diesel")]
#[cfg(test)]
mod diesel_tests {
    use diesel::{insert_into, table, Insertable, Queryable};
    use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection};

    use super::*;

    /// Connects to the database named by the `DATABASE_URL` environment variable and begins a
    /// test transaction on the connection.
    async fn setup_db() -> AsyncPgConnection {
        let url = std::env::var("DATABASE_URL").unwrap();
        let mut connection = AsyncPgConnection::establish(&url)
            .await
            .unwrap();
        connection
            .begin_test_transaction()
            .await
            .unwrap();
        connection
    }

    /// Inserts a `Bytes` value into a temporary `BYTEA` column and checks that the row returned
    /// by the insert carries the same bytes.
    #[tokio::test]
    async fn test_bytes_db_round_trip() {
        table! {
            bytes_table (id) {
                id -> Int4,
                data -> Binary,
            }
        }

        #[derive(Insertable)]
        #[diesel(table_name = bytes_table)]
        struct NewByteEntry {
            data: Bytes,
        }

        #[derive(Queryable, PartialEq)]
        struct ByteEntry {
            id: i32,
            data: Bytes,
        }

        let mut connection = setup_db().await;
        let expected = Bytes::from_str("0x0123456789abcdef").unwrap();

        connection
            .batch_execute(
                r"
            CREATE TEMPORARY TABLE bytes_table (
                id SERIAL PRIMARY KEY,
                data BYTEA NOT NULL
            );
        ",
            )
            .await
            .unwrap();

        let row_to_insert = NewByteEntry { data: expected.clone() };

        let returned_rows: Vec<ByteEntry> = insert_into(bytes_table::table)
            .values(&row_to_insert)
            .get_results(&mut connection)
            .await
            .unwrap();

        assert_eq!(returned_rows[0].data, expected);
    }
}