safetensors/
tensor.rs

1//! Module Containing the most important structures
2use crate::lib::{Cow, HashMap, String, ToString, Vec};
3use crate::slice::{InvalidSlice, SliceIterator, TensorIndexer};
4use core::fmt::Display;
5use core::str::Utf8Error;
6use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
7#[cfg(feature = "std")]
8use std::io::Write;
9
/// Largest JSON header length (in bytes) accepted when reading or writing.
const MAX_HEADER_SIZE: usize = 100 * 1_000_000;
/// Width in bytes of the little-endian `u64` header-length prefix.
const N_LEN: usize = core::mem::size_of::<u64>();
12
/// Possible errors that can occur while reading or writing
/// a safetensors file.
#[derive(Debug)]
pub enum SafeTensorError {
    /// The header is an invalid UTF-8 string and cannot be read.
    InvalidHeader(Utf8Error),
    /// The header's first byte is not the expected `{`.
    /// NOTE(review): the corresponding check in `SafeTensors::read_metadata` is
    /// currently commented out.
    InvalidHeaderStart,
    /// The header does contain a valid string, but it is not valid JSON.
    InvalidHeaderDeserialization(serde_json::Error),
    /// The header is larger than 100 MB, which is considered too large (this
    /// limit might evolve in the future), or its declared length does not fit in
    /// a `usize`.
    HeaderTooLarge,
    /// The buffer is smaller than the 8-byte header-length prefix.
    HeaderTooSmall,
    /// The header length is invalid (e.g. it extends past the end of the buffer).
    InvalidHeaderLength,
    /// The tensor name was not found in the archive
    TensorNotFound(String),
    /// Invalid information between shape, dtype and the proposed offsets in the file
    TensorInvalidInfo,
    /// The offsets declared for tensor with name `String` in the header are invalid
    InvalidOffset(String),
    /// IoError
    #[cfg(feature = "std")]
    IoError(std::io::Error),
    /// JSON error
    JsonError(serde_json::Error),
    /// The following tensor cannot be created because the buffer size doesn't
    /// match shape + dtype. Carries the dtype, the shape and the actual buffer
    /// length in bytes.
    InvalidTensorView(Dtype, Vec<usize>, usize),
    /// The metadata is invalid because the data offsets of the tensor does not
    /// fully cover the buffer part of the file. The last offset **must** be
    /// the end of the file.
    MetadataIncompleteBuffer,
    /// The metadata contains information (shape or shape * dtype size) which lead to an
    /// arithmetic overflow. This is most likely an error in the file.
    ValidationOverflow,
    /// For dtypes smaller than one byte, some slices (or a whole tensor) would end
    /// outside a byte boundary; special care has to be taken and standard
    /// functions will fail.
    MisalignedSlice,
}
53
#[cfg(feature = "std")]
impl From<std::io::Error> for SafeTensorError {
    /// Wraps an I/O failure so it can be propagated with `?`.
    fn from(err: std::io::Error) -> Self {
        Self::IoError(err)
    }
}
60
61impl From<serde_json::Error> for SafeTensorError {
62    fn from(error: serde_json::Error) -> SafeTensorError {
63        SafeTensorError::JsonError(error)
64    }
65}
66
67impl Display for SafeTensorError {
68    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
69        use SafeTensorError::*;
70
71        match self {
72            InvalidHeader(error) => write!(f, "invalid UTF-8 in header: {error}"),
73            InvalidHeaderStart => write!(f, "invalid start character in header, must be `{{`"),
74            InvalidHeaderDeserialization(error) => write!(f, "invalid JSON in header: {error}"),
75            JsonError(error) => write!(f, "JSON error: {error}"),
76            HeaderTooLarge => write!(f, "header too large"),
77            HeaderTooSmall => write!(f, "header too small"),
78            InvalidHeaderLength => write!(f, "invalid header length"),
79            TensorNotFound(name) => write!(f, "tensor `{name}` not found"),
80            TensorInvalidInfo => write!(f, "invalid shape, data type, or offset for tensor"),
81            InvalidOffset(name) => write!(f, "invalid offset for tensor `{name}`"),
82            #[cfg(feature = "std")]
83            IoError(error) => write!(f, "I/O error: {error}"),
84            InvalidTensorView(dtype, shape, n_bytes) => {
85                write!(f, "tensor of type {dtype} and shape (")?;
86                for (i, &dim) in shape.iter().enumerate() {
87                    write!(f, "{sep}{dim}", sep = if i == 0 { "" } else { ", " })?;
88                }
89                write!(f, ") can't be created from {n_bytes} bytes")
90            }
91            MetadataIncompleteBuffer => write!(f, "incomplete metadata, file not fully covered"),
92            ValidationOverflow => write!(f, "overflow computing buffer size from shape and/or element type"),
93            MisalignedSlice => write!(f, "The slice is slicing for subbytes dtypes, and the slice does not end up at a byte boundary, this is invalid.")
94        }
95    }
96}
97
98#[cfg(not(feature = "std"))]
99impl core::error::Error for SafeTensorError {
100    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
101        match self {
102            SafeTensorError::InvalidHeader(source) => Some(source),
103            SafeTensorError::JsonError(source) => Some(source),
104            SafeTensorError::InvalidHeaderDeserialization(source) => Some(source),
105            _ => None,
106        }
107    }
108}
109
#[cfg(feature = "std")]
impl std::error::Error for SafeTensorError {
    /// Exposes the wrapped UTF-8, JSON or I/O error, if any, as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::InvalidHeader(utf8) => utf8,
            Self::JsonError(json) | Self::InvalidHeaderDeserialization(json) => json,
            Self::IoError(io) => io,
            _ => return None,
        };
        Some(cause)
    }
}
122
/// Output of `prepare`: the serialized header plus the total size of the tensor data.
struct PreparedData {
    /// Length in bytes of `header_bytes`; written as the little-endian `u64` prefix.
    n: u64,
    /// The JSON header, right-padded with spaces to a multiple of 8 bytes.
    header_bytes: Vec<u8>,
    /// Total number of tensor-data bytes that follow the header.
    offset: usize,
}
128
/// The trait necessary to enable safetensors to serialize a tensor.
///
/// If you have an owned tensor like this:
///
/// ```rust
/// use safetensors::tensor::{View, Dtype};
/// use std::borrow::Cow;
/// struct Tensor{ dtype: MyDtype, shape: Vec<usize>, data: Vec<u8>}
///
/// # type MyDtype = Dtype;
/// impl<'data> View for &'data Tensor{
///    fn dtype(&self) -> Dtype{
///        self.dtype.into()
///    }
///    fn shape(&self) -> &[usize]{
///         &self.shape
///    }
///    fn data(&self) -> Cow<'_, [u8]>{
///        (&self.data).into()
///    }
///    fn data_len(&self) -> usize{
///        self.data.len()
///    }
/// }
/// ```
///
/// For a borrowed tensor:
///
/// ```rust
/// use safetensors::tensor::{View, Dtype};
/// use std::borrow::Cow;
/// struct Tensor<'data>{ dtype: MyDtype, shape: Vec<usize>, data: &'data[u8]}
///
/// # type MyDtype = Dtype;
/// impl<'data> View for Tensor<'data>{
///    fn dtype(&self) -> Dtype{
///        self.dtype.into()
///    }
///    fn shape(&self) -> &[usize]{
///         &self.shape
///    }
///    fn data(&self) -> Cow<'_, [u8]>{
///        self.data.into()
///    }
///    fn data_len(&self) -> usize{
///        self.data.len()
///    }
/// }
/// ```
///
/// Now if you have some unknown buffer that could be on GPU for instance,
/// you can implement the trait to return an owned local buffer containing the data
/// on CPU (needed to write on disk)
/// ```rust
/// use safetensors::tensor::{View, Dtype};
/// use std::borrow::Cow;
///
/// # type MyDtype = Dtype;
/// # type OpaqueGpu = Vec<u8>;
/// struct Tensor{ dtype: MyDtype, shape: Vec<usize>, data: OpaqueGpu }
///
/// impl View for Tensor{
///    fn dtype(&self) -> Dtype{
///        self.dtype.into()
///    }
///    fn shape(&self) -> &[usize]{
///         &self.shape
///    }
///    fn data(&self) -> Cow<'_, [u8]>{
///        // This copies data from GPU to CPU.
///        let data: Vec<u8> = self.data.to_vec();
///        data.into()
///    }
///    fn data_len(&self) -> usize{
///        let n: usize = self.shape.iter().product();
///        // `bitsize` (rather than the deprecated `size`) also covers sub-byte dtypes.
///        let bits_per_element = self.dtype.bitsize();
///        n * bits_per_element / 8
///    }
/// }
/// ```
pub trait View {
    /// The `Dtype` of the tensor.
    fn dtype(&self) -> Dtype;
    /// The shape of the tensor.
    fn shape(&self) -> &[usize];
    /// The raw bytes of the tensor: borrowed when possible, owned when they must
    /// first be copied (e.g. from a device).
    fn data(&self) -> Cow<'_, [u8]>;
    /// The length of the data, in bytes.
    /// This is necessary as this might be faster to get than `data().len()`
    /// for instance for tensors residing in GPU.
    ///
    /// Serialization computes each tensor's byte offsets from this value and then
    /// writes `data()`, so the two should agree.
    fn data_len(&self) -> usize;
}
220
221fn prepare<S, V, I>(
222    data: I,
223    data_info: Option<HashMap<String, String>>,
224) -> Result<(PreparedData, Vec<V>), SafeTensorError>
225where
226    S: AsRef<str> + Ord + Display,
227    V: View,
228    I: IntoIterator<Item = (S, V)>,
229{
230    // Make sure we're sorting by descending dtype alignment
231    // Then by name
232    let mut data: Vec<_> = data.into_iter().collect();
233    data.sort_by(|(lname, left), (rname, right)| {
234        right.dtype().cmp(&left.dtype()).then(lname.cmp(rname))
235    });
236
237    let mut tensors: Vec<V> = Vec::with_capacity(data.len());
238    let mut hmetadata = Vec::with_capacity(data.len());
239    let mut offset = 0;
240
241    for (name, tensor) in data {
242        let n = tensor.data_len();
243        let tensor_info = TensorInfo {
244            dtype: tensor.dtype(),
245            shape: tensor.shape().to_vec(),
246            data_offsets: (offset, offset + n),
247        };
248        offset += n;
249        hmetadata.push((name.to_string(), tensor_info));
250        tensors.push(tensor);
251    }
252
253    let metadata: Metadata = Metadata::new(data_info, hmetadata)?;
254    let mut metadata_buf = serde_json::to_string(&metadata)?.into_bytes();
255
256    // Force alignment to 8 bytes.
257    let aligned_metadata_len = metadata_buf.len().next_multiple_of(N_LEN);
258    metadata_buf.resize(aligned_metadata_len, b' ');
259
260    Ok((
261        PreparedData {
262            n: aligned_metadata_len as u64,
263            header_bytes: metadata_buf,
264            offset,
265        },
266        tensors,
267    ))
268}
269
270/// Serialize to an owned byte buffer the dictionnary of tensors.
271pub fn serialize<
272    S: AsRef<str> + Ord + core::fmt::Display,
273    V: View,
274    I: IntoIterator<Item = (S, V)>,
275>(
276    data: I,
277    data_info: Option<HashMap<String, String>>,
278) -> Result<Vec<u8>, SafeTensorError> {
279    let (
280        PreparedData {
281            n,
282            header_bytes,
283            offset,
284        },
285        tensors,
286    ) = prepare(data, data_info)?;
287
288    if n > MAX_HEADER_SIZE as u64 {
289        return Err(SafeTensorError::HeaderTooLarge);
290    }
291
292    let expected_size = N_LEN + header_bytes.len() + offset;
293    let mut buffer: Vec<u8> = Vec::with_capacity(expected_size);
294    buffer.extend(n.to_le_bytes());
295    buffer.extend(header_bytes);
296
297    for tensor in tensors {
298        buffer.extend(tensor.data().as_ref());
299    }
300
301    Ok(buffer)
302}
303
/// Serializes the dictionary of tensors straight to the file at `filename`.
///
/// Streaming through a buffered writer avoids materializing the whole file in
/// memory. The on-disk layout is identical to [`serialize`].
///
/// # Errors
///
/// Returns [`SafeTensorError::HeaderTooLarge`] if the header exceeds the 100 MB
/// limit, [`SafeTensorError::IoError`] on any file-system failure, and propagates
/// any error raised while building the header.
#[cfg(feature = "std")]
pub fn serialize_to_file<S, V, I>(
    data: I,
    data_info: Option<HashMap<String, String>>,
    filename: &std::path::Path,
) -> Result<(), SafeTensorError>
where
    S: AsRef<str> + Ord + Display,
    V: View,
    I: IntoIterator<Item = (S, V)>,
{
    let (PreparedData { n: header_len, header_bytes, .. }, tensors) = prepare(data, data_info)?;

    if header_len > MAX_HEADER_SIZE as u64 {
        return Err(SafeTensorError::HeaderTooLarge);
    }

    let file = std::fs::File::create(filename)?;
    let mut writer = std::io::BufWriter::new(file);
    writer.write_all(&header_len.to_le_bytes())?;
    writer.write_all(&header_bytes)?;
    for tensor in tensors {
        writer.write_all(&tensor.data())?;
    }

    // Flush explicitly: dropping a `BufWriter` would silently swallow write errors.
    writer.flush()?;
    Ok(())
}
341
/// A structure owning some metadata to lookup tensors on a shared `data`
/// byte-buffer (not owned).
#[derive(Debug)]
pub struct SafeTensors<'data> {
    /// Parsed header, validated by `read_metadata` when built via `deserialize`.
    metadata: Metadata,
    /// The data section of the file (everything after the header); each tensor's
    /// `data_offsets` index into this slice.
    data: &'data [u8],
}
349
350impl<'data> SafeTensors<'data> {
351    /// Given a byte-buffer representing the whole safetensor file
352    /// parses the header, and returns the size of the header + the parsed data.
353    pub fn read_metadata(buffer: &'data [u8]) -> Result<(usize, Metadata), SafeTensorError> {
354        let buffer_len = buffer.len();
355        let Some(header_size_bytes) = buffer.get(..N_LEN) else {
356            return Err(SafeTensorError::HeaderTooSmall);
357        };
358        let arr: [u8; N_LEN] = header_size_bytes
359            .try_into()
360            .expect("this can't fail due to how `header_size_bytes` is defined above");
361        let n: usize = u64::from_le_bytes(arr)
362            .try_into()
363            .map_err(|_| SafeTensorError::HeaderTooLarge)?;
364
365        if n > MAX_HEADER_SIZE {
366            return Err(SafeTensorError::HeaderTooLarge);
367        }
368
369        let stop = n
370            .checked_add(N_LEN)
371            .ok_or(SafeTensorError::InvalidHeaderLength)?;
372
373        // the `.get(start..stop)` returns None if either index is out of bounds,
374        // so this implicitly also ensures that `stop <= buffer.len()`.
375        let Some(header_bytes) = buffer.get(N_LEN..stop) else {
376            return Err(SafeTensorError::InvalidHeaderLength);
377        };
378        let string = core::str::from_utf8(header_bytes).map_err(SafeTensorError::InvalidHeader)?;
379        // Assert the string starts with {
380        // NOTE: Add when we move to 0.4.0
381        // if !string.starts_with('{') {
382        //     return Err(SafeTensorError::InvalidHeaderStart);
383        // }
384        let metadata: HashMetadata =
385            serde_json::from_str(string).map_err(SafeTensorError::InvalidHeaderDeserialization)?;
386        let metadata: Metadata = metadata.try_into()?;
387        let buffer_end = metadata.validate()?;
388        if buffer_end + N_LEN + n != buffer_len {
389            return Err(SafeTensorError::MetadataIncompleteBuffer);
390        }
391
392        Ok((n, metadata))
393    }
394
395    /// Given a byte-buffer representing the whole safetensor file
396    /// parses it and returns the Deserialized form (No Tensor allocation).
397    ///
398    /// ```
399    /// use safetensors::SafeTensors;
400    /// use memmap2::MmapOptions;
401    /// use std::fs::File;
402    ///
403    /// let filename = "model.safetensors";
404    /// # use std::io::Write;
405    /// # let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
406    /// # File::create(filename).unwrap().write(serialized).unwrap();
407    /// let file = File::open(filename).unwrap();
408    /// let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
409    /// let tensors = SafeTensors::deserialize(&buffer).unwrap();
410    /// let tensor = tensors
411    ///         .tensor("test")
412    ///         .unwrap();
413    /// ```
414    pub fn deserialize(buffer: &'data [u8]) -> Result<Self, SafeTensorError> {
415        let (n, metadata) = SafeTensors::read_metadata(buffer)?;
416        let data = &buffer[N_LEN + n..];
417        Ok(Self { metadata, data })
418    }
419
420    /// Returns the tensors contained within the SafeTensors.
421    /// The tensors returned are merely views and the data is not owned by this
422    /// structure.
423    pub fn tensors(&self) -> Vec<(String, TensorView<'data>)> {
424        let mut tensors = Vec::with_capacity(self.metadata.index_map.len());
425        for (name, &index) in &self.metadata.index_map {
426            let info = &self.metadata.tensors[index];
427            let tensorview = TensorView {
428                dtype: info.dtype,
429                shape: info.shape.clone(),
430                data: &self.data[info.data_offsets.0..info.data_offsets.1],
431            };
432            tensors.push((name.to_string(), tensorview));
433        }
434        tensors
435    }
436
437    /// Returns an iterator over the tensors contained within the SafeTensors.
438    /// The tensors returned are merely views and the data is not owned by this
439    /// structure.
440    pub fn iter(&self) -> impl Iterator<Item = (&str, TensorView<'data>)> {
441        self.metadata.index_map.iter().map(|(name, &idx)| {
442            let info = &self.metadata.tensors[idx];
443            (
444                name.as_str(),
445                TensorView {
446                    dtype: info.dtype,
447                    shape: info.shape.clone(),
448                    data: &self.data[info.data_offsets.0..info.data_offsets.1],
449                },
450            )
451        })
452    }
453
454    /// Allow the user to get a specific tensor within the SafeTensors.
455    /// The tensor returned is merely a view and the data is not owned by this
456    /// structure.
457    pub fn tensor(&self, tensor_name: &str) -> Result<TensorView<'data>, SafeTensorError> {
458        let &index = self
459            .metadata
460            .index_map
461            .get(tensor_name)
462            .ok_or_else(|| SafeTensorError::TensorNotFound(tensor_name.to_string()))?;
463
464        let info = self
465            .metadata
466            .tensors
467            .get(index)
468            .ok_or_else(|| SafeTensorError::TensorNotFound(tensor_name.to_string()))?;
469
470        Ok(TensorView {
471            dtype: info.dtype,
472            shape: info.shape.clone(),
473            data: &self.data[info.data_offsets.0..info.data_offsets.1],
474        })
475    }
476
477    /// Return the names of the tensors within the SafeTensors.
478    /// These are used as keys to access to the actual tensors, that can be
479    /// retrieved using the tensor method.
480    pub fn names(&self) -> Vec<&'_ str> {
481        self.metadata.index_map.keys().map(String::as_str).collect()
482    }
483
484    /// Return how many tensors are currently stored within the SafeTensors.
485    #[inline]
486    pub fn len(&self) -> usize {
487        self.metadata.tensors.len()
488    }
489
490    /// Indicate if the SafeTensors contains or not any tensor.
491    #[inline]
492    pub fn is_empty(&self) -> bool {
493        self.metadata.tensors.is_empty()
494    }
495}
496
/// The struct representing the header of safetensor files which allow
/// indexing into the raw byte-buffer array and how to interpret it.
#[derive(Debug, Clone)]
pub struct Metadata {
    /// Free-form string map stored under the `__metadata__` header key, if present.
    metadata: Option<HashMap<String, String>>,
    /// Tensor infos, ordered by increasing `data_offsets` (enforced by `validate`).
    tensors: Vec<TensorInfo>,
    /// Tensor name -> position in `tensors`.
    index_map: HashMap<String, usize>,
}
505
/// Helper struct used only for serialization and deserialization.
///
/// Mirrors the on-disk JSON header. The optional `__metadata__` key maps to
/// `metadata`, and every other top-level key is a tensor name (via
/// `#[serde(flatten)]`).
#[derive(Serialize, Deserialize)]
struct HashMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "__metadata__")]
    metadata: Option<HashMap<String, String>>,
    #[serde(flatten)]
    tensors: HashMap<String, TensorInfo>,
}
515
516impl TryFrom<HashMetadata> for Metadata {
517    type Error = SafeTensorError;
518    fn try_from(hashdata: HashMetadata) -> Result<Self, Self::Error> {
519        let (metadata, tensors) = (hashdata.metadata, hashdata.tensors);
520        let mut tensors: Vec<_> = tensors.into_iter().collect();
521        // We need to sort by offsets
522        // Previous versions might have a different ordering
523        // Than we expect (Not aligned ordered, but purely name ordered,
524        // or actually any order).
525        tensors.sort_by(|(_, left), (_, right)| left.data_offsets.cmp(&right.data_offsets));
526        Metadata::new(metadata, tensors)
527    }
528}
529
530impl<'de> Deserialize<'de> for Metadata {
531    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
532    where
533        D: Deserializer<'de>,
534    {
535        let hashdata: HashMetadata = HashMetadata::deserialize(deserializer)?;
536
537        let metadata: Metadata = hashdata.try_into().map_err(serde::de::Error::custom)?;
538        Ok(metadata)
539    }
540}
541
542impl Serialize for Metadata {
543    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
544    where
545        S: Serializer,
546    {
547        let mut names = vec![""; self.index_map.len()];
548        for (name, &index) in &self.index_map {
549            names[index] = name;
550        }
551
552        let length = self.metadata.as_ref().map_or(0, HashMap::len);
553        let mut map = serializer.serialize_map(Some(self.tensors.len() + length))?;
554
555        if let Some(metadata) = &self.metadata {
556            map.serialize_entry("__metadata__", metadata)?;
557        }
558
559        for (name, info) in names.iter().zip(&self.tensors) {
560            map.serialize_entry(name, info)?;
561        }
562
563        map.end()
564    }
565}
566
567impl Metadata {
568    /// Creates a new metadata structure.
569    /// May fail if there is incorrect data in the Tensor Info.
570    /// Notably the tensors need to be ordered by increasing data_offsets.
571    pub fn new(
572        metadata: Option<HashMap<String, String>>,
573        tensors: Vec<(String, TensorInfo)>,
574    ) -> Result<Self, SafeTensorError> {
575        let mut index_map = HashMap::with_capacity(tensors.len());
576
577        let tensors: Vec<_> = tensors
578            .into_iter()
579            .enumerate()
580            .map(|(index, (k, tensor))| {
581                index_map.insert(k, index);
582                tensor
583            })
584            .collect();
585
586        let metadata = Self {
587            metadata,
588            tensors,
589            index_map,
590        };
591        metadata.validate()?;
592        Ok(metadata)
593    }
594
595    fn validate(&self) -> Result<usize, SafeTensorError> {
596        let mut start = 0;
597        for (i, info) in self.tensors.iter().enumerate() {
598            let (s, e) = info.data_offsets;
599            if s != start || e < s {
600                let tensor_name = self
601                    .index_map
602                    .iter()
603                    .find_map(|(name, &index)| if index == i { Some(&name[..]) } else { None })
604                    .unwrap_or("no_tensor");
605                return Err(SafeTensorError::InvalidOffset(tensor_name.to_string()));
606            }
607
608            start = e;
609
610            let nelements: usize = info
611                .shape
612                .iter()
613                .copied()
614                .try_fold(1usize, usize::checked_mul)
615                .ok_or(SafeTensorError::ValidationOverflow)?;
616            let nbits = nelements
617                .checked_mul(info.dtype.bitsize())
618                .ok_or(SafeTensorError::ValidationOverflow)?;
619
620            if nbits % 8 != 0 {
621                return Err(SafeTensorError::MisalignedSlice);
622            }
623            let size = nbits
624                .checked_div(8)
625                .ok_or(SafeTensorError::ValidationOverflow)?;
626
627            if e - s != size {
628                return Err(SafeTensorError::TensorInvalidInfo);
629            }
630        }
631        Ok(start)
632    }
633
634    /// Gives back the tensor metadata
635    pub fn info(&self, name: &str) -> Option<&TensorInfo> {
636        let &index = self.index_map.get(name)?;
637        self.tensors.get(index)
638    }
639
640    /// Gives back the tensor metadata
641    pub fn tensors(&self) -> HashMap<String, &TensorInfo> {
642        self.index_map
643            .iter()
644            .map(|(tensor_name, &index)| (tensor_name.clone(), &self.tensors[index]))
645            .collect()
646    }
647
648    /// Gives back the tensor names ordered by offset
649    pub fn offset_keys(&self) -> Vec<String> {
650        let mut index_vec: Vec<_> = self.index_map.iter().collect();
651        index_vec.sort_by_key(|a| a.1);
652        index_vec.into_iter().map(|a| a.0.clone()).collect()
653    }
654
655    /// Gives the size of the content buffer in bytes.
656    pub fn data_len(&self) -> usize {
657        if let Some(tensor) = self.tensors.last() {
658            tensor.data_offsets.1
659        } else {
660            0
661        }
662    }
663
664    /// Gives back the tensor metadata
665    pub fn metadata(&self) -> &Option<HashMap<String, String>> {
666        &self.metadata
667    }
668}
669
/// A view of a Tensor within the file.
/// Contains references to data within the full byte-buffer
/// And is thus a readable view of a single tensor
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct TensorView<'data> {
    /// Element type.
    dtype: Dtype,
    /// Dimensions of the tensor.
    shape: Vec<usize>,
    /// Borrowed bytes of this tensor only (not the whole file).
    data: &'data [u8],
}
679
680impl View for &TensorView<'_> {
681    fn dtype(&self) -> Dtype {
682        self.dtype
683    }
684
685    fn shape(&self) -> &[usize] {
686        &self.shape
687    }
688
689    fn data(&self) -> Cow<'_, [u8]> {
690        self.data.into()
691    }
692
693    fn data_len(&self) -> usize {
694        self.data.len()
695    }
696}
697
698impl View for TensorView<'_> {
699    fn dtype(&self) -> Dtype {
700        self.dtype
701    }
702
703    fn shape(&self) -> &[usize] {
704        &self.shape
705    }
706
707    fn data(&self) -> Cow<'_, [u8]> {
708        self.data.into()
709    }
710
711    fn data_len(&self) -> usize {
712        self.data.len()
713    }
714}
715
716impl<'data> TensorView<'data> {
717    /// Create new tensor view
718    pub fn new(
719        dtype: Dtype,
720        shape: Vec<usize>,
721        data: &'data [u8],
722    ) -> Result<Self, SafeTensorError> {
723        let n_elements: usize = shape.iter().product();
724
725        let nbits = n_elements * dtype.bitsize();
726        if nbits % 8 != 0 {
727            return Err(SafeTensorError::MisalignedSlice);
728        }
729        let size = nbits
730            .checked_div(8)
731            .ok_or(SafeTensorError::ValidationOverflow)?;
732
733        if data.len() != size {
734            Err(SafeTensorError::InvalidTensorView(dtype, shape, data.len()))
735        } else {
736            Ok(Self { dtype, shape, data })
737        }
738    }
739    /// The current tensor dtype
740    pub fn dtype(&self) -> Dtype {
741        self.dtype
742    }
743
744    /// The current tensor shape
745    pub fn shape(&self) -> &[usize] {
746        &self.shape
747    }
748
749    /// The current tensor byte-buffer
750    pub fn data(&self) -> &'data [u8] {
751        self.data
752    }
753
754    /// The various pieces of the data buffer according to the asked slice
755    pub fn sliced_data(
756        &'data self,
757        slices: &[TensorIndexer],
758    ) -> Result<SliceIterator<'data>, InvalidSlice> {
759        SliceIterator::new(self, slices)
760    }
761}
762
/// A single tensor information.
/// Endianness is assumed to be little endian
/// Ordering is assumed to be 'C'.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TensorInfo {
    /// The type of each element of the tensor
    pub dtype: Dtype,
    /// The shape of the tensor
    pub shape: Vec<usize>,
    /// The `(start, end)` byte offsets (end exclusive) of the tensor's data,
    /// relative to the start of the data section that follows the header.
    pub data_offsets: (usize, usize),
}
775
/// The various available dtypes. They MUST be in increasing alignment order.
///
/// The derived `Ord` follows declaration order. Serialization sorts tensors by
/// descending `Dtype` so that the most-aligned dtypes are laid out first, so a new
/// variant must be inserted at the position matching its alignment.
#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
#[non_exhaustive]
pub enum Dtype {
    /// Boolean type
    BOOL,
    /// MXF4 (see <https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf>)
    F4,
    /// MXF6 (see <https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf>)
    #[allow(non_camel_case_types)]
    F6_E2M3,
    /// MXF6 (see <https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf>)
    #[allow(non_camel_case_types)]
    F6_E3M2,
    /// Unsigned byte
    U8,
    /// Signed byte
    I8,
    /// FP8 (see <https://arxiv.org/pdf/2209.05433.pdf>)
    #[allow(non_camel_case_types)]
    F8_E5M2,
    /// FP8 (see <https://arxiv.org/pdf/2209.05433.pdf>)
    #[allow(non_camel_case_types)]
    F8_E4M3,
    /// F8_E8M0 (see <https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf>)
    #[allow(non_camel_case_types)]
    F8_E8M0,
    /// Signed integer (16-bit)
    I16,
    /// Unsigned integer (16-bit)
    U16,
    /// Half-precision floating point
    F16,
    /// Brain floating point
    BF16,
    /// Signed integer (32-bit)
    I32,
    /// Unsigned integer (32-bit)
    U32,
    /// Floating point (32-bit)
    F32,
    /// Complex (32-bit parts)
    C64,
    /// Floating point (64-bit)
    F64,
    /// Signed integer (64-bit)
    I64,
    /// Unsigned integer (64-bit)
    U64,
}
826
827impl Dtype {
828    /// Gives out the size (in bits) of 1 element of this dtype.
829    pub fn bitsize(&self) -> usize {
830        match self {
831            Dtype::F4 => 4,
832            Dtype::F6_E3M2 => 6,
833            Dtype::F6_E2M3 => 6,
834            Dtype::BOOL => 8,
835            Dtype::U8 => 8,
836            Dtype::I8 => 8,
837            Dtype::F8_E5M2 => 8,
838            Dtype::F8_E4M3 => 8,
839            Dtype::F8_E8M0 => 8,
840            Dtype::I16 => 16,
841            Dtype::U16 => 16,
842            Dtype::I32 => 32,
843            Dtype::U32 => 32,
844            Dtype::I64 => 64,
845            Dtype::U64 => 64,
846            Dtype::F16 => 16,
847            Dtype::BF16 => 16,
848            Dtype::F32 => 32,
849            Dtype::F64 => 64,
850            Dtype::C64 => 64,
851        }
852    }
853    /// Gives out the size (in bytes) of 1 element of this dtype.
854    #[deprecated(
855        since = "0.6.0",
856        note = "Use `bitsize` instead as some elements have smaller than a full byte of width"
857    )]
858    pub fn size(&self) -> usize {
859        self.bitsize() / 8
860    }
861}
862
863impl Display for Dtype {
864    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
865        f.write_str(match *self {
866            Dtype::F4 => "F4",
867            Dtype::F6_E2M3 => "F6_E2M3",
868            Dtype::F6_E3M2 => "F6_E3M2",
869            Dtype::BOOL => "BOOL",
870            Dtype::I8 => "I8",
871            Dtype::U8 => "U8",
872            Dtype::F8_E5M2 => "F8_E5M2",
873            Dtype::F8_E4M3 => "F8_E4M3",
874            Dtype::F8_E8M0 => "F8_E8M0",
875            Dtype::I16 => "I16",
876            Dtype::U16 => "U16",
877            Dtype::I32 => "I32",
878            Dtype::U32 => "U32",
879            Dtype::I64 => "I64",
880            Dtype::U64 => "U64",
881            Dtype::F16 => "F16",
882            Dtype::BF16 => "BF16",
883            Dtype::F32 => "F32",
884            Dtype::F64 => "F64",
885            Dtype::C64 => "C64",
886        })
887    }
888}
889
#[cfg(test)]
mod tests {
    use super::*;
    use crate::slice::IndexOp;
    use proptest::prelude::*;
    #[cfg(not(feature = "std"))]
    extern crate std;

