web_rwkv/tensor/
serialization.rs

1use std::{borrow::Cow, fmt, marker::PhantomData, sync::Arc};
2
3use serde::{
4    de::{DeserializeSeed, Error, SeqAccess, Visitor},
5    Deserialize, Deserializer, Serialize,
6};
7
8use super::{kind::Kind, shape::Shape, TensorCpu, TensorGpu};
9use crate::{context::Context, num::Scalar};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(bound(deserialize = "'de: 'a"))]
13struct TensorBlob<'a> {
14    shape: Shape,
15    #[serde(with = "serde_bytes")]
16    data: Cow<'a, [u8]>,
17}
18
19impl<T: Scalar> From<TensorCpu<T>> for TensorBlob<'_> {
20    fn from(value: TensorCpu<T>) -> Self {
21        let TensorCpu { shape, data, .. } = value;
22        let data = bytemuck::cast_slice(&data).to_vec().into();
23        Self { shape, data }
24    }
25}
26
27impl<T: Scalar> From<TensorBlob<'_>> for TensorCpu<T> {
28    fn from(value: TensorBlob) -> Self {
29        let TensorBlob { shape, data } = value;
30        let data = data.to_vec().into_boxed_slice();
31        // let data: Vec<T> = bytemuck::cast_slice(&data).to_vec();
32        let data = Box::leak(data);
33        let data: Box<[T]> = unsafe {
34            let ptr = data.as_ptr() as *const T;
35            let len = data.len() / size_of::<T>();
36            let slice = core::slice::from_raw_parts(ptr, len);
37            Box::from(slice)
38        };
39        let data = data.into();
40        Self {
41            shape,
42            data,
43            id: uid::Id::new(),
44            phantom: PhantomData,
45        }
46    }
47}
48
49impl<T: Scalar + Serialize> Serialize for TensorCpu<T> {
50    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
51    where
52        S: serde::Serializer,
53    {
54        TensorBlob::from(self.clone()).serialize(serializer)
55    }
56}
57
58impl<'de, T: Scalar + Deserialize<'de>> Deserialize<'de> for TensorCpu<T> {
59    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
60    where
61        D: serde::Deserializer<'de>,
62    {
63        TensorBlob::deserialize(deserializer).map(Into::into)
64    }
65}
66
67#[cfg(not(target_arch = "wasm32"))]
68impl<T: Scalar + Serialize, K: Kind> Serialize for TensorGpu<T, K> {
69    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70    where
71        S: serde::Serializer,
72    {
73        TensorBlob::from(self.back_in_place()).serialize(serializer)
74    }
75}
76
#[cfg(target_arch = "wasm32")]
impl<T: Scalar + Serialize, K: Kind> Serialize for TensorGpu<T, K> {
    /// Serializing GPU tensors is not supported on `wasm32`.
    ///
    /// # Errors
    ///
    /// Always returns a serializer error. The previous behavior was to panic, which aborted
    /// the whole serialization instead of letting the caller handle the failure.
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // NOTE(review): presumably unsupported because the `back_in_place` read-back used on
        // other targets is not available on wasm32 — confirm.
        Err(<S::Error as serde::ser::Error>::custom(
            "serializing `TensorGpu` is not supported on wasm32",
        ))
    }
}
86
87impl Serialize for Context {
88    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
89    where
90        S: serde::Serializer,
91    {
92        PhantomData::<Context>::serialize(&PhantomData, serializer)
93    }
94}
95
/// Deserialization seed that threads a borrowed context of type `C` through the
/// deserialization of a `Product`.
///
/// Values that cannot be rebuilt from bytes alone (for example tensors that must be recreated
/// through a context) implement `DeserializeSeed` for `Seed<'de, C, Product>`. Container
/// impls (`Vec`, `Option`) forward the same context to each element.
pub struct Seed<'a, C, Product> {
    /// Context made available to the `DeserializeSeed` implementation.
    pub context: &'a C,
    _phantom: PhantomData<Product>,
}

