Skip to main content

wireguard_netstack/
dns.rs

1//! DNS-over-HTTPS (DoH) resolver with configurable DNS servers.
2//!
3//! This module provides DNS resolution using DNS-over-HTTPS. It can work in two modes:
4//!
5//! 1. **Direct mode**: Uses regular TCP/TLS connections (for use before WireGuard is up)
6//! 2. **Tunnel mode**: Uses the WireGuard tunnel for DNS queries
7//!
8//! Both modes ensure DNS privacy by using encrypted HTTPS connections.
9//!
10//! # Configurable DNS Servers
11//!
12//! You can configure different DNS servers for:
13//! - **Pre-connection (direct mode)**: Used before the WireGuard tunnel is established
14//! - **Post-connection (tunnel mode)**: Used after the tunnel is up, queries go through VPN
15//!
16//! By default, Cloudflare DNS (1.1.1.1, 1.0.0.1) is used.
17
18use crate::error::{Error, Result};
19use crate::netstack::{NetStack, TcpConnection};
20use parking_lot::Mutex;
21use std::collections::HashMap;
22use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28
29/// Configuration for a DNS-over-HTTPS server.
30#[derive(Debug, Clone)]
31pub struct DohServerConfig {
32    /// The hostname of the DoH server (used for TLS SNI and Host header).
33    pub hostname: String,
34    /// The IP addresses of the DoH server (we try these in order).
35    /// These must be hardcoded since we can't resolve the DoH server using DoH itself.
36    pub ips: Vec<Ipv4Addr>,
37}
38
39impl DohServerConfig {
40    /// Create a new DoH server configuration.
41    pub fn new(hostname: impl Into<String>, ips: Vec<Ipv4Addr>) -> Self {
42        Self {
43            hostname: hostname.into(),
44            ips,
45        }
46    }
47
48    /// Cloudflare DNS (1.1.1.1, 1.0.0.1) - the default.
49    pub fn cloudflare() -> Self {
50        Self {
51            hostname: "1dot1dot1dot1.cloudflare-dns.com".into(),
52            ips: vec![Ipv4Addr::new(1, 1, 1, 1), Ipv4Addr::new(1, 0, 0, 1)],
53        }
54    }
55
56    /// Google DNS (8.8.8.8, 8.8.4.4).
57    pub fn google() -> Self {
58        Self {
59            hostname: "dns.google".into(),
60            ips: vec![Ipv4Addr::new(8, 8, 8, 8), Ipv4Addr::new(8, 8, 4, 4)],
61        }
62    }
63
64    /// Quad9 DNS (9.9.9.9, 149.112.112.112).
65    pub fn quad9() -> Self {
66        Self {
67            hostname: "dns.quad9.net".into(),
68            ips: vec![Ipv4Addr::new(9, 9, 9, 9), Ipv4Addr::new(149, 112, 112, 112)],
69        }
70    }
71
72    /// AdGuard DNS (94.140.14.14, 94.140.15.15).
73    pub fn adguard() -> Self {
74        Self {
75            hostname: "dns.adguard-dns.com".into(),
76            ips: vec![Ipv4Addr::new(94, 140, 14, 14), Ipv4Addr::new(94, 140, 15, 15)],
77        }
78    }
79
80    /// NextDNS - requires your NextDNS configuration ID.
81    /// The IPs are the anycast addresses for NextDNS.
82    pub fn nextdns(config_id: &str) -> Self {
83        Self {
84            hostname: format!("{}.dns.nextdns.io", config_id),
85            ips: vec![Ipv4Addr::new(45, 90, 28, 0), Ipv4Addr::new(45, 90, 30, 0)],
86        }
87    }
88}
89
90impl Default for DohServerConfig {
91    fn default() -> Self {
92        Self::cloudflare()
93    }
94}
95
96/// DNS configuration for the library.
97///
98/// Allows configuring different DNS servers for pre-connection (direct mode)
99/// and post-connection (tunnel mode) DNS resolution.
100#[derive(Debug, Clone)]
101pub struct DnsConfig {
102    /// DNS server to use before the WireGuard tunnel is established.
103    /// This is used to resolve the WireGuard endpoint hostname.
104    pub pre_connection: DohServerConfig,
105    /// DNS server to use after the WireGuard tunnel is established.
106    /// All DNS queries will go through the VPN tunnel.
107    pub post_connection: DohServerConfig,
108}
109
110impl DnsConfig {
111    /// Create a new DNS configuration with the same server for both modes.
112    pub fn new(server: DohServerConfig) -> Self {
113        Self {
114            pre_connection: server.clone(),
115            post_connection: server,
116        }
117    }
118
119    /// Create a DNS configuration with different servers for pre and post connection.
120    pub fn with_different_servers(pre_connection: DohServerConfig, post_connection: DohServerConfig) -> Self {
121        Self {
122            pre_connection,
123            post_connection,
124        }
125    }
126
127    /// Use Cloudflare DNS for both modes (default).
128    pub fn cloudflare() -> Self {
129        Self::new(DohServerConfig::cloudflare())
130    }
131
132    /// Use Google DNS for both modes.
133    pub fn google() -> Self {
134        Self::new(DohServerConfig::google())
135    }
136
137    /// Use Quad9 DNS for both modes.
138    pub fn quad9() -> Self {
139        Self::new(DohServerConfig::quad9())
140    }
141}
142
143impl Default for DnsConfig {
144    fn default() -> Self {
145        Self::cloudflare()
146    }
147}
148
149/// DNS cache entry.
150#[derive(Clone)]
151struct CacheEntry {
152    addresses: Vec<Ipv4Addr>,
153    expires_at: Instant,
154}
155
156/// Transport mode for DoH queries.
157#[derive(Clone)]
158enum Transport {
159    /// Use regular TCP connections (direct internet access).
160    Direct,
161    /// Use WireGuard tunnel for connections.
162    Tunnel(Arc<NetStack>),
163}
164
165/// A DNS-over-HTTPS resolver with configurable DNS servers.
166///
167/// This resolver can work in two modes:
168/// - Direct: Uses regular TCP/TLS for DNS queries (before WireGuard is up)
169/// - Tunnel: Routes DNS queries through the WireGuard tunnel
170///
171/// # Example
172///
173/// ```no_run
174/// use wireguard_netstack::{DohResolver, DohServerConfig};
175///
176/// // Use default Cloudflare DNS
177/// let resolver = DohResolver::new_direct();
178///
179/// // Use custom DNS server
180/// let resolver = DohResolver::new_direct_with_config(DohServerConfig::google());
181/// ```
182pub struct DohResolver {
183    transport: Transport,
184    tls_connector: TlsConnector,
185    /// DNS cache.
186    cache: Mutex<HashMap<String, CacheEntry>>,
187    /// Cache TTL (default 5 minutes).
188    cache_ttl: Duration,
189    /// DoH server configuration.
190    server_config: DohServerConfig,
191}
192
193impl DohResolver {
194    /// Create a new DoH resolver that uses the WireGuard tunnel with default Cloudflare DNS.
195    pub fn new_tunneled(netstack: Arc<NetStack>) -> Self {
196        Self::new_tunneled_with_config(netstack, DohServerConfig::default())
197    }
198
199    /// Create a new DoH resolver that uses the WireGuard tunnel with custom DNS config.
200    pub fn new_tunneled_with_config(netstack: Arc<NetStack>, config: DohServerConfig) -> Self {
201        Self::new_with_transport(Transport::Tunnel(netstack), config)
202    }
203
204    /// Create a new DoH resolver that uses direct TCP connections with default Cloudflare DNS.
205    /// Use this before the WireGuard tunnel is established.
206    pub fn new_direct() -> Self {
207        Self::new_direct_with_config(DohServerConfig::default())
208    }
209
210    /// Create a new DoH resolver that uses direct TCP connections with custom DNS config.
211    /// Use this before the WireGuard tunnel is established.
212    pub fn new_direct_with_config(config: DohServerConfig) -> Self {
213        Self::new_with_transport(Transport::Direct, config)
214    }
215
216    /// Create a resolver with the specified transport and server configuration.
217    fn new_with_transport(transport: Transport, server_config: DohServerConfig) -> Self {
218        // Install ring as the crypto provider (may already be installed)
219        let _ = rustls::crypto::ring::default_provider().install_default();
220
221        // Set up rustls with webpki roots
222        let root_store =
223            rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
224
225        let tls_config = rustls::ClientConfig::builder()
226            .with_root_certificates(root_store)
227            .with_no_client_auth();
228
229        let tls_connector = TlsConnector::from(Arc::new(tls_config));
230
231        Self {
232            transport,
233            tls_connector,
234            cache: Mutex::new(HashMap::new()),
235            cache_ttl: Duration::from_secs(300), // 5 minutes
236            server_config,
237        }
238    }
239
240    /// Get the current server configuration.
241    pub fn server_config(&self) -> &DohServerConfig {
242        &self.server_config
243    }
244
245    /// Resolve a hostname to IPv4 addresses using DNS-over-HTTPS.
246    pub async fn resolve(&self, hostname: &str) -> Result<Vec<Ipv4Addr>> {
247        // Check if it's already an IP address
248        if let Ok(ip) = hostname.parse::<Ipv4Addr>() {
249            return Ok(vec![ip]);
250        }
251
252        // Check cache
253        {
254            let cache = self.cache.lock();
255            if let Some(entry) = cache.get(hostname) {
256                if entry.expires_at > Instant::now() {
257                    log::debug!("DNS cache hit for {}", hostname);
258                    return Ok(entry.addresses.clone());
259                }
260            }
261        }
262
263        let mode = match &self.transport {
264            Transport::Direct => "direct",
265            Transport::Tunnel(_) => "tunneled",
266        };
267        log::info!("Resolving {} via DoH ({})", hostname, mode);
268
269        // Try each DoH server IP
270        let mut last_error = None;
271        for doh_ip in &self.server_config.ips {
272            match self.query_doh(*doh_ip, hostname).await {
273                Ok(addrs) => {
274                    // Cache the result
275                    {
276                        let mut cache = self.cache.lock();
277                        cache.insert(
278                            hostname.to_string(),
279                            CacheEntry {
280                                addresses: addrs.clone(),
281                                expires_at: Instant::now() + self.cache_ttl,
282                            },
283                        );
284                    }
285                    return Ok(addrs);
286                }
287                Err(e) => {
288                    log::warn!("DoH query to {} failed: {}", doh_ip, e);
289                    last_error = Some(e);
290                }
291            }
292        }
293
294        Err(last_error.unwrap_or(Error::DnsAllServersFailed))
295    }
296
297    /// Resolve a hostname to a single socket address.
298    pub async fn resolve_addr(&self, hostname: &str, port: u16) -> Result<SocketAddr> {
299        let addrs = self.resolve(hostname).await?;
300        let ip = addrs
301            .into_iter()
302            .next()
303            .ok_or_else(|| Error::DnsNoRecords(hostname.to_string()))?;
304        Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
305    }
306
307    /// Query a DoH server for DNS records.
308    async fn query_doh(&self, doh_ip: Ipv4Addr, hostname: &str) -> Result<Vec<Ipv4Addr>> {
309        let addr = SocketAddr::V4(SocketAddrV4::new(doh_ip, 443));
310
311        // Build the DNS wire format query
312        let dns_query = build_dns_query(hostname)?;
313
314        // Build HTTP/1.1 request for DNS-over-HTTPS (using POST with application/dns-message)
315        let http_request = format!(
316            "POST /dns-query HTTP/1.1\r\n\
317             Host: {}\r\n\
318             Content-Type: application/dns-message\r\n\
319             Accept: application/dns-message\r\n\
320             Content-Length: {}\r\n\
321             Connection: close\r\n\
322             \r\n",
323            self.server_config.hostname,
324            dns_query.len()
325        );
326
327        // Connect and perform TLS handshake based on transport mode
328        let response = match &self.transport {
329            Transport::Direct => {
330                self.query_direct(addr, &http_request, &dns_query).await?
331            }
332            Transport::Tunnel(netstack) => {
333                self.query_tunneled(netstack.clone(), addr, &http_request, &dns_query)
334                    .await?
335            }
336        };
337
338        log::debug!("Received {} bytes from DoH server", response.len());
339
340        // Parse HTTP response
341        parse_doh_response(&response, hostname)
342    }
343
344    /// Query DoH server using direct TCP connection.
345    async fn query_direct(
346        &self,
347        addr: SocketAddr,
348        http_request: &str,
349        dns_query: &[u8],
350    ) -> Result<Vec<u8>> {
351        // Connect via regular TCP
352        let tcp_stream = TcpStream::connect(addr).await?;
353
354        // TLS handshake
355        let server_name = rustls::pki_types::ServerName::try_from(self.server_config.hostname.clone())
356            .map_err(|e| Error::TlsHandshake(format!("Invalid server name: {}", e)))?;
357
358        log::debug!("Starting TLS handshake with DoH server {} (direct)", addr);
359        let mut tls_stream = self
360            .tls_connector
361            .connect(server_name, tcp_stream)
362            .await
363            .map_err(|e| Error::TlsHandshake(e.to_string()))?;
364
365        log::debug!("TLS handshake completed, sending DNS query");
366
367        // Send HTTP request
368        tls_stream.write_all(http_request.as_bytes()).await?;
369        tls_stream.write_all(dns_query).await?;
370        tls_stream.flush().await?;
371
372        log::debug!("DNS query sent, waiting for response");
373
374        // Read response
375        let mut response = Vec::new();
376        let mut buf = [0u8; 4096];
377        loop {
378            match tls_stream.read(&mut buf).await {
379                Ok(0) => break,
380                Ok(n) => {
381                    response.extend_from_slice(&buf[..n]);
382                    if response.len() > 4 && response_complete(&response) {
383                        break;
384                    }
385                }
386                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
387                Err(e) => return Err(e.into()),
388            }
389        }
390
391        Ok(response)
392    }
393
394    /// Query DoH server using WireGuard tunnel.
395    async fn query_tunneled(
396        &self,
397        netstack: Arc<NetStack>,
398        addr: SocketAddr,
399        http_request: &str,
400        dns_query: &[u8],
401    ) -> Result<Vec<u8>> {
402        // Connect via WireGuard tunnel
403        let tcp_conn = TcpConnection::connect(netstack, addr).await?;
404
405        let tcp_stream = TunnelTcpStream {
406            conn: Arc::new(tcp_conn),
407        };
408
409        // TLS handshake
410        let server_name = rustls::pki_types::ServerName::try_from(self.server_config.hostname.clone())
411            .map_err(|e| Error::TlsHandshake(format!("Invalid server name: {}", e)))?;
412
413        log::debug!(
414            "Starting TLS handshake with DoH server {} (tunneled)",
415            addr
416        );
417        let mut tls_stream = self
418            .tls_connector
419            .connect(server_name, tcp_stream)
420            .await
421            .map_err(|e| Error::TlsHandshake(e.to_string()))?;
422
423        log::debug!("TLS handshake completed, sending DNS query");
424
425        // Send HTTP request
426        tls_stream.write_all(http_request.as_bytes()).await?;
427        tls_stream.write_all(dns_query).await?;
428        tls_stream.flush().await?;
429
430        log::debug!("DNS query sent, waiting for response");
431
432        // Read response
433        let mut response = Vec::new();
434        let mut buf = [0u8; 4096];
435        loop {
436            match tls_stream.read(&mut buf).await {
437                Ok(0) => break,
438                Ok(n) => {
439                    response.extend_from_slice(&buf[..n]);
440                    if response.len() > 4 && response_complete(&response) {
441                        break;
442                    }
443                }
444                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
445                Err(e) => return Err(e.into()),
446            }
447        }
448
449        Ok(response)
450    }
451}
452
453/// Check if we've received a complete HTTP response.
454fn response_complete(data: &[u8]) -> bool {
455    // Look for end of headers
456    if let Some(header_end) = find_header_end(data) {
457        // Parse Content-Length if present
458        let headers = &data[..header_end];
459        if let Some(content_length) = parse_content_length(headers) {
460            let body_start = header_end + 4; // Skip \r\n\r\n
461            let body_len = data.len().saturating_sub(body_start);
462            return body_len >= content_length;
463        }
464        // For chunked encoding or connection close, assume complete if we see 0-length read
465        return true;
466    }
467    false
468}
469
470/// Find the position of the header/body separator.
471fn find_header_end(data: &[u8]) -> Option<usize> {
472    for i in 0..data.len().saturating_sub(3) {
473        if &data[i..i + 4] == b"\r\n\r\n" {
474            return Some(i);
475        }
476    }
477    None
478}
479
480/// Parse Content-Length header.
481fn parse_content_length(headers: &[u8]) -> Option<usize> {
482    let headers_str = std::str::from_utf8(headers).ok()?;
483    for line in headers_str.lines() {
484        if line.to_lowercase().starts_with("content-length:") {
485            let value = line.split(':').nth(1)?.trim();
486            return value.parse().ok();
487        }
488    }
489    None
490}
491
492/// Build a DNS query in wire format.
493fn build_dns_query(hostname: &str) -> Result<Vec<u8>> {
494    let mut query = Vec::new();
495
496    // Transaction ID (random)
497    let id: u16 = rand::random();
498    query.extend_from_slice(&id.to_be_bytes());
499
500    // Flags: standard query, recursion desired
501    query.extend_from_slice(&[0x01, 0x00]); // QR=0, OPCODE=0, RD=1
502
503    // QDCOUNT = 1
504    query.extend_from_slice(&[0x00, 0x01]);
505    // ANCOUNT = 0
506    query.extend_from_slice(&[0x00, 0x00]);
507    // NSCOUNT = 0
508    query.extend_from_slice(&[0x00, 0x00]);
509    // ARCOUNT = 0
510    query.extend_from_slice(&[0x00, 0x00]);
511
512    // Question section
513    // Encode hostname as DNS name
514    for label in hostname.split('.') {
515        if label.len() > 63 {
516            return Err(Error::DnsLabelTooLong(label.to_string()));
517        }
518        query.push(label.len() as u8);
519        query.extend_from_slice(label.as_bytes());
520    }
521    query.push(0); // Root label
522
523    // QTYPE = A (1)
524    query.extend_from_slice(&[0x00, 0x01]);
525    // QCLASS = IN (1)
526    query.extend_from_slice(&[0x00, 0x01]);
527
528    Ok(query)
529}
530
531/// Parse DNS-over-HTTPS response.
532fn parse_doh_response(response: &[u8], hostname: &str) -> Result<Vec<Ipv4Addr>> {
533    // Find header/body separator
534    let header_end = find_header_end(response)
535        .ok_or_else(|| Error::InvalidHttpResponse("no header end found".into()))?;
536
537    let body_start = header_end + 4;
538    if body_start >= response.len() {
539        return Err(Error::InvalidHttpResponse("empty body".into()));
540    }
541
542    // Check HTTP status
543    let headers =
544        std::str::from_utf8(&response[..header_end]).map_err(|_| Error::InvalidHttpResponse("invalid headers".into()))?;
545
546    let status_line = headers.lines().next().unwrap_or("");
547    if !status_line.contains("200") {
548        return Err(Error::DohServerError(status_line.to_string()));
549    }
550
551    let dns_response = &response[body_start..];
552    parse_dns_response(dns_response, hostname)
553}
554
555/// Parse DNS response wire format.
556fn parse_dns_response(data: &[u8], hostname: &str) -> Result<Vec<Ipv4Addr>> {
557    if data.len() < 12 {
558        return Err(Error::DnsResponseTooShort);
559    }
560
561    // Parse header
562    let flags = u16::from_be_bytes([data[2], data[3]]);
563    let rcode = flags & 0x000F;
564
565    if rcode != 0 {
566        return Err(Error::DnsError(rcode));
567    }
568
569    let ancount = u16::from_be_bytes([data[6], data[7]]) as usize;
570    if ancount == 0 {
571        return Err(Error::DnsNoRecords(hostname.to_string()));
572    }
573
574    log::debug!("DNS response has {} answers", ancount);
575
576    // Skip header and question section
577    let mut pos = 12;
578
579    // Skip question section (QDCOUNT questions)
580    let qdcount = u16::from_be_bytes([data[4], data[5]]) as usize;
581    for _ in 0..qdcount {
582        pos = skip_dns_name(data, pos)?;
583        pos += 4; // QTYPE + QCLASS
584    }
585
586    // Parse answer section
587    let mut addresses = Vec::new();
588    for _ in 0..ancount {
589        if pos >= data.len() {
590            break;
591        }
592
593        // Skip name
594        pos = skip_dns_name(data, pos)?;
595
596        if pos + 10 > data.len() {
597            break;
598        }
599
600        let rtype = u16::from_be_bytes([data[pos], data[pos + 1]]);
601        let _rclass = u16::from_be_bytes([data[pos + 2], data[pos + 3]]);
602        let _ttl = u32::from_be_bytes([data[pos + 4], data[pos + 5], data[pos + 6], data[pos + 7]]);
603        let rdlength = u16::from_be_bytes([data[pos + 8], data[pos + 9]]) as usize;
604
605        pos += 10;
606
607        if pos + rdlength > data.len() {
608            break;
609        }
610
611        // Type A = 1
612        if rtype == 1 && rdlength == 4 {
613            let ip = Ipv4Addr::new(data[pos], data[pos + 1], data[pos + 2], data[pos + 3]);
614            log::debug!("Resolved {} -> {}", hostname, ip);
615            addresses.push(ip);
616        }
617
618        pos += rdlength;
619    }
620
621    if addresses.is_empty() {
622        return Err(Error::DnsNoRecords(hostname.to_string()));
623    }
624
625    Ok(addresses)
626}
627
628/// Skip a DNS name (handles compression).
629fn skip_dns_name(data: &[u8], mut pos: usize) -> Result<usize> {
630    loop {
631        if pos >= data.len() {
632            return Err(Error::DnsNameTooLong);
633        }
634
635        let len = data[pos] as usize;
636
637        // Check for compression pointer
638        if len & 0xC0 == 0xC0 {
639            // Compression pointer is 2 bytes
640            return Ok(pos + 2);
641        }
642
643        // Check for end of name
644        if len == 0 {
645            return Ok(pos + 1);
646        }
647
648        // Skip label
649        pos += 1 + len;
650    }
651}
652
653/// A TCP stream wrapper for tunneled DoH connections.
654pub(crate) struct TunnelTcpStream {
655    conn: Arc<TcpConnection>,
656}
657
658impl tokio::io::AsyncRead for TunnelTcpStream {
659    fn poll_read(
660        self: std::pin::Pin<&mut Self>,
661        cx: &mut std::task::Context<'_>,
662        buf: &mut tokio::io::ReadBuf<'_>,
663    ) -> std::task::Poll<std::io::Result<()>> {
664        let conn = self.conn.clone();
665        let unfilled = buf.initialize_unfilled();
666
667        conn.netstack.poll();
668
669        if conn.netstack.can_recv(conn.handle) {
670            match conn.netstack.recv(conn.handle, unfilled) {
671                Ok(n) if n > 0 => {
672                    buf.advance(n);
673                    return std::task::Poll::Ready(Ok(()));
674                }
675                Ok(_) => {}
676                Err(e) => {
677                    return std::task::Poll::Ready(Err(std::io::Error::new(
678                        std::io::ErrorKind::Other,
679                        e.to_string(),
680                    )));
681                }
682            }
683        }
684
685        if !conn.netstack.may_recv(conn.handle) {
686            return std::task::Poll::Ready(Ok(()));
687        }
688
689        let waker = cx.waker().clone();
690        tokio::spawn(async move {
691            tokio::time::sleep(Duration::from_millis(1)).await;
692            waker.wake();
693        });
694
695        std::task::Poll::Pending
696    }
697}
698
699impl tokio::io::AsyncWrite for TunnelTcpStream {
700    fn poll_write(
701        self: std::pin::Pin<&mut Self>,
702        cx: &mut std::task::Context<'_>,
703        buf: &[u8],
704    ) -> std::task::Poll<std::io::Result<usize>> {
705        let conn = self.conn.clone();
706
707        conn.netstack.poll();
708
709        if conn.netstack.can_send(conn.handle) {
710            match conn.netstack.send(conn.handle, buf) {
711                Ok(n) => {
712                    conn.netstack.poll();
713                    return std::task::Poll::Ready(Ok(n));
714                }
715                Err(e) => {
716                    return std::task::Poll::Ready(Err(std::io::Error::new(
717                        std::io::ErrorKind::Other,
718                        e.to_string(),
719                    )));
720                }
721            }
722        }
723
724        if !conn.netstack.may_send(conn.handle) {
725            return std::task::Poll::Ready(Err(std::io::Error::new(
726                std::io::ErrorKind::BrokenPipe,
727                "Connection closed",
728            )));
729        }
730
731        let waker = cx.waker().clone();
732        tokio::spawn(async move {
733            tokio::time::sleep(Duration::from_millis(1)).await;
734            waker.wake();
735        });
736
737        std::task::Poll::Pending
738    }
739
740    fn poll_flush(
741        self: std::pin::Pin<&mut Self>,
742        _cx: &mut std::task::Context<'_>,
743    ) -> std::task::Poll<std::io::Result<()>> {
744        self.conn.netstack.poll();
745        std::task::Poll::Ready(Ok(()))
746    }
747
748    fn poll_shutdown(
749        self: std::pin::Pin<&mut Self>,
750        _cx: &mut std::task::Context<'_>,
751    ) -> std::task::Poll<std::io::Result<()>> {
752        self.conn.shutdown();
753        self.conn.netstack.poll();
754        std::task::Poll::Ready(Ok(()))
755    }
756}