1use std::{
2 borrow::Borrow,
3 clone::Clone,
4 fmt::{Debug, Display, Formatter, LowerHex, Result as FmtResult},
5 ops::Deref,
6 str::FromStr,
7};
8
9#[cfg(feature = "diesel")]
10use diesel::{
11 deserialize::{self, FromSql, FromSqlRow},
12 expression::AsExpression,
13 pg::Pg,
14 serialize::{self, ToSql},
15 sql_types::Binary,
16};
17use rand::Rng;
18use serde::{Deserialize, Serialize};
19use thiserror::Error;
20
21use crate::serde_primitives::hex_bytes;
22
/// Newtype wrapper around [`bytes::Bytes`] that adds hex-oriented formatting, parsing and
/// integer conversions.
///
/// - `Debug` renders as `Bytes(0x…)`; `Display` and `{:x}` render as `0x…` in lowercase hex.
/// - Under serde, the inner buffer is (de)serialized through `crate::serde_primitives::hex_bytes`
///   (presumably a hex-string encoding — confirm in that module).
/// - With the `diesel` feature enabled, the type maps to the SQL `Binary` type.
#[derive(Clone, Default, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
#[cfg_attr(feature = "diesel", derive(AsExpression, FromSqlRow,))]
#[cfg_attr(feature = "diesel", diesel(sql_type = Binary))]
pub struct Bytes(#[serde(with = "hex_bytes")] pub bytes::Bytes);
28
29fn bytes_to_hex(b: &Bytes) -> String {
30 hex::encode(b.0.as_ref())
31}
32
33impl Debug for Bytes {
34 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
35 write!(f, "Bytes(0x{})", bytes_to_hex(self))
36 }
37}
38
39impl Display for Bytes {
40 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
41 write!(f, "0x{}", bytes_to_hex(self))
42 }
43}
44
45impl LowerHex for Bytes {
46 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
47 write!(f, "0x{}", bytes_to_hex(self))
48 }
49}
50
51impl Bytes {
52 pub fn new() -> Self {
53 Self(bytes::Bytes::new())
54 }
55 pub fn to_vec(&self) -> Vec<u8> {
69 self.as_ref().to_vec()
70 }
71
72 pub fn lpad(&self, length: usize, pad_byte: u8) -> Bytes {
97 let mut padded_vec = vec![pad_byte; length.saturating_sub(self.len())];
98 padded_vec.extend_from_slice(self.as_ref());
99
100 Bytes(bytes::Bytes::from(padded_vec))
101 }
102
103 pub fn rpad(&self, length: usize, pad_byte: u8) -> Bytes {
128 let mut padded_vec = self.to_vec();
129 padded_vec.resize(length, pad_byte);
130
131 Bytes(bytes::Bytes::from(padded_vec))
132 }
133
134 pub fn zero(length: usize) -> Bytes {
151 Bytes::from(vec![0u8; length])
152 }
153
154 pub fn random(length: usize) -> Bytes {
171 let mut data = vec![0u8; length];
172 rand::thread_rng().fill(&mut data[..]);
173 Bytes::from(data)
174 }
175
176 pub fn is_zero(&self) -> bool {
189 self.as_ref().iter().all(|b| *b == 0)
190 }
191}
192
193impl Deref for Bytes {
194 type Target = [u8];
195
196 #[inline]
197 fn deref(&self) -> &[u8] {
198 self.as_ref()
199 }
200}
201
202impl AsRef<[u8]> for Bytes {
203 fn as_ref(&self) -> &[u8] {
204 self.0.as_ref()
205 }
206}
207
208impl Borrow<[u8]> for Bytes {
209 fn borrow(&self) -> &[u8] {
210 self.as_ref()
211 }
212}
213
214impl IntoIterator for Bytes {
215 type Item = u8;
216 type IntoIter = bytes::buf::IntoIter<bytes::Bytes>;
217
218 fn into_iter(self) -> Self::IntoIter {
219 self.0.into_iter()
220 }
221}
222
223impl<'a> IntoIterator for &'a Bytes {
224 type Item = &'a u8;
225 type IntoIter = core::slice::Iter<'a, u8>;
226
227 fn into_iter(self) -> Self::IntoIter {
228 self.as_ref().iter()
229 }
230}
231
232impl From<&[u8]> for Bytes {
233 fn from(src: &[u8]) -> Self {
234 Self(bytes::Bytes::copy_from_slice(src))
235 }
236}
237
238impl From<bytes::Bytes> for Bytes {
239 fn from(src: bytes::Bytes) -> Self {
240 Self(src)
241 }
242}
243
244impl From<Bytes> for bytes::Bytes {
245 fn from(src: Bytes) -> Self {
246 src.0
247 }
248}
249
250impl From<Vec<u8>> for Bytes {
251 fn from(src: Vec<u8>) -> Self {
252 Self(src.into())
253 }
254}
255
256impl From<Bytes> for Vec<u8> {
257 fn from(value: Bytes) -> Self {
258 value.to_vec()
259 }
260}
261
262impl<const N: usize> From<[u8; N]> for Bytes {
263 fn from(src: [u8; N]) -> Self {
264 src.to_vec().into()
265 }
266}
267
268impl<'a, const N: usize> From<&'a [u8; N]> for Bytes {
269 fn from(src: &'a [u8; N]) -> Self {
270 src.to_vec().into()
271 }
272}
273
274impl PartialEq<[u8]> for Bytes {
275 fn eq(&self, other: &[u8]) -> bool {
276 self.as_ref() == other
277 }
278}
279
280impl PartialEq<Bytes> for [u8] {
281 fn eq(&self, other: &Bytes) -> bool {
282 *other == *self
283 }
284}
285
286impl PartialEq<Vec<u8>> for Bytes {
287 fn eq(&self, other: &Vec<u8>) -> bool {
288 self.as_ref() == &other[..]
289 }
290}
291
292impl PartialEq<Bytes> for Vec<u8> {
293 fn eq(&self, other: &Bytes) -> bool {
294 *other == *self
295 }
296}
297
298impl PartialEq<bytes::Bytes> for Bytes {
299 fn eq(&self, other: &bytes::Bytes) -> bool {
300 other == self.as_ref()
301 }
302}
303
/// Error returned when a string cannot be parsed into [`Bytes`].
///
/// The wrapped `String` describes the failure; the error displays as
/// `Failed to parse bytes: <message>`.
#[derive(Debug, Clone, Error)]
#[error("Failed to parse bytes: {0}")]
pub struct ParseBytesError(String);
307
308impl FromStr for Bytes {
309 type Err = ParseBytesError;
310
311 fn from_str(value: &str) -> Result<Self, Self::Err> {
312 if let Some(value) = value.strip_prefix("0x") {
313 hex::decode(value)
314 } else {
315 hex::decode(value)
316 }
317 .map(Into::into)
318 .map_err(|e| ParseBytesError(format!("Invalid hex: {e}")))
319 }
320}
321
322impl From<&str> for Bytes {
323 fn from(value: &str) -> Self {
324 value.parse().unwrap()
325 }
326}
327
/// Serializes `Bytes` as a Postgres `Binary` value by delegating to diesel's `&[u8]` impl.
#[cfg(feature = "diesel")]
impl ToSql<Binary, Pg> for Bytes {
    fn to_sql<'b>(&'b self, out: &mut serialize::Output<'b, '_, Pg>) -> serialize::Result {
        // View the inner buffer as a plain slice so the existing `&[u8]` serializer can be reused.
        let bytes_slice: &[u8] = &self.0;
        // NOTE(review): `reborrow` presumably adapts `out`'s lifetimes to what the delegated
        // `to_sql` call expects — confirm against the diesel version in use.
        <&[u8] as ToSql<Binary, Pg>>::to_sql(&bytes_slice, &mut out.reborrow())
    }
}
335
/// Deserializes a Postgres `Binary` value into `Bytes` by reusing diesel's `Vec<u8>` decoder.
#[cfg(feature = "diesel")]
impl FromSql<Binary, Pg> for Bytes {
    fn from_sql(
        bytes: <diesel::pg::Pg as diesel::backend::Backend>::RawValue<'_>,
    ) -> deserialize::Result<Self> {
        <Vec<u8> as FromSql<Binary, Pg>>::from_sql(bytes).map(Self::from)
    }
}
345
346macro_rules! impl_from_uint_for_bytes {
347 ($($t:ty),*) => {
348 $(
349 impl From<$t> for Bytes {
350 fn from(src: $t) -> Self {
351 let size = std::mem::size_of::<$t>();
352 let mut buf = vec![0u8; size];
353 buf.copy_from_slice(&src.to_be_bytes());
354
355 Self(bytes::Bytes::from(buf))
356 }
357 }
358 )*
359 };
360}
361
362impl_from_uint_for_bytes!(u8, u16, u32, u64, u128);
363
364macro_rules! impl_from_bytes_for_uint {
365 ($($t:ty),*) => {
366 $(
367 impl From<Bytes> for $t {
368 fn from(src: Bytes) -> Self {
369 let bytes_slice = src.as_ref();
370
371 let mut buf = [0u8; std::mem::size_of::<$t>()];
373
374 buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
376
377 <$t>::from_be_bytes(buf)
379 }
380 }
381 )*
382 };
383}
384
385impl_from_bytes_for_uint!(u8, u16, u32, u64, u128);
386
387macro_rules! impl_from_bytes_for_signed_int {
388 ($($t:ty),*) => {
389 $(
390 impl From<Bytes> for $t {
391 fn from(src: Bytes) -> Self {
392 let bytes_slice = src.as_ref();
393
394 let mut buf = if bytes_slice.get(0).map_or(false, |&b| b & 0x80 != 0) {
396 [0xFFu8; std::mem::size_of::<$t>()] } else {
398 [0x00u8; std::mem::size_of::<$t>()] };
400
401 buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
403
404 <$t>::from_be_bytes(buf)
406 }
407 }
408 )*
409 };
410}
411
412impl_from_bytes_for_signed_int!(i8, i16, i32, i64, i128);
413
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_bytes() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        let expected = Bytes(b);

        assert_eq!(wrapped_b, expected);
    }

    #[test]
    fn test_from_slice() {
        let arr = [1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(&arr);
        let expected = Bytes(bytes::Bytes::from(arr.to_vec()));

        assert_eq!(b, expected);
    }

    #[test]
    fn hex_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        let expected = String::from("0x0123456789abcdef");
        assert_eq!(format!("{b:x}"), expected);
        assert_eq!(format!("{b}"), expected);
    }

    #[test]
    fn test_from_str() {
        let b = Bytes::from_str("0x1213");
        assert!(b.is_ok());
        let b = b.unwrap();
        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());

        let b = Bytes::from_str("1213");
        let b = b.unwrap();
        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());

        assert!(Bytes::from_str("0xzz").is_err());
    }

    #[test]
    fn test_debug_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        assert_eq!(format!("{b:?}"), "Bytes(0x0123456789abcdef)");
        assert_eq!(format!("{b:#?}"), "Bytes(0x0123456789abcdef)");
    }

    #[test]
    fn test_to_vec() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());

        assert_eq!(b.to_vec(), vec);
    }

    #[test]
    fn test_vec_partialeq() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());
        assert_eq!(b, vec);
        assert_eq!(vec, b);

        let wrong_vec = vec![1, 3, 52, 137];
        assert_ne!(b, wrong_vec);
        assert_ne!(wrong_vec, b);
    }

    #[test]
    fn test_bytes_partialeq() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        assert_eq!(wrapped_b, b);

        // Exercise `Bytes: PartialEq<bytes::Bytes>` on a mismatch as well, rather than the
        // `bytes` crate's own equality between two `bytes::Bytes`.
        let wrong_b = bytes::Bytes::from("0123absd");
        assert_ne!(wrapped_b, wrong_b);
    }

    #[test]
    fn test_padding_when_shorter() {
        let b = Bytes::from(vec![1, 2]);
        assert_eq!(b.lpad(4, 0), vec![0, 0, 1, 2]);
        assert_eq!(b.rpad(4, 0xFF), vec![1, 2, 0xFF, 0xFF]);
        assert_eq!(b.lpad(1, 0), vec![1, 2]);
    }

    #[test]
    fn test_zero() {
        assert!(Bytes::zero(3).is_zero());
        assert_eq!(Bytes::zero(3).len(), 3);
        assert!(!Bytes::from(vec![0, 1]).is_zero());
    }

    #[test]
    fn test_uint_round_trip() {
        let b = Bytes::from(0x0102_u16);
        assert_eq!(b, vec![0x01, 0x02]);
        assert_eq!(u16::from(b), 0x0102);
        assert_eq!(Bytes::from(1u32).len(), 4);
    }

    #[test]
    fn test_u128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: u128 = u128::from(data);
        assert_eq!(result, 67_305_985);
    }

    #[test]
    fn test_i128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: i128 = i128::from(data);
        assert_eq!(result, 67_305_985);
    }

    #[test]
    fn test_i32_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: i32 = i32::from(data);
        assert_eq!(result, 67_305_985);
    }

    #[test]
    fn test_signed_from_bytes_sign_extends() {
        assert_eq!(i32::from(Bytes::from(vec![0xFF, 0xFE])), -2);
        assert_eq!(i16::from(Bytes::from(vec![0x80])), -128);
        assert_eq!(i64::from(Bytes::from(vec![0x7F])), 127);
        assert_eq!(i8::from(Bytes::new()), 0);
    }
}
514
// Integration test for the diesel `ToSql`/`FromSql` impls of `Bytes`. It needs a reachable
// Postgres instance whose connection URL is provided in the `DATABASE_URL` environment variable.
#[cfg(feature = "diesel")]
#[cfg(test)]
mod diesel_tests {
    use diesel::{insert_into, table, Insertable, Queryable};
    use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection};

    use super::*;

    /// Connects using `DATABASE_URL` and opens a test transaction on the connection.
    ///
    /// Panics if `DATABASE_URL` is unset, the connection fails, or the transaction cannot be
    /// started. NOTE(review): diesel's test transactions are presumably never committed, so
    /// nothing persists — confirm for the diesel_async version in use.
    async fn setup_db() -> AsyncPgConnection {
        let db_url = std::env::var("DATABASE_URL").unwrap();
        let mut conn = AsyncPgConnection::establish(&db_url)
            .await
            .unwrap();
        conn.begin_test_transaction()
            .await
            .unwrap();
        conn
    }

    /// Inserts a `Bytes` value into a `BYTEA` column and checks that the row returned by the
    /// insert carries the same bytes.
    #[tokio::test]
    async fn test_bytes_db_round_trip() {
        // Schema for the temporary table created below.
        table! {
            bytes_table (id) {
                id -> Int4,
                data -> Binary,
            }
        }

        #[derive(Insertable)]
        #[diesel(table_name = bytes_table)]
        struct NewByteEntry {
            data: Bytes,
        }

        #[derive(Queryable, PartialEq)]
        struct ByteEntry {
            id: i32,
            data: Bytes,
        }

        let mut conn = setup_db().await;
        let example_bytes = Bytes::from_str("0x0123456789abcdef").unwrap();

        // A temporary table keeps the test independent of any migrated schema.
        conn.batch_execute(
            r"
            CREATE TEMPORARY TABLE bytes_table (
                id SERIAL PRIMARY KEY,
                data BYTEA NOT NULL
            );
            ",
        )
        .await
        .unwrap();

        let new_entry = NewByteEntry { data: example_bytes.clone() };

        // `get_results` returns the inserted rows, exercising both ToSql and FromSql.
        let inserted: Vec<ByteEntry> = insert_into(bytes_table::table)
            .values(&new_entry)
            .get_results(&mut conn)
            .await
            .unwrap();

        assert_eq!(inserted[0].data, example_bytes);
    }
}
579}