tor_proto/stream/
incoming.rs1use bitvec::prelude::*;
4
5use super::{AnyCmdChecker, DataStream, StreamReader, StreamStatus};
6use crate::circuit::ClientCircSyncView;
7use crate::memquota::StreamAccount;
8use crate::tunnel::reactor::CloseStreamBehavior;
9use crate::tunnel::StreamTarget;
10use crate::{Error, Result};
11use derive_deftly::Deftly;
12use oneshot_fused_workaround as oneshot;
13use tor_cell::relaycell::{msg, RelayCmd, UnparsedRelayMsg};
14use tor_cell::restricted_msg;
15use tor_error::internal;
16use tor_memquota::derive_deftly_template_HasMemoryCost;
17
18#[derive(Debug)]
28pub struct IncomingStream {
29 request: IncomingStreamRequest,
31 stream: StreamTarget,
33 reader: StreamReader,
35 memquota: StreamAccount,
37}
38
39impl IncomingStream {
40 pub(crate) fn new(
42 request: IncomingStreamRequest,
43 stream: StreamTarget,
44 reader: StreamReader,
45 memquota: StreamAccount,
46 ) -> Self {
47 Self {
48 request,
49 stream,
50 reader,
51 memquota,
52 }
53 }
54
55 pub fn request(&self) -> &IncomingStreamRequest {
57 &self.request
58 }
59
60 pub async fn accept_data(self, message: msg::Connected) -> Result<DataStream> {
63 let Self {
64 request,
65 mut stream,
66 reader,
67 memquota,
68 } = self;
69
70 match request {
71 IncomingStreamRequest::Begin(_) | IncomingStreamRequest::BeginDir(_) => {
72 stream.send(message.into()).await?;
73 Ok(DataStream::new_connected(reader, stream, memquota))
74 }
75 IncomingStreamRequest::Resolve(_) => {
76 Err(internal!("Cannot accept data on a RESOLVE stream").into())
77 }
78 }
79 }
80
81 pub async fn reject(mut self, message: msg::End) -> Result<()> {
83 let rx = self.reject_inner(CloseStreamBehavior::SendEnd(message))?;
84
85 rx.await.map_err(|_| Error::CircuitClosed)?.map(|_| ())
86 }
87
88 fn reject_inner(
92 &mut self,
93 message: CloseStreamBehavior,
94 ) -> Result<oneshot::Receiver<Result<()>>> {
95 self.stream.close_pending(message)
96 }
97
98 pub async fn discard(mut self) -> Result<()> {
104 let rx = self.reject_inner(CloseStreamBehavior::SendNothing)?;
105
106 rx.await.map_err(|_| Error::CircuitClosed)?.map(|_| ())
107 }
108}
109
restricted_msg! {
    /// The relay messages that may open an incoming stream.
    ///
    /// Generated by `restricted_msg!`. Incoming stream messages are decoded
    /// into this type, so anything other than these variants fails to decode.
    #[derive(Clone, Debug, Deftly)]
    #[derive_deftly(HasMemoryCost)]
    #[non_exhaustive]
    pub enum IncomingStreamRequest: RelayMsg {
        /// A `BEGIN` message.
        Begin,
        /// A `BEGIN_DIR` message.
        BeginDir,
        /// A `RESOLVE` message.
        Resolve,
    }
}
129
130type RelayCmdSet = bitvec::BitArr!(for 256);
135
136#[derive(Debug)]
139pub(crate) struct IncomingCmdChecker {
140 allow_commands: RelayCmdSet,
148}
149
150impl IncomingCmdChecker {
151 pub(crate) fn new_any(allow_commands: &[RelayCmd]) -> AnyCmdChecker {
153 let mut array = BitArray::ZERO;
154 for c in allow_commands {
155 array.set(u8::from(*c) as usize, true);
156 }
157 Box::new(Self {
158 allow_commands: array,
159 })
160 }
161}
162
163impl super::CmdChecker for IncomingCmdChecker {
164 fn check_msg(&mut self, msg: &UnparsedRelayMsg) -> Result<StreamStatus> {
165 if self.allow_commands[u8::from(msg.cmd()) as usize] {
166 Ok(StreamStatus::Open)
167 } else {
168 Err(Error::StreamProto(format!(
169 "Unexpected {} on incoming stream",
170 msg.cmd()
171 )))
172 }
173 }
174
175 fn consume_checked_msg(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
176 let _ = msg
177 .decode::<IncomingStreamRequest>()
178 .map_err(|err| Error::from_bytes_err(err, "invalid message on incoming stream"))?;
179
180 Ok(())
181 }
182}
183
/// A policy that decides how each incoming stream request is handled.
///
/// Implementors must be `Send + 'static`.
pub trait IncomingStreamRequestFilter: Send + 'static {
    /// Decide what to do with the request described by `ctx`.
    ///
    /// `circ` is a view of a circuit; presumably the circuit the request
    /// arrived on — TODO confirm at the call site.
    ///
    /// # Errors
    ///
    /// May return an error instead of a disposition; how the caller reacts to
    /// that is not visible here.
    fn disposition(
        &mut self,
        ctx: &IncomingStreamRequestContext<'_>,
        circ: &ClientCircSyncView<'_>,
    ) -> Result<IncomingStreamRequestDisposition>;
}
200
/// The decision an [`IncomingStreamRequestFilter`] makes about a request.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum IncomingStreamRequestDisposition {
    /// Accept the request.
    Accept,
    /// Close the circuit.
    CloseCircuit,
    /// Reject the request, using the carried `END` message.
    ///
    /// NOTE(review): presumably this message is sent back to the requester —
    /// confirm at the call site.
    RejectRequest(msg::End),
}
213
214pub struct IncomingStreamRequestContext<'a> {
216 pub(crate) request: &'a IncomingStreamRequest,
218}
219
220impl<'a> IncomingStreamRequestContext<'a> {
221 pub fn request(&self) -> &'a IncomingStreamRequest {
223 self.request
224 }
225}
226
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use tor_cell::relaycell::{
        msg::{Begin, BeginDir, Data, Resolve},
        AnyRelayMsgOuter, RelayCellFormat,
    };

    use super::*;

    /// A checker must accept exactly the commands it was built with and
    /// reject every other command.
    #[test]
    fn incoming_cmd_checker() {
        // Encode a relay message into a single V0 cell body and re-wrap it, so
        // that it can be handed to a checker as an `UnparsedRelayMsg`.
        let unparsed = |relay_msg| {
            let body = AnyRelayMsgOuter::new(None, relay_msg)
                .encode(&mut rand::thread_rng())
                .unwrap();
            UnparsedRelayMsg::from_singleton_body(RelayCellFormat::V0, body).unwrap()
        };
        let samples = [
            unparsed(Begin::new("allium.example.com", 443, 0).unwrap().into()),
            unparsed(BeginDir::default().into()),
            unparsed(Resolve::new("allium.example.com").into()),
            unparsed(Data::new(&[1, 2, 3]).unwrap().into()),
        ];

        // Nothing allowed; only BEGIN; every stream-opening command.
        let allow_lists: [&[RelayCmd]; 3] = [
            &[],
            &[RelayCmd::BEGIN],
            &[RelayCmd::BEGIN, RelayCmd::BEGIN_DIR, RelayCmd::RESOLVE],
        ];

        for allowed in allow_lists {
            let mut checker = IncomingCmdChecker::new_any(allowed);
            for sample in &samples {
                let verdict = checker.check_msg(sample);
                if allowed.contains(&sample.cmd()) {
                    assert_eq!(verdict.unwrap(), StreamStatus::Open);
                } else {
                    assert!(verdict.is_err());
                }
            }
        }
    }
}