1use serde::de;
16use std::fmt;
17
18use crate::Key;
19
20pub struct Deserializer<'a, D, T> {
24 deserializer: D,
25 key: Option<&'a Key<T>>,
26}
27
28impl<'a, 'de, D, T> Deserializer<'a, D, T>
29where
30 D: de::Deserializer<'de>,
31{
32 pub fn new(deserializer: D, key: Option<&'a Key<T>>) -> Deserializer<'a, D, T> {
36 Deserializer { deserializer, key }
37 }
38}
39
// Generates a `de::Deserializer` method that forwards to the wrapped deserializer,
// substituting a decrypting `Visitor` for the caller's visitor.
//
// Forms:
//   forward_deserialize!(method);
//   forward_deserialize!(method, arg => Type, ...);
// Extra arguments are passed through unchanged, ahead of the visitor.
macro_rules! forward_deserialize {
    ($method:ident) => {
        forward_deserialize!($method,);
    };
    ($method:ident, $($param:tt => $param_ty:ty),*) => {
        fn $method<V>(self, $($param: $param_ty,)* visitor: V) -> Result<V::Value, D::Error>
        where
            V: de::Visitor<'de>,
        {
            let Self { deserializer: inner, key } = self;
            inner.$method($($param,)* Visitor { visitor, key })
        }
    };
}
54
55impl<'de, D, T> de::Deserializer<'de> for Deserializer<'_, D, T>
56where
57 D: de::Deserializer<'de>,
58{
59 type Error = D::Error;
60
61 forward_deserialize!(deserialize_any);
62 forward_deserialize!(deserialize_bool);
63 forward_deserialize!(deserialize_u8);
64 forward_deserialize!(deserialize_u16);
65 forward_deserialize!(deserialize_u32);
66 forward_deserialize!(deserialize_u64);
67 forward_deserialize!(deserialize_i8);
68 forward_deserialize!(deserialize_i16);
69 forward_deserialize!(deserialize_i32);
70 forward_deserialize!(deserialize_i64);
71 forward_deserialize!(deserialize_f32);
72 forward_deserialize!(deserialize_f64);
73 forward_deserialize!(deserialize_char);
74 forward_deserialize!(deserialize_str);
75 forward_deserialize!(deserialize_string);
76 forward_deserialize!(deserialize_unit);
77 forward_deserialize!(deserialize_option);
78 forward_deserialize!(deserialize_seq);
79 forward_deserialize!(deserialize_bytes);
80 forward_deserialize!(deserialize_byte_buf);
81 forward_deserialize!(deserialize_map);
82 forward_deserialize!(deserialize_unit_struct, name => &'static str);
83 forward_deserialize!(deserialize_newtype_struct, name => &'static str);
84 forward_deserialize!(deserialize_tuple_struct, name => &'static str, len => usize);
85 forward_deserialize!(deserialize_struct,
86 name => &'static str,
87 fields => &'static [&'static str]);
88 forward_deserialize!(deserialize_identifier);
89 forward_deserialize!(deserialize_tuple, len => usize);
90 forward_deserialize!(deserialize_enum,
91 name => &'static str,
92 variants => &'static [&'static str]);
93 forward_deserialize!(deserialize_ignored_any);
94}
95
96struct Visitor<'a, V, T> {
97 visitor: V,
98 key: Option<&'a Key<T>>,
99}
100
101impl<V, T> Visitor<'_, V, T> {
102 fn expand_str<E>(&self, s: &str) -> Result<Option<String>, E>
103 where
104 E: de::Error,
105 {
106 if s.starts_with("${enc:") && s.ends_with('}') {
107 match self.key {
108 Some(key) => match key.decrypt(&s[6..s.len() - 1]) {
109 Ok(s) => Ok(Some(s)),
110 Err(e) => Err(E::custom(e.to_string())),
111 },
112 None => Err(E::custom("missing encryption key")),
113 }
114 } else {
115 Ok(None)
116 }
117 }
118}
119
// Generates a `de::Visitor` method that hands the visited value straight to the wrapped
// visitor without inspecting it. Used for non-string values, which never need
// decryption.
macro_rules! forward_visit {
    ($method:ident, $value_ty:ty) => {
        fn $method<E: de::Error>(self, value: $value_ty) -> Result<V::Value, E> {
            self.visitor.$method(value)
        }
    };
}
130
131impl<'de, V, T> de::Visitor<'de> for Visitor<'_, V, T>
132where
133 V: de::Visitor<'de>,
134{
135 type Value = V::Value;
136
137 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
138 self.visitor.expecting(formatter)
139 }
140
141 forward_visit!(visit_bool, bool);
142 forward_visit!(visit_i8, i8);
143 forward_visit!(visit_i16, i16);
144 forward_visit!(visit_i32, i32);
145 forward_visit!(visit_i64, i64);
146 forward_visit!(visit_u8, u8);
147 forward_visit!(visit_u16, u16);
148 forward_visit!(visit_u32, u32);
149 forward_visit!(visit_u64, u64);
150 forward_visit!(visit_f32, f32);
151 forward_visit!(visit_f64, f64);
152 forward_visit!(visit_char, char);
153 forward_visit!(visit_bytes, &[u8]);
154 forward_visit!(visit_byte_buf, Vec<u8>);
155
156 fn visit_str<E>(self, v: &str) -> Result<V::Value, E>
157 where
158 E: de::Error,
159 {
160 match self.expand_str(v)? {
161 Some(s) => self.visitor.visit_string(s),
162 None => self.visitor.visit_str(v),
163 }
164 }
165
166 fn visit_string<E>(self, v: String) -> Result<V::Value, E>
167 where
168 E: de::Error,
169 {
170 match self.expand_str(&v)? {
171 Some(s) => self.visitor.visit_string(s),
172 None => self.visitor.visit_string(v),
173 }
174 }
175
176 fn visit_borrowed_str<E>(self, v: &'de str) -> Result<V::Value, E>
177 where
178 E: de::Error,
179 {
180 match self.expand_str(v)? {
181 Some(s) => self.visitor.visit_string(s),
182 None => self.visitor.visit_borrowed_str(v),
183 }
184 }
185
186 fn visit_unit<E>(self) -> Result<V::Value, E>
187 where
188 E: de::Error,
189 {
190 self.visitor.visit_unit()
191 }
192
193 fn visit_none<E>(self) -> Result<V::Value, E>
194 where
195 E: de::Error,
196 {
197 self.visitor.visit_none()
198 }
199
200 fn visit_some<D>(self, deserializer: D) -> Result<V::Value, D::Error>
201 where
202 D: de::Deserializer<'de>,
203 {
204 let deserializer = Deserializer::new(deserializer, self.key);
205 self.visitor.visit_some(deserializer)
206 }
207
208 fn visit_newtype_struct<D>(self, deserializer: D) -> Result<V::Value, D::Error>
209 where
210 D: de::Deserializer<'de>,
211 {
212 let deserializer = Deserializer::new(deserializer, self.key);
213 self.visitor.visit_newtype_struct(deserializer)
214 }
215
216 fn visit_seq<V2>(self, visitor: V2) -> Result<V::Value, V2::Error>
217 where
218 V2: de::SeqAccess<'de>,
219 {
220 let visitor = Visitor {
221 visitor,
222 key: self.key,
223 };
224 self.visitor.visit_seq(visitor)
225 }
226
227 fn visit_map<V2>(self, visitor: V2) -> Result<V::Value, V2::Error>
228 where
229 V2: de::MapAccess<'de>,
230 {
231 let visitor = Visitor {
232 visitor,
233 key: self.key,
234 };
235 self.visitor.visit_map(visitor)
236 }
237
238 fn visit_enum<V2>(self, visitor: V2) -> Result<V::Value, V2::Error>
239 where
240 V2: de::EnumAccess<'de>,
241 {
242 let visitor = Visitor {
243 visitor,
244 key: self.key,
245 };
246 self.visitor.visit_enum(visitor)
247 }
248}
249
250impl<'de, V, T> de::SeqAccess<'de> for Visitor<'_, V, T>
251where
252 V: de::SeqAccess<'de>,
253{
254 type Error = V::Error;
255
256 fn next_element_seed<S>(&mut self, seed: S) -> Result<Option<S::Value>, V::Error>
257 where
258 S: de::DeserializeSeed<'de>,
259 {
260 let seed = DeserializeSeed {
261 seed,
262 key: self.key,
263 };
264 self.visitor.next_element_seed(seed)
265 }
266
267 fn size_hint(&self) -> Option<usize> {
268 self.visitor.size_hint()
269 }
270}
271
272impl<'de, V, T> de::MapAccess<'de> for Visitor<'_, V, T>
273where
274 V: de::MapAccess<'de>,
275{
276 type Error = V::Error;
277
278 fn next_key_seed<S>(&mut self, seed: S) -> Result<Option<S::Value>, V::Error>
279 where
280 S: de::DeserializeSeed<'de>,
281 {
282 let seed = DeserializeSeed {
283 seed,
284 key: self.key,
285 };
286 self.visitor.next_key_seed(seed)
287 }
288
289 fn next_value_seed<S>(&mut self, seed: S) -> Result<S::Value, V::Error>
290 where
291 S: de::DeserializeSeed<'de>,
292 {
293 let seed = DeserializeSeed {
294 seed,
295 key: self.key,
296 };
297 self.visitor.next_value_seed(seed)
298 }
299
300 #[allow(clippy::type_complexity)]
301 fn next_entry_seed<K, V2>(
302 &mut self,
303 kseed: K,
304 vseed: V2,
305 ) -> Result<Option<(K::Value, V2::Value)>, V::Error>
306 where
307 K: de::DeserializeSeed<'de>,
308 V2: de::DeserializeSeed<'de>,
309 {
310 let kseed = DeserializeSeed {
311 seed: kseed,
312 key: self.key,
313 };
314 let vseed = DeserializeSeed {
315 seed: vseed,
316 key: self.key,
317 };
318 self.visitor.next_entry_seed(kseed, vseed)
319 }
320
321 fn size_hint(&self) -> Option<usize> {
322 self.visitor.size_hint()
323 }
324}
325
326impl<'a, 'de, V, T> de::EnumAccess<'de> for Visitor<'a, V, T>
327where
328 V: de::EnumAccess<'de>,
329{
330 type Error = V::Error;
331 type Variant = Visitor<'a, V::Variant, T>;
332
333 #[allow(clippy::type_complexity)]
334 fn variant_seed<S>(self, seed: S) -> Result<(S::Value, Visitor<'a, V::Variant, T>), V::Error>
335 where
336 S: de::DeserializeSeed<'de>,
337 {
338 let seed = DeserializeSeed {
339 seed,
340 key: self.key,
341 };
342 match self.visitor.variant_seed(seed) {
343 Ok((value, variant)) => {
344 let variant = Visitor {
345 visitor: variant,
346 key: self.key,
347 };
348 Ok((value, variant))
349 }
350 Err(e) => Err(e),
351 }
352 }
353}
354
355impl<'de, V, T> de::VariantAccess<'de> for Visitor<'_, V, T>
356where
357 V: de::VariantAccess<'de>,
358{
359 type Error = V::Error;
360
361 fn unit_variant(self) -> Result<(), V::Error> {
362 self.visitor.unit_variant()
363 }
364
365 fn newtype_variant_seed<S>(self, seed: S) -> Result<S::Value, V::Error>
366 where
367 S: de::DeserializeSeed<'de>,
368 {
369 let seed = DeserializeSeed {
370 seed,
371 key: self.key,
372 };
373 self.visitor.newtype_variant_seed(seed)
374 }
375
376 fn tuple_variant<V2>(self, len: usize, visitor: V2) -> Result<V2::Value, V::Error>
377 where
378 V2: de::Visitor<'de>,
379 {
380 let visitor = Visitor {
381 visitor,
382 key: self.key,
383 };
384 self.visitor.tuple_variant(len, visitor)
385 }
386
387 fn struct_variant<V2>(
388 self,
389 fields: &'static [&'static str],
390 visitor: V2,
391 ) -> Result<V2::Value, V::Error>
392 where
393 V2: de::Visitor<'de>,
394 {
395 let visitor = Visitor {
396 visitor,
397 key: self.key,
398 };
399 self.visitor.struct_variant(fields, visitor)
400 }
401}
402
403struct DeserializeSeed<'a, S, T> {
404 seed: S,
405 key: Option<&'a Key<T>>,
406}
407
408impl<'de, S, T> de::DeserializeSeed<'de> for DeserializeSeed<'_, S, T>
409where
410 S: de::DeserializeSeed<'de>,
411{
412 type Value = S::Value;
413
414 fn deserialize<D>(self, deserializer: D) -> Result<S::Value, D::Error>
415 where
416 D: de::Deserializer<'de>,
417 {
418 let deserializer = Deserializer::new(deserializer, self.key);
419 self.seed.deserialize(deserializer)
420 }
421}