    const MAX_DIMENSION: usize = 8;
    const MAX_SIZE: usize = 8;
    const MAX_TENSORS: usize = 8;

    /// Strategy yielding every `Dtype` variant, so property tests cover all of them.
    fn arbitrary_dtype() -> impl Strategy<Value = Dtype> {
        prop_oneof![
            Just(Dtype::BOOL),
            Just(Dtype::F4),
            Just(Dtype::F6_E3M2),
            Just(Dtype::F6_E2M3),
            Just(Dtype::F8_E5M2),
            Just(Dtype::F8_E4M3),
            Just(Dtype::F8_E8M0),
            Just(Dtype::U8),
            Just(Dtype::I8),
            Just(Dtype::I16),
            Just(Dtype::U16),
            Just(Dtype::I32),
            Just(Dtype::U32),
            Just(Dtype::I64),
            Just(Dtype::U64),
            Just(Dtype::F16),
            Just(Dtype::BF16),
            Just(Dtype::F32),
            Just(Dtype::F64),
            Just(Dtype::C64),
        ]
    }

    fn arbitrary_shape() -> impl Strategy<Value = Vec<usize>> {
        // We do not allow empty shapes or 0 sizes.
        (1..MAX_DIMENSION).prop_flat_map(|length| prop::collection::vec(1..MAX_SIZE, length))
    }

