tor_cell/relaycell/
extend.rs1use super::extlist::{Ext, ExtList, ExtListRef, decl_extension_group};
4#[cfg(feature = "hs")]
5use super::hs::pow::ProofOfWork;
6use caret::caret_int;
7use itertools::Itertools as _;
8use tor_bytes::{EncodeResult, Reader, Writeable as _, Writer};
9use tor_protover::NumberedSubver;
10
11caret_int! {
12 #[derive(PartialOrd,Ord)]
14 pub struct CircRequestExtType(u8) {
15 CC_REQUEST = 1,
17 PROOF_OF_WORK = 2,
20 SUBPROTOCOL_REQUEST = 3,
22 }
23}
24
25caret_int! {
26 #[derive(PartialOrd,Ord)]
28 pub struct CircResponseExtType(u8) {
29 CC_RESPONSE = 2
31 }
32}
33
/// Request extension identified by `CircRequestExtType::CC_REQUEST`.
///
/// The extension carries no data: the struct has no fields, and its `Ext`
/// implementation neither reads nor writes any body bytes.
///
/// NOTE(review): "CC" presumably stands for congestion control — confirm
/// against the protocol specification.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct CcRequest {}
40
41impl Ext for CcRequest {
42 type Id = CircRequestExtType;
43 fn type_id(&self) -> Self::Id {
44 CircRequestExtType::CC_REQUEST
45 }
46 fn take_body_from(_b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
47 Ok(Self {})
48 }
49 fn write_body_onto<B: Writer + ?Sized>(&self, _b: &mut B) -> EncodeResult<()> {
50 Ok(())
51 }
52}
53
/// Response extension identified by `CircResponseExtType::CC_RESPONSE`.
///
/// Its body is a single byte, the `sendme_inc` value.
///
/// NOTE(review): "CC" presumably stands for congestion control — confirm
/// against the protocol specification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CcResponse {
    /// The one-byte `sendme_inc` value carried by this extension.
    sendme_inc: u8,
}

impl CcResponse {
    /// Construct a `CcResponse` carrying the given `sendme_inc` value.
    pub fn new(sendme_inc: u8) -> Self {
        Self { sendme_inc }
    }

