tycho_common/
hex_bytes.rs

1use std::{
2    borrow::Borrow,
3    clone::Clone,
4    fmt::{Debug, Display, Formatter, LowerHex, Result as FmtResult},
5    ops::Deref,
6    str::FromStr,
7};
8
9#[cfg(feature = "diesel")]
10use diesel::{
11    deserialize::{self, FromSql, FromSqlRow},
12    expression::AsExpression,
13    pg::Pg,
14    serialize::{self, ToSql},
15    sql_types::Binary,
16};
17use rand::Rng;
18use serde::{Deserialize, Serialize};
19use thiserror::Error;
20
21use crate::serde_primitives::hex_bytes;
22
23/// Wrapper type around Bytes to deserialize/serialize from/to hex
24#[derive(Clone, Default, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
25#[cfg_attr(feature = "diesel", derive(AsExpression, FromSqlRow,))]
26#[cfg_attr(feature = "diesel", diesel(sql_type = Binary))]
27pub struct Bytes(#[serde(with = "hex_bytes")] pub bytes::Bytes);
28
29fn bytes_to_hex(b: &Bytes) -> String {
30    hex::encode(b.0.as_ref())
31}
32
33impl Debug for Bytes {
34    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
35        write!(f, "Bytes(0x{})", bytes_to_hex(self))
36    }
37}
38
39impl Display for Bytes {
40    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
41        write!(f, "0x{}", bytes_to_hex(self))
42    }
43}
44
45impl LowerHex for Bytes {
46    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
47        write!(f, "0x{}", bytes_to_hex(self))
48    }
49}
50
51impl Bytes {
52    pub fn new() -> Self {
53        Self(bytes::Bytes::new())
54    }
55    /// This function converts the internal byte array into a `Vec<u8>`
56    ///
57    /// # Returns
58    ///
59    /// A `Vec<u8>` containing the bytes from the `Bytes` struct.
60    ///
61    /// # Example
62    ///
63    /// ```
64    /// let bytes = Bytes::from(vec![0x01, 0x02, 0x03]);
65    /// let vec = bytes.to_vec();
66    /// assert_eq!(vec, vec![0x01, 0x02, 0x03]);
67    /// ```
68    pub fn to_vec(&self) -> Vec<u8> {
69        self.as_ref().to_vec()
70    }
71
72    /// Left-pads the byte array to the specified length with the given padding byte.
73    ///
74    /// This function creates a new `Bytes` instance by prepending the specified padding byte
75    /// to the current byte array until its total length matches the desired length.
76    ///
77    /// If the current length of the byte array is greater than or equal to the specified length,
78    /// the original byte array is returned unchanged.
79    ///
80    /// # Arguments
81    ///
82    /// * `length` - The desired total length of the resulting byte array.
83    /// * `pad_byte` - The byte value to use for padding. Commonly `0x00`.
84    ///
85    /// # Returns
86    ///
87    /// A new `Bytes` instance with the byte array left-padded to the desired length.
88    ///
89    /// # Example
90    ///
91    /// ```
92    /// let bytes = Bytes::from(vec![0x01, 0x02, 0x03]);
93    /// let padded = bytes.lpad(6, 0x00);
94    /// assert_eq!(padded.to_vec(), vec![0x00, 0x00, 0x00, 0x01, 0x02, 0x03]);
95    /// ```
96    pub fn lpad(&self, length: usize, pad_byte: u8) -> Bytes {
97        let mut padded_vec = vec![pad_byte; length.saturating_sub(self.len())];
98        padded_vec.extend_from_slice(self.as_ref());
99
100        Bytes(bytes::Bytes::from(padded_vec))
101    }
102
103    /// Right-pads the byte array to the specified length with the given padding byte.
104    ///
105    /// This function creates a new `Bytes` instance by appending the specified padding byte
106    /// to the current byte array until its total length matches the desired length.
107    ///
108    /// If the current length of the byte array is greater than or equal to the specified length,
109    /// the original byte array is returned unchanged.
110    ///
111    /// # Arguments
112    ///
113    /// * `length` - The desired total length of the resulting byte array.
114    /// * `pad_byte` - The byte value to use for padding. Commonly `0x00`.
115    ///
116    /// # Returns
117    ///
118    /// A new `Bytes` instance with the byte array right-padded to the desired length.
119    ///
120    /// # Example
121    ///
122    /// ```
123    /// let bytes = Bytes::from(vec![0x01, 0x02, 0x03]);
124    /// let padded = bytes.rpad(6, 0x00);
125    /// assert_eq!(padded.to_vec(), vec![0x01, 0x02, 0x03, 0x00, 0x00, 0x00]);
126    /// ```
127    pub fn rpad(&self, length: usize, pad_byte: u8) -> Bytes {
128        let mut padded_vec = self.to_vec();
129        padded_vec.resize(length, pad_byte);
130
131        Bytes(bytes::Bytes::from(padded_vec))
132    }
133
134    /// Creates a `Bytes` object of the specified length, filled with zeros.
135    ///
136    /// # Arguments
137    ///
138    /// * `length` - The length of the `Bytes` object to be created.
139    ///
140    /// # Returns
141    ///
142    /// A `Bytes` object of the specified length, where each byte is set to zero.
143    ///
144    /// # Example
145    ///
146    /// ```
147    /// let b = Bytes::zero(5);
148    /// assert_eq!(b, Bytes::from(vec![0, 0, 0, 0, 0]));
149    /// ```
150    pub fn zero(length: usize) -> Bytes {
151        Bytes::from(vec![0u8; length])
152    }
153
154    /// Creates a `Bytes` object of the specified length, filled with random bytes.
155    ///
156    /// # Arguments
157    ///
158    /// * `length` - The length of the `Bytes` object to be created.
159    ///
160    /// # Returns
161    ///
162    /// A `Bytes` object of the specified length, filled with random bytes.
163    ///
164    /// # Example
165    ///
166    /// ```
167    /// let random_bytes = Bytes::random(5);
168    /// assert_eq!(random_bytes.len(), 5);
169    /// ```
170    pub fn random(length: usize) -> Bytes {
171        let mut data = vec![0u8; length];
172        rand::thread_rng().fill(&mut data[..]);
173        Bytes::from(data)
174    }
175}
176
177impl Deref for Bytes {
178    type Target = [u8];
179
180    #[inline]
181    fn deref(&self) -> &[u8] {
182        self.as_ref()
183    }
184}
185
186impl AsRef<[u8]> for Bytes {
187    fn as_ref(&self) -> &[u8] {
188        self.0.as_ref()
189    }
190}
191
192impl Borrow<[u8]> for Bytes {
193    fn borrow(&self) -> &[u8] {
194        self.as_ref()
195    }
196}
197
198impl IntoIterator for Bytes {
199    type Item = u8;
200    type IntoIter = bytes::buf::IntoIter<bytes::Bytes>;
201
202    fn into_iter(self) -> Self::IntoIter {
203        self.0.into_iter()
204    }
205}
206
207impl<'a> IntoIterator for &'a Bytes {
208    type Item = &'a u8;
209    type IntoIter = core::slice::Iter<'a, u8>;
210
211    fn into_iter(self) -> Self::IntoIter {
212        self.as_ref().iter()
213    }
214}
215
216impl From<&[u8]> for Bytes {
217    fn from(src: &[u8]) -> Self {
218        Self(bytes::Bytes::copy_from_slice(src))
219    }
220}
221
222impl From<bytes::Bytes> for Bytes {
223    fn from(src: bytes::Bytes) -> Self {
224        Self(src)
225    }
226}
227
228impl From<Bytes> for bytes::Bytes {
229    fn from(src: Bytes) -> Self {
230        src.0
231    }
232}
233
234impl From<Vec<u8>> for Bytes {
235    fn from(src: Vec<u8>) -> Self {
236        Self(src.into())
237    }
238}
239
240impl From<Bytes> for Vec<u8> {
241    fn from(value: Bytes) -> Self {
242        value.to_vec()
243    }
244}
245
246impl<const N: usize> From<[u8; N]> for Bytes {
247    fn from(src: [u8; N]) -> Self {
248        src.to_vec().into()
249    }
250}
251
252impl<'a, const N: usize> From<&'a [u8; N]> for Bytes {
253    fn from(src: &'a [u8; N]) -> Self {
254        src.to_vec().into()
255    }
256}
257
258impl PartialEq<[u8]> for Bytes {
259    fn eq(&self, other: &[u8]) -> bool {
260        self.as_ref() == other
261    }
262}
263
264impl PartialEq<Bytes> for [u8] {
265    fn eq(&self, other: &Bytes) -> bool {
266        *other == *self
267    }
268}
269
270impl PartialEq<Vec<u8>> for Bytes {
271    fn eq(&self, other: &Vec<u8>) -> bool {
272        self.as_ref() == &other[..]
273    }
274}
275
276impl PartialEq<Bytes> for Vec<u8> {
277    fn eq(&self, other: &Bytes) -> bool {
278        *other == *self
279    }
280}
281
282impl PartialEq<bytes::Bytes> for Bytes {
283    fn eq(&self, other: &bytes::Bytes) -> bool {
284        other == self.as_ref()
285    }
286}
287
288#[derive(Debug, Clone, Error)]
289#[error("Failed to parse bytes: {0}")]
290pub struct ParseBytesError(String);
291
292impl FromStr for Bytes {
293    type Err = ParseBytesError;
294
295    fn from_str(value: &str) -> Result<Self, Self::Err> {
296        if let Some(value) = value.strip_prefix("0x") {
297            hex::decode(value)
298        } else {
299            hex::decode(value)
300        }
301        .map(Into::into)
302        .map_err(|e| ParseBytesError(format!("Invalid hex: {e}")))
303    }
304}
305
306impl From<&str> for Bytes {
307    fn from(value: &str) -> Self {
308        value.parse().unwrap()
309    }
310}
311
#[cfg(feature = "diesel")]
impl ToSql<Binary, Pg> for Bytes {
    /// Writes the wrapped bytes by delegating to diesel's `ToSql<Binary, Pg>` impl for
    /// `&[u8]`.
    fn to_sql<'b>(&'b self, out: &mut serialize::Output<'b, '_, Pg>) -> serialize::Result {
        let raw: &[u8] = self.as_ref();
        <&[u8] as ToSql<Binary, Pg>>::to_sql(&raw, &mut out.reborrow())
    }
}
319
#[cfg(feature = "diesel")]
impl FromSql<Binary, Pg> for Bytes {
    /// Reads the value through diesel's `Vec<u8>` deserializer and wraps the result;
    /// deserialization errors are propagated unchanged.
    fn from_sql(
        value: <diesel::pg::Pg as diesel::backend::Backend>::RawValue<'_>,
    ) -> deserialize::Result<Self> {
        <Vec<u8> as FromSql<Binary, Pg>>::from_sql(value)
            .map(|raw| Bytes(bytes::Bytes::from(raw)))
    }
}
329
330macro_rules! impl_from_uint_for_bytes {
331    ($($t:ty),*) => {
332        $(
333            impl From<$t> for Bytes {
334                fn from(src: $t) -> Self {
335                    let size = std::mem::size_of::<$t>();
336                    let mut buf = vec![0u8; size];
337                    buf.copy_from_slice(&src.to_be_bytes());
338
339                    Self(bytes::Bytes::from(buf))
340                }
341            }
342        )*
343    };
344}
345
346impl_from_uint_for_bytes!(u8, u16, u32, u64, u128);
347
348macro_rules! impl_from_bytes_for_uint {
349    ($($t:ty),*) => {
350        $(
351            impl From<Bytes> for $t {
352                fn from(src: Bytes) -> Self {
353                    let bytes_slice = src.as_ref();
354
355                    // Create an array with zeros.
356                    let mut buf = [0u8; std::mem::size_of::<$t>()];
357
358                    // Copy bytes from `bytes_slice` to the end of `buf` to maintain big-endian order.
359                    buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
360
361                    // Convert to the integer type using big-endian.
362                    <$t>::from_be_bytes(buf)
363                }
364            }
365        )*
366    };
367}
368
369impl_from_bytes_for_uint!(u8, u16, u32, u64, u128);
370
371macro_rules! impl_from_bytes_for_signed_int {
372    ($($t:ty),*) => {
373        $(
374            impl From<Bytes> for $t {
375                fn from(src: Bytes) -> Self {
376                    let bytes_slice = src.as_ref();
377
378                    // Create an array with zeros or ones for negative numbers.
379                    let mut buf = if bytes_slice.get(0).map_or(false, |&b| b & 0x80 != 0) {
380                        [0xFFu8; std::mem::size_of::<$t>()] // Sign-extend with 0xFF for negative numbers.
381                    } else {
382                        [0x00u8; std::mem::size_of::<$t>()] // Fill with 0x00 for positive numbers.
383                    };
384
385                    // Copy bytes from `bytes_slice` to the end of `buf` to maintain big-endian order.
386                    buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
387
388                    // Convert to the signed integer type using big-endian.
389                    <$t>::from_be_bytes(buf)
390                }
391            }
392        )*
393    };
394}
395
396impl_from_bytes_for_signed_int!(i8, i16, i32, i64, i128);
397
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_bytes() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        let expected = Bytes(b);

        assert_eq!(wrapped_b, expected);
    }

    #[test]
    fn test_from_slice() {
        let arr = [1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(&arr);
        let expected = Bytes(bytes::Bytes::from(arr.to_vec()));

        assert_eq!(b, expected);
    }

    #[test]
    fn hex_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        let expected = String::from("0x0123456789abcdef");
        assert_eq!(format!("{b:x}"), expected);
        assert_eq!(format!("{b}"), expected);
    }

    #[test]
    fn test_from_str() {
        let expected = hex::decode("1213").unwrap();

        let prefixed = Bytes::from_str("0x1213").expect("0x-prefixed hex should parse");
        assert_eq!(prefixed.as_ref(), expected);

        let bare = Bytes::from_str("1213").expect("unprefixed hex should parse");
        assert_eq!(bare.as_ref(), expected);
    }

    #[test]
    fn test_from_str_rejects_invalid_hex() {
        assert!(Bytes::from_str("0xzz").is_err());
        // Odd number of hex digits.
        assert!(Bytes::from_str("123").is_err());
    }

    #[test]
    fn test_debug_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        assert_eq!(format!("{b:?}"), "Bytes(0x0123456789abcdef)");
        assert_eq!(format!("{b:#?}"), "Bytes(0x0123456789abcdef)");
    }

    #[test]
    fn test_to_vec() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());

        assert_eq!(b.to_vec(), vec);
    }

    #[test]
    fn test_vec_partialeq() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());
        assert_eq!(b, vec);
        assert_eq!(vec, b);

        let wrong_vec = vec![1, 3, 52, 137];
        assert_ne!(b, wrong_vec);
        assert_ne!(wrong_vec, b);
    }

    #[test]
    fn test_bytes_partialeq() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        assert_eq!(wrapped_b, b);

        // Compare the *wrapper* against different contents; comparing two raw
        // `bytes::Bytes` values would not exercise `PartialEq<bytes::Bytes> for Bytes`.
        let wrong_b = bytes::Bytes::from("0123absd");
        assert_ne!(wrapped_b, wrong_b);
    }

    #[test]
    fn test_padding_grows_to_length() {
        let b = Bytes::from(vec![0x01, 0x02, 0x03]);
        assert_eq!(b.lpad(5, 0x00).to_vec(), vec![0x00, 0x00, 0x01, 0x02, 0x03]);
        assert_eq!(b.rpad(5, 0xff).to_vec(), vec![0x01, 0x02, 0x03, 0xff, 0xff]);
        // `lpad` never truncates.
        assert_eq!(b.lpad(2, 0x00), b);
    }

    #[test]
    fn test_zero_and_random_lengths() {
        assert_eq!(Bytes::zero(3).to_vec(), vec![0, 0, 0]);
        assert_eq!(Bytes::random(7).len(), 7);
    }

    #[test]
    fn test_uint_round_trip() {
        let b = Bytes::from(0x0102_u16);
        assert_eq!(b.to_vec(), vec![0x01, 0x02]);
        assert_eq!(u16::from(b), 0x0102);
    }

    #[test]
    fn test_u128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        assert_eq!(u128::from(data), 67_305_985);
    }

    #[test]
    fn test_i128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        assert_eq!(i128::from(data), 67_305_985);
    }

    #[test]
    fn test_i32_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        assert_eq!(i32::from(data), 67_305_985);
    }

    #[test]
    fn test_signed_sign_extension() {
        assert_eq!(i32::from(Bytes::from(vec![0xff])), -1);
        assert_eq!(i32::from(Bytes::from(vec![0x7f])), 127);
    }
}
498
#[cfg(feature = "diesel")]
#[cfg(test)]
mod diesel_tests {
    use diesel::{insert_into, table, Insertable, Queryable};
    use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection};

    use super::*;

    /// Connects to the database named by the `DATABASE_URL` environment variable and starts
    /// a diesel test transaction on the connection.
    async fn setup_db() -> AsyncPgConnection {
        let url = std::env::var("DATABASE_URL").unwrap();
        let mut connection = AsyncPgConnection::establish(&url)
            .await
            .unwrap();
        connection
            .begin_test_transaction()
            .await
            .unwrap();
        connection
    }

    /// Inserts a `Bytes` value into a temporary `BYTEA` column and checks that the row
    /// returned by the insert carries the same bytes.
    #[tokio::test]
    async fn test_bytes_db_round_trip() {
        table! {
            bytes_table (id) {
                id -> Int4,
                data -> Binary,
            }
        }

        #[derive(Insertable)]
        #[diesel(table_name = bytes_table)]
        struct NewByteEntry {
            data: Bytes,
        }

        #[derive(Queryable, PartialEq)]
        struct ByteEntry {
            id: i32,
            data: Bytes,
        }

        let mut connection = setup_db().await;
        let expected = Bytes::from_str("0x0123456789abcdef").unwrap();

        connection.batch_execute(
            r"
            CREATE TEMPORARY TABLE bytes_table (
                id SERIAL PRIMARY KEY,
                data BYTEA NOT NULL
            );
        ",
        )
        .await
        .unwrap();

        let row = NewByteEntry { data: expected.clone() };
        let inserted: Vec<ByteEntry> = insert_into(bytes_table::table)
            .values(&row)
            .get_results(&mut connection)
            .await
            .unwrap();

        assert_eq!(inserted[0].data, expected);
    }
}