    fn arbitrary_metadata() -> impl Strategy<Value = Metadata> {
        // We generate at least one tensor.
        (1..MAX_TENSORS)
            .prop_flat_map(|size| {
                // Returns a strategy generating `size` data types and shapes.
                (
                    prop::collection::vec(arbitrary_dtype(), size),
                    prop::collection::vec(arbitrary_shape(), size),
                )
            })
            .prop_filter_map("Misaligned slices", |(dtypes, shapes)| {
                // Builds a metadata object from a random (dtypes, shapes) pair by
                // laying the tensors out back to back. Tensors whose bit length is
                // not a whole number of bytes are skipped; the case is rejected
                // only when no tensor survives.
                let mut start = 0;
                let tensors: Vec<TensorInfo> = dtypes
                    .iter()
                    .zip(shapes)
                    .filter_map(|(dtype, shape)| {
                        // This cannot overflow because the size of
                        // the vector and elements are so small.
                        let bitlength: usize = shape.iter().product::<usize>() * dtype.bitsize();
                        if bitlength % 8 != 0 {
                            return None;
                        }
                        // Divisibility was checked above, so this is exact.
                        let end = start + bitlength / 8;
                        let tensor = TensorInfo {
                            dtype: *dtype,
                            shape,
                            data_offsets: (start, end),
                        };
                        start = end;
                        Some(tensor)
                    })
                    .collect();
                let index_map = (0..tensors.len())
                    .map(|index| (format!("t.{index}"), index))
                    .collect();
                if tensors.is_empty() {
                    None
                } else {
                    Some(Metadata {
                        metadata: None,
                        tensors,
                        index_map,
                    })
                }
            })
    }

