rustydht_lib/storage/
outbound_request_storage.rs1use std::net::SocketAddr;
2
3use crate::common::{Id, TransactionId};
4use crate::packets::{Message, MessageType};
5
6use std::time::{Duration, Instant};
7
8use log::debug;
9use tokio::sync::mpsc;
10
11#[derive(Default)]
12pub struct OutboundRequestStorage {
13 requests: std::collections::HashMap<TransactionId, RequestInfo>,
14}
15
16impl OutboundRequestStorage {
17 pub fn new() -> OutboundRequestStorage {
18 Self::default()
19 }
20
21 pub fn add_request(&mut self, info: RequestInfo) {
22 self.requests
23 .insert(info.packet.transaction_id.clone().into(), info);
24 }
25
26 #[cfg(test)]
27 pub fn has_request<T>(&self, tid: &T) -> bool
28 where
29 T: Into<TransactionId>,
30 T: Clone,
31 {
32 self.requests.contains_key(&tid.clone().into())
33 }
34
35 pub fn get_matching_request_info(
36 &self,
37 msg: &Message,
38 src_addr: SocketAddr,
39 ) -> Option<&RequestInfo> {
40 let tid = msg.transaction_id.clone().into();
41
42 if let MessageType::Response(res_specific) = &msg.message_type {
44 if let Some(request_info) = self.requests.get(&tid) {
46 if request_info.addr == src_addr {
48 let response_sender_id = msg.get_author_id();
49 if request_info.id.is_none()
51 || (response_sender_id.is_some()
52 && request_info.id.unwrap() == response_sender_id.unwrap())
53 {
54 if let MessageType::Request(req_specific) =
56 &request_info.packet.message_type
57 {
58 if crate::packets::response_matches_request(res_specific, req_specific)
60 {
61 return Some(request_info);
62 }
63 }
64 }
65 }
66 }
67 }
68
69 None
70 }
71
72 pub fn take_matching_request_info(
73 &mut self,
74 response: &Message,
75 src_addr: SocketAddr,
76 ) -> Option<RequestInfo> {
77 if self.get_matching_request_info(response, src_addr).is_some() {
78 let tid = response.transaction_id.clone().into();
79 return self.requests.remove(&tid);
80 }
81
82 None
83 }
84
85 pub fn prune_older_than(&mut self, duration: Duration) {
86 match Instant::now().checked_sub(duration) {
87 None => {
88 debug!(target: "rustydht_lib::OutboundRequestStorage",
89 "Outbound request storage skipping pruning due to monotonic clock underflow"
90 );
91 }
92
93 Some(time) => {
94 let len_before = self.requests.len();
95 self.requests
96 .retain(|_, v| -> bool { v.created_at >= time });
97 let len_after = self.requests.len();
98 debug!(target: "rustydht_lib::OutboundRequestStorage", "Pruned {} request records", len_before - len_after);
99 }
100 }
101 }
102
103 pub fn len(&self) -> usize {
104 self.requests.len()
105 }
106
107 pub fn is_empty(&self) -> bool {
108 self.requests.is_empty()
109 }
110}
111
112#[derive(Debug)]
113pub struct RequestInfo {
114 addr: SocketAddr,
115 id: Option<Id>,
116 packet: Message,
117 created_at: Instant,
118 pub(crate) response_channel: Option<mpsc::Sender<Message>>,
119}
120
121impl RequestInfo {
122 pub fn new(
123 addr: SocketAddr,
124 id: Option<Id>,
125 packet: Message,
126 response_channel: Option<mpsc::Sender<Message>>,
127 ) -> RequestInfo {
128 RequestInfo {
129 addr,
130 id,
131 packet,
132 created_at: Instant::now(),
133 response_channel,
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packets::MessageBuilder;

    /// Basic add / match / take round trip, including rejection of a wrong source address.
    #[test]
    fn test_outbound_request_storage() {
        let mut storage = OutboundRequestStorage::new();

        let our_id = Id::from_hex("0000000000000000000000000000000000000000").unwrap();
        let req = MessageBuilder::new_ping_request()
            .sender_id(our_id)
            .build()
            .unwrap();

        let request_target_addr = "127.0.0.1:1234".parse().unwrap();
        let request_info = RequestInfo::new(request_target_addr, None, req.clone(), None);

        storage.add_request(request_info);
        assert!(storage.has_request(&req.transaction_id));

        let simulated_response = MessageBuilder::new_ping_response()
            .sender_id(our_id)
            .transaction_id(req.transaction_id.clone())
            .requester_ip("127.0.0.1:1235".parse().unwrap())
            .build()
            .unwrap();

        assert!(storage
            .get_matching_request_info(&simulated_response, request_target_addr)
            .is_some());

        assert!(storage
            .get_matching_request_info(&simulated_response, "5.5.5.5:1234".parse().unwrap())
            .is_none());

        assert!(storage
            .take_matching_request_info(&simulated_response, request_target_addr)
            .is_some());

        assert!(!storage.has_request(&req.transaction_id));
    }

    /// A record with an expected node id only matches responses authored by that id, and a
    /// failed take leaves the record in place.
    #[test]
    fn test_outbound_request_storage_checks_responder_id() {
        let mut storage = OutboundRequestStorage::new();

        let our_id = Id::from_hex("0000000000000000000000000000000000000000").unwrap();
        let their_id = Id::from_hex("1111111111111111111111111111111111111111").unwrap();
        let other_id = Id::from_hex("2222222222222222222222222222222222222222").unwrap();
        let req = MessageBuilder::new_ping_request()
            .sender_id(our_id)
            .build()
            .unwrap();

        let target: SocketAddr = "127.0.0.1:1234".parse().unwrap();
        storage.add_request(RequestInfo::new(target, Some(their_id), req.clone(), None));

        let wrong_author = MessageBuilder::new_ping_response()
            .sender_id(other_id)
            .transaction_id(req.transaction_id.clone())
            .requester_ip("127.0.0.1:1235".parse().unwrap())
            .build()
            .unwrap();
        assert!(storage
            .get_matching_request_info(&wrong_author, target)
            .is_none());
        assert!(storage
            .take_matching_request_info(&wrong_author, target)
            .is_none());
        assert!(storage.has_request(&req.transaction_id));

        let right_author = MessageBuilder::new_ping_response()
            .sender_id(their_id)
            .transaction_id(req.transaction_id.clone())
            .requester_ip("127.0.0.1:1235".parse().unwrap())
            .build()
            .unwrap();
        assert!(storage
            .take_matching_request_info(&right_author, target)
            .is_some());
        assert!(storage.is_empty());
    }

    /// Pruning with a zero duration drops records created before the call and keeps
    /// records stamped in the future.
    #[test]
    fn test_outbound_storage_prune() {
        let mut storage = OutboundRequestStorage::new();

        let our_id = Id::from_hex("0000000000000000000000000000000000000000").unwrap();
        let req = MessageBuilder::new_ping_request()
            .sender_id(our_id)
            .build()
            .unwrap();
        let req_2 = MessageBuilder::new_ping_request()
            .sender_id(our_id)
            .build()
            .unwrap();

        let request_info =
            RequestInfo::new("127.0.0.1:1234".parse().unwrap(), None, req.clone(), None);
        let mut request_info_2 =
            RequestInfo::new("127.0.0.1:1234".parse().unwrap(), None, req_2.clone(), None);
        request_info_2.created_at = Instant::now() + Duration::from_secs(10);

        storage.add_request(request_info);
        storage.add_request(request_info_2);
        assert!(storage.has_request(&req.transaction_id));

        // Make sure the monotonic clock has moved past the first record's timestamp.
        // Otherwise, with a coarse clock, the prune cutoff could equal `created_at` and the
        // record would be kept (the retain condition is `created_at >= cutoff`).
        std::thread::sleep(Duration::from_millis(1));

        storage.prune_older_than(Duration::from_secs(0));
        assert!(!storage.has_request(&req.transaction_id));
        assert!(storage.has_request(&req_2.transaction_id));
        assert_eq!(storage.len(), 1);
    }
}