1use crate::lib::{Cow, HashMap, String, ToString, Vec};
3use crate::slice::{InvalidSlice, SliceIterator, TensorIndexer};
4use core::fmt::Display;
5use core::str::Utf8Error;
6use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
7#[cfg(feature = "std")]
8use std::io::Write;
9
// Hard cap on the JSON header size (100 MB) so a malicious length prefix
// cannot force a huge allocation before validation.
const MAX_HEADER_SIZE: usize = 100_000_000;
// Byte width of the little-endian `u64` length prefix at the start of a file.
const N_LEN: usize = size_of::<u64>();
12
/// Errors that can occur while serializing or deserializing safetensors data.
#[derive(Debug)]
pub enum SafeTensorError {
    /// The header bytes were not valid UTF-8.
    InvalidHeader(Utf8Error),
    /// The header did not start with the expected `{` character.
    InvalidHeaderStart,
    /// The header was not valid JSON for the expected schema.
    InvalidHeaderDeserialization(serde_json::Error),
    /// The declared header length exceeds `MAX_HEADER_SIZE` (or `usize`).
    HeaderTooLarge,
    /// The buffer is too small to contain the 8-byte length prefix.
    HeaderTooSmall,
    /// The declared header length overflows or runs past the buffer.
    InvalidHeaderLength,
    /// No tensor with the given name exists in the file.
    TensorNotFound(String),
    /// A tensor's shape, dtype, and data offsets are mutually inconsistent.
    TensorInvalidInfo,
    /// The named tensor's data offsets are out of order or leave gaps.
    InvalidOffset(String),
    /// An underlying I/O operation failed.
    #[cfg(feature = "std")]
    IoError(std::io::Error),
    /// JSON (de)serialization failed outside of header parsing.
    JsonError(serde_json::Error),
    /// A `TensorView` could not be built: dtype, shape, and byte count disagree.
    InvalidTensorView(Dtype, Vec<usize>, usize),
    /// The buffer is larger or smaller than the metadata accounts for.
    MetadataIncompleteBuffer,
    /// Computing a buffer size from shape and dtype overflowed `usize`.
    ValidationOverflow,
    /// A slice over a sub-byte dtype does not end on a byte boundary.
    MisalignedSlice,
}
53
#[cfg(feature = "std")]
impl From<std::io::Error> for SafeTensorError {
    /// Wraps a raw I/O error into [`SafeTensorError::IoError`].
    fn from(source: std::io::Error) -> SafeTensorError {
        Self::IoError(source)
    }
}
60
61impl From<serde_json::Error> for SafeTensorError {
62 fn from(error: serde_json::Error) -> SafeTensorError {
63 SafeTensorError::JsonError(error)
64 }
65}
66
67impl Display for SafeTensorError {
68 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
69 use SafeTensorError::*;
70
71 match self {
72 InvalidHeader(error) => write!(f, "invalid UTF-8 in header: {error}"),
73 InvalidHeaderStart => write!(f, "invalid start character in header, must be `{{`"),
74 InvalidHeaderDeserialization(error) => write!(f, "invalid JSON in header: {error}"),
75 JsonError(error) => write!(f, "JSON error: {error}"),
76 HeaderTooLarge => write!(f, "header too large"),
77 HeaderTooSmall => write!(f, "header too small"),
78 InvalidHeaderLength => write!(f, "invalid header length"),
79 TensorNotFound(name) => write!(f, "tensor `{name}` not found"),
80 TensorInvalidInfo => write!(f, "invalid shape, data type, or offset for tensor"),
81 InvalidOffset(name) => write!(f, "invalid offset for tensor `{name}`"),
82 #[cfg(feature = "std")]
83 IoError(error) => write!(f, "I/O error: {error}"),
84 InvalidTensorView(dtype, shape, n_bytes) => {
85 write!(f, "tensor of type {dtype} and shape (")?;
86 for (i, &dim) in shape.iter().enumerate() {
87 write!(f, "{sep}{dim}", sep = if i == 0 { "" } else { ", " })?;
88 }
89 write!(f, ") can't be created from {n_bytes} bytes")
90 }
91 MetadataIncompleteBuffer => write!(f, "incomplete metadata, file not fully covered"),
92 ValidationOverflow => write!(f, "overflow computing buffer size from shape and/or element type"),
93 MisalignedSlice => write!(f, "The slice is slicing for subbytes dtypes, and the slice does not end up at a byte boundary, this is invalid.")
94 }
95 }
96}
97
98#[cfg(not(feature = "std"))]
99impl core::error::Error for SafeTensorError {
100 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
101 match self {
102 SafeTensorError::InvalidHeader(source) => Some(source),
103 SafeTensorError::JsonError(source) => Some(source),
104 SafeTensorError::InvalidHeaderDeserialization(source) => Some(source),
105 _ => None,
106 }
107 }
108}
109
#[cfg(feature = "std")]
impl std::error::Error for SafeTensorError {
    /// Reports the wrapped UTF-8/JSON/I-O error as the underlying cause, when any.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        use SafeTensorError::*;
        match self {
            InvalidHeader(inner) => Some(inner),
            InvalidHeaderDeserialization(inner) => Some(inner),
            JsonError(inner) => Some(inner),
            IoError(inner) => Some(inner),
            _ => None,
        }
    }
}
122
/// Intermediate result of [`prepare`]: the serialized header plus layout info.
struct PreparedData {
    // Header length in bytes (already padded to an 8-byte multiple).
    n: u64,
    // JSON header, space-padded so following tensor data stays aligned.
    header_bytes: Vec<u8>,
    // Total byte length of all tensor data that will follow the header.
    offset: usize,
}
128
/// Minimal interface a tensor container must expose to be serialized.
pub trait View {
    /// The element data type.
    fn dtype(&self) -> Dtype;
    /// The tensor shape (dimensions).
    fn shape(&self) -> &[usize];
    /// The raw tensor bytes (borrowed or owned).
    fn data(&self) -> Cow<'_, [u8]>;
    /// The byte length of the tensor's data.
    fn data_len(&self) -> usize;
}
220
221fn prepare<S, V, I>(
222 data: I,
223 data_info: Option<HashMap<String, String>>,
224) -> Result<(PreparedData, Vec<V>), SafeTensorError>
225where
226 S: AsRef<str> + Ord + Display,
227 V: View,
228 I: IntoIterator<Item = (S, V)>,
229{
230 let mut data: Vec<_> = data.into_iter().collect();
233 data.sort_by(|(lname, left), (rname, right)| {
234 right.dtype().cmp(&left.dtype()).then(lname.cmp(rname))
235 });
236
237 let mut tensors: Vec<V> = Vec::with_capacity(data.len());
238 let mut hmetadata = Vec::with_capacity(data.len());
239 let mut offset = 0;
240
241 for (name, tensor) in data {
242 let n = tensor.data_len();
243 let tensor_info = TensorInfo {
244 dtype: tensor.dtype(),
245 shape: tensor.shape().to_vec(),
246 data_offsets: (offset, offset + n),
247 };
248 offset += n;
249 hmetadata.push((name.to_string(), tensor_info));
250 tensors.push(tensor);
251 }
252
253 let metadata: Metadata = Metadata::new(data_info, hmetadata)?;
254 let mut metadata_buf = serde_json::to_string(&metadata)?.into_bytes();
255
256 let aligned_metadata_len = metadata_buf.len().next_multiple_of(N_LEN);
258 metadata_buf.resize(aligned_metadata_len, b' ');
259
260 Ok((
261 PreparedData {
262 n: aligned_metadata_len as u64,
263 header_bytes: metadata_buf,
264 offset,
265 },
266 tensors,
267 ))
268}
269
270pub fn serialize<
272 S: AsRef<str> + Ord + core::fmt::Display,
273 V: View,
274 I: IntoIterator<Item = (S, V)>,
275>(
276 data: I,
277 data_info: Option<HashMap<String, String>>,
278) -> Result<Vec<u8>, SafeTensorError> {
279 let (
280 PreparedData {
281 n,
282 header_bytes,
283 offset,
284 },
285 tensors,
286 ) = prepare(data, data_info)?;
287
288 if n > MAX_HEADER_SIZE as u64 {
289 return Err(SafeTensorError::HeaderTooLarge);
290 }
291
292 let expected_size = N_LEN + header_bytes.len() + offset;
293 let mut buffer: Vec<u8> = Vec::with_capacity(expected_size);
294 buffer.extend(n.to_le_bytes());
295 buffer.extend(header_bytes);
296
297 for tensor in tensors {
298 buffer.extend(tensor.data().as_ref());
299 }
300
301 Ok(buffer)
302}
303
#[cfg(feature = "std")]
/// Serializes tensors straight to `filename` through a buffered writer,
/// avoiding the intermediate in-memory buffer that [`serialize`] builds.
///
/// # Errors
/// Fails on oversized headers, metadata/JSON errors, or any I/O failure.
pub fn serialize_to_file<S, V, I>(
    data: I,
    data_info: Option<HashMap<String, String>>,
    filename: &std::path::Path,
) -> Result<(), SafeTensorError>
where
    S: AsRef<str> + Ord + Display,
    V: View,
    I: IntoIterator<Item = (S, V)>,
{
    let (prepared, tensors) = prepare(data, data_info)?;

    if prepared.n > MAX_HEADER_SIZE as u64 {
        return Err(SafeTensorError::HeaderTooLarge);
    }

    let mut out = std::io::BufWriter::new(std::fs::File::create(filename)?);
    out.write_all(prepared.n.to_le_bytes().as_ref())?;
    out.write_all(&prepared.header_bytes)?;
    for tensor in tensors {
        out.write_all(tensor.data().as_ref())?;
    }
    // Flush explicitly: Drop would swallow any write error.
    out.flush()?;

    Ok(())
}
341
/// A parsed safetensors file: validated metadata plus a zero-copy borrow of
/// the tensor-data section of the original buffer.
#[derive(Debug)]
pub struct SafeTensors<'data> {
    // Validated header metadata (names, dtypes, shapes, offsets).
    metadata: Metadata,
    // The byte region after the header; offsets in `metadata` index into it.
    data: &'data [u8],
}
349
impl<'data> SafeTensors<'data> {
    /// Parses the 8-byte little-endian length prefix and JSON header.
    ///
    /// Returns the header byte length `n` and the validated [`Metadata`].
    ///
    /// # Errors
    /// Fails if the buffer cannot hold the prefix, the declared length is
    /// invalid or exceeds `MAX_HEADER_SIZE`, the header is not UTF-8 JSON,
    /// or the metadata does not exactly cover the remaining buffer.
    pub fn read_metadata(buffer: &'data [u8]) -> Result<(usize, Metadata), SafeTensorError> {
        let buffer_len = buffer.len();
        let Some(header_size_bytes) = buffer.get(..N_LEN) else {
            return Err(SafeTensorError::HeaderTooSmall);
        };
        let arr: [u8; N_LEN] = header_size_bytes
            .try_into()
            .expect("this can't fail due to how `header_size_bytes` is defined above");
        // On 32-bit targets a u64 length may not fit in usize.
        let n: usize = u64::from_le_bytes(arr)
            .try_into()
            .map_err(|_| SafeTensorError::HeaderTooLarge)?;

        if n > MAX_HEADER_SIZE {
            return Err(SafeTensorError::HeaderTooLarge);
        }

        let stop = n
            .checked_add(N_LEN)
            .ok_or(SafeTensorError::InvalidHeaderLength)?;

        let Some(header_bytes) = buffer.get(N_LEN..stop) else {
            return Err(SafeTensorError::InvalidHeaderLength);
        };
        let string = core::str::from_utf8(header_bytes).map_err(SafeTensorError::InvalidHeader)?;
        let metadata: HashMetadata =
            serde_json::from_str(string).map_err(SafeTensorError::InvalidHeaderDeserialization)?;
        let metadata: Metadata = metadata.try_into()?;
        // `validate` returns the end offset of the last tensor; the file must
        // be exactly prefix + header + data, no more and no less.
        let buffer_end = metadata.validate()?;
        if buffer_end + N_LEN + n != buffer_len {
            return Err(SafeTensorError::MetadataIncompleteBuffer);
        }

        Ok((n, metadata))
    }

