1use std::{
2 borrow::Borrow,
3 clone::Clone,
4 fmt::{Debug, Display, Formatter, LowerHex, Result as FmtResult},
5 ops::Deref,
6 str::FromStr,
7};
8
9#[cfg(feature = "diesel")]
10use diesel::{
11 deserialize::{self, FromSql, FromSqlRow},
12 expression::AsExpression,
13 pg::Pg,
14 serialize::{self, ToSql},
15 sql_types::Binary,
16};
17use rand::Rng;
18use serde::{Deserialize, Serialize};
19use thiserror::Error;
20
21use crate::serde_primitives::hex_bytes;
22
23#[derive(Clone, Default, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
25#[cfg_attr(feature = "diesel", derive(AsExpression, FromSqlRow,))]
26#[cfg_attr(feature = "diesel", diesel(sql_type = Binary))]
27pub struct Bytes(#[serde(with = "hex_bytes")] pub bytes::Bytes);
28
29fn bytes_to_hex(b: &Bytes) -> String {
30 hex::encode(b.0.as_ref())
31}
32
33impl Debug for Bytes {
34 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
35 write!(f, "Bytes(0x{})", bytes_to_hex(self))
36 }
37}
38
39impl Display for Bytes {
40 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
41 write!(f, "0x{}", bytes_to_hex(self))
42 }
43}
44
45impl LowerHex for Bytes {
46 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
47 write!(f, "0x{}", bytes_to_hex(self))
48 }
49}
50
51impl Bytes {
52 pub fn new() -> Self {
53 Self(bytes::Bytes::new())
54 }
55 pub fn to_vec(&self) -> Vec<u8> {
69 self.as_ref().to_vec()
70 }
71
72 pub fn lpad(&self, length: usize, pad_byte: u8) -> Bytes {
97 let mut padded_vec = vec![pad_byte; length.saturating_sub(self.len())];
98 padded_vec.extend_from_slice(self.as_ref());
99
100 Bytes(bytes::Bytes::from(padded_vec))
101 }
102
103 pub fn rpad(&self, length: usize, pad_byte: u8) -> Bytes {
128 let mut padded_vec = self.to_vec();
129 padded_vec.resize(length, pad_byte);
130
131 Bytes(bytes::Bytes::from(padded_vec))
132 }
133
134 pub fn zero(length: usize) -> Bytes {
151 Bytes::from(vec![0u8; length])
152 }
153
154 pub fn random(length: usize) -> Bytes {
171 let mut data = vec![0u8; length];
172 rand::thread_rng().fill(&mut data[..]);
173 Bytes::from(data)
174 }
175}
176
177impl Deref for Bytes {
178 type Target = [u8];
179
180 #[inline]
181 fn deref(&self) -> &[u8] {
182 self.as_ref()
183 }
184}
185
186impl AsRef<[u8]> for Bytes {
187 fn as_ref(&self) -> &[u8] {
188 self.0.as_ref()
189 }
190}
191
192impl Borrow<[u8]> for Bytes {
193 fn borrow(&self) -> &[u8] {
194 self.as_ref()
195 }
196}
197
198impl IntoIterator for Bytes {
199 type Item = u8;
200 type IntoIter = bytes::buf::IntoIter<bytes::Bytes>;
201
202 fn into_iter(self) -> Self::IntoIter {
203 self.0.into_iter()
204 }
205}
206
207impl<'a> IntoIterator for &'a Bytes {
208 type Item = &'a u8;
209 type IntoIter = core::slice::Iter<'a, u8>;
210
211 fn into_iter(self) -> Self::IntoIter {
212 self.as_ref().iter()
213 }
214}
215
216impl From<&[u8]> for Bytes {
217 fn from(src: &[u8]) -> Self {
218 Self(bytes::Bytes::copy_from_slice(src))
219 }
220}
221
222impl From<bytes::Bytes> for Bytes {
223 fn from(src: bytes::Bytes) -> Self {
224 Self(src)
225 }
226}
227
228impl From<Bytes> for bytes::Bytes {
229 fn from(src: Bytes) -> Self {
230 src.0
231 }
232}
233
234impl From<Vec<u8>> for Bytes {
235 fn from(src: Vec<u8>) -> Self {
236 Self(src.into())
237 }
238}
239
240impl From<Bytes> for Vec<u8> {
241 fn from(value: Bytes) -> Self {
242 value.to_vec()
243 }
244}
245
246impl<const N: usize> From<[u8; N]> for Bytes {
247 fn from(src: [u8; N]) -> Self {
248 src.to_vec().into()
249 }
250}
251
252impl<'a, const N: usize> From<&'a [u8; N]> for Bytes {
253 fn from(src: &'a [u8; N]) -> Self {
254 src.to_vec().into()
255 }
256}
257
258impl PartialEq<[u8]> for Bytes {
259 fn eq(&self, other: &[u8]) -> bool {
260 self.as_ref() == other
261 }
262}
263
264impl PartialEq<Bytes> for [u8] {
265 fn eq(&self, other: &Bytes) -> bool {
266 *other == *self
267 }
268}
269
270impl PartialEq<Vec<u8>> for Bytes {
271 fn eq(&self, other: &Vec<u8>) -> bool {
272 self.as_ref() == &other[..]
273 }
274}
275
276impl PartialEq<Bytes> for Vec<u8> {
277 fn eq(&self, other: &Bytes) -> bool {
278 *other == *self
279 }
280}
281
282impl PartialEq<bytes::Bytes> for Bytes {
283 fn eq(&self, other: &bytes::Bytes) -> bool {
284 other == self.as_ref()
285 }
286}
287
288#[derive(Debug, Clone, Error)]
289#[error("Failed to parse bytes: {0}")]
290pub struct ParseBytesError(String);
291
292impl FromStr for Bytes {
293 type Err = ParseBytesError;
294
295 fn from_str(value: &str) -> Result<Self, Self::Err> {
296 if let Some(value) = value.strip_prefix("0x") {
297 hex::decode(value)
298 } else {
299 hex::decode(value)
300 }
301 .map(Into::into)
302 .map_err(|e| ParseBytesError(format!("Invalid hex: {e}")))
303 }
304}
305
306impl From<&str> for Bytes {
307 fn from(value: &str) -> Self {
308 value.parse().unwrap()
309 }
310}
311
#[cfg(feature = "diesel")]
impl ToSql<Binary, Pg> for Bytes {
    /// Serialises the wrapped bytes as a Diesel `Binary` value for Postgres by
    /// delegating to Diesel's `&[u8]` implementation.
    fn to_sql<'b>(&'b self, out: &mut serialize::Output<'b, '_, Pg>) -> serialize::Result {
        let raw: &[u8] = self.0.as_ref();
        let mut inner_out = out.reborrow();
        <&[u8] as ToSql<Binary, Pg>>::to_sql(&raw, &mut inner_out)
    }
}
319
#[cfg(feature = "diesel")]
impl FromSql<Binary, Pg> for Bytes {
    /// Deserialises a Postgres `Binary` value via Diesel's `Vec<u8>` implementation and
    /// wraps the result. Errors from that conversion are returned unchanged.
    fn from_sql(
        raw: <diesel::pg::Pg as diesel::backend::Backend>::RawValue<'_>,
    ) -> deserialize::Result<Self> {
        <Vec<u8> as FromSql<Binary, Pg>>::from_sql(raw)
            .map(|owned| Bytes(bytes::Bytes::from(owned)))
    }
}
329
330macro_rules! impl_from_uint_for_bytes {
331 ($($t:ty),*) => {
332 $(
333 impl From<$t> for Bytes {
334 fn from(src: $t) -> Self {
335 let size = std::mem::size_of::<$t>();
336 let mut buf = vec![0u8; size];
337 buf.copy_from_slice(&src.to_be_bytes());
338
339 Self(bytes::Bytes::from(buf))
340 }
341 }
342 )*
343 };
344}
345
346impl_from_uint_for_bytes!(u8, u16, u32, u64, u128);
347
348macro_rules! impl_from_bytes_for_uint {
349 ($($t:ty),*) => {
350 $(
351 impl From<Bytes> for $t {
352 fn from(src: Bytes) -> Self {
353 let bytes_slice = src.as_ref();
354
355 let mut buf = [0u8; std::mem::size_of::<$t>()];
357
358 buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
360
361 <$t>::from_be_bytes(buf)
363 }
364 }
365 )*
366 };
367}
368
369impl_from_bytes_for_uint!(u8, u16, u32, u64, u128);
370
371macro_rules! impl_from_bytes_for_signed_int {
372 ($($t:ty),*) => {
373 $(
374 impl From<Bytes> for $t {
375 fn from(src: Bytes) -> Self {
376 let bytes_slice = src.as_ref();
377
378 let mut buf = if bytes_slice.get(0).map_or(false, |&b| b & 0x80 != 0) {
380 [0xFFu8; std::mem::size_of::<$t>()] } else {
382 [0x00u8; std::mem::size_of::<$t>()] };
384
385 buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
387
388 <$t>::from_be_bytes(buf)
390 }
391 }
392 )*
393 };
394}
395
396impl_from_bytes_for_signed_int!(i8, i16, i32, i64, i128);
397
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_bytes() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        let expected = Bytes(b);

        assert_eq!(wrapped_b, expected);
    }

    #[test]
    fn test_from_slice() {
        let arr = [1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(&arr);
        let expected = Bytes(bytes::Bytes::from(arr.to_vec()));

        assert_eq!(b, expected);
    }

    #[test]
    fn hex_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        let expected = String::from("0x0123456789abcdef");
        assert_eq!(format!("{b:x}"), expected);
        assert_eq!(format!("{b}"), expected);
    }

    #[test]
    fn test_from_str() {
        let b = Bytes::from_str("0x1213");
        assert!(b.is_ok());
        let b = b.unwrap();
        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());

        let b = Bytes::from_str("1213");
        let b = b.unwrap();
        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());
    }

    #[test]
    fn test_from_str_rejects_invalid_hex() {
        assert!(Bytes::from_str("0xzz").is_err());
        // Odd number of hex digits.
        assert!(Bytes::from_str("123").is_err());
    }

    #[test]
    fn test_debug_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        assert_eq!(format!("{b:?}"), "Bytes(0x0123456789abcdef)");
        assert_eq!(format!("{b:#?}"), "Bytes(0x0123456789abcdef)");
    }

    #[test]
    fn test_to_vec() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());

        assert_eq!(b.to_vec(), vec);
    }

    #[test]
    fn test_vec_partialeq() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());
        assert_eq!(b, vec);
        assert_eq!(vec, b);

        let wrong_vec = vec![1, 3, 52, 137];
        assert_ne!(b, wrong_vec);
        assert_ne!(wrong_vec, b);
    }

    #[test]
    fn test_bytes_partialeq() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        assert_eq!(wrapped_b, b);

        // Compare the wrapper (not the raw buffer) so that
        // `PartialEq<bytes::Bytes> for Bytes` is exercised for inequality too.
        let wrong_b = bytes::Bytes::from("0123absd");
        assert_ne!(wrapped_b, wrong_b);
    }

    #[test]
    fn test_lpad() {
        let b = Bytes::from(vec![1, 2]);
        assert_eq!(b.lpad(4, 0), vec![0, 0, 1, 2]);
        // Already long enough: contents are returned unchanged.
        assert_eq!(b.lpad(1, 0), vec![1, 2]);
    }

    #[test]
    fn test_rpad() {
        let b = Bytes::from(vec![1, 2]);
        assert_eq!(b.rpad(4, 0xff), vec![1, 2, 0xff, 0xff]);
        assert_eq!(b.rpad(2, 0xff), vec![1, 2]);
    }

    #[test]
    fn test_u128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: u128 = u128::from(data);
        assert_eq!(result, 67_305_985);
    }

    #[test]
    fn test_i128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: i128 = i128::from(data);
        assert_eq!(result, 67_305_985);
    }

    #[test]
    fn test_i32_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: i32 = i32::from(data);
        assert_eq!(result, 67_305_985);
    }

    #[test]
    fn test_i32_from_negative_bytes() {
        // A set top bit on the first byte sign-extends the value.
        let data = Bytes::from(vec![0xFF, 0xFE]);
        assert_eq!(i32::from(data), -2);
    }
}
498
// Database round-trip tests for the Diesel `ToSql`/`FromSql` impls. They are compiled
// only with the `diesel` feature in test builds, and they read the connection string
// from the `DATABASE_URL` environment variable.
#[cfg(feature = "diesel")]
#[cfg(test)]
mod diesel_tests {
    use diesel::{insert_into, table, Insertable, Queryable};
    use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection};

    use super::*;

    /// Connects to the Postgres database named by `DATABASE_URL` and begins a test
    /// transaction on the connection.
    ///
    /// Panics (via `unwrap`) if `DATABASE_URL` is unset, the connection fails, or the
    /// transaction cannot be started.
    // NOTE(review): presumably the test transaction is never committed, so nothing
    // persists between runs. Confirm against Diesel's `begin_test_transaction` docs.
    async fn setup_db() -> AsyncPgConnection {
        let db_url = std::env::var("DATABASE_URL").unwrap();
        let mut conn = AsyncPgConnection::establish(&db_url)
            .await
            .unwrap();
        conn.begin_test_transaction()
            .await
            .unwrap();
        conn
    }

    /// Inserts one `Bytes` value into a temporary `BYTEA` table and checks that the row
    /// returned by the insert holds the same bytes.
    #[tokio::test]
    async fn test_bytes_db_round_trip() {
        // Diesel schema for the temporary table created below; `data` maps to `Binary`.
        table! {
            bytes_table (id) {
                id -> Int4,
                data -> Binary,
            }
        }

        // Insert shape: `id` is omitted and filled by the `SERIAL` column.
        #[derive(Insertable)]
        #[diesel(table_name = bytes_table)]
        struct NewByteEntry {
            data: Bytes,
        }

        // Read-back shape; field order matches the `table!` column order.
        #[derive(Queryable, PartialEq)]
        struct ByteEntry {
            id: i32,
            data: Bytes,
        }

        let mut conn = setup_db().await;
        let example_bytes = Bytes::from_str("0x0123456789abcdef").unwrap();

        // Session-scoped table, so the test needs no pre-existing schema.
        conn.batch_execute(
            r"
            CREATE TEMPORARY TABLE bytes_table (
                id SERIAL PRIMARY KEY,
                data BYTEA NOT NULL
            );
            ",
        )
        .await
        .unwrap();

        let new_entry = NewByteEntry { data: example_bytes.clone() };

        // `get_results` on an insert returns the inserted rows, exercising both `ToSql`
        // (on insert) and `FromSql` (on read-back).
        let inserted: Vec<ByteEntry> = insert_into(bytes_table::table)
            .values(&new_entry)
            .get_results(&mut conn)
            .await
            .unwrap();

        assert_eq!(inserted[0].data, example_bytes);
    }
}