web_rwkv_converter/lib.rs

//! Converts RWKV PyTorch (pickle) checkpoints into the `safetensors` format,
//! casting `bf16` weights to `f16` along the way.

use std::{borrow::Cow, path::Path};

use anyhow::Result;
use half::{bf16, f16};
use repugnant_pickle::{
    RepugnantTorchTensor as TorchTensor, RepugnantTorchTensors as TorchTensors, TensorType,
};
use safetensors::{Dtype, View};

/// An in-memory tensor holding converted `f16` data.
struct Tensor {
    name: String,
    shape: Vec<usize>,
    data: Vec<f16>,
}

/// Implementing `View` lets `safetensors` serialize the tensor's `f16` buffer directly.
impl View for Tensor {
    fn dtype(&self) -> Dtype {
        Dtype::F16
    }

    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn data(&self) -> Cow<[u8]> {
        Cow::Borrowed(bytemuck::cast_slice(&self.data))
    }

    fn data_len(&self) -> usize {
        self.data.len() * self.dtype().size()
    }
}

/// Reads each tensor's payload out of the raw checkpoint bytes, applying the
/// `rename` rules to its name and transposing the last two dimensions of any
/// tensor whose name matches a `transpose` pattern.
fn load_tensors<'a, 'b, 'c, 'd>(
    data: &'a [u8],
    torch: TorchTensors,
    rename: impl IntoIterator<Item = (&'b str, &'c str)> + Clone + 'a,
    transpose: impl IntoIterator<Item = &'d str> + Clone + 'a,
) -> impl IntoIterator<Item = Tensor> + 'a {
    torch.into_iter().map(move |tensor: TorchTensor| {
        // Apply every rename rule as a substring replacement on the tensor name.
        let name = rename
            .clone()
            .into_iter()
            .fold(tensor.name, |name, (p, to)| name.replace(p, to));
        let shape = tensor.shape;
        let size: usize = shape.iter().product();
        let bytes = size * tensor.tensor_type.size();

        // Only `bf16` source tensors are supported; convert them to `f16`.
        assert!(matches!(tensor.tensor_type, TensorType::BFloat16));
        let start = tensor.absolute_offset as usize;
        let end = start + bytes;
        let data: &[bf16] = bytemuck::cast_slice(&data[start..end]);
        let data: Vec<_> = data.iter().map(|x| f16::from_f32(x.to_f32())).collect();

        if transpose.clone().into_iter().any(|p| name.contains(p)) {
            // Transpose the last two dimensions; the third-from-last dimension,
            // if present, is treated as a batch.
            let mut transposed = vec![f16::ZERO; data.len()];
            let num_col = *shape.iter().nth_back(0).expect("should be at least 2d");
            let num_row = *shape.iter().nth_back(1).expect("should be at least 2d");
            let num_batch = *shape.iter().nth_back(2).unwrap_or(&1);
            for b in 0..num_batch {
                for i in 0..num_row {
                    for j in 0..num_col {
                        let from = b * num_col * num_row + i * num_col + j;
                        let to = b * num_col * num_row + j * num_row + i;
                        transposed[to] = data[from];
                    }
                }
            }
            // Swap the last two shape entries to match the transposed layout.
            let mut shape = shape;
            *shape.iter_mut().nth_back(0).unwrap() = num_row;
            *shape.iter_mut().nth_back(1).unwrap() = num_col;

            println!("{name}\t{:?}\t(Transposed)", shape);
            Tensor {
                name,
                shape,
                data: transposed,
            }
        } else {
            println!("{name}\t{:?}", shape);
            Tensor { name, shape, data }
        }
    })
}

/// Converts the PyTorch checkpoint at `input` (whose raw bytes are passed in
/// `data`) into a `safetensors` file written to `output`.
pub fn convert_safetensors<'a, 'b, 'c>(
    input: impl AsRef<Path>,
    data: &'a [u8],
    output: impl AsRef<Path>,
    rename: impl IntoIterator<Item = (&'b str, &'b str)> + Clone,
    transpose: impl IntoIterator<Item = &'c str> + Clone,
) -> Result<()> {
    // Parse the pickle metadata from the file, then slice tensor payloads out
    // of the raw bytes.
    let torch = TorchTensors::new_from_file(input)?;
    let tensors = load_tensors(data, torch, rename, transpose);

    // Pair each tensor with its (possibly renamed) name for serialization.
    let data = tensors.into_iter().map(|tensor| {
        let name = tensor.name.clone();
        (name, tensor)
    });

    safetensors::serialize_to_file(data, &None, output.as_ref())?;
    Ok(())
}
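
/// A minimal usage sketch, assuming a checkpoint at a hypothetical `model.pth`
/// path; the rename and transpose rules below are illustrative placeholders,
/// not rules required by any particular model. The raw file bytes are read up
/// front because `convert_safetensors` parses the pickle metadata from the
/// path and slices the tensor payloads out of `data`.
#[allow(dead_code)]
fn example_convert() -> Result<()> {
    let data = std::fs::read("model.pth")?; // hypothetical input checkpoint
    convert_safetensors(
        "model.pth",
        &data,
        "model.st",                   // hypothetical output path
        [("time_mix_", "time_mix.")], // illustrative rename rule
        ["att.key.weight"],           // illustrative transpose pattern
    )
}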