    /// Parses a complete safetensors buffer into a zero-copy view over it.
    pub fn deserialize(buffer: &'data [u8]) -> Result<Self, SafeTensorError> {
        let (n, metadata) = SafeTensors::read_metadata(buffer)?;
        let data = &buffer[N_LEN + n..];
        Ok(Self { metadata, data })
    }

    /// All tensors as `(name, view)` pairs (HashMap iteration: unspecified order).
    pub fn tensors(&self) -> Vec<(String, TensorView<'data>)> {
        let mut tensors = Vec::with_capacity(self.metadata.index_map.len());
        for (name, &index) in &self.metadata.index_map {
            let info = &self.metadata.tensors[index];
            let tensorview = TensorView {
                dtype: info.dtype,
                shape: info.shape.clone(),
                data: &self.data[info.data_offsets.0..info.data_offsets.1],
            };
            tensors.push((name.to_string(), tensorview));
        }
        tensors
    }

    /// Iterates over `(name, view)` pairs without collecting into a `Vec`.
    pub fn iter(&self) -> impl Iterator<Item = (&str, TensorView<'data>)> {
        self.metadata.index_map.iter().map(|(name, &idx)| {
            let info = &self.metadata.tensors[idx];
            (
                name.as_str(),
                TensorView {
                    dtype: info.dtype,
                    shape: info.shape.clone(),
                    data: &self.data[info.data_offsets.0..info.data_offsets.1],
                },
            )
        })
    }

