web_rwkv_converter/
lib.rs1use std::{borrow::Cow, path::Path};
2
3use anyhow::Result;
4use half::{bf16, f16};
5use repugnant_pickle::{
6 RepugnantTorchTensor as TorchTensor, RepugnantTorchTensors as TorchTensors, TensorType,
7};
8use safetensors::{Dtype, View};
9
10struct Tensor {
11 name: String,
12 shape: Vec<usize>,
13 data: Vec<f16>,
14}
15
16impl View for Tensor {
17 fn dtype(&self) -> Dtype {
18 Dtype::F16
19 }
20
21 fn shape(&self) -> &[usize] {
22 &self.shape
23 }
24
25 fn data(&self) -> Cow<[u8]> {
26 Cow::Borrowed(bytemuck::cast_slice(&self.data))
27 }
28
29 fn data_len(&self) -> usize {
30 self.data.len() * self.dtype().size()
31 }
32}
33
34fn load_tensors<'a, 'b, 'c, 'd>(
35 data: &'a [u8],
36 torch: TorchTensors,
37 rename: impl IntoIterator<Item = (&'b str, &'c str)> + Clone + 'a,
38 transpose: impl IntoIterator<Item = &'d str> + Clone + 'a,
39) -> impl IntoIterator<Item = Tensor> + 'a {
40 torch.into_iter().map(move |tensor: TorchTensor| {
41 let name = rename
42 .clone()
43 .into_iter()
44 .fold(tensor.name, |name, (p, to)| name.replace(p, to));
45 let shape = tensor.shape;
46 let size: usize = shape.iter().product();
47 let bytes = size * tensor.tensor_type.size();
48
49 assert!(matches!(tensor.tensor_type, TensorType::BFloat16));
50 let start = tensor.absolute_offset as usize;
51 let end = start + bytes;
52 let data: &[bf16] = bytemuck::cast_slice(&data[start..end]);
53 let data: Vec<_> = data.iter().map(|x| f16::from_f32(x.to_f32())).collect();
54
55 if transpose.clone().into_iter().any(|p| name.contains(p)) {
56 let mut transposed = vec![f16::ZERO; data.len()];
57 let num_col = *shape.iter().nth_back(0).expect("should be at least 2d");
58 let num_row = *shape.iter().nth_back(1).expect("should be at least 2d");
59 let num_batch = *shape.iter().nth_back(2).unwrap_or(&1);
60 for b in 0..num_batch {
61 for i in 0..num_row {
62 for j in 0..num_col {
63 let from = b * num_col * num_row + i * num_col + j;
64 let to = b * num_col * num_row + j * num_row + i;
65 transposed[to] = data[from];
66 }
67 }
68 }
69 let mut shape = shape;
70 *shape.iter_mut().nth_back(0).unwrap() = num_row;
71 *shape.iter_mut().nth_back(1).unwrap() = num_col;
72
73 println!("{name}\t{:?}\t(Transposed)", shape);
74 Tensor {
75 name,
76 shape,
77 data: transposed,
78 }
79 } else {
80 println!("{name}\t{:?}", shape);
81 Tensor { name, shape, data }
82 }
83 })
84}
85
86pub fn convert_safetensors<'a, 'b, 'c>(
87 input: impl AsRef<Path>,
88 data: &'a [u8],
89 output: impl AsRef<Path>,
90 rename: impl IntoIterator<Item = (&'b str, &'b str)> + Clone,
91 transpose: impl IntoIterator<Item = &'c str> + Clone,
92) -> Result<()> {
93 let torch = TorchTensors::new_from_file(input)?;
94 let tensors = load_tensors(data, torch, rename, transpose);
95
96 let data = tensors.into_iter().map(|tensor| {
97 let name = tensor.name.clone();
98 (name, tensor)
99 });
100
101 safetensors::serialize_to_file(data, &None, output.as_ref())?;
102 Ok(())
103}