impl<'a, C, Product> Seed<'a, C, Product> {
    /// Creates a seed that borrows `context` for deserializing a `Product`.
    pub fn new(context: &'a C) -> Self {
        Self {
            context,
            _phantom: PhantomData,
        }
    }
}
109
110impl<'de> DeserializeSeed<'de> for Seed<'de, Context, Context> {
111    type Value = Context;
112
113    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
114    where
115        D: Deserializer<'de>,
116    {
117        <PhantomData<Context> as Deserialize<'de>>::deserialize(deserializer)?;
118        Ok(self.context.clone())
119    }
120}
121
122impl<'de, T: Scalar + Deserialize<'de>> DeserializeSeed<'de> for Seed<'de, Context, TensorCpu<T>> {
123    type Value = TensorCpu<T>;
124
125    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
126    where
127        D: serde::Deserializer<'de>,
128    {
129        Deserialize::deserialize(deserializer)
130    }
131}
132
133impl<'de, T: Deserialize<'de>> DeserializeSeed<'de> for Seed<'de, Context, Arc<T>> {
134    type Value = Arc<T>;
135
136    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
137    where
138        D: serde::Deserializer<'de>,
139    {
140        Deserialize::deserialize(deserializer)
141    }
142}
143
144impl<'de, T: Scalar + Deserialize<'de>, K: Kind> DeserializeSeed<'de>
145    for Seed<'de, Context, TensorGpu<T, K>>
146{
147    type Value = TensorGpu<T, K>;
148
149    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
150    where
151        D: serde::Deserializer<'de>,
152    {
153        let context = &self.context;
154        let tensor: TensorBlob<'de> = Deserialize::deserialize(deserializer)?;
155        let tensor = TensorGpu::from_data_u8(context, tensor.shape, &tensor.data);
156        context.queue.submit(None);
157        tensor.map_err(D::Error::custom)
158    }
159}
160
161#[macro_export]
162macro_rules! impl_deserialize_seed {
163    ($tt:tt) => {
164        impl<'de, C> serde::de::DeserializeSeed<'de>
165            for $crate::tensor::serialization::Seed<'de, C, $tt>
166        {
167            type Value = $tt;
168
169            fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
170            where
171                D: serde::de::Deserializer<'de>,
172            {
173                $tt::deserialize(deserializer)
174            }
175        }
176    };
177    ($tt:tt, $gt:tt) => {
178        impl<'de, C, $gt> serde::de::DeserializeSeed<'de>
179            for $crate::tensor::serialization::Seed<'de, C, $tt<$gt>>
180        {
181            type Value = $tt<$gt>;
182
183            fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
184            where
185                D: serde::de::Deserializer<'de>,
186            {
187                $tt::<$gt>::deserialize(deserializer)
188            }
189        }
190    };
191}
192
193impl_deserialize_seed!(bool);
194impl_deserialize_seed!(usize);
195impl_deserialize_seed!(PhantomData, T);
196
197impl<'de, C, T> DeserializeSeed<'de> for Seed<'de, C, Vec<T>>
198where
199    Seed<'de, C, T>: DeserializeSeed<'de, Value = T>,
200{
201    type Value = Vec<T>;
202
203    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
204    where
205        D: serde::Deserializer<'de>,
206    {
207        struct VecVisitor<'de, C, T> {
208            context: &'de C,
209            marker: PhantomData<T>,
210        }
211
212        impl<'de, C, T> Visitor<'de> for VecVisitor<'de, C, T>
213        where
214            Seed<'de, C, T>: DeserializeSeed<'de, Value = T>,
215        {
216            type Value = Vec<T>;
217
218            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
219                formatter.write_str("a sequence")
220            }
221
222            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
223            where
224                A: SeqAccess<'de>,
225            {
226                let mut values = Vec::<T>::new();
227
228                while let Some(value) = seq.next_element_seed(Seed::<C, T>::new(self.context))? {
229                    values.push(value);
230                }
231
232                Ok(values)
233            }
234        }
235
236        let visitor: VecVisitor<C, T> = VecVisitor {
237            context: self.context,
238            marker: PhantomData,
239        };
240        deserializer.deserialize_seq(visitor)
241    }
242}
243
244impl<'de, C, T> DeserializeSeed<'de> for Seed<'de, C, Option<T>>
245where
246    Seed<'de, C, T>: DeserializeSeed<'de, Value = T>,
247{
248    type Value = Option<T>;
249
250    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
251    where
252        D: serde::Deserializer<'de>,
253    {
254        struct OptionVisitor<'de, C, T> {
255            context: &'de C,
256            marker: PhantomData<T>,
257        }
258
259        impl<'de, C, T> Visitor<'de> for OptionVisitor<'de, C, T>
260        where
261            Seed<'de, C, T>: DeserializeSeed<'de, Value = T>,
262        {
263            type Value = Option<T>;
264
265            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
266                formatter.write_str("option")
267            }
268
269            #[inline]
270            fn visit_unit<E>(self) -> Result<Self::Value, E>
271            where
272                E: Error,
273            {
274                Ok(None)
275            }
276
277            #[inline]
278            fn visit_none<E>(self) -> Result<Self::Value, E>
279            where
280                E: Error,
281            {
282                Ok(None)
283            }
284
285            #[inline]
286            fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
287            where
288                D: Deserializer<'de>,
289            {
290                let seed = Seed::<C, T>::new(self.context);
291                DeserializeSeed::deserialize(seed, deserializer).map(Some)
292            }
293
294            fn __private_visit_untagged_option<D>(self, deserializer: D) -> Result<Self::Value, ()>
295            where
296                D: Deserializer<'de>,
297            {
298                let seed = Seed::<C, T>::new(self.context);
299                Ok(DeserializeSeed::deserialize(seed, deserializer).ok())
300            }
301        }
302
303        let visitor: OptionVisitor<C, T> = OptionVisitor {
304            context: self.context,
305            marker: PhantomData,
306        };
307        deserializer.deserialize_option(visitor)
308    }
309}