    /// Looks up a single tensor by name.
    ///
    /// # Errors
    /// [`SafeTensorError::TensorNotFound`] when the name is absent.
    pub fn tensor(&self, tensor_name: &str) -> Result<TensorView<'data>, SafeTensorError> {
        let &index = self
            .metadata
            .index_map
            .get(tensor_name)
            .ok_or_else(|| SafeTensorError::TensorNotFound(tensor_name.to_string()))?;

        let info = self
            .metadata
            .tensors
            .get(index)
            .ok_or_else(|| SafeTensorError::TensorNotFound(tensor_name.to_string()))?;

        Ok(TensorView {
            dtype: info.dtype,
            shape: info.shape.clone(),
            data: &self.data[info.data_offsets.0..info.data_offsets.1],
        })
    }

    /// Names of all tensors (HashMap iteration: unspecified order).
    pub fn names(&self) -> Vec<&'_ str> {
        self.metadata.index_map.keys().map(String::as_str).collect()
    }

    /// Number of tensors in the file.
    #[inline]
    pub fn len(&self) -> usize {
        self.metadata.tensors.len()
    }

    /// True when the file contains no tensors.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.metadata.tensors.is_empty()
    }
}
496
/// Validated header metadata: tensor infos in storage (offset) order plus a
/// name-to-index map, and the optional free-form `__metadata__` strings.
#[derive(Debug, Clone)]
pub struct Metadata {
    // Optional user-provided string map, serialized under `__metadata__`.
    metadata: Option<HashMap<String, String>>,
    // Tensor infos ordered by their byte offsets in the data section.
    tensors: Vec<TensorInfo>,
    // Maps a tensor name to its index in `tensors`.
    index_map: HashMap<String, usize>,
}
505
/// Raw JSON shape of the header: a flat map of tensor names, with the
/// optional `__metadata__` entry pulled out. Converted into [`Metadata`]
/// (which imposes ordering and validation) via `TryFrom`.
#[derive(Serialize, Deserialize)]
struct HashMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "__metadata__")]
    metadata: Option<HashMap<String, String>>,
    #[serde(flatten)]
    tensors: HashMap<String, TensorInfo>,
}
515
516impl TryFrom<HashMetadata> for Metadata {
517 type Error = SafeTensorError;
518 fn try_from(hashdata: HashMetadata) -> Result<Self, Self::Error> {
519 let (metadata, tensors) = (hashdata.metadata, hashdata.tensors);
520 let mut tensors: Vec<_> = tensors.into_iter().collect();
521 tensors.sort_by(|(_, left), (_, right)| left.data_offsets.cmp(&right.data_offsets));
526 Metadata::new(metadata, tensors)
527 }
528}
529
530impl<'de> Deserialize<'de> for Metadata {
531 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
532 where
533 D: Deserializer<'de>,
534 {
535 let hashdata: HashMetadata = HashMetadata::deserialize(deserializer)?;
536
537 let metadata: Metadata = hashdata.try_into().map_err(serde::de::Error::custom)?;
538 Ok(metadata)
539 }
540}
541
542impl Serialize for Metadata {
543 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
544 where
545 S: Serializer,
546 {
547 let mut names = vec![""; self.index_map.len()];
548 for (name, &index) in &self.index_map {
549 names[index] = name;
550 }
551
552 let length = self.metadata.as_ref().map_or(0, HashMap::len);
553 let mut map = serializer.serialize_map(Some(self.tensors.len() + length))?;
554
555 if let Some(metadata) = &self.metadata {
556 map.serialize_entry("__metadata__", metadata)?;
557 }
558
559 for (name, info) in names.iter().zip(&self.tensors) {
560 map.serialize_entry(name, info)?;
561 }
562
563 map.end()
564 }
565}
566
impl Metadata {
    /// Builds metadata from optional free-form info and ordered
    /// `(name, info)` pairs, validating offsets before returning.
    ///
    /// # Errors
    /// Propagates any failure from [`Metadata::validate`].
    pub fn new(
        metadata: Option<HashMap<String, String>>,
        tensors: Vec<(String, TensorInfo)>,
    ) -> Result<Self, SafeTensorError> {
        let mut index_map = HashMap::with_capacity(tensors.len());

        // Split the pairs: infos keep their order, names map to indices.
        let tensors: Vec<_> = tensors
            .into_iter()
            .enumerate()
            .map(|(index, (k, tensor))| {
                index_map.insert(k, index);
                tensor
            })
            .collect();

        let metadata = Self {
            metadata,
            tensors,
            index_map,
        };
        metadata.validate()?;
        Ok(metadata)
    }

    /// Checks that tensors are contiguous (each starts where the previous
    /// ended), that byte sizes match shape × dtype, and that nothing
    /// overflows. Returns the end offset of the last tensor.
    fn validate(&self) -> Result<usize, SafeTensorError> {
        let mut start = 0;
        for (i, info) in self.tensors.iter().enumerate() {
            let (s, e) = info.data_offsets;
            if s != start || e < s {
                // Recover the offending tensor's name for the error message.
                let tensor_name = self
                    .index_map
                    .iter()
                    .find_map(|(name, &index)| if index == i { Some(&name[..]) } else { None })
                    .unwrap_or("no_tensor");
                return Err(SafeTensorError::InvalidOffset(tensor_name.to_string()));
            }

            start = e;

            // Checked arithmetic: absurd shapes must error, not overflow.
            let nelements: usize = info
                .shape
                .iter()
                .copied()
                .try_fold(1usize, usize::checked_mul)
                .ok_or(SafeTensorError::ValidationOverflow)?;
            let nbits = nelements
                .checked_mul(info.dtype.bitsize())
                .ok_or(SafeTensorError::ValidationOverflow)?;

            // Sub-byte dtypes (e.g. F4) must still fill whole bytes.
            if nbits % 8 != 0 {
                return Err(SafeTensorError::MisalignedSlice);
            }
            let size = nbits
                .checked_div(8)
                .ok_or(SafeTensorError::ValidationOverflow)?;

            if e - s != size {
                return Err(SafeTensorError::TensorInvalidInfo);
            }
        }
        Ok(start)
    }

