shadowsocks_service/local/dns/
client_cache.rs

1//! DNS Client cache
2
3#[cfg(unix)]
4use std::path::Path;
5use std::{
6    collections::{HashMap, VecDeque, hash_map::Entry},
7    future::Future,
8    io,
9    net::SocketAddr,
10    time::Duration,
11};
12
13use hickory_resolver::proto::{ProtoError, op::Message};
14use log::{debug, trace};
15use tokio::sync::Mutex;
16
17use shadowsocks::{config::ServerConfig, net::ConnectOpts, relay::socks5::Address};
18
19use crate::local::context::ServiceContext;
20
21use super::upstream::DnsClient;
22
23#[derive(Clone, Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
24enum DnsClientKey {
25    TcpLocal(SocketAddr),
26    UdpLocal(SocketAddr),
27    TcpRemote(Address),
28    UdpRemote(Address),
29}
30
31pub struct DnsClientCache {
32    cache: Mutex<HashMap<DnsClientKey, VecDeque<DnsClient>>>,
33    timeout: Duration,
34    retry_count: usize,
35    max_client_per_addr: usize,
36}
37
38impl DnsClientCache {
39    pub fn new(max_client_per_addr: usize) -> Self {
40        Self {
41            cache: Mutex::new(HashMap::new()),
42            timeout: Duration::from_secs(5),
43            retry_count: 1,
44            max_client_per_addr,
45        }
46    }
47
48    pub async fn lookup_local(
49        &self,
50        ns: SocketAddr,
51        msg: Message,
52        connect_opts: &ConnectOpts,
53        is_udp: bool,
54    ) -> Result<Message, ProtoError> {
55        let key = match is_udp {
56            true => DnsClientKey::UdpLocal(ns),
57            false => DnsClientKey::TcpLocal(ns),
58        };
59        self.lookup_dns(&key, msg, Some(connect_opts), None, None).await
60    }
61
62    pub async fn lookup_remote(
63        &self,
64        context: &ServiceContext,
65        svr_cfg: &ServerConfig,
66        ns: &Address,
67        msg: Message,
68        is_udp: bool,
69    ) -> Result<Message, ProtoError> {
70        let key = match is_udp {
71            true => DnsClientKey::UdpRemote(ns.clone()),
72            false => DnsClientKey::TcpRemote(ns.clone()),
73        };
74        self.lookup_dns(&key, msg, None, Some(context), Some(svr_cfg)).await
75    }
76
77    #[cfg(unix)]
78    pub async fn lookup_unix_stream<P: AsRef<Path>>(&self, ns: &P, msg: Message) -> Result<Message, ProtoError> {
79        let mut last_err = None;
80
81        for _ in 0..self.retry_count {
82            // UNIX stream won't keep connection alive
83            //
84            // https://github.com/shadowsocks/shadowsocks-rust/pull/567
85            //
86            // 1. The cost of recreating UNIX stream sockets are very low
87            // 2. This feature is only used by shadowsocks-android, and it doesn't support connection reuse
88
89            let mut client = match DnsClient::connect_unix_stream(ns).await {
90                Ok(client) => client,
91                Err(err) => {
92                    last_err = Some(From::from(err));
93                    continue;
94                }
95            };
96
97            let res = match client.lookup_timeout(msg.clone(), self.timeout).await {
98                Ok(msg) => msg,
99                Err(error) => {
100                    last_err = Some(error);
101                    continue;
102                }
103            };
104            return Ok(res);
105        }
106        Err(last_err.unwrap())
107    }
108
109    async fn lookup_dns(
110        &self,
111        dck: &DnsClientKey,
112        msg: Message,
113        connect_opts: Option<&ConnectOpts>,
114        context: Option<&ServiceContext>,
115        svr_cfg: Option<&ServerConfig>,
116    ) -> Result<Message, ProtoError> {
117        let mut last_err = None;
118        for _ in 0..self.retry_count {
119            let create_fn = async {
120                match dck {
121                    DnsClientKey::TcpLocal(tcp_l) => {
122                        let connect_opts = connect_opts.expect("connect options is required for local DNS");
123                        DnsClient::connect_tcp_local(*tcp_l, connect_opts).await
124                    }
125                    DnsClientKey::UdpLocal(udp_l) => {
126                        let connect_opts = connect_opts.expect("connect options is required for local DNS");
127                        DnsClient::connect_udp_local(*udp_l, connect_opts).await
128                    }
129                    DnsClientKey::TcpRemote(tcp_l) => {
130                        let context = context.expect("context is required for remote DNS");
131                        let svr_cfg = svr_cfg.expect("server config is required for remote DNS");
132
133                        DnsClient::connect_tcp_remote(
134                            context.context(),
135                            svr_cfg,
136                            tcp_l,
137                            context.connect_opts_ref(),
138                            context.flow_stat(),
139                        )
140                        .await
141                    }
142                    DnsClientKey::UdpRemote(udp_l) => {
143                        let context = context.expect("context is required for remote DNS");
144                        let svr_cfg = svr_cfg.expect("server config is required for remote DNS");
145
146                        DnsClient::connect_udp_remote(
147                            context.context(),
148                            svr_cfg,
149                            udp_l.clone(),
150                            context.connect_opts_ref(),
151                            context.flow_stat(),
152                        )
153                        .await
154                    }
155                }
156            };
157            match self.get_client_or_create(dck, create_fn).await {
158                Ok(mut client) => match client.lookup_timeout(msg.clone(), self.timeout).await {
159                    Ok(msg) => {
160                        self.save_client(dck.clone(), client).await;
161                        return Ok(msg);
162                    }
163                    Err(err) => {
164                        last_err = Some(err);
165                        continue;
166                    }
167                },
168                Err(err) => {
169                    last_err = Some(From::from(err));
170                    continue;
171                }
172            }
173        }
174        Err(last_err.unwrap())
175    }
176
177    async fn get_client_or_create<C>(&self, key: &DnsClientKey, create_fn: C) -> io::Result<DnsClient>
178    where
179        C: Future<Output = io::Result<DnsClient>>,
180    {
181        // Check if there already is a cached client
182        if let Some(q) = self.cache.lock().await.get_mut(key) {
183            while let Some(mut c) = q.pop_front() {
184                trace!("take cached DNS client for {:?}", key);
185                if !c.check_connected().await {
186                    debug!("cached DNS client for {:?} is lost", key);
187                    continue;
188                }
189                return Ok(c);
190            }
191        }
192        trace!("creating connection to DNS server {:?}", key);
193
194        // Create one
195        create_fn.await
196    }
197
198    async fn save_client(&self, key: DnsClientKey, client: DnsClient) {
199        match self.cache.lock().await.entry(key) {
200            Entry::Occupied(occ) => {
201                let q = occ.into_mut();
202                q.push_back(client);
203                if q.len() > self.max_client_per_addr {
204                    q.pop_front();
205                }
206            }
207            Entry::Vacant(vac) => {
208                let mut q = VecDeque::with_capacity(self.max_client_per_addr);
209                q.push_back(client);
210                vac.insert(q);
211            }
212        }
213    }
214}