1use bitvec::prelude::*;
4use derive_deftly::Deftly;
5use oneshot_fused_workaround as oneshot;
6
7use tor_cell::relaycell::{RelayCellFormat, RelayCmd, StreamId, UnparsedRelayMsg, msg};
8use tor_cell::restricted_msg;
9use tor_error::internal;
10use tor_memquota::derive_deftly_template_HasMemoryCost;
11use tor_memquota::mq_queue::{self, MpscSpec};
12use tor_rtcompat::DynTimeProvider;
13
14use crate::circuit::CircHopSyncView;
15use crate::circuit::circhop::ReactorStreamComponents;
16use crate::stream::cmdcheck::{AnyCmdChecker, CmdChecker, StreamStatus};
17use crate::stream::{CloseStreamBehavior, StreamComponents};
18use crate::{Error, Result};
19
20use crate::client::stream::DataStream;
22
23use crate::memquota::StreamAccount;
24use crate::{HopLocation, HopNum};
25
/// A `CmdChecker` for incoming data streams, constructed via
/// `InboundDataCmdChecker::new_connected`.
///
/// It admits only DATA messages (the stream stays open) and END messages
/// (the stream becomes closed); every other command is a protocol error.
#[derive(Debug, Default)]
pub(crate) struct InboundDataCmdChecker;
29
restricted_msg! {
    /// The relay messages that may be decoded on an incoming data stream.
    ///
    /// Used by `InboundDataCmdChecker::consume_checked_msg` to validate the
    /// body of each message that passed `check_msg`.
    enum IncomingDataStreamMsg:RelayMsg {
        // NOTE(review): SENDME is not listed here; presumably SENDMEs are
        // handled before messages reach the command checker — confirm.
        Data, End,
    }
}
37
38impl CmdChecker for InboundDataCmdChecker {
39 fn check_msg(&mut self, msg: &tor_cell::relaycell::UnparsedRelayMsg) -> Result<StreamStatus> {
40 use StreamStatus::*;
41 match msg.cmd() {
42 RelayCmd::DATA => Ok(Open),
43 RelayCmd::END => Ok(Closed),
44 _ => Err(Error::StreamProto(format!(
45 "Unexpected {} on an incoming data stream!",
46 msg.cmd()
47 ))),
48 }
49 }
50
51 fn consume_checked_msg(&mut self, msg: tor_cell::relaycell::UnparsedRelayMsg) -> Result<()> {
52 let _ = msg
53 .decode::<IncomingDataStreamMsg>()
54 .map_err(|err| Error::from_bytes_err(err, "cell on half-closed stream"))?;
55 Ok(())
56 }
57}
58
59impl InboundDataCmdChecker {
60 pub(crate) fn new_connected() -> AnyCmdChecker {
66 Box::new(Self)
67 }
68}
69
/// An incoming stream request that has not yet been accepted, rejected, or
/// discarded.
///
/// Resolve it by calling one of the consuming methods `accept_data`,
/// `reject`, or `discard`.
#[derive(Debug)]
pub struct IncomingStream {
    /// Time provider handed to the `DataStream` if the request is accepted.
    time_provider: DynTimeProvider,
    /// The request message that opened this stream.
    request: IncomingStreamRequest,
    /// The stream plumbing (target, receiver, XON/XOFF reader control, and
    /// memory-quota account) used to reply to or build a stream from this request.
    components: StreamComponents,
}
88
89impl IncomingStream {
90 pub(crate) fn new(
92 time_provider: DynTimeProvider,
93 request: IncomingStreamRequest,
94 components: StreamComponents,
95 ) -> Self {
96 Self {
97 time_provider,
98 request,
99 components,
100 }
101 }
102
103 pub fn request(&self) -> &IncomingStreamRequest {
105 &self.request
106 }
107
108 pub async fn accept_data(self, message: msg::Connected) -> Result<DataStream> {
111 let Self {
112 time_provider,
113 request,
114 components:
115 StreamComponents {
116 mut target,
117 stream_receiver,
118 xon_xoff_reader_ctrl,
119 memquota,
120 },
121 } = self;
122
123 match request {
124 IncomingStreamRequest::Begin(_) | IncomingStreamRequest::BeginDir(_) => {
125 target.send(message.into()).await?;
126 Ok(DataStream::new_connected(
127 time_provider,
128 stream_receiver,
129 xon_xoff_reader_ctrl,
130 target,
131 memquota,
132 ))
133 }
134 IncomingStreamRequest::Resolve(_) => {
135 Err(internal!("Cannot accept data on a RESOLVE stream").into())
136 }
137 }
138 }
139
140 pub async fn reject(mut self, message: msg::End) -> Result<()> {
142 let rx = self.reject_inner(CloseStreamBehavior::SendEnd(message))?;
143
144 rx.await.map_err(|_| Error::CircuitClosed)?
145 }
146
147 fn reject_inner(
151 &mut self,
152 message: CloseStreamBehavior,
153 ) -> Result<oneshot::Receiver<Result<()>>> {
154 self.components.target.close_pending(message)
155 }
156
157 pub async fn discard(mut self) -> Result<()> {
163 let rx = self.reject_inner(CloseStreamBehavior::SendNothing)?;
164
165 rx.await.map_err(|_| Error::CircuitClosed)?.map(|_| ())
166 }
167}
168
restricted_msg! {
    /// A relay message requesting a new incoming stream.
    ///
    /// `IncomingCmdChecker::consume_checked_msg` decodes messages on incoming
    /// streams as this type.
    #[derive(Clone, Debug, Deftly)]
    #[derive_deftly(HasMemoryCost)]
    #[non_exhaustive]
    pub enum IncomingStreamRequest: RelayMsg {
        /// A BEGIN message.
        Begin,
        /// A BEGIN_DIR message.
        BeginDir,
        /// A RESOLVE message.
        Resolve,
    }
}
188
/// A set of relay commands: one bit for each of the 256 possible command
/// values, indexed by the command's `u8` value.
type RelayCmdSet = bitvec::BitArr!(for 256);
194
/// A `CmdChecker` for incoming streams that admits only a configured set of
/// relay commands (see `IncomingCmdChecker::new_any`).
#[derive(Debug)]
pub(crate) struct IncomingCmdChecker {
    /// The commands this checker admits; bit `u8::from(cmd)` is set for each.
    allow_commands: RelayCmdSet,
}
208
209impl IncomingCmdChecker {
210 pub(crate) fn new_any(allow_commands: &[RelayCmd]) -> AnyCmdChecker {
212 let mut array = BitArray::ZERO;
213 for c in allow_commands {
214 array.set(u8::from(*c) as usize, true);
215 }
216 Box::new(Self {
217 allow_commands: array,
218 })
219 }
220}
221
222impl CmdChecker for IncomingCmdChecker {
223 fn check_msg(&mut self, msg: &UnparsedRelayMsg) -> Result<StreamStatus> {
224 if self.allow_commands[u8::from(msg.cmd()) as usize] {
225 Ok(StreamStatus::Open)
226 } else {
227 Err(Error::StreamProto(format!(
228 "Unexpected {} on incoming stream",
229 msg.cmd()
230 )))
231 }
232 }
233
234 fn consume_checked_msg(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
235 let _ = msg
236 .decode::<IncomingStreamRequest>()
237 .map_err(|err| Error::from_bytes_err(err, "invalid message on incoming stream"))?;
238
239 Ok(())
240 }
241}
242
/// A filter that decides what to do with each incoming stream request.
///
/// Implementors must be `Send + 'static`. `Debug` is not required.
pub trait IncomingStreamRequestFilter: Send + 'static {
    /// Decide the disposition of the request described by `ctx`, given a
    /// synchronous view `circ` of the circuit hop it arrived on.
    ///
    /// NOTE(review): how the caller treats an `Err` return is not visible in
    /// this file — confirm before relying on it.
    fn disposition(
        &mut self,
        ctx: &IncomingStreamRequestContext<'_>,
        circ: &CircHopSyncView<'_>,
    ) -> Result<IncomingStreamRequestDisposition>;
}
258
/// The decision an `IncomingStreamRequestFilter` makes about one incoming
/// stream request.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum IncomingStreamRequestDisposition {
    /// Accept the request (presumably it is then delivered for handling —
    /// confirm with the caller).
    Accept,
    /// Close the circuit the request arrived on (presumably the whole circuit,
    /// not just the stream — confirm with the caller).
    CloseCircuit,
    /// Reject the request using the given END message (presumably sent back
    /// to the requester — confirm with the caller).
    RejectRequest(msg::End),
}
271
/// Information about an incoming stream request, passed to
/// `IncomingStreamRequestFilter::disposition`.
pub struct IncomingStreamRequestContext<'a> {
    /// The request message under consideration.
    pub(crate) request: &'a IncomingStreamRequest,
}
impl<'a> IncomingStreamRequestContext<'a> {
    /// Return the request message under consideration. The borrow lives for
    /// `'a`, not just for the borrow of `self`.
    pub fn request(&self) -> &'a IncomingStreamRequest {
        self.request
    }
}
283
/// An incoming stream request bundled with what is needed to set up the
/// stream; this is the item type carried by `StreamReqSender`.
#[derive(Debug, Deftly)]
#[derive_deftly(HasMemoryCost)]
pub(crate) struct StreamReqInfo {
    /// The request message.
    pub(crate) req: IncomingStreamRequest,
    /// The stream ID for the requested stream.
    pub(crate) stream_id: StreamId,
    /// The hop the request is associated with.
    /// NOTE(review): the meaning of `None` is not visible here — confirm.
    pub(crate) hop: Option<HopLocation>,
    /// The relay cell format for this stream.
    // Indirect memory cost is declared as zero for this field.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) relay_cell_format: RelayCellFormat,
    /// The reactor-side components of the stream.
    pub(crate) stream_components: ReactorStreamComponents,
    /// The memory-quota account for the stream.
    // Indirect memory cost is declared as zero for this field.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) memquota: StreamAccount,
}
308
/// Sender half of the `tor_memquota` MPSC queue that carries `StreamReqInfo`
/// items. Only compiled with the `hs-service` or `relay` feature.
#[cfg(any(feature = "hs-service", feature = "relay"))]
pub(crate) type StreamReqSender = mq_queue::Sender<StreamReqInfo, MpscSpec>;
312
/// State for handling incoming stream requests. It records where to forward
/// requests, which hop they are expected on, how to check their commands, and
/// how to filter them. Only compiled with the `hs-service` or `relay` feature.
#[derive(educe::Educe)]
#[educe(Debug)]
#[cfg(any(feature = "hs-service", feature = "relay"))]
pub(crate) struct IncomingStreamRequestHandler {
    /// Where incoming stream requests are sent.
    pub(crate) incoming_sender: StreamReqSender,
    /// The hop incoming requests are associated with.
    /// NOTE(review): the meaning of `None` is not visible here — confirm.
    pub(crate) hop_num: Option<HopNum>,
    /// The command checker applied to incoming request messages.
    pub(crate) cmd_checker: AnyCmdChecker,
    /// The filter that decides each request's disposition.
    // Excluded from `Debug`: `IncomingStreamRequestFilter` does not require
    // `Debug`, so the boxed trait object cannot be formatted.
    #[educe(Debug(ignore))]
    pub(crate) filter: Box<dyn IncomingStreamRequestFilter>,
}
331
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)]

    use tor_cell::relaycell::{
        AnyRelayMsgOuter, RelayCellFormat,
        msg::{Begin, BeginDir, Data, Resolve},
    };

    use super::*;

    /// `IncomingCmdChecker` must admit exactly the commands it was built with.
    #[test]
    fn incoming_cmd_checker() {
        // Encode a message as a V0 cell body, then re-parse it into the
        // unparsed form that checkers inspect.
        let to_unparsed = |m| {
            let encoded = AnyRelayMsgOuter::new(None, m)
                .encode(RelayCellFormat::V0, &mut rand::rng())
                .unwrap();
            UnparsedRelayMsg::from_singleton_body(RelayCellFormat::V0, encoded).unwrap()
        };

        let begin = to_unparsed(Begin::new("allium.example.com", 443, 0).unwrap().into());
        let begin_dir = to_unparsed(BeginDir::default().into());
        let resolve = to_unparsed(Resolve::new("allium.example.com").into());
        let data = to_unparsed(Data::new(&[1, 2, 3]).unwrap().into());

        // An empty allow-list rejects everything.
        let mut deny_all = IncomingCmdChecker::new_any(&[]);
        for m in [&begin, &begin_dir, &resolve, &data] {
            assert!(deny_all.check_msg(m).is_err());
        }

        // Allowing only BEGIN admits BEGIN and nothing else.
        let mut only_begin = IncomingCmdChecker::new_any(&[RelayCmd::BEGIN]);
        assert_eq!(only_begin.check_msg(&begin).unwrap(), StreamStatus::Open);
        for m in [&begin_dir, &resolve, &data] {
            assert!(only_begin.check_msg(m).is_err());
        }

        // Allowing all three request kinds still rejects DATA.
        let mut all_requests = IncomingCmdChecker::new_any(&[
            RelayCmd::BEGIN,
            RelayCmd::BEGIN_DIR,
            RelayCmd::RESOLVE,
        ]);
        for m in [&begin, &begin_dir, &resolve] {
            assert_eq!(all_requests.check_msg(m).unwrap(), StreamStatus::Open);
        }
        assert!(all_requests.check_msg(&data).is_err());
    }
}