    /// Info for the tensor `name`, if present.
    pub fn info(&self, name: &str) -> Option<&TensorInfo> {
        let &index = self.index_map.get(name)?;
        self.tensors.get(index)
    }

    /// All tensor infos keyed by name.
    pub fn tensors(&self) -> HashMap<String, &TensorInfo> {
        self.index_map
            .iter()
            .map(|(tensor_name, &index)| (tensor_name.clone(), &self.tensors[index]))
            .collect()
    }

    /// Tensor names ordered by their position in the data section.
    pub fn offset_keys(&self) -> Vec<String> {
        let mut index_vec: Vec<_> = self.index_map.iter().collect();
        index_vec.sort_by_key(|a| a.1);
        index_vec.into_iter().map(|a| a.0.clone()).collect()
    }

    /// Total byte length of the data section (0 when there are no tensors).
    pub fn data_len(&self) -> usize {
        if let Some(tensor) = self.tensors.last() {
            tensor.data_offsets.1
        } else {
            0
        }
    }

    /// The optional free-form `__metadata__` string map.
    pub fn metadata(&self) -> &Option<HashMap<String, String>> {
        &self.metadata
    }
}
669
/// A typed, shaped, zero-copy view over one tensor's bytes.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct TensorView<'data> {
    // Element data type.
    dtype: Dtype,
    // Tensor dimensions.
    shape: Vec<usize>,
    // Borrowed raw bytes; length is validated against dtype × shape.
    data: &'data [u8],
}
679
680impl View for &TensorView<'_> {
681 fn dtype(&self) -> Dtype {
682 self.dtype
683 }
684
685 fn shape(&self) -> &[usize] {
686 &self.shape
687 }
688
689 fn data(&self) -> Cow<'_, [u8]> {
690 self.data.into()
691 }
692
693 fn data_len(&self) -> usize {
694 self.data.len()
695 }
696}
697
698impl View for TensorView<'_> {
699 fn dtype(&self) -> Dtype {
700 self.dtype
701 }
702
703 fn shape(&self) -> &[usize] {
704 &self.shape
705 }
706
707 fn data(&self) -> Cow<'_, [u8]> {
708 self.data.into()
709 }
710
711 fn data_len(&self) -> usize {
712 self.data.len()
713 }
714}
715
716impl<'data> TensorView<'data> {
717 pub fn new(
719 dtype: Dtype,
720 shape: Vec<usize>,
721 data: &'data [u8],
722 ) -> Result<Self, SafeTensorError> {
723 let n_elements: usize = shape.iter().product();
724
725 let nbits = n_elements * dtype.bitsize();
726 if nbits % 8 != 0 {
727 return Err(SafeTensorError::MisalignedSlice);
728 }
729 let size = nbits
730 .checked_div(8)
731 .ok_or(SafeTensorError::ValidationOverflow)?;
732
733 if data.len() != size {
734 Err(SafeTensorError::InvalidTensorView(dtype, shape, data.len()))
735 } else {
736 Ok(Self { dtype, shape, data })
737 }
738 }
739 pub fn dtype(&self) -> Dtype {
741 self.dtype
742 }
743
744 pub fn shape(&self) -> &[usize] {
746 &self.shape
747 }
748
749 pub fn data(&self) -> &'data [u8] {
751 self.data
752 }
753
754 pub fn sliced_data(
756 &'data self,
757 slices: &[TensorIndexer],
758 ) -> Result<SliceIterator<'data>, InvalidSlice> {
759 SliceIterator::new(self, slices)
760 }
761}
762
/// Per-tensor header entry: dtype, shape, and the half-open byte range
/// `[start, end)` of its data, relative to the start of the data section.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TensorInfo {
    /// Element data type.
    pub dtype: Dtype,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// `(start, end)` byte offsets into the data section.
    pub data_offsets: (usize, usize),
}
775
/// Supported element data types. The `Ord` derive follows declaration order
/// (roughly ascending bit width), which `prepare` relies on for layout.
#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
#[non_exhaustive]
pub enum Dtype {
    /// Boolean (one byte per element).
    BOOL,
    /// 4-bit float (sub-byte).
    F4,
    /// 6-bit float, 2 exponent / 3 mantissa bits (sub-byte).
    #[allow(non_camel_case_types)]
    F6_E2M3,
    /// 6-bit float, 3 exponent / 2 mantissa bits (sub-byte).
    #[allow(non_camel_case_types)]
    F6_E3M2,
    /// Unsigned 8-bit integer.
    U8,
    /// Signed 8-bit integer.
    I8,
    /// 8-bit float, 5 exponent / 2 mantissa bits.
    #[allow(non_camel_case_types)]
    F8_E5M2,
    /// 8-bit float, 4 exponent / 3 mantissa bits.
    #[allow(non_camel_case_types)]
    F8_E4M3,
    /// 8-bit float, 8 exponent / 0 mantissa bits.
    #[allow(non_camel_case_types)]
    F8_E8M0,
    /// Signed 16-bit integer.
    I16,
    /// Unsigned 16-bit integer.
    U16,
    /// Half-precision float.
    F16,
    /// Brain float 16.
    BF16,
    /// Signed 32-bit integer.
    I32,
    /// Unsigned 32-bit integer.
    U32,
    /// Single-precision float.
    F32,
    /// 64-bit complex number.
    C64,
    /// Double-precision float.
    F64,
    /// Signed 64-bit integer.
    I64,
    /// Unsigned 64-bit integer.
    U64,
}
826
827impl Dtype {
828 pub fn bitsize(&self) -> usize {
830 match self {
831 Dtype::F4 => 4,
832 Dtype::F6_E3M2 => 6,
833 Dtype::F6_E2M3 => 6,
834 Dtype::BOOL => 8,
835 Dtype::U8 => 8,
836 Dtype::I8 => 8,
837 Dtype::F8_E5M2 => 8,
838 Dtype::F8_E4M3 => 8,
839 Dtype::F8_E8M0 => 8,
840 Dtype::I16 => 16,
841 Dtype::U16 => 16,
842 Dtype::I32 => 32,
843 Dtype::U32 => 32,
844 Dtype::I64 => 64,
845 Dtype::U64 => 64,
846 Dtype::F16 => 16,
847 Dtype::BF16 => 16,
848 Dtype::F32 => 32,
849 Dtype::F64 => 64,
850 Dtype::C64 => 64,
851 }
852 }
853 #[deprecated(
855 since = "0.6.0",
856 note = "Use `bitsize` instead as some elements have smaller than a full byte of width"
857 )]
858 pub fn size(&self) -> usize {
859 self.bitsize() / 8
860 }
861}
862
863impl Display for Dtype {
864 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
865 f.write_str(match *self {
866 Dtype::F4 => "F4",
867 Dtype::F6_E2M3 => "F6_E2M3",
868 Dtype::F6_E3M2 => "F6_E3M2",
869 Dtype::BOOL => "BOOL",
870 Dtype::I8 => "I8",
871 Dtype::U8 => "U8",
872 Dtype::F8_E5M2 => "F8_E5M2",
873 Dtype::F8_E4M3 => "F8_E4M3",
874 Dtype::F8_E8M0 => "F8_E8M0",
875 Dtype::I16 => "I16",
876 Dtype::U16 => "U16",
877 Dtype::I32 => "I32",
878 Dtype::U32 => "U32",
879 Dtype::I64 => "I64",
880 Dtype::U64 => "U64",
881 Dtype::F16 => "F16",
882 Dtype::BF16 => "BF16",
883 Dtype::F32 => "F32",
884 Dtype::F64 => "F64",
885 Dtype::C64 => "C64",
886 })
887 }
888}
889
890#[cfg(test)]
891mod tests {
892 use super::*;
893 use crate::slice::IndexOp;
894 use proptest::prelude::*;
895 #[cfg(not(feature = "std"))]
896 extern crate std;
897 use std::io::Write;
898
    // Upper bounds keeping proptest-generated tensors small and fast.
    const MAX_DIMENSION: usize = 8;
    const MAX_SIZE: usize = 8;
    const MAX_TENSORS: usize = 8;
902
    // Proptest strategy yielding any supported dtype (F8_E8M0 not included).
    fn arbitrary_dtype() -> impl Strategy<Value = Dtype> {
        prop_oneof![
            Just(Dtype::BOOL),
            Just(Dtype::F4),
            Just(Dtype::F6_E3M2),
            Just(Dtype::F6_E2M3),
            Just(Dtype::F8_E5M2),
            Just(Dtype::F8_E4M3),
            Just(Dtype::U8),
            Just(Dtype::I8),
            Just(Dtype::I16),
            Just(Dtype::U16),
            Just(Dtype::I32),
            Just(Dtype::U32),
            Just(Dtype::I64),
            Just(Dtype::U64),
            Just(Dtype::F16),
            Just(Dtype::BF16),
            Just(Dtype::F32),
            Just(Dtype::F64),
            Just(Dtype::C64),
        ]
    }
926
    // Strategy for a shape of 1..MAX_DIMENSION dims, each 1..MAX_SIZE.
    fn arbitrary_shape() -> impl Strategy<Value = Vec<usize>> {
        (1..MAX_DIMENSION).prop_flat_map(|length| prop::collection::vec(1..MAX_SIZE, length))
    }
931
    // Strategy building a valid, contiguous `Metadata` with tensors named
    // `t.0`, `t.1`, …; shapes whose bit size is not byte-aligned are dropped.
    fn arbitrary_metadata() -> impl Strategy<Value = Metadata> {
        (1..MAX_TENSORS)
            .prop_flat_map(|size| {
                (
                    prop::collection::vec(arbitrary_dtype(), size),
                    prop::collection::vec(arbitrary_shape(), size),
                )
            })
            .prop_filter_map("Misaligned slices", |(dtypes, shapes)| {
                let mut start = 0;
                let tensors: Vec<TensorInfo> = dtypes
                    .iter()
                    .zip(shapes)
                    .flat_map(|(dtype, shape)| {
                        // Skip combinations whose total bits don't fill bytes.
                        let bitlength: usize = shape.iter().product::<usize>() * dtype.bitsize();
                        if bitlength % 8 != 0 {
                            return None;
                        }
                        let length = bitlength.div_ceil(8);
                        let end = start + length;
                        let tensor = TensorInfo {
                            dtype: *dtype,
                            shape,
                            data_offsets: (start, end),
                        };
                        start = end;
                        Some(tensor)
                    })
                    .collect();
                let index_map = (0..tensors.len())
                    .map(|index| (format!("t.{index}"), index))
                    .collect();
                if tensors.is_empty() {
                    None
                } else {
                    Some(Metadata {
                        metadata: None,
                        tensors,
                        index_map,
                    })
                }
            })
    }
980
981 fn data_size(metadata: &Metadata) -> usize {
989 metadata.tensors.last().unwrap().data_offsets.1
990 }
991
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]

        // Every generated tensor name must be retrievable by `tensor()`.
        #[test]
        fn test_indexing(metadata in arbitrary_metadata()) {
            let data = vec![0u8; data_size(&metadata)];
            let tensors = SafeTensors { metadata, data: &data };
            for name in tensors.names() {
                assert!(tensors.tensor(name).is_ok());
            }
        }
        // serialize -> deserialize must be lossless and keep data aligned.
        #[test]
        fn test_roundtrip(metadata in arbitrary_metadata()) {
            let data: Vec<u8> = (0..data_size(&metadata)).map(|x| x as u8).collect();
            let before = SafeTensors { metadata, data: &data };
            let tensors = before.tensors();
            let bytes = serialize(tensors.iter().map(|(name, view)| (name.to_string(), view)), None).unwrap();

            let after = SafeTensors::deserialize(&bytes).unwrap();

            assert_eq!(before.names().len(), after.names().len());
            for name in before.names() {
                let tensor_before = before.tensor(name).unwrap();
                let tensor_after = after.tensor(name).unwrap();
                // Each tensor's data pointer is aligned to its element size.
                assert_eq!(tensor_after.data().as_ptr() as usize % tensor_after.dtype().bitsize().div_ceil(8), 0);
                assert_eq!(tensor_before, tensor_after);
            }
        }
    }
