1use byteorder::{BigEndian, LittleEndian, WriteBytesExt};
10use num_bigint::BigInt;
11use num_traits::Signed;
12use serde::ser;
13use serde::ser::Serialize;
14use std::collections::BTreeSet;
15use std::io;
16
17use super::consts::*;
18use super::error::{Error, Result};
19use super::value::{HashableValue, Value};
20
/// Pickle protocol version written into the stream header.
///
/// The serializer also consults this to decide how `bytes` values and the
/// `set`/`frozenset` globals are encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PickleProto {
    /// Protocol 2: bytes are emitted as a `_codecs.encode(..., "latin1")` call
    /// and builtins are looked up in the `__builtin__` module.
    V2,
    /// Protocol 3 (the default): bytes use SHORT_BINBYTES / BINBYTES and
    /// builtins are looked up in the `builtins` module.
    V3,
}
27
28impl Default for PickleProto {
29 fn default() -> Self {
30 Self::V3
31 }
32}
33
34#[derive(Clone, Debug, Default)]
36pub struct SerOptions {
37 proto: PickleProto,
38 compat_enum_repr: bool,
39}
40
41impl SerOptions {
42 pub fn new() -> Self {
47 Default::default()
48 }
49
50 pub fn proto_v2(mut self) -> Self {
52 self.proto = PickleProto::V2;
53 self
54 }
55
56 pub fn compat_enum_repr(mut self) -> Self {
76 self.compat_enum_repr = true;
77 self
78 }
79}
80
81pub struct Serializer<W> {
83 writer: W,
84 options: SerOptions,
85}
86
87impl<W: io::Write> Serializer<W> {
88 pub fn new(writer: W, options: SerOptions) -> Self {
89 Serializer { writer, options }
90 }
91
92 pub fn into_inner(self) -> W {
94 self.writer
95 }
96
97 #[inline]
98 fn write_opcode(&mut self, opcode: u8) -> Result<()> {
99 self.writer.write_all(&[opcode]).map_err(From::from)
100 }
101
102 fn serialize_hashable_value(&mut self, value: &HashableValue) -> Result<()> {
103 use serde::Serializer;
104 match *value {
105 HashableValue::None => self.serialize_unit(),
107 HashableValue::Bool(b) => self.serialize_bool(b),
108 HashableValue::I64(i) => self.serialize_i64(i),
109 HashableValue::F64(f) => self.serialize_f64(f),
110 HashableValue::Bytes(ref b) => self.serialize_bytes(b),
111 HashableValue::String(ref s) => self.serialize_str(s),
112 HashableValue::Int(ref i) => self.serialize_bigint(i),
113 HashableValue::FrozenSet(ref s) => self.serialize_set(s, b"frozenset"),
114 HashableValue::Tuple(ref t) => {
115 self.serialize_tuplevalue(t, |slf, v| slf.serialize_hashable_value(v))
116 }
117 }
118 }
119
120 fn serialize_value(&mut self, value: &Value) -> Result<()> {
121 use serde::Serializer;
122 match *value {
123 Value::None => self.serialize_unit(),
125 Value::Bool(b) => self.serialize_bool(b),
126 Value::I64(i) => self.serialize_i64(i),
127 Value::F64(f) => self.serialize_f64(f),
128 Value::Bytes(ref b) => self.serialize_bytes(b),
129 Value::String(ref s) => self.serialize_str(s),
130 Value::List(ref l) => {
131 self.write_opcode(EMPTY_LIST)?;
132 for chunk in l.chunks(1000) {
133 self.write_opcode(MARK)?;
134 for item in chunk {
135 self.serialize_value(item)?;
136 }
137 self.write_opcode(APPENDS)?;
138 }
139 Ok(())
140 }
141 Value::Dict(ref d) => {
142 self.write_opcode(EMPTY_DICT)?;
143 self.write_opcode(MARK)?;
144 for (n, (key, value)) in d.iter().enumerate() {
145 if n % 1000 == 999 {
146 self.write_opcode(SETITEMS)?;
147 self.write_opcode(MARK)?;
148 }
149 self.serialize_hashable_value(key)?;
150 self.serialize_value(value)?;
151 }
152 self.write_opcode(SETITEMS)?;
153 Ok(())
154 }
155
156 Value::Int(ref i) => self.serialize_bigint(i),
158 Value::Tuple(ref t) => self.serialize_tuplevalue(t, |slf, v| slf.serialize_value(v)),
159 Value::Set(ref s) => self.serialize_set(s, b"set"),
160 Value::FrozenSet(ref s) => self.serialize_set(s, b"frozenset"),
161 }
162 }
163
164 fn serialize_bigint(&mut self, i: &BigInt) -> Result<()> {
165 let bytes = if i.is_negative() {
166 let n_bytes = i.to_bytes_le().1.len();
167 let pos = i + (BigInt::from(1) << (n_bytes * 8));
168 let mut bytes = pos.to_bytes_le().1;
169 while bytes.len() < n_bytes {
170 bytes.push(0x00);
171 }
172 if *bytes.last().unwrap() < 0x80 {
173 bytes.push(0xff);
174 }
175 bytes
176 } else {
177 let mut bytes = i.to_bytes_le().1;
178 if *bytes.last().unwrap() >= 0x80 {
179 bytes.push(0x00);
180 }
181 bytes
182 };
183 if bytes.len() < 256 {
184 self.write_opcode(LONG1)?;
185 self.writer.write_u8(bytes.len() as u8)?;
186 } else {
187 self.write_opcode(LONG4)?;
188 self.writer.write_u32::<LittleEndian>(bytes.len() as u32)?;
189 }
190 self.writer.write_all(&bytes).map_err(From::from)
191 }
192
193 fn serialize_tuplevalue<T, F>(&mut self, t: &[T], f: F) -> Result<()>
194 where
195 F: Fn(&mut Self, &T) -> Result<()>,
196 {
197 if t.is_empty() {
198 self.write_opcode(EMPTY_TUPLE)
199 } else if t.len() == 1 {
200 f(self, &t[0])?;
201 self.write_opcode(TUPLE1)
202 } else if t.len() == 2 {
203 f(self, &t[0])?;
204 f(self, &t[1])?;
205 self.write_opcode(TUPLE2)
206 } else if t.len() == 3 {
207 f(self, &t[0])?;
208 f(self, &t[1])?;
209 f(self, &t[2])?;
210 self.write_opcode(TUPLE3)
211 } else {
212 self.write_opcode(MARK)?;
213 for item in t.iter() {
214 f(self, item)?;
215 }
216 self.write_opcode(TUPLE)?;
217 Ok(())
218 }
219 }
220
221 fn serialize_set(&mut self, items: &BTreeSet<HashableValue>, name: &[u8]) -> Result<()> {
222 self.write_opcode(GLOBAL)?;
223 if self.options.proto == PickleProto::V3 {
224 self.writer.write_all(b"builtins\n")?;
225 } else {
226 self.writer.write_all(b"__builtin__\n")?;
227 }
228 self.writer.write_all(name)?;
229 self.writer.write_all(b"\n")?;
230 self.write_opcode(EMPTY_LIST)?;
231 self.write_opcode(MARK)?;
232 for (n, item) in items.iter().enumerate() {
233 if n % 1000 == 999 {
234 self.write_opcode(APPENDS)?;
235 self.write_opcode(MARK)?;
236 }
237 self.serialize_hashable_value(item)?;
238 }
239 self.write_opcode(APPENDS)?;
240 self.write_opcode(TUPLE1)?;
241 self.write_opcode(REDUCE)
242 }
243}
244
245pub struct Compound<'a, W: io::Write + 'a> {
246 ser: &'a mut Serializer<W>,
247 state: Option<usize>,
248}
249
250impl<'a, W: io::Write> ser::SerializeSeq for Compound<'a, W> {
251 type Ok = ();
252 type Error = Error;
253
254 #[inline]
255 fn serialize_element<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<()> {
256 value.serialize(&mut *self.ser)?;
257 *self.state.as_mut().unwrap() += 1;
259 if self.state.unwrap() == 1000 {
260 self.ser.write_opcode(APPENDS)?;
261 self.ser.write_opcode(MARK)?;
262 self.state = Some(0);
263 }
264 Ok(())
265 }
266
267 #[inline]
268 fn end(self) -> Result<()> {
269 if self.state.is_some() {
270 self.ser.write_opcode(APPENDS)?;
271 }
272 Ok(())
273 }
274}
275
276impl<'a, W: io::Write> ser::SerializeTuple for Compound<'a, W> {
277 type Ok = ();
278 type Error = Error;
279
280 #[inline]
281 fn serialize_element<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<()> {
282 value.serialize(&mut *self.ser)
283 }
284
285 #[inline]
286 fn end(self) -> Result<()> {
287 if self.state.is_some() {
288 self.ser.write_opcode(TUPLE)?;
289 }
290 Ok(())
291 }
292}
293
294impl<'a, W: io::Write> ser::SerializeTupleStruct for Compound<'a, W> {
295 type Ok = ();
296 type Error = Error;
297
298 #[inline]
299 fn serialize_field<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<()> {
300 ser::SerializeTuple::serialize_element(self, value)
301 }
302
303 #[inline]
304 fn end(self) -> Result<()> {
305 ser::SerializeTuple::end(self)
306 }
307}
308
309impl<'a, W: io::Write> ser::SerializeTupleVariant for Compound<'a, W> {
310 type Ok = ();
311 type Error = Error;
312
313 #[inline]
314 fn serialize_field<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<()> {
315 value.serialize(&mut *self.ser)
316 }
317
318 #[inline]
319 fn end(self) -> Result<()> {
320 self.ser.write_opcode(APPENDS)?;
321 if self.ser.options.compat_enum_repr {
322 self.ser.write_opcode(TUPLE2)
323 } else {
324 self.ser.write_opcode(SETITEM)
325 }
326 }
327}
328
329impl<'a, W: io::Write> ser::SerializeMap for Compound<'a, W> {
330 type Ok = ();
331 type Error = Error;
332
333 #[inline]
334 fn serialize_key<T: Serialize + ?Sized>(&mut self, key: &T) -> Result<()> {
335 key.serialize(&mut *self.ser)
336 }
337
338 #[inline]
339 fn serialize_value<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<()> {
340 value.serialize(&mut *self.ser)?;
341 *self.state.as_mut().unwrap() += 1;
343 if self.state.unwrap() == 1000 {
344 self.ser.write_opcode(SETITEMS)?;
345 self.ser.write_opcode(MARK)?;
346 self.state = Some(0);
347 }
348 Ok(())
349 }
350
351 #[inline]
352 fn end(self) -> Result<()> {
353 if self.state.is_some() {
354 self.ser.write_opcode(SETITEMS)?;
355 }
356 Ok(())
357 }
358}
359
360impl<'a, W: io::Write> ser::SerializeStruct for Compound<'a, W> {
361 type Ok = ();
362 type Error = Error;
363
364 #[inline]
365 fn serialize_field<T: Serialize + ?Sized>(&mut self, key: &'static str, value: &T) -> Result<()> {
366 ser::SerializeMap::serialize_key(self, key)?;
367 ser::SerializeMap::serialize_value(self, value)
368 }
369
370 #[inline]
371 fn end(self) -> Result<()> {
372 ser::SerializeMap::end(self)
373 }
374}
375
376impl<'a, W: io::Write> ser::SerializeStructVariant for Compound<'a, W> {
377 type Ok = ();
378 type Error = Error;
379
380 #[inline]
381 fn serialize_field<T: Serialize + ?Sized>(&mut self, key: &'static str, value: &T) -> Result<()> {
382 ser::SerializeStruct::serialize_field(self, key, value)
383 }
384
385 #[inline]
386 fn end(self) -> Result<()> {
387 if self.state.is_some() {
388 self.ser.write_opcode(SETITEMS)?;
389 }
390 if self.ser.options.compat_enum_repr {
391 self.ser.write_opcode(TUPLE2)
392 } else {
393 self.ser.write_opcode(SETITEM)
394 }
395 }
396}
397
398impl<'a, W: io::Write> ser::Serializer for &'a mut Serializer<W> {
399 type Ok = ();
400 type Error = Error;
401
402 type SerializeSeq = Compound<'a, W>;
403 type SerializeTuple = Self::SerializeSeq;
404 type SerializeTupleStruct = Self::SerializeTuple;
405 type SerializeTupleVariant = Self::SerializeTuple;
406 type SerializeMap = Compound<'a, W>;
407 type SerializeStruct = Self::SerializeMap;
408 type SerializeStructVariant = Self::SerializeMap;
409
410 #[inline]
411 fn serialize_bool(self, value: bool) -> Result<()> {
412 self.write_opcode(if value { NEWTRUE } else { NEWFALSE })
413 }
414
415 #[inline]
416 fn serialize_i8(self, value: i8) -> Result<()> {
417 if value > 0 {
418 self.write_opcode(BININT1)?;
419 self.writer.write_i8(value).map_err(From::from)
420 } else {
421 self.write_opcode(BININT)?;
422 self.writer.write_i32::<LittleEndian>(value.into()).map_err(From::from)
423 }
424 }
425
426 #[inline]
427 fn serialize_i16(self, value: i16) -> Result<()> {
428 if value > 0 {
429 self.write_opcode(BININT2)?;
430 self.writer.write_i16::<LittleEndian>(value).map_err(From::from)
431 } else {
432 self.write_opcode(BININT)?;
433 self.writer.write_i32::<LittleEndian>(value.into()).map_err(From::from)
434 }
435 }
436
437 #[inline]
438 fn serialize_i32(self, value: i32) -> Result<()> {
439 self.write_opcode(BININT)?;
440 self.writer.write_i32::<LittleEndian>(value).map_err(From::from)
441 }
442
443 #[inline]
444 fn serialize_i64(self, value: i64) -> Result<()> {
445 if (-0x8000_0000..0x8000_0000).contains(&value) {
446 self.write_opcode(BININT)?;
447 self.writer.write_i32::<LittleEndian>(value as i32).map_err(From::from)
448 } else {
449 self.write_opcode(LONG1)?;
450 self.writer.write_i8(8)?;
451 self.writer.write_i64::<LittleEndian>(value).map_err(From::from)
452 }
453 }
454
455 #[inline]
456 fn serialize_u8(self, value: u8) -> Result<()> {
457 self.write_opcode(BININT1)?;
458 self.writer.write_u8(value).map_err(From::from)
459 }
460
461 #[inline]
462 fn serialize_u16(self, value: u16) -> Result<()> {
463 self.write_opcode(BININT2)?;
464 self.writer.write_u16::<LittleEndian>(value).map_err(From::from)
465 }
466
467 #[inline]
468 fn serialize_u32(self, value: u32) -> Result<()> {
469 if value < 0x8000_0000 {
470 self.write_opcode(BININT)?;
471 self.writer.write_u32::<LittleEndian>(value).map_err(From::from)
472 } else {
473 self.write_opcode(LONG1)?;
474 self.writer.write_i8(5)?;
475 self.writer.write_u32::<LittleEndian>(value)?;
476 self.writer.write_i8(0).map_err(From::from)
479 }
480 }
481
482 #[inline]
483 fn serialize_u64(self, value: u64) -> Result<()> {
484 if value < 0x8000_0000 {
485 self.write_opcode(BININT)?;
486 self.writer.write_u32::<LittleEndian>(value as u32).map_err(From::from)
487 } else {
488 self.write_opcode(LONG1)?;
489 self.writer.write_i8(9)?;
490 self.writer.write_u64::<LittleEndian>(value)?;
491 self.writer.write_i8(0).map_err(From::from)
494 }
495 }
496
497 #[inline]
498 fn serialize_f32(self, value: f32) -> Result<()> {
499 self.write_opcode(BINFLOAT)?;
500 self.writer.write_f64::<BigEndian>(value.into()).map_err(From::from)
502 }
503
504 #[inline]
505 fn serialize_f64(self, value: f64) -> Result<()> {
506 self.write_opcode(BINFLOAT)?;
507 self.writer.write_f64::<BigEndian>(value).map_err(From::from)
508 }
509
510 #[inline]
511 fn serialize_char(self, value: char) -> Result<()> {
512 let mut string = String::with_capacity(4); string.push(value);
514 self.serialize_str(&string)
515 }
516
517 #[inline]
518 fn serialize_str(self, value: &str) -> Result<()> {
519 self.write_opcode(BINUNICODE)?;
520 self.writer.write_u32::<LittleEndian>(value.len() as u32)?;
521 self.writer.write_all(value.as_bytes()).map_err(From::from)
522 }
523
524 #[inline]
525 fn serialize_bytes(self, value: &[u8]) -> Result<()> {
526 if self.options.proto == PickleProto::V3 {
527 if value.len() < 256 {
528 self.write_opcode(SHORT_BINBYTES)?;
529 self.writer.write_u8(value.len() as u8)?;
530 } else {
531 self.write_opcode(BINBYTES)?;
532 self.writer.write_u32::<LittleEndian>(value.len() as u32)?;
533 }
534 self.writer.write_all(value).map_err(From::from)
535 } else {
536 self.write_opcode(GLOBAL)?;
544 self.writer.write_all(b"_codecs\nencode\n")?;
545 let utf8_value: String = value.iter().map(|&c| c as char).collect();
551 self.serialize_str(&utf8_value)?;
552 self.serialize_str("latin1")?;
553 self.write_opcode(TUPLE2)?;
554 self.write_opcode(REDUCE).map_err(From::from)
555 }
556 }
557
558 #[inline]
559 fn serialize_unit(self) -> Result<()> {
560 self.write_opcode(NONE)
563 }
564
565 #[inline]
566 fn serialize_unit_struct(self, _name: &'static str) -> Result<()> {
567 self.write_opcode(NONE)
568 }
569
570 #[inline]
571 fn serialize_unit_variant(
572 self, _name: &'static str, _variant_index: u32, variant: &'static str,
573 ) -> Result<()> {
574 self.serialize_str(variant)?;
575 if self.options.compat_enum_repr {
576 self.write_opcode(TUPLE1)
577 } else {
578 Ok(())
579 }
580 }
581
582 #[inline]
583 fn serialize_newtype_struct<T: Serialize + ?Sized>(self, _name: &'static str, value: &T) -> Result<()> {
584 value.serialize(self)
585 }
586
587 #[inline]
588 fn serialize_newtype_variant<T: Serialize + ?Sized>(
589 self, _name: &'static str, _variant_index: u32, variant: &'static str, value: &T,
590 ) -> Result<()> {
591 if self.options.compat_enum_repr {
592 self.serialize_str(variant)?;
593 value.serialize(&mut *self)?;
594 self.write_opcode(TUPLE2)
595 } else {
596 self.write_opcode(EMPTY_DICT)?;
597 self.serialize_str(variant)?;
598 value.serialize(&mut *self)?;
599 self.write_opcode(SETITEM)
600 }
601 }
602
603 #[inline]
604 fn serialize_none(self) -> Result<()> {
605 self.serialize_unit()
606 }
607
608 #[inline]
609 fn serialize_some<T: Serialize + ?Sized>(self, value: &T) -> Result<()> {
610 value.serialize(self)
611 }
612
613 #[inline]
614 fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq> {
615 self.write_opcode(EMPTY_LIST)?;
616 match len {
617 Some(0) => Ok(Compound { ser: self, state: None }),
618 _ => {
619 self.write_opcode(MARK)?;
620 Ok(Compound { ser: self, state: Some(0) })
621 }
622 }
623 }
624
625 #[inline]
626 fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple> {
627 if len == 0 {
628 self.write_opcode(EMPTY_TUPLE)?;
629 Ok(Compound { ser: self, state: None })
630 } else {
631 self.write_opcode(MARK)?;
632 Ok(Compound { ser: self, state: Some(0) })
633 }
634 }
635
636 #[inline]
637 fn serialize_tuple_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeTupleStruct> {
638 self.serialize_tuple(len)
639 }
640
641 #[inline]
642 fn serialize_tuple_variant(
643 self, _name: &'static str, _variant_index: u32, variant: &'static str, _len: usize,
644 ) -> Result<Self::SerializeTupleVariant> {
645 if !self.options.compat_enum_repr {
646 self.write_opcode(EMPTY_DICT)?;
647 }
648 self.serialize_str(variant)?;
649 self.write_opcode(EMPTY_LIST)?;
650 self.write_opcode(MARK)?;
651 Ok(Compound { ser: self, state: None })
652 }
653
654 #[inline]
655 fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap> {
656 self.write_opcode(EMPTY_DICT)?;
657 match len {
658 Some(0) => Ok(Compound { ser: self, state: None }),
659 _ => {
660 self.write_opcode(MARK)?;
661 Ok(Compound { ser: self, state: Some(0) })
662 }
663 }
664 }
665
666 #[inline]
667 fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
668 self.serialize_map(Some(len))
669 }
670
671 #[inline]
672 fn serialize_struct_variant(
673 self, _name: &'static str, _variant_index: u32, variant: &'static str, len: usize,
674 ) -> Result<Self::SerializeStructVariant> {
675 if !self.options.compat_enum_repr {
676 self.write_opcode(EMPTY_DICT)?;
677 }
678 self.serialize_str(variant)?;
679 self.serialize_map(Some(len))
680 }
681}
682
683fn wrap_write<W: io::Write, F>(mut writer: W, inner: F, options: SerOptions) -> Result<()>
684where
685 F: FnOnce(&mut Serializer<W>) -> Result<()>,
686{
687 writer.write_all(&[PROTO])?;
688 if options.proto == PickleProto::V3 {
689 writer.write_all(b"\x03")?;
690 } else {
691 writer.write_all(b"\x02")?;
692 }
693 let mut ser = Serializer::new(writer, options);
694 inner(&mut ser)?;
695 let mut writer = ser.into_inner();
696 writer.write_all(&[STOP]).map_err(From::from)
697}
698
699pub fn value_to_writer<W: io::Write>(writer: &mut W, value: &Value, options: SerOptions) -> Result<()> {
701 wrap_write(writer, |ser| ser.serialize_value(value), options)
702}
703
704#[inline]
706pub fn to_writer<W: io::Write, T: Serialize>(writer: &mut W, value: &T, options: SerOptions) -> Result<()> {
707 wrap_write(writer, |ser| value.serialize(ser), options)
708}
709
710#[inline]
712pub fn value_to_vec(value: &Value, options: SerOptions) -> Result<Vec<u8>> {
713 let mut writer = Vec::with_capacity(128);
714 value_to_writer(&mut writer, value, options)?;
715 Ok(writer)
716}
717
718#[inline]
720pub fn to_vec<T: Serialize>(value: &T, options: SerOptions) -> Result<Vec<u8>> {
721 let mut writer = Vec::with_capacity(128);
722 to_writer(&mut writer, value, options)?;
723 Ok(writer)
724}