tor_cell/relaycell/hs/ext.rs
1//! Helpers to manage lists of HS cell extensions.
2//
3// TODO: We might generalize this even more in the future to handle other
4// similar lists in our cell protocol.
5
6use derive_deftly::Deftly;
7use tor_bytes::{EncodeError, EncodeResult, Readable, Reader, Result, Writeable, Writer};
8use tor_memquota::{derive_deftly_template_HasMemoryCost, HasMemoryCostStructural};
9
10/// A list of extensions, represented in a common format used by many HS-related
11/// message.
12///
13/// The common format is:
14/// ```text
15/// N_EXTENSIONS [1 byte]
16/// N_EXTENSIONS times:
17/// EXT_FIELD_TYPE [1 byte]
18/// EXT_FIELD_LEN [1 byte]
19/// EXT_FIELD [EXT_FIELD_LEN bytes]
20/// ```
21///
22/// It is subject to the additional restraints:
23///
24/// * Each extension type SHOULD be sent only once in a message.
25/// * Parties MUST ignore any occurrences all occurrences of an extension
26/// with a given type after the first such occurrence.
27/// * Extensions SHOULD be sent in numerically ascending order by type.
28#[derive(Clone, Debug, derive_more::Deref, derive_more::DerefMut, Deftly)]
29#[derive_deftly(HasMemoryCost)]
30#[deftly(has_memory_cost(bounds = "T: HasMemoryCostStructural"))]
31pub(super) struct ExtList<T> {
32 /// The extensions themselves.
33 extensions: Vec<T>,
34}
35impl<T> Default for ExtList<T> {
36 fn default() -> Self {
37 Self {
38 extensions: Vec::new(),
39 }
40 }
41}
42/// An kind of extension that can be used with some kind of HS-related message.
43///
44/// Each extendible message will likely define its own enum,
45/// implementing this trait,
46/// representing the possible extensions.
47pub(super) trait ExtGroup: Readable + Writeable {
48 /// An identifier kind used with this sort of extension
49 type Id: From<u8> + Into<u8> + Eq + PartialEq + Ord + Copy;
50 /// The field-type id for this particular extension.
51 fn type_id(&self) -> Self::Id;
52}
53/// A single typed extension that can be used with some kind of HS-related message.
54pub(super) trait Ext: Sized {
55 /// An identifier kind used with this sort of extension.
56 ///
57 /// Typically defined with caret_int.
58 type Id: From<u8> + Into<u8>;
59 /// The field-type id for this particular extension.
60 fn type_id(&self) -> Self::Id;
61 /// Extract the body (not the type or the length) from a single
62 /// extension.
63 fn take_body_from(b: &mut Reader<'_>) -> Result<Self>;
64 /// Write the body (not the type or the length) for a single extension.
65 fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()>;
66}
67impl<T: ExtGroup> Readable for ExtList<T> {
68 fn take_from(b: &mut Reader<'_>) -> Result<Self> {
69 let n_extensions = b.take_u8()?;
70 let extensions: Result<Vec<T>> = (0..n_extensions).map(|_| b.extract::<T>()).collect();
71 Ok(Self {
72 extensions: extensions?,
73 })
74 }
75}
76impl<T: ExtGroup> Writeable for ExtList<T> {
77 fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
78 let n_extensions = self
79 .extensions
80 .len()
81 .try_into()
82 .map_err(|_| EncodeError::BadLengthValue)?;
83 b.write_u8(n_extensions);
84 let mut exts_sorted: Vec<&T> = self.extensions.iter().collect();
85 exts_sorted.sort_by_key(|ext| ext.type_id());
86 exts_sorted.iter().try_for_each(|ext| ext.write_onto(b))?;
87 Ok(())
88 }
89}
90impl<T: ExtGroup> ExtList<T> {
91 /// Insert `ext` into this list of extensions, replacing any previous
92 /// extension with the same field type ID.
93 pub(super) fn replace_by_type(&mut self, ext: T) {
94 self.retain(|e| e.type_id() != ext.type_id());
95 self.push(ext);
96 }
97}
98
99/// An unrecognized or unencoded extension for some HS-related message.
100#[derive(Clone, Debug, Deftly)]
101#[derive_deftly(HasMemoryCost)]
102// Use `Copy + 'static` and `#[deftly(has_memory_cost(copy))]` so that we don't
103// need to derive HasMemoryCost for the id types, which are indeed all Copy.
104#[deftly(has_memory_cost(bounds = "ID: Copy + 'static"))]
105pub struct UnrecognizedExt<ID> {
106 /// The field type ID for this extension.
107 #[deftly(has_memory_cost(copy))]
108 pub(super) type_id: ID,
109 /// The body of this extension.
110 pub(super) body: Vec<u8>,
111}
112
113impl<ID> UnrecognizedExt<ID> {
114 /// Return a new unrecognized extension with a given ID and body.
115 ///
116 /// NOTE: nothing actually enforces that this type ID is not
117 /// recognized.
118 ///
119 /// NOTE: This function accepts bodies longer than 255 bytes, but
120 /// it is not possible to encode them.
121 pub fn new(type_id: ID, body: impl Into<Vec<u8>>) -> Self {
122 Self {
123 type_id,
124 body: body.into(),
125 }
126 }
127}
128
/// Declare an extension group that takes a given identifier.
///
/// Given `enum Name [IdType] { CaseA, CaseB, ... }`, this generates an enum
/// `Name` with one variant per case (each wrapping the type of the same
/// name), plus an `Unrecognized(UnrecognizedExt<IdType>)` variant. It also
/// generates `Readable`, `Writeable`, `ExtGroup` and `From` impls for `Name`.
///
/// Expectations on the input, as used by the expansion:
/// * `IdType` has one associated constant per case, named after the case in
///   SCREAMING_SNAKE_CASE (e.g. `CaseA` -> `IdType::CASE_A`). These are the
///   IDs that get parsed as that case.
/// * Each case type provides `take_body_from`, `write_body_onto` and
///   `type_id` (i.e. is expected to implement `Ext`).
/// * The expansion refers to `Readable`, `Reader`, `Result`, `Writeable`,
///   `Writer`, `EncodeResult`, `UnrecognizedExt`, `ExtGroup` and
///   `paste::paste!` by unqualified or crate-relative paths, so these must
///   resolve at the invocation site.
//
// TODO: This is rather similar to restrict_msg(), isn't it? Also, we use this
// pattern of (number, (cmd, length, body)*) a few times in Tor outside the
// hs module. Perhaps we can extend and unify our code here...
macro_rules! decl_extension_group {
    {
        $( #[$meta:meta] )*
        $v:vis enum $id:ident [ $type_id:ty ] {
            $(
                $(#[$cmeta:meta])*
                $case:ident),*
            $(,)?
        }
    } => {paste::paste!{
        $( #[$meta] )*
        $v enum $id {
            $( $(#[$cmeta])*
               $case($case),
            )*
            /// An extension of a type we do not recognize, or which we have not
            /// encoded.
            Unrecognized(UnrecognizedExt<$type_id>)
        }
        impl Readable for $id {
            fn take_from(b: &mut Reader<'_>) -> Result<Self> {
                // One type byte, converted into the group's identifier type.
                let type_id = b.take_u8()?.into();
                Ok(match type_id {
                    $(
                        // Recognized ID: parse the body (inside its one-byte
                        // length prefix) as the matching case type. A parse
                        // error here fails the whole decode.
                        $type_id::[< $case:snake:upper >] => {
                            Self::$case( b.read_nested_u8len(|r| $case::take_body_from(r))? )
                        }
                    )*
                    // Any other ID: keep the raw body bytes so the extension
                    // can be preserved and re-encoded.
                    _ => {
                        Self::Unrecognized(UnrecognizedExt {
                            type_id,
                            body: b.read_nested_u8len(|r| Ok(r.take_rest().into()))?,
                        })
                    }
                })
            }
        }
        impl Writeable for $id {
            fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
                // `DerefMut` is only used inside the per-case repetition
                // below, so it is unused when the group declares no cases.
                #[allow(unused)]
                use std::ops::DerefMut;
                match self {
                    $(
                        Self::$case(val) => {
                            // Type byte, then the body inside a one-byte
                            // length prefix.
                            b.write_u8(val.type_id().into());
                            let mut nested = b.write_nested_u8len();
                            val.write_body_onto(nested.deref_mut())?;
                            // NOTE(review): `finish()` is expected to report a
                            // body too long for the one-byte length; confirm
                            // against tor_bytes.
                            nested.finish()?;
                        }
                    )*
                    Self::Unrecognized(unrecognized) => {
                        b.write_u8(unrecognized.type_id.into());
                        let mut nested = b.write_nested_u8len();
                        nested.write_all(&unrecognized.body[..]);
                        nested.finish()?;
                    }
                }
                Ok(())
            }
        }
        impl ExtGroup for $id {
            type Id = $type_id;
            fn type_id(&self) -> Self::Id {
                match self {
                    $(
                        Self::$case(val) => val.type_id(),
                    )*
                    Self::Unrecognized(unrecognized) => unrecognized.type_id,
                }
            }
        }
        // Convenience conversions from each case type into the group enum.
        $(
            impl From<$case> for $id {
                fn from(val: $case) -> $id {
                    $id :: $case ( val )
                }
            }
        )*
        impl From<UnrecognizedExt<$type_id>> for $id {
            fn from(val: UnrecognizedExt<$type_id>) -> $id {
                $id :: Unrecognized(val)
            }
        }
}}
}
// Re-export the macro so sibling modules can import it by path.
pub(super) use decl_extension_group;