1022
    // Serializing one F32 tensor must produce this exact byte layout
    // (length prefix, JSON header, little-endian float data).
    #[test]
    fn test_serialization() {
        let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
            .into_iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        let shape = vec![1, 2, 3];
        let attn_0 = TensorView::new(Dtype::F32, shape, &data).unwrap();
        let metadata: HashMap<String, TensorView> =
            [("attn.0".to_string(), attn_0)].into_iter().collect();

        let out = serialize(&metadata, None).unwrap();
        assert_eq!(
            out,
            [
                64, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 46, 48, 34, 58, 123, 34, 100,
                116, 121, 112, 101, 34, 58, 34, 70, 51, 50, 34, 44, 34, 115, 104, 97, 112, 101, 34,
                58, 91, 49, 44, 50, 44, 51, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102, 115,
                101, 116, 115, 34, 58, 91, 48, 44, 50, 52, 93, 125, 125, 0, 0, 0, 0, 0, 0, 128, 63,
                0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64
            ]
        );
        let _parsed = SafeTensors::deserialize(&out).unwrap();
    }
1047
    // Two F4 elements pack into a single byte; the header is space-padded
    // to an 8-byte multiple, and the round-trip must be exact.
    #[test]
    fn test_serialization_fp4() {
        let data: Vec<u8> = vec![0u8];
        let shape = vec![1, 2];
        let attn_0 = TensorView::new(Dtype::F4, shape, &data).unwrap();
        let metadata: HashMap<String, TensorView> =
            [("attn.0".to_string(), attn_0)].into_iter().collect();

        let out = serialize(&metadata, None).unwrap();
        assert_eq!(
            out,
            [
                64, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 46, 48, 34, 58, 123, 34, 100,
                116, 121, 112, 101, 34, 58, 34, 70, 52, 34, 44, 34, 115, 104, 97, 112, 101, 34, 58,
                91, 49, 44, 50, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102, 115, 101, 116,
                115, 34, 58, 91, 48, 44, 49, 93, 125, 125, 32, 32, 32, 32, 0
            ]
        );
        let parsed = SafeTensors::deserialize(&out).unwrap();
        let tensors: HashMap<_, _> = parsed.tensors().into_iter().collect();
        assert_eq!(tensors, metadata);
    }