    /// Return the `sendme_inc` value carried by this extension.
    pub fn sendme_inc(&self) -> u8 {
        let Self { sendme_inc } = self;
        *sendme_inc
    }
}
75
76impl Ext for CcResponse {
77 type Id = CircResponseExtType;
78 fn type_id(&self) -> Self::Id {
79 CircResponseExtType::CC_RESPONSE
80 }
81
82 fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
83 let sendme_inc = b.take_u8()?;
84 Ok(Self { sendme_inc })
85 }
86
87 fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
88 b.write_u8(self.sendme_inc);
89 Ok(())
90 }
91}
92
93#[derive(Clone, Debug, PartialEq, Eq)]
95pub struct SubprotocolRequest {
96 protocols: Vec<tor_protover::NumberedSubver>,
98}
99
100impl<A> FromIterator<A> for SubprotocolRequest
101where
102 A: Into<tor_protover::NumberedSubver>,
103{
104 fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
105 let mut protocols: Vec<_> = iter.into_iter().map(Into::into).collect();
106 protocols.sort();
107 protocols.dedup();
108 Self { protocols }
109 }
110}
111
112impl Ext for SubprotocolRequest {
113 type Id = CircRequestExtType;
114
115 fn type_id(&self) -> Self::Id {
116 CircRequestExtType::SUBPROTOCOL_REQUEST
117 }
118
119 fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
120 let mut protocols = Vec::new();
121 while b.remaining() != 0 {
122 protocols.push(b.extract()?);
123 }
124
125 if !is_strictly_ascending(&protocols) {
126 return Err(tor_bytes::Error::InvalidMessage(
127 "SubprotocolRequest not sorted and deduplicated.".into(),
128 ));
129 }
130
131 Ok(Self { protocols })
132 }
133
134 fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
135 for p in self.protocols.iter() {
136 b.write(p)?;
137 }
138 Ok(())
139 }
140}
141impl SubprotocolRequest {
142 pub fn contains(&self, cap: tor_protover::NamedSubver) -> bool {
144 self.protocols.binary_search(&cap.into()).is_ok()
145 }
146
147 pub fn contains_only(&self, list: &tor_protover::Protocols) -> bool {
150 self.protocols
151 .iter()
152 .all(|p| list.supports_numbered_subver(*p))
153 }
154}
155
156decl_extension_group! {
157 #[derive(Debug,Clone,PartialEq)]
160 #[non_exhaustive]
161 pub enum CircRequestExt [ CircRequestExtType ] {
162 CcRequest,
164 [ feature: #[cfg(feature = "hs")] ]
166 ProofOfWork,
167 SubprotocolRequest,
169 }
170}
171
172decl_extension_group! {
173 #[derive(Debug,Clone,PartialEq)]
179 #[non_exhaustive]
180 pub enum CircResponseExt [ CircResponseExtType ] {
181 CcResponse,
183 }
184}
185
186macro_rules! impl_encode_decode {
189 ($extgroup:ty, $name:expr) => {
190 impl $extgroup {
191 pub fn write_many_onto<W: Writer>(exts: &[Self], out: &mut W) -> EncodeResult<()> {
193 ExtListRef::from(exts).write_onto(out)?;
194 Ok(())
195 }
196 pub fn decode(message: &[u8]) -> crate::Result<Vec<Self>> {
199 let err_cvt = |err| crate::Error::BytesErr { err, parsed: $name };
200 let mut r = tor_bytes::Reader::from_slice(message);
201 let list: ExtList<_> = r.extract().map_err(err_cvt)?;
202 r.should_be_exhausted().map_err(err_cvt)?;
203 Ok(list.into_vec())
204 }
205 }
206 };
207}
208
209impl_encode_decode!(CircRequestExt, "CREATE2 extension list");
210impl_encode_decode!(CircResponseExt, "CREATED2 extension list");
211
212fn is_strictly_ascending(vers: &[NumberedSubver]) -> bool {
214 vers.iter().tuple_windows().all(|(a, b)| a < b)
216}
217
#[cfg(test)]
mod test {
    // Lints relaxed for test code only.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    /// Collecting sorts and deduplicates, and the encoded body decodes back
    /// to an equal value.
    #[test]
    fn subproto_ext_valid() {
        use tor_protover::named::*;
        let request: SubprotocolRequest =
            [RELAY_NTORV3, RELAY_NTORV3, LINK_V4].into_iter().collect();

        let mut encoded: Vec<u8> = Vec::new();
        request.write_body_onto(&mut encoded).unwrap();
        // The repeated entry is dropped and the remaining two are sorted.
        assert_eq!(&encoded[..], [0, 4, 2, 4]);

        let mut reader = Reader::from_slice(&encoded[..]);
        let decoded = SubprotocolRequest::take_body_from(&mut reader).unwrap();
        assert_eq!(request, decoded);
    }

    /// Truncated, duplicated and unsorted bodies are all rejected.
    #[test]
    fn subproto_invalid() {
        // Decode a body that is expected to fail, returning the error.
        let decode_failure = |body: &[u8]| {
            let mut r = Reader::from_slice(body);
            SubprotocolRequest::take_body_from(&mut r).unwrap_err()
        };

        // Body ends partway through an entry.
        let e = decode_failure(&[0, 4, 2]);
        dbg!(e.to_string());
        assert!(e.to_string().contains("too short"));

        // The same entry appears twice.
        let e = decode_failure(&[0, 4, 0, 4]);
        dbg!(e.to_string());
        assert!(e.to_string().contains("deduplicated"));

        // Entries are in descending order.
        let e = decode_failure(&[2, 4, 0, 4]);
        dbg!(e.to_string());
        assert!(e.to_string().contains("sorted"));
    }

    /// `contains` and `contains_only` answer membership and subset queries.
    #[test]
    fn subproto_supported() {
        use tor_protover::named::*;
        let request: SubprotocolRequest =
            [RELAY_NTORV3, RELAY_NTORV3, LINK_V4].into_iter().collect();

        assert!(request.contains(LINK_V4));
        assert!(!request.contains(LINK_V2));

        // Protocol sets that cover everything in the request.
        let superset: tor_protover::Protocols =
            [RELAY_NTORV3, LINK_V4, CONFLUX_BASE].into_iter().collect();
        let exact: tor_protover::Protocols = [RELAY_NTORV3, LINK_V4].into_iter().collect();
        assert!(request.contains_only(&superset));
        assert!(request.contains_only(&exact));

        // Protocol sets that miss at least one requested entry.
        let link_only: tor_protover::Protocols = [LINK_V4].into_iter().collect();
        let link_and_conflux: tor_protover::Protocols =
            [LINK_V4, CONFLUX_BASE].into_iter().collect();
        let conflux_only: tor_protover::Protocols = [CONFLUX_BASE].into_iter().collect();
        assert!(!request.contains_only(&link_only));
        assert!(!request.contains_only(&link_and_conflux));
        assert!(!request.contains_only(&conflux_only));
    }
}