1use crate::error::{EncodingError, Result, YetiError};
39use std::cell::RefCell;
40use std::collections::HashSet;
41
42use serde::de::{DeserializeSeed, MapAccess, SeqAccess, Visitor};
43use serde::ser::{SerializeMap, SerializeSeq, Serializer};
44use serde::{Deserializer, Serialize};
45
46pub fn transcode_msgpack_to_json(bytes: &[u8]) -> Result<Vec<u8>> {
61 transcode(bytes, None)
62}
63
64#[allow(clippy::implicit_hasher)]
80pub fn transcode_msgpack_to_json_filtered(
81 bytes: &[u8],
82 allowed: &HashSet<String>,
83) -> Result<Vec<u8>> {
84 transcode(bytes, Some(allowed))
85}
86
87fn transcode(bytes: &[u8], allowed: Option<&HashSet<String>>) -> Result<Vec<u8>> {
88 let mut de = rmp_serde::Deserializer::new(bytes);
89 let mut out = Vec::with_capacity(bytes.len() * 2);
90 let mut ser = serde_json::Serializer::new(&mut out);
91
92 let transcoder = RootTranscoder {
95 de: RefCell::new(Some(&mut de)),
96 allowed,
97 };
98
99 transcoder.serialize(&mut ser).map_err(|e| {
100 YetiError::Encoding(EncodingError::MessagePack(format!(
101 "Failed to transcode MessagePack to JSON: {e}"
102 )))
103 })?;
104
105 Ok(out)
106}
107
108struct RootTranscoder<'a, D> {
116 de: RefCell<Option<D>>,
117 allowed: Option<&'a HashSet<String>>,
118}
119
120impl<'de, D: Deserializer<'de>> Serialize for RootTranscoder<'_, D> {
121 fn serialize<S: Serializer>(&self, ser: S) -> std::result::Result<S::Ok, S::Error> {
122 let de = self.de.borrow_mut().take().ok_or_else(|| {
126 serde::ser::Error::custom("RootTranscoder deserializer already consumed")
127 })?;
128 de.deserialize_any(RootVisitor {
134 ser,
135 allowed: self.allowed,
136 })
137 .map_err(serde::ser::Error::custom)?
138 }
139}
140
141struct RootVisitor<'a, S: Serializer> {
143 ser: S,
144 allowed: Option<&'a HashSet<String>>,
145}
146
147impl<'de, S: Serializer> Visitor<'de> for RootVisitor<'_, S> {
148 type Value = std::result::Result<S::Ok, S::Error>;
149
150 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
151 f.write_str("any MessagePack value")
152 }
153
154 fn visit_map<A: MapAccess<'de>>(
155 self,
156 mut map: A,
157 ) -> std::result::Result<Self::Value, A::Error> {
158 let mut sm = match self.ser.serialize_map(None) {
159 Ok(sm) => sm,
160 Err(e) => return Ok(Err(e)),
161 };
162 while let Some(key) = map.next_key::<String>()? {
163 let keep = self
164 .allowed
165 .is_none_or(|set| key == "id" || set.contains(&key));
166 if keep {
167 if let Err(e) = sm.serialize_key(&key) {
168 return Ok(Err(e));
169 }
170 let value = map.next_value_seed(JsonSeed { sm: &mut sm })?;
171 if let Err(e) = value {
172 return Ok(Err(e));
173 }
174 } else {
175 map.next_value::<serde::de::IgnoredAny>()?;
177 }
178 }
179 Ok(sm.end())
180 }
181
182 fn visit_seq<A: SeqAccess<'de>>(
183 self,
184 mut seq: A,
185 ) -> std::result::Result<Self::Value, A::Error> {
186 let mut ss = match self.ser.serialize_seq(None) {
187 Ok(ss) => ss,
188 Err(e) => return Ok(Err(e)),
189 };
190 loop {
191 let stepped = seq.next_element_seed(SeqSeed { ss: &mut ss })?;
192 match stepped {
193 Some(Ok(())) => {},
194 Some(Err(e)) => return Ok(Err(e)),
195 None => break,
196 }
197 }
198 Ok(ss.end())
199 }
200
201 fn visit_bool<E>(self, v: bool) -> std::result::Result<Self::Value, E> {
202 Ok(self.ser.serialize_bool(v))
203 }
204 fn visit_i64<E>(self, v: i64) -> std::result::Result<Self::Value, E> {
205 Ok(self.ser.serialize_i64(v))
206 }
207 fn visit_u64<E>(self, v: u64) -> std::result::Result<Self::Value, E> {
208 Ok(self.ser.serialize_u64(v))
209 }
210 fn visit_f64<E>(self, v: f64) -> std::result::Result<Self::Value, E> {
211 Ok(self.ser.serialize_f64(v))
212 }
213 fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E> {
214 Ok(self.ser.serialize_str(v))
215 }
216 fn visit_string<E>(self, v: String) -> std::result::Result<Self::Value, E> {
217 Ok(self.ser.serialize_str(&v))
218 }
219 fn visit_bytes<E>(self, v: &[u8]) -> std::result::Result<Self::Value, E> {
220 use serde::ser::SerializeSeq;
225 let mut seq = match self.ser.serialize_seq(Some(v.len())) {
226 Ok(s) => s,
227 Err(e) => return Ok(Err(e)),
228 };
229 for b in v {
230 if let Err(e) = seq.serialize_element(b) {
231 return Ok(Err(e));
232 }
233 }
234 Ok(seq.end())
235 }
236 fn visit_none<E>(self) -> std::result::Result<Self::Value, E> {
237 Ok(self.ser.serialize_none())
238 }
239 fn visit_unit<E>(self) -> std::result::Result<Self::Value, E> {
240 Ok(self.ser.serialize_unit())
241 }
242 fn visit_some<D2: Deserializer<'de>>(
243 self,
244 de: D2,
245 ) -> std::result::Result<Self::Value, D2::Error> {
246 de.deserialize_any(self)
247 }
248}
249
250struct JsonSeed<'a, M: SerializeMap> {
252 sm: &'a mut M,
253}
254
255impl<'de, M: SerializeMap> DeserializeSeed<'de> for JsonSeed<'_, M> {
256 type Value = std::result::Result<(), M::Error>;
257
258 fn deserialize<D2: Deserializer<'de>>(
259 self,
260 de: D2,
261 ) -> std::result::Result<Self::Value, D2::Error> {
262 let v = ValueTranscoder {
263 de: RefCell::new(Some(de)),
264 };
265 Ok(self.sm.serialize_value(&v))
266 }
267}
268
269struct SeqSeed<'a, Q: SerializeSeq> {
272 ss: &'a mut Q,
273}
274
275impl<'de, Q: SerializeSeq> DeserializeSeed<'de> for SeqSeed<'_, Q> {
276 type Value = std::result::Result<(), Q::Error>;
277
278 fn deserialize<D2: Deserializer<'de>>(
279 self,
280 de: D2,
281 ) -> std::result::Result<Self::Value, D2::Error> {
282 let v = ValueTranscoder {
283 de: RefCell::new(Some(de)),
284 };
285 Ok(self.ss.serialize_element(&v))
286 }
287}
288
289struct ValueTranscoder<D> {
294 de: RefCell<Option<D>>,
295}
296
297impl<'de, D: Deserializer<'de>> Serialize for ValueTranscoder<D> {
298 fn serialize<S: Serializer>(&self, ser: S) -> std::result::Result<S::Ok, S::Error> {
299 let de = self.de.borrow_mut().take().ok_or_else(|| {
302 serde::ser::Error::custom("ValueTranscoder deserializer already consumed")
303 })?;
304 de.deserialize_any(RootVisitor { ser, allowed: None })
305 .map_err(serde::ser::Error::custom)?
306 }
307}
308
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]

    use super::*;
    use crate::encoding::{decode, encode};
    use serde_json::{Value, json};

    /// Reference implementation: decode into a `serde_json::Value`, then
    /// re-serialize. The streaming transcoder must match its output byte for byte.
    fn round_trip_oracle(msgpack: &[u8]) -> Vec<u8> {
        let value: Value = decode(msgpack).unwrap();
        serde_json::to_vec(&value).unwrap()
    }

    /// Encodes `value` as MessagePack and checks that the transcoder and the
    /// oracle produce identical JSON.
    fn assert_identical(value: &Value) {
        let msgpack = encode(value).unwrap();
        let expected = round_trip_oracle(&msgpack);
        let actual = transcode_msgpack_to_json(&msgpack).unwrap();
        // Compare as text first so a failure prints readable JSON.
        assert_eq!(
            String::from_utf8_lossy(&actual),
            String::from_utf8_lossy(&expected),
            "transcode diverged from the round-trip oracle"
        );
        assert_eq!(
            actual, expected,
            "transcode produced non-byte-identical JSON"
        );
    }

    #[test]
    fn scalars_round_trip_identically() {
        assert_identical(&json!(null));
        assert_identical(&json!(true));
        assert_identical(&json!(false));
        assert_identical(&json!(0));
        assert_identical(&json!(-42));
        // Above 2^53: must survive as an exact integer, not a float.
        assert_identical(&json!(9_007_199_254_740_993u64));
        assert_identical(&json!(3.5));
        assert_identical(&json!("hello"));
        assert_identical(&json!("with \"quotes\" and \n newlines \t and unicode ✓"));
    }

    #[test]
    fn arrays_round_trip_identically() {
        assert_identical(&json!([]));
        assert_identical(&json!([1, 2, 3]));
        assert_identical(&json!(["a", "b", "c"]));
        assert_identical(&json!([{"x": 1}, {"y": 2}]));
        assert_identical(&json!([[1, [2, [3]]], null, "x"]));
    }

    #[test]
    fn objects_round_trip_identically() {
        assert_identical(&json!({}));
        assert_identical(&json!({"id": "abc", "name": "n"}));
        assert_identical(&json!({"zeta": 1, "alpha": 2, "mid": 3}));
        assert_identical(&json!({
            "id": "rec_00000001",
            "name": "User",
            "nested": {"b": 2, "a": 1, "deep": {"y": 2, "x": 1}},
            "list": [1, 2, {"k": "v"}],
            "flag": true,
            "missing": null
        }));
    }

    #[test]
    fn realistic_record_round_trips_identically() {
        let value = json!({
            "id": "rec_00000042",
            "name": "User Name 42",
            "email": "user42@example.com",
            "category": "category_2",
            "description": "A detailed description with punctuation, commas, and \"quotes\".",
            "metadata": "{\"created\":\"42\",\"tags\":[\"a\",\"b\"]}",
            "payload": "X".repeat(550),
        });
        assert_identical(&value);
    }

    #[test]
    fn filter_drops_unknown_top_level_keys() {
        let value = json!({
            "id": "rec_1",
            "name": "keep",
            "stale_field": "drop me",
            "another_stale": 99,
        });
        let msgpack = encode(&value).unwrap();

        let mut allowed = HashSet::new();
        allowed.insert("name".to_owned());
        let actual = transcode_msgpack_to_json_filtered(&msgpack, &allowed).unwrap();

        // Oracle: apply the same rule ("id" always kept) to a decoded Value.
        let mut oracle_value = value;
        if let Some(obj) = oracle_value.as_object_mut() {
            obj.retain(|k, _| k == "id" || allowed.contains(k));
        }
        let expected = serde_json::to_vec(&oracle_value).unwrap();

        assert_eq!(actual, expected);
        let parsed: Value = serde_json::from_slice(&actual).unwrap();
        assert!(parsed.get("stale_field").is_none());
        assert!(parsed.get("another_stale").is_none());
        assert_eq!(parsed["id"], json!("rec_1"));
        assert_eq!(parsed["name"], json!("keep"));
    }

    #[test]
    fn filter_keeps_all_when_every_key_allowed() {
        let value = json!({"id": "x", "a": 1, "b": 2, "c": 3});
        let msgpack = encode(&value).unwrap();
        let allowed: HashSet<String> = ["a", "b", "c"].iter().map(|s| (*s).to_owned()).collect();

        let filtered = transcode_msgpack_to_json_filtered(&msgpack, &allowed).unwrap();
        let unfiltered = transcode_msgpack_to_json(&msgpack).unwrap();
        assert_eq!(filtered, unfiltered);
        assert_eq!(filtered, round_trip_oracle(&msgpack));
    }

    #[test]
    fn filter_does_not_touch_nested_objects() {
        let value = json!({
            "id": "x",
            "keep": {"stale_field": "nested stays", "n": 1},
        });
        let msgpack = encode(&value).unwrap();
        let allowed: HashSet<String> = std::iter::once("keep".to_owned()).collect();

        let actual = transcode_msgpack_to_json_filtered(&msgpack, &allowed).unwrap();
        let parsed: Value = serde_json::from_slice(&actual).unwrap();
        assert_eq!(parsed["keep"]["stale_field"], json!("nested stays"));
        assert_eq!(parsed["keep"]["n"], json!(1));
    }

    #[test]
    fn binary_payload_becomes_array_of_byte_values() {
        // Hand-built bin8 value: marker 0xc4, length 3, payload [0x00, 0x7f, 0xff].
        // JSON has no binary type, so each byte is emitted as a number.
        let msgpack = [0xc4, 0x03, 0x00, 0x7f, 0xff];
        let actual = transcode_msgpack_to_json(&msgpack).unwrap();
        assert_eq!(actual, b"[0,127,255]");
    }

    #[test]
    fn invalid_msgpack_is_an_error() {
        // 0xc1 is the one byte MessagePack reserves as "never used".
        let err = transcode_msgpack_to_json(&[0xc1]);
        assert!(err.is_err());
    }

    #[test]
    fn empty_and_truncated_input_are_errors() {
        assert!(transcode_msgpack_to_json(&[]).is_err());
        // fixarray header announcing 2 elements, but only one follows. The
        // decode error surfaces from inside a nested element.
        assert!(transcode_msgpack_to_json(&[0x92, 0x01]).is_err());
    }
}