1use std::{
2 borrow::Borrow,
3 clone::Clone,
4 fmt::{Debug, Display, Formatter, LowerHex, Result as FmtResult},
5 ops::Deref,
6 str::FromStr,
7};
8
9use deepsize::{Context, DeepSizeOf};
10#[cfg(feature = "diesel")]
11use diesel::{
12 deserialize::{self, FromSql, FromSqlRow},
13 expression::AsExpression,
14 pg::Pg,
15 serialize::{self, ToSql},
16 sql_types::Binary,
17};
18use rand::Rng;
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
22use crate::serde_primitives::hex_bytes;
23
/// Owned byte buffer wrapping [`bytes::Bytes`], with hex-oriented formatting and parsing.
///
/// serde (de)serialization of the inner buffer is delegated to `hex_bytes`
/// (presumably a hex-string encoding — confirm in `crate::serde_primitives`).
/// With the `diesel` feature enabled the type maps to the SQL `Binary` type.
#[derive(Clone, Default, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
#[cfg_attr(feature = "diesel", derive(AsExpression, FromSqlRow,))]
#[cfg_attr(feature = "diesel", diesel(sql_type = Binary))]
pub struct Bytes(#[serde(with = "hex_bytes")] pub bytes::Bytes);
29
30impl DeepSizeOf for Bytes {
31 fn deep_size_of_children(&self, _ctx: &mut Context) -> usize {
32 self.0.len()
35 }
36}
37
38fn bytes_to_hex(b: &Bytes) -> String {
39 hex::encode(b.0.as_ref())
40}
41
42impl Debug for Bytes {
43 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
44 write!(f, "Bytes(0x{})", bytes_to_hex(self))
45 }
46}
47
48impl Display for Bytes {
49 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
50 write!(f, "0x{}", bytes_to_hex(self))
51 }
52}
53
54impl LowerHex for Bytes {
55 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
56 write!(f, "0x{}", bytes_to_hex(self))
57 }
58}
59
60impl Bytes {
61 pub fn new() -> Self {
62 Self(bytes::Bytes::new())
63 }
64 pub fn to_vec(&self) -> Vec<u8> {
78 self.as_ref().to_vec()
79 }
80
81 pub fn lpad(&self, length: usize, pad_byte: u8) -> Bytes {
106 let mut padded_vec = vec![pad_byte; length.saturating_sub(self.len())];
107 padded_vec.extend_from_slice(self.as_ref());
108
109 Bytes(bytes::Bytes::from(padded_vec))
110 }
111
112 pub fn rpad(&self, length: usize, pad_byte: u8) -> Bytes {
137 let mut padded_vec = self.to_vec();
138 padded_vec.resize(length, pad_byte);
139
140 Bytes(bytes::Bytes::from(padded_vec))
141 }
142
143 pub fn zero(length: usize) -> Bytes {
160 Bytes::from(vec![0u8; length])
161 }
162
163 pub fn random(length: usize) -> Bytes {
180 let mut data = vec![0u8; length];
181 rand::thread_rng().fill(&mut data[..]);
182 Bytes::from(data)
183 }
184
185 pub fn is_zero(&self) -> bool {
198 self.as_ref().iter().all(|b| *b == 0)
199 }
200}
201
202impl Deref for Bytes {
203 type Target = [u8];
204
205 #[inline]
206 fn deref(&self) -> &[u8] {
207 self.as_ref()
208 }
209}
210
211impl AsRef<[u8]> for Bytes {
212 fn as_ref(&self) -> &[u8] {
213 self.0.as_ref()
214 }
215}
216
217impl Borrow<[u8]> for Bytes {
218 fn borrow(&self) -> &[u8] {
219 self.as_ref()
220 }
221}
222
223impl IntoIterator for Bytes {
224 type Item = u8;
225 type IntoIter = bytes::buf::IntoIter<bytes::Bytes>;
226
227 fn into_iter(self) -> Self::IntoIter {
228 self.0.into_iter()
229 }
230}
231
232impl<'a> IntoIterator for &'a Bytes {
233 type Item = &'a u8;
234 type IntoIter = core::slice::Iter<'a, u8>;
235
236 fn into_iter(self) -> Self::IntoIter {
237 self.as_ref().iter()
238 }
239}
240
241impl From<&[u8]> for Bytes {
242 fn from(src: &[u8]) -> Self {
243 Self(bytes::Bytes::copy_from_slice(src))
244 }
245}
246
247impl From<bytes::Bytes> for Bytes {
248 fn from(src: bytes::Bytes) -> Self {
249 Self(src)
250 }
251}
252
253impl From<Bytes> for bytes::Bytes {
254 fn from(src: Bytes) -> Self {
255 src.0
256 }
257}
258
259impl From<Vec<u8>> for Bytes {
260 fn from(src: Vec<u8>) -> Self {
261 Self(src.into())
262 }
263}
264
265impl From<Bytes> for Vec<u8> {
266 fn from(value: Bytes) -> Self {
267 value.to_vec()
268 }
269}
270
271impl<const N: usize> From<[u8; N]> for Bytes {
272 fn from(src: [u8; N]) -> Self {
273 src.to_vec().into()
274 }
275}
276
277impl<'a, const N: usize> From<&'a [u8; N]> for Bytes {
278 fn from(src: &'a [u8; N]) -> Self {
279 src.to_vec().into()
280 }
281}
282
283impl PartialEq<[u8]> for Bytes {
284 fn eq(&self, other: &[u8]) -> bool {
285 self.as_ref() == other
286 }
287}
288
289impl PartialEq<Bytes> for [u8] {
290 fn eq(&self, other: &Bytes) -> bool {
291 *other == *self
292 }
293}
294
295impl PartialEq<Vec<u8>> for Bytes {
296 fn eq(&self, other: &Vec<u8>) -> bool {
297 self.as_ref() == &other[..]
298 }
299}
300
301impl PartialEq<Bytes> for Vec<u8> {
302 fn eq(&self, other: &Bytes) -> bool {
303 *other == *self
304 }
305}
306
307impl PartialEq<bytes::Bytes> for Bytes {
308 fn eq(&self, other: &bytes::Bytes) -> bool {
309 other == self.as_ref()
310 }
311}
312
/// Error returned by `Bytes::from_str` when the input is not valid hex.
///
/// The wrapped string is the human-readable reason; the derived `Display` renders it
/// after a "Failed to parse bytes: " prefix.
#[derive(Debug, Clone, Error)]
#[error("Failed to parse bytes: {0}")]
pub struct ParseBytesError(String);
316
317impl FromStr for Bytes {
318 type Err = ParseBytesError;
319
320 fn from_str(value: &str) -> Result<Self, Self::Err> {
321 if let Some(value) = value.strip_prefix("0x") {
322 hex::decode(value)
323 } else {
324 hex::decode(value)
325 }
326 .map(Into::into)
327 .map_err(|e| ParseBytesError(format!("Invalid hex: {e}")))
328 }
329}
330
331impl From<&str> for Bytes {
332 fn from(value: &str) -> Self {
333 value.parse().unwrap()
334 }
335}
336
/// Serializes `Bytes` to Postgres as the diesel `Binary` SQL type (used with a `BYTEA`
/// column in this file's diesel tests) by delegating to the `&[u8]` implementation.
#[cfg(feature = "diesel")]
impl ToSql<Binary, Pg> for Bytes {
    fn to_sql<'b>(&'b self, out: &mut serialize::Output<'b, '_, Pg>) -> serialize::Result {
        // View the payload as a plain slice so the existing `&[u8]` impl can write it.
        let bytes_slice: &[u8] = &self.0;
        // `&bytes_slice` borrows a local and cannot live for `'b`; `reborrow()` yields an
        // `Output` whose lifetime can be narrowed to match that shorter borrow.
        <&[u8] as ToSql<Binary, Pg>>::to_sql(&bytes_slice, &mut out.reborrow())
    }
}
344
#[cfg(feature = "diesel")]
impl FromSql<Binary, Pg> for Bytes {
    /// Decodes a Postgres `Binary` value by reading it as `Vec<u8>` and wrapping the result.
    fn from_sql(
        bytes: <diesel::pg::Pg as diesel::backend::Backend>::RawValue<'_>,
    ) -> deserialize::Result<Self> {
        <Vec<u8> as FromSql<Binary, Pg>>::from_sql(bytes)
            .map(|raw| Bytes(bytes::Bytes::from(raw)))
    }
}
354
355macro_rules! impl_from_uint_for_bytes {
356 ($($t:ty),*) => {
357 $(
358 impl From<$t> for Bytes {
359 fn from(src: $t) -> Self {
360 let size = std::mem::size_of::<$t>();
361 let mut buf = vec![0u8; size];
362 buf.copy_from_slice(&src.to_be_bytes());
363
364 Self(bytes::Bytes::from(buf))
365 }
366 }
367 )*
368 };
369}
370
371impl_from_uint_for_bytes!(u8, u16, u32, u64, u128);
372
373macro_rules! impl_from_bytes_for_uint {
374 ($($t:ty),*) => {
375 $(
376 impl From<Bytes> for $t {
377 fn from(src: Bytes) -> Self {
378 let bytes_slice = src.as_ref();
379
380 let mut buf = [0u8; std::mem::size_of::<$t>()];
382
383 buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
385
386 <$t>::from_be_bytes(buf)
388 }
389 }
390 )*
391 };
392}
393
394impl_from_bytes_for_uint!(u8, u16, u32, u64, u128);
395
396macro_rules! impl_from_bytes_for_signed_int {
397 ($($t:ty),*) => {
398 $(
399 impl From<Bytes> for $t {
400 fn from(src: Bytes) -> Self {
401 let bytes_slice = src.as_ref();
402
403 let mut buf = if bytes_slice.get(0).map_or(false, |&b| b & 0x80 != 0) {
405 [0xFFu8; std::mem::size_of::<$t>()] } else {
407 [0x00u8; std::mem::size_of::<$t>()] };
409
410 buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
412
413 <$t>::from_be_bytes(buf)
415 }
416 }
417 )*
418 };
419}
420
421impl_from_bytes_for_signed_int!(i8, i16, i32, i64, i128);
422
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_bytes() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        let expected = Bytes(b);

        assert_eq!(wrapped_b, expected);
    }

    #[test]
    fn test_from_slice() {
        let arr = [1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(&arr);
        let expected = Bytes(bytes::Bytes::from(arr.to_vec()));

        assert_eq!(b, expected);
    }

    #[test]
    fn hex_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        let expected = String::from("0x0123456789abcdef");
        assert_eq!(format!("{b:x}"), expected);
        assert_eq!(format!("{b}"), expected);
    }

    #[test]
    fn test_from_str() {
        let b = Bytes::from_str("0x1213").expect("prefixed hex should parse");
        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());

        let b = Bytes::from_str("1213").expect("unprefixed hex should parse");
        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());

        // Non-hex characters and odd-length input are rejected.
        assert!(Bytes::from_str("0xzz").is_err());
        assert!(Bytes::from_str("0x123").is_err());
    }

    #[test]
    fn test_debug_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        assert_eq!(format!("{b:?}"), "Bytes(0x0123456789abcdef)");
        assert_eq!(format!("{b:#?}"), "Bytes(0x0123456789abcdef)");
    }

    #[test]
    fn test_to_vec() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());

        assert_eq!(b.to_vec(), vec);
    }

    #[test]
    fn test_vec_partialeq() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());
        assert_eq!(b, vec);
        assert_eq!(vec, b);

        let wrong_vec = vec![1, 3, 52, 137];
        assert_ne!(b, wrong_vec);
        assert_ne!(wrong_vec, b);
    }

    #[test]
    fn test_bytes_partialeq() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        assert_eq!(wrapped_b, b);

        // Compare the wrapper (not two raw `bytes::Bytes`) so `PartialEq<bytes::Bytes>`
        // is actually exercised for the unequal case.
        let wrong_b = bytes::Bytes::from("0123absd");
        assert_ne!(wrapped_b, wrong_b);
    }

    #[test]
    fn test_padding_grows() {
        let b = Bytes::from(vec![1, 2]);
        assert_eq!(b.lpad(4, 0), Bytes::from(vec![0, 0, 1, 2]));
        assert_eq!(b.rpad(4, 0xAA), Bytes::from(vec![1, 2, 0xAA, 0xAA]));
        // lpad never truncates.
        assert_eq!(b.lpad(1, 0), b);
    }

    #[test]
    fn test_uint_to_bytes() {
        assert_eq!(Bytes::from(0x0102u16), Bytes::from(vec![1, 2]));
        assert_eq!(Bytes::from(1u32), Bytes::from(vec![0, 0, 0, 1]));
    }

    #[test]
    fn test_u128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        assert_eq!(u128::from(data), 0x0403_0201);
    }

    #[test]
    fn test_i128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        assert_eq!(i128::from(data), 0x0403_0201);
    }

    #[test]
    fn test_i32_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        assert_eq!(i32::from(data), 0x0403_0201);
    }

    #[test]
    fn test_negative_signed_from_bytes() {
        // A set high bit in the first byte is sign-extended to the full width.
        assert_eq!(i32::from(Bytes::from(vec![0xFF, 0xFE])), -2);
        assert_eq!(i64::from(Bytes::from(vec![0x80])), -128);
    }
}
523
#[cfg(feature = "diesel")]
#[cfg(test)]
mod diesel_tests {
    use diesel::{insert_into, table, Insertable, Queryable};
    use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection};

    use super::*;

    /// Connects to the database named by `DATABASE_URL` and opens a test transaction
    /// (per diesel's `begin_test_transaction`, it is never committed).
    async fn setup_db() -> AsyncPgConnection {
        let db_url = std::env::var("DATABASE_URL")
            .expect("DATABASE_URL must be set to run the diesel tests");
        let mut conn = AsyncPgConnection::establish(&db_url)
            .await
            .expect("failed to connect to DATABASE_URL");
        conn.begin_test_transaction()
            .await
            .expect("failed to begin test transaction");
        conn
    }

    #[tokio::test]
    async fn test_bytes_db_round_trip() {
        table! {
            bytes_table (id) {
                id -> Int4,
                data -> Binary,
            }
        }

        #[derive(Insertable)]
        #[diesel(table_name = bytes_table)]
        struct NewByteEntry {
            data: Bytes,
        }

        #[derive(Queryable, PartialEq)]
        struct ByteEntry {
            id: i32,
            data: Bytes,
        }

        let mut conn = setup_db().await;
        let example_bytes = Bytes::from_str("0x0123456789abcdef").unwrap();

        conn.batch_execute(
            r"
            CREATE TEMPORARY TABLE bytes_table (
                id SERIAL PRIMARY KEY,
                data BYTEA NOT NULL
            );
            ",
        )
        .await
        .expect("failed to create temporary table");

        let new_entry = NewByteEntry { data: example_bytes.clone() };

        let inserted: Vec<ByteEntry> = insert_into(bytes_table::table)
            .values(&new_entry)
            .get_results(&mut conn)
            .await
            .expect("insert should return the inserted row");

        // Exactly one row was inserted; check before indexing for a clearer failure.
        assert_eq!(inserted.len(), 1);
        assert_eq!(inserted[0].data, example_bytes);
    }
}
588}