1070
    // Three F4 elements occupy 12 bits — not a whole number of bytes —
    // so constructing the view must fail with `MisalignedSlice`.
    #[test]
    fn test_serialization_fp4_misaligned() {
        let data: Vec<u8> = vec![0u8, 1u8];
        let shape = vec![1, 3];
        let attn_0 = TensorView::new(Dtype::F4, shape, &data);
        assert!(matches!(attn_0, Err(SafeTensorError::MisalignedSlice)));
    }
1078
    // Two F4 elements need exactly one byte; two bytes of data must be
    // rejected as an `InvalidTensorView`.
    #[test]
    fn test_serialization_fp4_invalid() {
        let data: Vec<u8> = vec![0u8, 1u8];
        let shape = vec![1, 2];
        let attn_0 = TensorView::new(Dtype::F4, shape, &data);
        assert!(matches!(
            attn_0,
            Err(SafeTensorError::InvalidTensorView(Dtype::F4, _shape, _size))
        ));
    }
1089
    // A file with zero tensors is valid, with and without `__metadata__`.
    #[test]
    fn test_empty() {
        let tensors: HashMap<String, TensorView> = HashMap::new();

        let out = serialize(&tensors, None).unwrap();
        assert_eq!(
            out,
            [8, 0, 0, 0, 0, 0, 0, 0, 123, 125, 32, 32, 32, 32, 32, 32]
        );
        let _parsed = SafeTensors::deserialize(&out).unwrap();

        let metadata: Option<HashMap<String, String>> = Some(
            [("framework".to_string(), "pt".to_string())]
                .into_iter()
                .collect(),
        );
        let out = serialize(&tensors, metadata).unwrap();
        assert_eq!(
            out,
            [
                40, 0, 0, 0, 0, 0, 0, 0, 123, 34, 95, 95, 109, 101, 116, 97, 100, 97, 116, 97, 95,
                95, 34, 58, 123, 34, 102, 114, 97, 109, 101, 119, 111, 114, 107, 34, 58, 34, 112,
                116, 34, 125, 125, 32, 32, 32, 32, 32
            ]
        );
        let _parsed = SafeTensors::deserialize(&out).unwrap();
    }
1117
    // Header padding must force the data section onto an aligned boundary.
    // NOTE(review): "alignement" is a typo in the test name; left as-is to
    // avoid churning test identifiers.
    #[test]
    fn test_serialization_forced_alignement() {
        let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
            .into_iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        let shape = vec![1, 1, 2, 3];
        let attn_0 = TensorView::new(Dtype::F32, shape, &data).unwrap();
        let metadata: HashMap<String, TensorView> =
            [("attn0".to_string(), attn_0)].into_iter().collect();

        let out = serialize(&metadata, None).unwrap();
        assert_eq!(
            out,
            [
                72, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 48, 34, 58, 123, 34, 100, 116,
                121, 112, 101, 34, 58, 34, 70, 51, 50, 34, 44, 34, 115, 104, 97, 112, 101, 34, 58,
                91, 49, 44, 49, 44, 50, 44, 51, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102,
                115, 101, 116, 115, 34, 58, 91, 48, 44, 50, 52, 93, 125, 125, 32, 32, 32, 32, 32,
                32, 32, 0, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0,
                160, 64
            ],
        );
        let parsed = SafeTensors::deserialize(&out).unwrap();
        let tensor = parsed.tensor("attn0").unwrap();
        // Data pointer must be aligned to the element byte size.
        assert_eq!(
            tensor.data().as_ptr() as usize % tensor.dtype().bitsize().div_ceil(8),
            0
        );
    }
