rustydht_lib/storage/outbound_request_storage.rs

1use std::net::SocketAddr;
2
3use crate::common::{Id, TransactionId};
4use crate::packets::{Message, MessageType};
5
6use std::time::{Duration, Instant};
7
8use log::debug;
9use tokio::sync::mpsc;
10
11#[derive(Default)]
12pub struct OutboundRequestStorage {
13    requests: std::collections::HashMap<TransactionId, RequestInfo>,
14}
15
16impl OutboundRequestStorage {
17    pub fn new() -> OutboundRequestStorage {
18        Self::default()
19    }
20
21    pub fn add_request(&mut self, info: RequestInfo) {
22        self.requests
23            .insert(info.packet.transaction_id.clone().into(), info);
24    }
25
26    #[cfg(test)]
27    pub fn has_request<T>(&self, tid: &T) -> bool
28    where
29        T: Into<TransactionId>,
30        T: Clone,
31    {
32        self.requests.contains_key(&tid.clone().into())
33    }
34
35    pub fn get_matching_request_info(
36        &self,
37        msg: &Message,
38        src_addr: SocketAddr,
39    ) -> Option<&RequestInfo> {
40        let tid = msg.transaction_id.clone().into();
41
42        // Is this packet a response?
43        if let MessageType::Response(res_specific) = &msg.message_type {
44            // Is there a matching transaction id in storage?
45            if let Some(request_info) = self.requests.get(&tid) {
46                // Did this response come from the expected IP address?
47                if request_info.addr == src_addr {
48                    let response_sender_id = msg.get_author_id();
49                    // Does the Id of the sender match the recorded addressee of the original request (if any)?
50                    if request_info.id.is_none()
51                        || (response_sender_id.is_some()
52                            && request_info.id.unwrap() == response_sender_id.unwrap())
53                    {
54                        // Is the thing in storage a request packet (It should always be...)
55                        if let MessageType::Request(req_specific) =
56                            &request_info.packet.message_type
57                        {
58                            // Does the response type match the request type?
59                            if crate::packets::response_matches_request(res_specific, req_specific)
60                            {
61                                return Some(request_info);
62                            }
63                        }
64                    }
65                }
66            }
67        }
68
69        None
70    }
71
72    pub fn take_matching_request_info(
73        &mut self,
74        response: &Message,
75        src_addr: SocketAddr,
76    ) -> Option<RequestInfo> {
77        if self.get_matching_request_info(response, src_addr).is_some() {
78            let tid = response.transaction_id.clone().into();
79            return self.requests.remove(&tid);
80        }
81
82        None
83    }
84
85    pub fn prune_older_than(&mut self, duration: Duration) {
86        match Instant::now().checked_sub(duration) {
87            None => {
88                debug!(target: "rustydht_lib::OutboundRequestStorage",
89                    "Outbound request storage skipping pruning due to monotonic clock underflow"
90                );
91            }
92
93            Some(time) => {
94                let len_before = self.requests.len();
95                self.requests
96                    .retain(|_, v| -> bool { v.created_at >= time });
97                let len_after = self.requests.len();
98                debug!(target: "rustydht_lib::OutboundRequestStorage", "Pruned {} request records", len_before - len_after);
99            }
100        }
101    }
102
103    pub fn len(&self) -> usize {
104        self.requests.len()
105    }
106
107    pub fn is_empty(&self) -> bool {
108        self.requests.is_empty()
109    }
110}
111
112#[derive(Debug)]
113pub struct RequestInfo {
114    addr: SocketAddr,
115    id: Option<Id>,
116    packet: Message,
117    created_at: Instant,
118    pub(crate) response_channel: Option<mpsc::Sender<Message>>,
119}
120
121impl RequestInfo {
122    pub fn new(
123        addr: SocketAddr,
124        id: Option<Id>,
125        packet: Message,
126        response_channel: Option<mpsc::Sender<Message>>,
127    ) -> RequestInfo {
128        RequestInfo {
129            addr,
130            id,
131            packet,
132            created_at: Instant::now(),
133            response_channel,
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packets::MessageBuilder;

    #[test]
    fn test_outbound_request_storage() {
        let mut storage = OutboundRequestStorage::new();

        let our_id = Id::from_hex("0000000000000000000000000000000000000000").unwrap();
        let req = MessageBuilder::new_ping_request()
            .sender_id(our_id)
            .build()
            .unwrap();

        let request_target_addr = "127.0.0.1:1234".parse().unwrap();
        let request_info = RequestInfo::new(request_target_addr, None, req.clone(), None);

        // Add request to storage, make sure it's there
        storage.add_request(request_info);
        assert!(storage.has_request(&req.transaction_id));
        assert_eq!(storage.len(), 1);

        // Simulate a response, see if we correctly get the request back from storage
        let simulated_response = MessageBuilder::new_ping_response()
            .sender_id(our_id)
            .transaction_id(req.transaction_id.clone())
            .requester_ip("127.0.0.1:1235".parse().unwrap())
            .build()
            .unwrap();

        // We should get something if the SocketAddr matches
        assert!(storage
            .get_matching_request_info(&simulated_response, request_target_addr)
            .is_some());

        // We should NOT get something if the SocketAddr doesn't match
        assert!(storage
            .get_matching_request_info(&simulated_response, "5.5.5.5:1234".parse().unwrap())
            .is_none());

        // A response from the wrong address must not evict the pending request
        assert!(storage
            .take_matching_request_info(&simulated_response, "5.5.5.5:1234".parse().unwrap())
            .is_none());
        assert!(storage.has_request(&req.transaction_id));

        // Take the response
        assert!(storage
            .take_matching_request_info(&simulated_response, request_target_addr)
            .is_some());

        // Should have nothing left
        assert!(!storage.has_request(&req.transaction_id));
        assert!(storage.is_empty());
    }

    #[test]
    fn test_outbound_storage_prune() {
        let mut storage = OutboundRequestStorage::new();

        let our_id = Id::from_hex("0000000000000000000000000000000000000000").unwrap();
        let req = MessageBuilder::new_ping_request()
            .sender_id(our_id)
            .build()
            .unwrap();
        let req_2 = MessageBuilder::new_ping_request()
            .sender_id(our_id)
            .build()
            .unwrap();

        let mut request_info =
            RequestInfo::new("127.0.0.1:1234".parse().unwrap(), None, req.clone(), None);
        // `Instant` is only guaranteed to be non-decreasing, so the creation time above may
        // equal the cutoff computed inside `prune_older_than`. Such a record would be kept.
        // Backdate it explicitly instead of relying on the clock having advanced.
        request_info.created_at = Instant::now()
            .checked_sub(Duration::from_secs(10))
            .expect("monotonic clock should be more than 10s past its epoch");
        let mut request_info_2 =
            RequestInfo::new("127.0.0.1:1234".parse().unwrap(), None, req_2.clone(), None);
        request_info_2.created_at = Instant::now() + Duration::from_secs(10);

        // Add requests to storage, make sure they're there
        storage.add_request(request_info);
        storage.add_request(request_info_2);
        assert!(storage.has_request(&req.transaction_id));
        assert!(storage.has_request(&req_2.transaction_id));

        // Prune: the old request must be gone, the future-dated one must remain
        storage.prune_older_than(Duration::from_secs(0));
        assert!(!storage.has_request(&req.transaction_id));
        assert!(storage.has_request(&req_2.transaction_id));
        assert_eq!(storage.len(), 1);
    }
}