1use core::convert::{TryFrom, TryInto};
2use crate::{Encodable, ErrorKind, header::Header, Length, Result, Tag};
3
4#[derive(Debug)]
6pub struct Encoder<'a> {
7 bytes: Option<&'a mut [u8]>,
9
10 position: Length,
12}
13
14impl<'a> Encoder<'a> {
15 pub fn new(bytes: &'a mut [u8]) -> Self {
17 Self {
18 bytes: Some(bytes),
19 position: Length::zero(),
20 }
21 }
22
23 pub fn encode<T: Encodable>(&mut self, encodable: &T) -> Result<()> {
25 if self.is_failed() {
26 self.error(ErrorKind::Failed)?;
27 }
28
29 encodable.encode(self).map_err(|e| {
30 self.bytes.take();
31 e.nested(self.position)
32 })
33 }
34
35 pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
38 self.bytes.take();
39 Err(kind.at(self.position))
40 }
41
42 pub fn is_failed(&self) -> bool {
44 self.bytes.is_none()
45 }
46
47 pub fn finish(self) -> Result<&'a [u8]> {
50 let position = self.position;
51
52 match self.bytes {
53 Some(bytes) => bytes
54 .get(..self.position.into())
55 .ok_or_else(|| ErrorKind::Truncated.at(position)),
56 None => Err(ErrorKind::Failed.at(position)),
57 }
58 }
59
60 pub fn encode_tagged_collection(&mut self, tag: Tag, encodables: &[&dyn Encodable]) -> Result<()> {
62 let expected_len = Length::try_from(encodables)?;
63 Header::new(tag, expected_len).and_then(|header| header.encode(self))?;
64
65 let mut nested_encoder = Encoder::new(self.reserve(expected_len)?);
66
67 for encodable in encodables {
68 encodable.encode(&mut nested_encoder)?;
69 }
70
71 if nested_encoder.finish()?.len() == expected_len.into() {
72 Ok(())
73 } else {
74 self.error(ErrorKind::Length { tag })
75 }
76 }
77
78 pub fn encode_untagged_collection(&mut self, encodables: &[&dyn Encodable]) -> Result<()> {
80 let expected_len = Length::try_from(encodables)?;
81 let mut nested_encoder = Encoder::new(self.reserve(expected_len)?);
82
83 for encodable in encodables {
84 encodable.encode(&mut nested_encoder)?;
85 }
86 Ok(())
87 }
88
89 pub(crate) fn byte(&mut self, byte: u8) -> Result<()> {
91 match self.reserve(1u8)?.first_mut() {
92 Some(b) => {
93 *b = byte;
94 Ok(())
95 }
96 None => self.error(ErrorKind::Truncated),
97 }
98 }
99
100 pub(crate) fn bytes(&mut self, slice: &[u8]) -> Result<()> {
102 self.reserve(slice.len())?.copy_from_slice(slice);
103 Ok(())
104 }
105
106 fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
109 let len = len
110 .try_into()
111 .or_else(|_| self.error(ErrorKind::Overflow))?;
112
113 if len > self.remaining_len()? {
114 self.error(ErrorKind::Overlength)?;
115 }
116
117 let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
118 let range = self.position.into()..end.into();
119 let position = &mut self.position;
120
121 let slice = &mut self.bytes.as_mut().expect("DER encoder tainted")[range];
130 *position = end;
131
132 Ok(slice)
133 }
134
135 fn buffer_len(&self) -> Result<Length> {
137 self.bytes
138 .as_ref()
139 .map(|bytes| bytes.len())
140 .ok_or_else(|| ErrorKind::Failed.at(self.position))
141 .and_then(TryInto::try_into)
142 }
143
144 fn remaining_len(&self) -> Result<Length> {
146 self.buffer_len()?
147 .to_usize()
148 .checked_sub(self.position.into())
149 .ok_or_else(|| ErrorKind::Truncated.at(self.position))
150 .and_then(TryInto::try_into)
151 }
152}
153
#[cfg(test)]
mod tests {
    use core::convert::TryFrom;
    use crate::{Encodable, Tag, TaggedSlice};

    /// Tag used throughout these tests; it encodes as the single byte `0x2A`.
    fn tag() -> Tag {
        Tag::try_from(42).unwrap()
    }

    #[test]
    fn zero_length() {
        let tv = TaggedSlice::from(tag(), &[]).unwrap();
        let mut buf = [0u8; 4];
        assert_eq!(tv.encode_to_slice(&mut buf).unwrap(), &[0x2A, 0x00]);
    }

    #[test]
    fn short_value() {
        let value = [1u8, 2, 3];
        let tv = TaggedSlice::from(tag(), &value).unwrap();
        let mut buf = [0u8; 8];
        assert_eq!(
            tv.encode_to_slice(&mut buf).unwrap(),
            &[0x2A, 0x03, 0x01, 0x02, 0x03]
        );
    }

    #[test]
    fn exact_fit_buffer() {
        let tv = TaggedSlice::from(tag(), &[]).unwrap();
        let mut buf = [0u8; 2];
        assert_eq!(tv.encode_to_slice(&mut buf).unwrap(), &[0x2A, 0x00]);
    }

    #[test]
    fn buffer_too_small() {
        let tv = TaggedSlice::from(tag(), &[]).unwrap();

        let mut one = [0u8; 1];
        assert!(tv.encode_to_slice(&mut one).is_err());

        let mut empty = [0u8; 0];
        assert!(tv.encode_to_slice(&mut empty).is_err());
    }
}
166
167