1151
    // Slicing a deserialized tensor along different axes yields the
    // expected raw bytes (compared against the equivalent f32 values).
    #[test]
    fn test_slicing() {
        let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
            .into_iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        let attn_0 = TensorView {
            dtype: Dtype::F32,
            shape: vec![1, 2, 3],
            data: &data,
        };
        let metadata: HashMap<String, TensorView> =
            [("attn.0".to_string(), attn_0)].into_iter().collect();

        let out = serialize(&metadata, None).unwrap();
        let parsed = SafeTensors::deserialize(&out).unwrap();

        // Slice the second axis: keep row 0 only.
        let out_buffer: Vec<u8> = parsed
            .tensor("attn.0")
            .unwrap()
            .slice((.., ..1))
            .unwrap()
            .flat_map(|b| b.to_vec())
            .collect();
        assert_eq!(out_buffer, vec![0u8, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64]);
        assert_eq!(
            out_buffer,
            vec![0.0f32, 1.0, 2.0]
                .into_iter()
                .flat_map(|f| f.to_le_bytes())
                .collect::<Vec<_>>()
        );
        // Slice the third axis: keep column 0 only.
        let out_buffer: Vec<u8> = parsed
            .tensor("attn.0")
            .unwrap()
            .slice((.., .., ..1))
            .unwrap()
            .flat_map(|b| b.to_vec())
            .collect();
        assert_eq!(out_buffer, vec![0u8, 0, 0, 0, 0, 0, 64, 64]);
        assert_eq!(
            out_buffer,
            vec![0.0f32, 3.0]
                .into_iter()
                .flat_map(|f| f.to_le_bytes())
                .collect::<Vec<_>>()
        );
    }
1200
    // Full-size GPT-2 layout (12 heads).
    #[test]
    fn test_gpt2() {
        gpt2_like(12, "gpt2");
    }
1205
    // Reduced GPT-2 layout (6 heads) for a faster variant of the same test.
    #[test]
    fn test_gpt2_tiny() {
        gpt2_like(6, "gpt2_tiny");
    }
1210
1211 fn gpt2_like(n_heads: usize, model_id: &str) {
1212 let mut tensors_desc = vec![
1213 ("wte".to_string(), vec![50257, 768]),
1214 ("wpe".to_string(), vec![1024, 768]),
1215 ];
1216 for i in 0..n_heads {
1217 tensors_desc.push((format!("h.{i}.ln_1.weight"), vec![768]));
1218 tensors_desc.push((format!("h.{i}.ln_1.bias"), vec![768]));
1219 tensors_desc.push((format!("h.{i}.attn.bias"), vec![1, 1, 1024, 1024]));
1220 tensors_desc.push((format!("h.{i}.attn.c_attn.weight"), vec![768, 2304]));
1221 tensors_desc.push((format!("h.{i}.attn.c_attn.bias"), vec![2304]));
1222 tensors_desc.push((format!("h.{i}.attn.c_proj.weight"), vec![768, 768]));
1223 tensors_desc.push((format!("h.{i}.attn.c_proj.bias"), vec![768]));
1224 tensors_desc.push((format!("h.{i}.ln_2.weight"), vec![768]));
1225 tensors_desc.push((format!("h.{i}.ln_2.bias"), vec![768]));
1226 tensors_desc.push((format!("h.{i}.mlp.c_fc.weight"), vec![768, 3072]));
1227 tensors_desc.push((format!("h.{i}.mlp.c_fc.bias"), vec![3072]));
1228 tensors_desc.push((format!("h.{i}.mlp.c_proj.weight"), vec![3072, 768]));
1229 tensors_desc.push((format!("h.{i}.mlp.c_proj.bias"), vec![768]));
1230 }
1231 tensors_desc.push(("ln_f.weight".to_string(), vec![768]));
1232 tensors_desc.push(("ln_f.bias".to_string(), vec![768]));
1233
1234 let dtype = Dtype::F32;
1235 let nbits: usize = tensors_desc
1236 .iter()
1237 .map(|(_, shape)| shape.iter().product::<usize>())
1238 .sum::<usize>()
1239 * dtype.bitsize();
1240 if nbits % 8 != 0 {
1241 panic!("Misaligned slice");
1242 }
1243 let n = nbits
1244 .checked_div(8)
1245 .ok_or(SafeTensorError::ValidationOverflow)
1246 .unwrap(); let all_data = vec![0; n];
1248 let mut metadata = HashMap::with_capacity(tensors_desc.len());
1249 let mut offset = 0;
1250 for (name, shape) in tensors_desc {
1251 let n: usize = shape.iter().product();
1252 let buffer = &all_data[offset..offset + (n * dtype.bitsize()) / 8];
1253 let tensor = TensorView::new(dtype, shape, buffer).unwrap();
1254 metadata.insert(name, tensor);
1255 offset += n;
1256 }
1257
1258 let filename = format!("./out_{model_id}.safetensors");
1259
1260 let out = serialize(&metadata, None).unwrap();
1261 std::fs::write(&filename, out).unwrap();
1262 let raw = std::fs::read(&filename).unwrap();
1263 let _deserialized = SafeTensors::deserialize(&raw).unwrap();
1264 std::fs::remove_file(&filename).unwrap();
1265
1266 #[cfg(feature = "std")]
1268 {
1269 serialize_to_file(&metadata, None, std::path::Path::new(&filename)).unwrap();
1270 let raw = std::fs::read(&filename).unwrap();
1271 let _deserialized = SafeTensors::deserialize(&raw).unwrap();
1272 std::fs::remove_file(&filename).unwrap();
1273 }
1274 }
1275
    // A zero-dimensional (scalar) tensor — empty shape, one element — is valid.
    #[test]
    fn test_empty_shapes_allowed() {
        let serialized = b"8\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[],\"data_offsets\":[0,4]}}\x00\x00\x00\x00";

        let loaded = SafeTensors::deserialize(serialized).unwrap();
        assert_eq!(loaded.names(), vec!["test"]);
        let tensor = loaded.tensor("test").unwrap();
        assert!(tensor.shape().is_empty());
        assert_eq!(tensor.dtype(), Dtype::I32);
        assert_eq!(tensor.data(), b"\0\0\0\0");
    }
1288
    // A hand-crafted valid file parses into the expected single tensor.
    #[test]
    fn test_deserialization() {
        let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";

        let loaded = SafeTensors::deserialize(serialized).unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded.names(), vec!["test"]);
        let tensor = loaded.tensor("test").unwrap();
        assert_eq!(tensor.shape(), vec![2, 2]);
        assert_eq!(tensor.dtype(), Dtype::I32);
        assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
    }
1303
    // A `TensorView` borrows from the input buffer, not from `SafeTensors`,
    // so it must outlive the dropped `SafeTensors` value.
    #[test]
    fn test_lifetimes() {
        let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";

        let tensor = {
            let loaded = SafeTensors::deserialize(serialized).unwrap();
            loaded.tensor("test").unwrap()
        };

        assert_eq!(tensor.shape(), vec![2, 2]);
        assert_eq!(tensor.dtype(), Dtype::I32);
        assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
    }