    /// This method returns the size of the data corresponding to the metadata. It
    /// assumes that `metadata` contains at least one tensor, and that tensors are
    /// ordered by offset in `metadata.tensors`.
    ///
    /// # Panics
    ///
    /// This method will panic if `metadata` does not contain any tensors.
    fn data_size(metadata: &Metadata) -> usize {
        metadata.tensors.last().unwrap().data_offsets.1
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]

        #[test]
        fn test_indexing(metadata in arbitrary_metadata()) {
            let data = vec![0u8; data_size(&metadata)];
            let tensors = SafeTensors { metadata, data: &data };
            for name in tensors.names() {
                assert!(tensors.tensor(name).is_ok());
            }
        }
        #[test]
        fn test_roundtrip(metadata in arbitrary_metadata()) {
            let data: Vec<u8> = (0..data_size(&metadata)).map(|x| x as u8).collect();
            let before = SafeTensors { metadata, data: &data };
            let tensors = before.tensors();
            let bytes = serialize(tensors.iter().map(|(name, view)| (name.to_string(), view)), None).unwrap();

            let after = SafeTensors::deserialize(&bytes).unwrap();

            // Check that the tensors are the same after deserialization.
            assert_eq!(before.names().len(), after.names().len());
            for name in before.names() {
                let tensor_before = before.tensor(name).unwrap();
                let tensor_after = after.tensor(name).unwrap();
                assert_eq!(tensor_after.data().as_ptr() as usize % tensor_after.dtype().bitsize().div_ceil(8), 0);
                assert_eq!(tensor_before, tensor_after);
            }
        }
    }

    #[test]
    fn test_serialization() {
        let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
            .into_iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        let shape = vec![1, 2, 3];
        let attn_0 = TensorView::new(Dtype::F32, shape, &data).unwrap();
        let metadata: HashMap<String, TensorView> =
            [("attn.0".to_string(), attn_0)].into_iter().collect();

        let out = serialize(&metadata, None).unwrap();
        assert_eq!(
            out,
            [
                64, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 46, 48, 34, 58, 123, 34, 100,
                116, 121, 112, 101, 34, 58, 34, 70, 51, 50, 34, 44, 34, 115, 104, 97, 112, 101, 34,
                58, 91, 49, 44, 50, 44, 51, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102, 115,
                101, 116, 115, 34, 58, 91, 48, 44, 50, 52, 93, 125, 125, 0, 0, 0, 0, 0, 0, 128, 63,
                0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64
            ]
        );
        let _parsed = SafeTensors::deserialize(&out).unwrap();
    }

    #[test]
    fn test_serialization_fp4() {
        let data: Vec<u8> = vec![0u8];
        let shape = vec![1, 2];
        let attn_0 = TensorView::new(Dtype::F4, shape, &data).unwrap();
        let metadata: HashMap<String, TensorView> =
            [("attn.0".to_string(), attn_0)].into_iter().collect();

        let out = serialize(&metadata, None).unwrap();
        assert_eq!(
            out,
            [
                64, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 46, 48, 34, 58, 123, 34, 100,
                116, 121, 112, 101, 34, 58, 34, 70, 52, 34, 44, 34, 115, 104, 97, 112, 101, 34, 58,
                91, 49, 44, 50, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102, 115, 101, 116,
                115, 34, 58, 91, 48, 44, 49, 93, 125, 125, 32, 32, 32, 32, 0
            ]
        );
        let parsed = SafeTensors::deserialize(&out).unwrap();
        let tensors: HashMap<_, _> = parsed.tensors().into_iter().collect();
        assert_eq!(tensors, metadata);
    }

    #[test]
    fn test_serialization_fp4_misaligned() {
        let data: Vec<u8> = vec![0u8, 1u8];
        let shape = vec![1, 3];
        let attn_0 = TensorView::new(Dtype::F4, shape, &data);
        assert!(matches!(attn_0, Err(SafeTensorError::MisalignedSlice)));
    }

    #[test]
    fn test_serialization_fp4_invalid() {
        let data: Vec<u8> = vec![0u8, 1u8];
        let shape = vec![1, 2];
        let attn_0 = TensorView::new(Dtype::F4, shape, &data);
        assert!(matches!(
            attn_0,
            Err(SafeTensorError::InvalidTensorView(Dtype::F4, _shape, _size))
        ));
    }

    #[test]
    fn test_empty() {
        let tensors: HashMap<String, TensorView> = HashMap::new();

        let out = serialize(&tensors, None).unwrap();
        assert_eq!(
            out,
            [8, 0, 0, 0, 0, 0, 0, 0, 123, 125, 32, 32, 32, 32, 32, 32]
        );
        let _parsed = SafeTensors::deserialize(&out).unwrap();

        let metadata: Option<HashMap<String, String>> = Some(
            [("framework".to_string(), "pt".to_string())]
                .into_iter()
                .collect(),
        );
        let out = serialize(&tensors, metadata).unwrap();
        assert_eq!(
            out,
            [
                40, 0, 0, 0, 0, 0, 0, 0, 123, 34, 95, 95, 109, 101, 116, 97, 100, 97, 116, 97, 95,
                95, 34, 58, 123, 34, 102, 114, 97, 109, 101, 119, 111, 114, 107, 34, 58, 34, 112,
                116, 34, 125, 125, 32, 32, 32, 32, 32
            ]
        );
        let _parsed = SafeTensors::deserialize(&out).unwrap();
    }

    #[test]
    fn test_serialization_forced_alignement() {
        let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
            .into_iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        let shape = vec![1, 1, 2, 3];
        let attn_0 = TensorView::new(Dtype::F32, shape, &data).unwrap();
        let metadata: HashMap<String, TensorView> =
            // Smaller string to force misalignment compared to previous test.
            [("attn0".to_string(), attn_0)].into_iter().collect();

        let out = serialize(&metadata, None).unwrap();
        assert_eq!(
            out,
            [
                72, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 48, 34, 58, 123, 34, 100, 116,
                121, 112, 101, 34, 58, 34, 70, 51, 50, 34, 44, 34, 115, 104, 97, 112, 101, 34, 58,
                91, 49, 44, 49, 44, 50, 44, 51, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102,
                // All the 32 (space) bytes pad the header so that the tensor data
                // is aligned for casting to f32, f64, etc.
                115, 101, 116, 115, 34, 58, 91, 48, 44, 50, 52, 93, 125, 125, 32, 32, 32, 32, 32,
                32, 32, 0, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0,
                160, 64
            ],
        );
        let parsed = SafeTensors::deserialize(&out).unwrap();
        let tensor = parsed.tensor("attn0").unwrap();
        assert_eq!(
            tensor.data().as_ptr() as usize % tensor.dtype().bitsize().div_ceil(8),
            0
        );
    }

    #[test]
    fn test_slicing() {
        let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
            .into_iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        let attn_0 = TensorView {
            dtype: Dtype::F32,
            shape: vec![1, 2, 3],
            data: &data,
        };
        let metadata: HashMap<String, TensorView> =
            [("attn.0".to_string(), attn_0)].into_iter().collect();

        let out = serialize(&metadata, None).unwrap();
        let parsed = SafeTensors::deserialize(&out).unwrap();

        let out_buffer: Vec<u8> = parsed
            .tensor("attn.0")
            .unwrap()
            .slice((.., ..1))
            .unwrap()
            .flat_map(|b| b.to_vec())
            .collect();
        assert_eq!(out_buffer, vec![0u8, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64]);
        assert_eq!(
            out_buffer,
            vec![0.0f32, 1.0, 2.0]
                .into_iter()
                .flat_map(|f| f.to_le_bytes())
                .collect::<Vec<_>>()
        );
        let out_buffer: Vec<u8> = parsed
            .tensor("attn.0")
            .unwrap()
            .slice((.., .., ..1))
            .unwrap()
            .flat_map(|b| b.to_vec())
            .collect();
        assert_eq!(out_buffer, vec![0u8, 0, 0, 0, 0, 0, 64, 64]);
        assert_eq!(
            out_buffer,
            vec![0.0f32, 3.0]
                .into_iter()
                .flat_map(|f| f.to_le_bytes())
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_gpt2() {
        gpt2_like(12, "gpt2");
    }

    #[test]
    fn test_gpt2_tiny() {
        gpt2_like(6, "gpt2_tiny");
    }

    /// Builds a GPT-2-shaped set of zero-filled F32 tensors with `n_heads`
    /// blocks, then round-trips it through `serialize` + `deserialize` (and
    /// `serialize_to_file` when the `std` feature is enabled) using a
    /// temporary file named after `model_id`.
    fn gpt2_like(n_heads: usize, model_id: &str) {
        let mut tensors_desc = vec![
            ("wte".to_string(), vec![50257, 768]),
            ("wpe".to_string(), vec![1024, 768]),
        ];
        for i in 0..n_heads {
            tensors_desc.push((format!("h.{i}.ln_1.weight"), vec![768]));
            tensors_desc.push((format!("h.{i}.ln_1.bias"), vec![768]));
            tensors_desc.push((format!("h.{i}.attn.bias"), vec![1, 1, 1024, 1024]));
            tensors_desc.push((format!("h.{i}.attn.c_attn.weight"), vec![768, 2304]));
            tensors_desc.push((format!("h.{i}.attn.c_attn.bias"), vec![2304]));
            tensors_desc.push((format!("h.{i}.attn.c_proj.weight"), vec![768, 768]));
            tensors_desc.push((format!("h.{i}.attn.c_proj.bias"), vec![768]));
            tensors_desc.push((format!("h.{i}.ln_2.weight"), vec![768]));
            tensors_desc.push((format!("h.{i}.ln_2.bias"), vec![768]));
            tensors_desc.push((format!("h.{i}.mlp.c_fc.weight"), vec![768, 3072]));
            tensors_desc.push((format!("h.{i}.mlp.c_fc.bias"), vec![3072]));
            tensors_desc.push((format!("h.{i}.mlp.c_proj.weight"), vec![3072, 768]));
            tensors_desc.push((format!("h.{i}.mlp.c_proj.bias"), vec![768]));
        }
        tensors_desc.push(("ln_f.weight".to_string(), vec![768]));
        tensors_desc.push(("ln_f.bias".to_string(), vec![768]));

        let dtype = Dtype::F32;
        let nbits: usize = tensors_desc
            .iter()
            .map(|(_, shape)| shape.iter().product::<usize>())
            .sum::<usize>()
            * dtype.bitsize();
        assert!(nbits % 8 == 0, "Misaligned slice");
        let all_data = vec![0; nbits / 8];
        let mut metadata = HashMap::with_capacity(tensors_desc.len());
        let mut offset = 0;
        for (name, shape) in tensors_desc {
            let n_elements: usize = shape.iter().product();
            let n_bytes = n_elements * dtype.bitsize() / 8;
            let buffer = &all_data[offset..offset + n_bytes];
            let tensor = TensorView::new(dtype, shape, buffer).unwrap();
            metadata.insert(name, tensor);
            // Advance by the bytes consumed (not the element count) so that
            // consecutive tensors occupy disjoint regions of `all_data`.
            offset += n_bytes;
        }
        // Every byte of the backing buffer is assigned to exactly one tensor.
        assert_eq!(offset, all_data.len());

        let filename = format!("./out_{model_id}.safetensors");

        let out = serialize(&metadata, None).unwrap();
        std::fs::write(&filename, out).unwrap();
        let raw = std::fs::read(&filename).unwrap();
        let _deserialized = SafeTensors::deserialize(&raw).unwrap();
        std::fs::remove_file(&filename).unwrap();

        // File api
        #[cfg(feature = "std")]
        {
            serialize_to_file(&metadata, None, std::path::Path::new(&filename)).unwrap();
            let raw = std::fs::read(&filename).unwrap();
            let _deserialized = SafeTensors::deserialize(&raw).unwrap();
            std::fs::remove_file(&filename).unwrap();
        }
    }

    #[test]
    fn test_empty_shapes_allowed() {
        let serialized = b"8\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[],\"data_offsets\":[0,4]}}\x00\x00\x00\x00";

        let loaded = SafeTensors::deserialize(serialized).unwrap();
        assert_eq!(loaded.names(), vec!["test"]);
        let tensor = loaded.tensor("test").unwrap();
        assert!(tensor.shape().is_empty());
        assert_eq!(tensor.dtype(), Dtype::I32);
        // 4 bytes
        assert_eq!(tensor.data(), b"\0\0\0\0");
    }

    #[test]
    fn test_deserialization() {
        let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";

        let loaded = SafeTensors::deserialize(serialized).unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded.names(), vec!["test"]);
        let tensor = loaded.tensor("test").unwrap();
        assert_eq!(tensor.shape(), vec![2, 2]);
        assert_eq!(tensor.dtype(), Dtype::I32);
        // 16 bytes
        assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
    }

    #[test]
    fn test_lifetimes() {
        let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";

        let tensor = {
            let loaded = SafeTensors::deserialize(serialized).unwrap();
            loaded.tensor("test").unwrap()
        };

        assert_eq!(tensor.shape(), vec![2, 2]);
        assert_eq!(tensor.dtype(), Dtype::I32);
        // 16 bytes
        assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
    }

    #[test]
    fn test_json_attack() {
        let mut tensors = HashMap::new();
        let dtype = Dtype::F32;
        let shape = vec![2, 2];
        let data_offsets = (0, 16);
        for i in 0..10 {
            tensors.insert(
                format!("weight_{i}"),
                TensorInfo {
                    dtype,
                    shape: shape.clone(),
                    data_offsets,
                },
            );
        }

        let metadata = HashMetadata {
            metadata: None,
            tensors,
        };
        let serialized = serde_json::to_string(&metadata).unwrap();
        let serialized = serialized.as_bytes();

        // Build the file in memory rather than on disk (the previous version
        // left `out.safetensors` behind). The header length is always written
        // as an 8-byte little-endian u64, whatever the platform's `usize`
        // width. Every tensor claims offsets (0, 16), so they all overlap and
        // deserialization must reject the header.
        let n = serialized.len() as u64;
        let mut buffer = Vec::with_capacity(8 + serialized.len() + 16);
        buffer.extend_from_slice(&n.to_le_bytes());
        buffer.extend_from_slice(serialized);
        buffer.extend_from_slice(&[0u8; 16]);

        match SafeTensors::deserialize(&buffer) {
            Err(SafeTensorError::InvalidOffset(_)) => {
                // Yes we have the correct error, name of the tensor is random though
            }
            Err(err) => panic!("Unexpected error {err:?}"),
            Ok(_) => panic!("This should not be able to be deserialized"),
        }
    }

    #[test]
    fn test_metadata_incomplete_buffer() {
        let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00extra_bogus_data_for_polyglot_file";

        match SafeTensors::deserialize(serialized) {
            Err(SafeTensorError::MetadataIncompleteBuffer) => {
                // Yes we have the correct error
            }
            _ => panic!("This should not be able to be deserialized"),
        }

        // Missing data in the buffer
        let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"; // <--- missing 2 bytes

        match SafeTensors::deserialize(serialized) {
            Err(SafeTensorError::MetadataIncompleteBuffer) => {
                // Yes we have the correct error
            }
            _ => panic!("This should not be able to be deserialized"),
        }
    }

    #[test]
    fn test_header_too_large() {
        let serialized = b"<\x00\x00\x00\x00\xff\xff\xff{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";

        match SafeTensors::deserialize(serialized) {
            Err(SafeTensorError::HeaderTooLarge) => {
                // Yes we have the correct error
            }
            _ => panic!("This should not be able to be deserialized"),
        }
    }

    #[test]
    fn test_header_too_small() {
        let serialized = b"";
        match SafeTensors::deserialize(serialized) {
            Err(SafeTensorError::HeaderTooSmall) => {
                // Yes we have the correct error
            }
            _ => panic!("This should not be able to be deserialized"),
        }
    }

    #[test]
    fn test_invalid_header_length() {
        let serialized = b"<\x00\x00\x00\x00\x00\x00\x00";
        match SafeTensors::deserialize(serialized) {
            Err(SafeTensorError::InvalidHeaderLength) => {
                // Yes we have the correct error
            }
            _ => panic!("This should not be able to be deserialized"),
        }
    }

    #[test]
    fn test_invalid_header_non_utf8() {
        let serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00\xff";
        match SafeTensors::deserialize(serialized) {
            Err(SafeTensorError::InvalidHeader(_)) => {
                // Yes we have the correct error
            }
            _ => panic!("This should not be able to be deserialized"),
        }
    }

    #[test]
    fn test_invalid_header_not_json() {
        let serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00{";
        match SafeTensors::deserialize(serialized) {
            Err(SafeTensorError::InvalidHeaderDeserialization(_)) => {
                // Yes we have the correct error
            }
            _ => panic!("This should not be able to be deserialized"),
        }
    }

    #[test]
    /// Test that the JSON header may be trailing-padded with JSON whitespace characters.
    fn test_whitespace_padded_header() {
        let serialized = b"\x06\x00\x00\x00\x00\x00\x00\x00{}\x0D\x20\x09\x0A";
        let loaded = SafeTensors::deserialize(serialized).unwrap();
        assert_eq!(loaded.len(), 0);
    }

    // Reserved for 0.4.0
    // #[test]
    // /// Test that the JSON header must begin with a `{` character.
    // fn test_whitespace_start_padded_header_is_not_allowed() {
    //     let serialized = b"\x06\x00\x00\x00\x00\x00\x00\x00\x09\x0A{}\x0D\x20";
    //     match SafeTensors::deserialize(serialized) {
    //         Err(SafeTensorError::InvalidHeaderStart) => {
    //             // Correct error
    //         }
    //         _ => panic!("This should not be able to be deserialized"),
    //     }
    // }

    #[test]
    fn test_zero_sized_tensor() {
        let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,0],\"data_offsets\":[0, 0]}}";
        let loaded = SafeTensors::deserialize(serialized).unwrap();

        assert_eq!(loaded.names(), vec!["test"]);
        let tensor = loaded.tensor("test").unwrap();
        assert_eq!(tensor.shape(), vec![2, 0]);
        assert_eq!(tensor.dtype(), Dtype::I32);
        assert_eq!(tensor.data(), b"");
    }

    #[test]
    fn test_invalid_info() {
        let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0, 4]}}";
        match SafeTensors::deserialize(serialized) {
            Err(SafeTensorError::TensorInvalidInfo) => {
                // Yes we have the correct error
            }
            something => panic!("This should not be able to be deserialized got {something:?}"),
        }
    }

    #[test]
    fn test_validation_overflow() {
        // u64::MAX =  18_446_744_073_709_551_615u64
        // Overflow the shape calculation.
        let serialized = b"O\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,18446744073709551614],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
        match SafeTensors::deserialize(serialized) {
            Err(SafeTensorError::ValidationOverflow) => {
                // Yes we have the correct error
            }
            _ => panic!("This should not be able to be deserialized"),
        }
        // u64::MAX =  18_446_744_073_709_551_615u64
        // Overflow the num_elements * total shape.
        let serialized = b"N\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,9223372036854775807],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
        match SafeTensors::deserialize(serialized) {
            Err(SafeTensorError::ValidationOverflow) => {
                // Yes we have the correct error
            }
            _ => panic!("This should not be able to be deserialized"),
        }
    }

    #[test]
    fn test_invalid_header_size_serialization() {
        let mut data_info = HashMap::<String, String>::new();
        let tensors: HashMap<String, TensorView> = HashMap::new();

        // a char is 1 byte in utf-8, so we can just repeat 'a' to get large metadata
        let very_large_metadata = "a".repeat(MAX_HEADER_SIZE);
        data_info.insert("very_large_metadata".to_string(), very_large_metadata);
        match serialize(&tensors, Some(data_info)) {
            Err(SafeTensorError::HeaderTooLarge) => {
                // Yes we have the correct error
            }
            _ => panic!("This should not be able to be serialized"),
        }
    }
}