tor_cell/relaycell/hs/
ext.rs

1//! Helpers to manage lists of HS cell extensions.
2//
3// TODO: We might generalize this even more in the future to handle other
4// similar lists in our cell protocol.
5
6use derive_deftly::Deftly;
7use tor_bytes::{EncodeError, EncodeResult, Readable, Reader, Result, Writeable, Writer};
8use tor_memquota::{derive_deftly_template_HasMemoryCost, HasMemoryCostStructural};
9
10/// A list of extensions, represented in a common format used by many HS-related
11/// message.
12///
13/// The common format is:
14/// ```text
15///      N_EXTENSIONS     [1 byte]
16///      N_EXTENSIONS times:
17///           EXT_FIELD_TYPE [1 byte]
18///           EXT_FIELD_LEN  [1 byte]
19///           EXT_FIELD      [EXT_FIELD_LEN bytes]
20/// ```
21///
22/// It is subject to the additional restraints:
23///
24/// * Each extension type SHOULD be sent only once in a message.
25/// * Parties MUST ignore any occurrences all occurrences of an extension
26///   with a given type after the first such occurrence.
27/// * Extensions SHOULD be sent in numerically ascending order by type.
28#[derive(Clone, Debug, derive_more::Deref, derive_more::DerefMut, Deftly)]
29#[derive_deftly(HasMemoryCost)]
30#[deftly(has_memory_cost(bounds = "T: HasMemoryCostStructural"))]
31pub(super) struct ExtList<T> {
32    /// The extensions themselves.
33    extensions: Vec<T>,
34}
35impl<T> Default for ExtList<T> {
36    fn default() -> Self {
37        Self {
38            extensions: Vec::new(),
39        }
40    }
41}
42/// An kind of extension that can be used with some kind of HS-related message.
43///
44/// Each extendible message will likely define its own enum,
45/// implementing this trait,
46/// representing the possible extensions.
47pub(super) trait ExtGroup: Readable + Writeable {
48    /// An identifier kind used with this sort of extension
49    type Id: From<u8> + Into<u8> + Eq + PartialEq + Ord + Copy;
50    /// The field-type id for this particular extension.
51    fn type_id(&self) -> Self::Id;
52}
53/// A single typed extension that can be used with some kind of HS-related message.
54pub(super) trait Ext: Sized {
55    /// An identifier kind used with this sort of extension.
56    ///
57    /// Typically defined with caret_int.
58    type Id: From<u8> + Into<u8>;
59    /// The field-type id for this particular extension.
60    fn type_id(&self) -> Self::Id;
61    /// Extract the body (not the type or the length) from a single
62    /// extension.
63    fn take_body_from(b: &mut Reader<'_>) -> Result<Self>;
64    /// Write the body (not the type or the length) for a single extension.
65    fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()>;
66}
67impl<T: ExtGroup> Readable for ExtList<T> {
68    fn take_from(b: &mut Reader<'_>) -> Result<Self> {
69        let n_extensions = b.take_u8()?;
70        let extensions: Result<Vec<T>> = (0..n_extensions).map(|_| b.extract::<T>()).collect();
71        Ok(Self {
72            extensions: extensions?,
73        })
74    }
75}
76impl<T: ExtGroup> Writeable for ExtList<T> {
77    fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
78        let n_extensions = self
79            .extensions
80            .len()
81            .try_into()
82            .map_err(|_| EncodeError::BadLengthValue)?;
83        b.write_u8(n_extensions);
84        let mut exts_sorted: Vec<&T> = self.extensions.iter().collect();
85        exts_sorted.sort_by_key(|ext| ext.type_id());
86        exts_sorted.iter().try_for_each(|ext| ext.write_onto(b))?;
87        Ok(())
88    }
89}
90impl<T: ExtGroup> ExtList<T> {
91    /// Insert `ext` into this list of extensions, replacing any previous
92    /// extension with the same field type ID.
93    pub(super) fn replace_by_type(&mut self, ext: T) {
94        self.retain(|e| e.type_id() != ext.type_id());
95        self.push(ext);
96    }
97}
98
99/// An unrecognized or unencoded extension for some HS-related message.
100#[derive(Clone, Debug, Deftly)]
101#[derive_deftly(HasMemoryCost)]
102// Use `Copy + 'static` and `#[deftly(has_memory_cost(copy))]` so that we don't
103// need to derive HasMemoryCost for the id types, which are indeed all Copy.
104#[deftly(has_memory_cost(bounds = "ID: Copy + 'static"))]
105pub struct UnrecognizedExt<ID> {
106    /// The field type ID for this extension.
107    #[deftly(has_memory_cost(copy))]
108    pub(super) type_id: ID,
109    /// The body of this extension.
110    pub(super) body: Vec<u8>,
111}
112
113impl<ID> UnrecognizedExt<ID> {
114    /// Return a new unrecognized extension with a given ID and body.
115    ///
116    /// NOTE: nothing actually enforces that this type ID is not
117    /// recognized.
118    ///
119    /// NOTE: This function accepts bodies longer than 255 bytes, but
120    /// it is not possible to encode them.
121    pub fn new(type_id: ID, body: impl Into<Vec<u8>>) -> Self {
122        Self {
123            type_id,
124            body: body.into(),
125        }
126    }
127}
128
/// Declare an extension group: an enum whose variants are the extensions that
/// may appear in one kind of message, all sharing a single identifier type.
///
/// Given `enum Name [ IdType ] { A, B, .. }`, the expansion defines:
///
/// * `enum Name`, with one tuple variant per listed case (wrapping the type
///   of the same name) plus a trailing `Unrecognized(UnrecognizedExt<IdType>)`;
/// * `Readable` and `Writeable` for `Name`, using the
///   `type (1 byte) | length (1 byte) | body` layout;
/// * `ExtGroup` for `Name`;
/// * `From<Case>` for every case, and `From<UnrecognizedExt<IdType>>`.
///
/// When decoding, a case `FooBar` is selected by the associated constant
/// `IdType::FOO_BAR`; any other type value yields `Unrecognized`.
/// Each case type is expected to supply `type_id`, `take_body_from` and
/// `write_body_onto` (the methods of `Ext`).
//
// TODO: This looks a lot like restrict_msg(), doesn't it?  Also, the
// (number, (cmd, length, body)*) pattern shows up a few times in Tor outside
// the hs module.  Perhaps we can extend and unify our code here...
macro_rules! decl_extension_group {
    {
        $( #[$meta:meta] )*
        $v:vis enum $id:ident [ $type_id:ty ] {
            $(
                $(#[$cmeta:meta])*
                $case:ident),*
            $(,)?
        }
    } => {paste::paste!{
        $( #[$meta] )*
        $v enum $id {
            $( $(#[$cmeta])*
               $case($case),
            )*
            /// An extension of a type we do not recognize, or which we have not
            /// encoded.
            Unrecognized(UnrecognizedExt<$type_id>)
        }
        impl Readable for $id {
            fn take_from(b: &mut Reader<'_>) -> Result<Self> {
                // The first byte is the extension type; the body that follows
                // is prefixed with a one-byte length.
                let type_id: $type_id = b.take_u8()?.into();
                let ext = match type_id {
                    $(
                        $type_id::[< $case:snake:upper >] => {
                            Self::$case( b.read_nested_u8len(|body| $case::take_body_from(body))? )
                        }
                    )*
                    _ => {
                        // Unknown type: keep the raw body bytes.
                        let body = b.read_nested_u8len(|body| Ok(body.take_rest().into()))?;
                        Self::Unrecognized(UnrecognizedExt { type_id, body })
                    }
                };
                Ok(ext)
            }
        }
        impl Writeable for $id {
            fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
                // Only used when the group has at least one declared case.
                #[allow(unused)]
                use std::ops::DerefMut;
                match self {
                    $(
                        Self::$case(ext) => {
                            b.write_u8(ext.type_id().into());
                            let mut body_w = b.write_nested_u8len();
                            ext.write_body_onto(body_w.deref_mut())?;
                            body_w.finish()?;
                        }
                    )*
                    Self::Unrecognized(unk) => {
                        b.write_u8(unk.type_id.into());
                        let mut body_w = b.write_nested_u8len();
                        body_w.write_all(&unk.body[..]);
                        body_w.finish()?;
                    }
                }
                Ok(())
            }
        }
        impl ExtGroup for $id {
            type Id = $type_id;
            fn type_id(&self) -> Self::Id {
                match self {
                    $(
                        Self::$case(ext) => ext.type_id(),
                    )*
                    Self::Unrecognized(unk) => unk.type_id,
                }
            }
        }
        $(
        impl From<$case> for $id {
            fn from(ext: $case) -> Self {
                Self::$case(ext)
            }
        }
        )*
        impl From<UnrecognizedExt<$type_id>> for $id {
            fn from(ext: UnrecognizedExt<$type_id>) -> Self {
                Self::Unrecognized(ext)
            }
        }
}}
}
220pub(super) use decl_extension_group;