1318
    // A crafted header where many tensors share the same data_offsets must
    // be rejected (offset validation), not silently accepted.
    #[test]
    fn test_json_attack() {
        let mut tensors = HashMap::new();
        let dtype = Dtype::F32;
        let shape = vec![2, 2];
        let data_offsets = (0, 16);
        for i in 0..10 {
            tensors.insert(
                format!("weight_{i}"),
                TensorInfo {
                    dtype,
                    shape: shape.clone(),
                    data_offsets,
                },
            );
        }

        let metadata = HashMetadata {
            metadata: None,
            tensors,
        };
        let serialized = serde_json::to_string(&metadata).unwrap();
        let serialized = serialized.as_bytes();

        let n = serialized.len();

        // Hand-write the file to bypass the library's own validation.
        let filename = "out.safetensors";
        let mut f = std::io::BufWriter::new(std::fs::File::create(filename).unwrap());
        f.write_all(n.to_le_bytes().as_ref()).unwrap();
        f.write_all(serialized).unwrap();
        f.write_all(b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0").unwrap();
        f.flush().unwrap();

        let reloaded = std::fs::read(filename).unwrap();
        match SafeTensors::deserialize(&reloaded) {
            Err(SafeTensorError::InvalidOffset(_)) => {
            }
            Err(err) => panic!("Unexpected error {err:?}"),
            Ok(_) => panic!("This should not be able to be deserialized"),
        }
    }
1361
1362 #[test]
1363 fn test_metadata_incomplete_buffer() {
1364 let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00extra_bogus_data_for_polyglot_file";
1365
1366 match SafeTensors::deserialize(serialized) {
1367 Err(SafeTensorError::MetadataIncompleteBuffer) => {
1368 }
1370 _ => panic!("This should not be able to be deserialized"),
1371 }
1372
1373 let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"; match SafeTensors::deserialize(serialized) {
1377 Err(SafeTensorError::MetadataIncompleteBuffer) => {
1378 }
1380 _ => panic!("This should not be able to be deserialized"),
1381 }
1382 }
1383
1384 #[test]
1385 fn test_header_too_large() {
1386 let serialized = b"<\x00\x00\x00\x00\xff\xff\xff{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
1387
1388 match SafeTensors::deserialize(serialized) {
1389 Err(SafeTensorError::HeaderTooLarge) => {
1390 }
1392 _ => panic!("This should not be able to be deserialized"),
1393 }
1394 }
1395
1396 #[test]
1397 fn test_header_too_small() {
1398 let serialized = b"";
1399 match SafeTensors::deserialize(serialized) {
1400 Err(SafeTensorError::HeaderTooSmall) => {
1401 }
1403 _ => panic!("This should not be able to be deserialized"),
1404 }
1405 }
1406
1407 #[test]
1408 fn test_invalid_header_length() {
1409 let serialized = b"<\x00\x00\x00\x00\x00\x00\x00";
1410 match SafeTensors::deserialize(serialized) {
1411 Err(SafeTensorError::InvalidHeaderLength) => {
1412 }
1414 _ => panic!("This should not be able to be deserialized"),
1415 }
1416 }
1417
1418 #[test]
1419 fn test_invalid_header_non_utf8() {
1420 let serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00\xff";
1421 match SafeTensors::deserialize(serialized) {
1422 Err(SafeTensorError::InvalidHeader(_)) => {
1423 }
1425 _ => panic!("This should not be able to be deserialized"),
1426 }
1427 }
1428
1429 #[test]
1430 fn test_invalid_header_not_json() {
1431 let serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00{";
1432 match SafeTensors::deserialize(serialized) {
1433 Err(SafeTensorError::InvalidHeaderDeserialization(_)) => {
1434 }
1436 _ => panic!("This should not be able to be deserialized"),
1437 }
1438 }
1439
1440 #[test]
1441 fn test_whitespace_padded_header() {
1443 let serialized = b"\x06\x00\x00\x00\x00\x00\x00\x00{}\x0D\x20\x09\x0A";
1444 let loaded = SafeTensors::deserialize(serialized).unwrap();
1445 assert_eq!(loaded.len(), 0);
1446 }
1447
1448 #[test]
1462 fn test_zero_sized_tensor() {
1463 let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,0],\"data_offsets\":[0, 0]}}";
1464 let loaded = SafeTensors::deserialize(serialized).unwrap();
1465
1466 assert_eq!(loaded.names(), vec!["test"]);
1467 let tensor = loaded.tensor("test").unwrap();
1468 assert_eq!(tensor.shape(), vec![2, 0]);
1469 assert_eq!(tensor.dtype(), Dtype::I32);
1470 assert_eq!(tensor.data(), b"");
1471 }
1472
1473 #[test]
1474 fn test_invalid_info() {
1475 let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0, 4]}}";
1476 match SafeTensors::deserialize(serialized) {
1477 Err(SafeTensorError::TensorInvalidInfo) => {
1478 }
1480 something => panic!("This should not be able to be deserialized got {something:?}"),
1481 }
1482 }
1483
1484 #[test]
1485 fn test_validation_overflow() {
1486 let serialized = b"O\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,18446744073709551614],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
1489 match SafeTensors::deserialize(serialized) {
1490 Err(SafeTensorError::ValidationOverflow) => {
1491 }
1493 _ => panic!("This should not be able to be deserialized"),
1494 }
1495 let serialized = b"N\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,9223372036854775807],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
1498 match SafeTensors::deserialize(serialized) {
1499 Err(SafeTensorError::ValidationOverflow) => {
1500 }
1502 _ => panic!("This should not be able to be deserialized"),
1503 }
1504 }
1505
1506 #[test]
1507 fn test_invalid_header_size_serialization() {
1508 let mut data_info = HashMap::<String, String>::new();
1509 let tensors: HashMap<String, TensorView> = HashMap::new();
1510
1511 let very_large_metadata = "a".repeat(MAX_HEADER_SIZE);
1513 data_info.insert("very_large_metadata".to_string(), very_large_metadata);
1514 match serialize(&tensors, Some(data_info)) {
1515 Err(SafeTensorError::HeaderTooLarge) => {
1516 }
1518 _ => panic!("This should not be able to be serialized"),
1519 }
1520 }
1521}