1use std::{borrow::Cow, fmt, marker::PhantomData, sync::Arc};
2
3use serde::{
4 de::{DeserializeSeed, Error, SeqAccess, Visitor},
5 Deserialize, Deserializer, Serialize,
6};
7
8use super::{kind::Kind, shape::Shape, TensorCpu, TensorGpu};
9use crate::{context::Context, num::Scalar};
10
/// Serialized form of a tensor: its shape plus the raw bytes of its elements.
///
/// `data` is a `Cow`, so a blob may either own its bytes or borrow them. On
/// deserialization, the `'de: 'a` bound allows the bytes to borrow from the
/// deserializer's input when the format supports borrowed byte strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(deserialize = "'de: 'a"))]
struct TensorBlob<'a> {
    /// Shape of the tensor the bytes belong to.
    shape: Shape,
    /// Element memory reinterpreted as bytes (via `bytemuck`, i.e. native byte order),
    /// encoded as a serde byte string through `serde_bytes` rather than a sequence of `u8`.
    #[serde(with = "serde_bytes")]
    data: Cow<'a, [u8]>,
}
18
19impl<T: Scalar> From<TensorCpu<T>> for TensorBlob<'_> {
20 fn from(value: TensorCpu<T>) -> Self {
21 let TensorCpu { shape, data, .. } = value;
22 let data = bytemuck::cast_slice(&data).to_vec().into();
23 Self { shape, data }
24 }
25}
26
27impl<T: Scalar> From<TensorBlob<'_>> for TensorCpu<T> {
28 fn from(value: TensorBlob) -> Self {
29 let TensorBlob { shape, data } = value;
30 let data = data.to_vec().into_boxed_slice();
31 let data = Box::leak(data);
33 let data: Box<[T]> = unsafe {
34 let ptr = data.as_ptr() as *const T;
35 let len = data.len() / size_of::<T>();
36 let slice = core::slice::from_raw_parts(ptr, len);
37 Box::from(slice)
38 };
39 let data = data.into();
40 Self {
41 shape,
42 data,
43 id: uid::Id::new(),
44 phantom: PhantomData,
45 }
46 }
47}
48
49impl<T: Scalar + Serialize> Serialize for TensorCpu<T> {
50 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
51 where
52 S: serde::Serializer,
53 {
54 TensorBlob::from(self.clone()).serialize(serializer)
55 }
56}
57
58impl<'de, T: Scalar + Deserialize<'de>> Deserialize<'de> for TensorCpu<T> {
59 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
60 where
61 D: serde::Deserializer<'de>,
62 {
63 TensorBlob::deserialize(deserializer).map(Into::into)
64 }
65}
66
67#[cfg(not(target_arch = "wasm32"))]
68impl<T: Scalar + Serialize, K: Kind> Serialize for TensorGpu<T, K> {
69 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70 where
71 S: serde::Serializer,
72 {
73 TensorBlob::from(self.back_in_place()).serialize(serializer)
74 }
75}
76
#[cfg(target_arch = "wasm32")]
impl<T: Scalar + Serialize, K: Kind> Serialize for TensorGpu<T, K> {
    /// GPU tensor serialization is not supported on `wasm32`.
    ///
    /// The native implementation reads the tensor back with `back_in_place`, which is
    /// not used on this target (presumably because blocking read-back is unavailable
    /// there — confirm).
    ///
    /// # Errors
    ///
    /// Always returns a serializer error instead of panicking, so callers can handle it.
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        Err(<S::Error as serde::ser::Error>::custom(
            "serializing a GPU tensor is not supported on wasm32",
        ))
    }
}
86
87impl Serialize for Context {
88 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
89 where
90 S: serde::Serializer,
91 {
92 PhantomData::<Context>::serialize(&PhantomData, serializer)
93 }
94}
95
/// Deserialization seed that threads a borrowed `Context` into the deserialization of a
/// `Product`. It is used by the `DeserializeSeed` impls written for specific products.
///
/// Within this definition, `Context` is a generic parameter name: it shadows any imported
/// `Context` type. `Product` is only tracked at the type level.
pub struct Seed<'a, Context, Product> {
    /// Shared state handed to each `DeserializeSeed` impl (e.g. used for GPU uploads).
    pub context: &'a Context,
    _phantom: PhantomData<Product>,
}
100
101impl<'a, Context, Product> Seed<'a, Context, Product> {
102 pub fn new(context: &'a Context) -> Self {
103 Self {
104 context,
105 _phantom: PhantomData,
106 }
107 }
108}
109
110impl<'de> DeserializeSeed<'de> for Seed<'de, Context, Context> {
111 type Value = Context;
112
113 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
114 where
115 D: Deserializer<'de>,
116 {
117 <PhantomData<Context> as Deserialize<'de>>::deserialize(deserializer)?;
118 Ok(self.context.clone())
119 }
120}
121
122impl<'de, T: Scalar + Deserialize<'de>> DeserializeSeed<'de> for Seed<'de, Context, TensorCpu<T>> {
123 type Value = TensorCpu<T>;
124
125 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
126 where
127 D: serde::Deserializer<'de>,
128 {
129 Deserialize::deserialize(deserializer)
130 }
131}
132
133impl<'de, T: Deserialize<'de>> DeserializeSeed<'de> for Seed<'de, Context, Arc<T>> {
134 type Value = Arc<T>;
135
136 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
137 where
138 D: serde::Deserializer<'de>,
139 {
140 Deserialize::deserialize(deserializer)
141 }
142}
143
144impl<'de, T: Scalar + Deserialize<'de>, K: Kind> DeserializeSeed<'de>
145 for Seed<'de, Context, TensorGpu<T, K>>
146{
147 type Value = TensorGpu<T, K>;
148
149 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
150 where
151 D: serde::Deserializer<'de>,
152 {
153 let context = &self.context;
154 let tensor: TensorBlob<'de> = Deserialize::deserialize(deserializer)?;
155 let tensor = TensorGpu::from_data_u8(context, tensor.shape, &tensor.data);
156 context.queue.submit(None);
157 tensor.map_err(D::Error::custom)
158 }
159}
160
/// Implements `serde::de::DeserializeSeed` for `Seed<'de, C, Type>`. The generated impl
/// ignores the seed's context and delegates to the type's plain `serde::Deserialize` impl.
///
/// Two forms are accepted:
/// - `impl_deserialize_seed!(Type)` for a non-generic type;
/// - `impl_deserialize_seed!(Type, G)` for a type with one generic parameter `G`. The
///   generated impl applies whenever `Type<G>: Deserialize`.
///
/// The expansion names the `Deserialize` trait explicitly, so call sites do not need it in
/// scope. They only need the `serde` crate to be reachable as `serde`.
#[macro_export]
macro_rules! impl_deserialize_seed {
    ($tt:tt) => {
        impl<'de, C> serde::de::DeserializeSeed<'de>
            for $crate::tensor::serialization::Seed<'de, C, $tt>
        {
            type Value = $tt;

            fn deserialize<D>(self, deserializer: D) -> ::core::result::Result<Self::Value, D::Error>
            where
                D: serde::de::Deserializer<'de>,
            {
                <$tt as serde::Deserialize<'de>>::deserialize(deserializer)
            }
        }
    };
    ($tt:tt, $gt:tt) => {
        impl<'de, C, $gt> serde::de::DeserializeSeed<'de>
            for $crate::tensor::serialization::Seed<'de, C, $tt<$gt>>
        where
            $tt<$gt>: serde::Deserialize<'de>,
        {
            type Value = $tt<$gt>;

            fn deserialize<D>(self, deserializer: D) -> ::core::result::Result<Self::Value, D::Error>
            where
                D: serde::de::Deserializer<'de>,
            {
                <$tt<$gt> as serde::Deserialize<'de>>::deserialize(deserializer)
            }
        }
    };
}
192
// Context-free seeds: these types are deserialized with their plain `Deserialize`
// impls and the seed's context is ignored.
impl_deserialize_seed!(bool);
impl_deserialize_seed!(usize);
impl_deserialize_seed!(PhantomData, T);
196
197impl<'de, C, T> DeserializeSeed<'de> for Seed<'de, C, Vec<T>>
198where
199 Seed<'de, C, T>: DeserializeSeed<'de, Value = T>,
200{
201 type Value = Vec<T>;
202
203 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
204 where
205 D: serde::Deserializer<'de>,
206 {
207 struct VecVisitor<'de, C, T> {
208 context: &'de C,
209 marker: PhantomData<T>,
210 }
211
212 impl<'de, C, T> Visitor<'de> for VecVisitor<'de, C, T>
213 where
214 Seed<'de, C, T>: DeserializeSeed<'de, Value = T>,
215 {
216 type Value = Vec<T>;
217
218 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
219 formatter.write_str("a sequence")
220 }
221
222 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
223 where
224 A: SeqAccess<'de>,
225 {
226 let mut values = Vec::<T>::new();
227
228 while let Some(value) = seq.next_element_seed(Seed::<C, T>::new(self.context))? {
229 values.push(value);
230 }
231
232 Ok(values)
233 }
234 }
235
236 let visitor: VecVisitor<C, T> = VecVisitor {
237 context: self.context,
238 marker: PhantomData,
239 };
240 deserializer.deserialize_seq(visitor)
241 }
242}
243
244impl<'de, C, T> DeserializeSeed<'de> for Seed<'de, C, Option<T>>
245where
246 Seed<'de, C, T>: DeserializeSeed<'de, Value = T>,
247{
248 type Value = Option<T>;
249
250 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
251 where
252 D: serde::Deserializer<'de>,
253 {
254 struct OptionVisitor<'de, C, T> {
255 context: &'de C,
256 marker: PhantomData<T>,
257 }
258
259 impl<'de, C, T> Visitor<'de> for OptionVisitor<'de, C, T>
260 where
261 Seed<'de, C, T>: DeserializeSeed<'de, Value = T>,
262 {
263 type Value = Option<T>;
264
265 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
266 formatter.write_str("option")
267 }
268
269 #[inline]
270 fn visit_unit<E>(self) -> Result<Self::Value, E>
271 where
272 E: Error,
273 {
274 Ok(None)
275 }
276
277 #[inline]
278 fn visit_none<E>(self) -> Result<Self::Value, E>
279 where
280 E: Error,
281 {
282 Ok(None)
283 }
284
285 #[inline]
286 fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
287 where
288 D: Deserializer<'de>,
289 {
290 let seed = Seed::<C, T>::new(self.context);
291 DeserializeSeed::deserialize(seed, deserializer).map(Some)
292 }
293
294 fn __private_visit_untagged_option<D>(self, deserializer: D) -> Result<Self::Value, ()>
295 where
296 D: Deserializer<'de>,
297 {
298 let seed = Seed::<C, T>::new(self.context);
299 Ok(DeserializeSeed::deserialize(seed, deserializer).ok())
300 }
301 }
302
303 let visitor: OptionVisitor<C, T> = OptionVisitor {
304 context: self.context,
305 marker: PhantomData,
306 };
307 deserializer.deserialize_option(visitor)
308 }
309}