Skip to main content

rusmes_smtp/
outbound_pool.rs

1//! Outbound SMTP connection pool
2//!
3//! [`OutboundPool`] maintains a bounded set of reusable TCP connections to
4//! remote SMTP servers.  Reusing an established connection avoids the per-
5//! message cost of TCP + (optionally) TLS handshake and SMTP greeting/EHLO
6//! round-trips.
7//!
8//! ## Design
9//!
10//! Each remote address (host:port string key) has its own
11//! [`tokio::sync::Mutex`]-guarded [`std::collections::VecDeque`] of idle
12//! [`PooledConn`]s.  `get_or_connect` pops from the front; `return_conn`
//! pushes to the back.  Before re-pooling a connection, `return_conn` itself
//! sends `RSET\r\n` and waits for a `250` response; if that fails the
//! connection is dropped rather than returned.
16//!
//! A background *idle reaper* task wakes every `idle_timeout / 2` and removes
//! connections that have been idle (since `last_used`) for longer than
//! `idle_timeout`.
19//!
20//! ## Caps
21//!
22//! * **per-remote cap** — at most `per_remote_cap` idle connections are kept
23//!   for any single remote address.  Connections beyond the cap are dropped on
24//!   return.
25//! * **global cap** — the sum of all idle connections across all remotes must
26//!   not exceed `global_cap`.  Connections are dropped on return when the
27//!   global counter would exceed the cap.
28
29use dashmap::DashMap;
30use std::collections::VecDeque;
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::sync::Arc;
33use std::time::{Duration, SystemTime};
34use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
35use tokio::net::TcpStream;
36use tokio::sync::Mutex;
37
38// ── SMTP extensions advertised by the remote server ──────────────────────────
39
/// Extensions advertised by the remote SMTP server in its EHLO response.
///
/// These are captured once during connection establishment and re-used for
/// every subsequent message sent over a pooled connection, avoiding a
/// redundant EHLO round-trip.
#[derive(Debug, Clone, Default)]
pub struct SmtpExtensions {
    /// Maximum message size (from `250-SIZE <n>`), or `None` if not
    /// advertised, unparsable, or advertised as `0` (RFC 1870: "no fixed
    /// maximum message size is in force").
    pub max_size: Option<usize>,
    /// Whether the remote advertises `PIPELINING`.
    pub pipelining: bool,
    /// Whether the remote advertises `8BITMIME`.
    pub eight_bit_mime: bool,
    /// Whether the remote advertises `STARTTLS`.
    pub starttls: bool,
}

impl SmtpExtensions {
    /// Parse the EHLO multi-line response text into an `SmtpExtensions` value.
    ///
    /// Each line is expected to look like `250-KEYWORD [params]` or
    /// `250 KEYWORD [params]` (CRLF or LF line endings). The keyword is the
    /// first whitespace-separated token after the reply code and separator.
    /// It is compared case-insensitively and *exactly*, so `SIZEX` is not
    /// mistaken for `SIZE` and trailing whitespace does not hide a keyword.
    /// Unrecognised tokens (including the greeting line's domain) are ignored.
    pub fn from_ehlo(ehlo_text: &str) -> Self {
        let mut ext = SmtpExtensions::default();
        for line in ehlo_text.lines() {
            // Strip the reply code ("250") and the "-" / " " separator.
            let rest = line
                .trim_start_matches(|c: char| c.is_ascii_digit())
                .trim_start_matches(|c: char| c == '-' || c == ' ');
            let mut tokens = rest.split_whitespace();
            let keyword = match tokens.next() {
                Some(keyword) => keyword,
                None => continue,
            };

            if keyword.eq_ignore_ascii_case("SIZE") {
                // RFC 1870: a value of 0 means there is no fixed limit.
                ext.max_size = tokens
                    .next()
                    .and_then(|value| value.parse::<usize>().ok())
                    .filter(|&limit| limit > 0);
            } else if keyword.eq_ignore_ascii_case("PIPELINING") {
                ext.pipelining = true;
            } else if keyword.eq_ignore_ascii_case("8BITMIME") {
                ext.eight_bit_mime = true;
            } else if keyword.eq_ignore_ascii_case("STARTTLS") {
                ext.starttls = true;
            }
        }
        ext
    }
}
84
85// ── Pooled connection ─────────────────────────────────────────────────────────
86
87/// A single idle SMTP connection held in the pool.
88pub struct PooledConn {
89    /// The underlying TCP stream, wrapped in a line-oriented buffer.
90    pub reader: BufReader<TcpStream>,
91    /// Wall-clock time of last use (used by the idle reaper).
92    pub last_used: SystemTime,
93    /// Extensions advertised by the remote server in the initial EHLO.
94    pub extensions: SmtpExtensions,
95    /// The canonical "host:port" key under which this connection is pooled.
96    pub remote_key: String,
97}
98
99impl PooledConn {
100    /// Extract the underlying [`TcpStream`] reference (read-only; write via
101    /// `reader.get_mut()`).
102    pub fn stream_mut(&mut self) -> &mut TcpStream {
103        self.reader.get_mut()
104    }
105}
106
107// ── Pool configuration ────────────────────────────────────────────────────────
108
/// Configuration snapshot used when constructing an [`OutboundPool`].
///
/// Mirrors `rusmes_config::SmtpOutboundConfig`. The fields are duplicated
/// here so that `rusmes-smtp` does not need a compile-time dependency on
/// `rusmes-config`.
#[derive(Debug, Clone)]
pub struct OutboundPoolConfig {
    /// Upper bound on idle connections kept for any one remote address.
    pub per_remote_cap: usize,
    /// Upper bound on idle connections kept across all remote addresses.
    pub global_cap: usize,
    /// How long a connection may sit idle before the reaper closes it.
    pub idle_timeout: Duration,
}

impl Default for OutboundPoolConfig {
    /// 8 idle connections per remote, 256 in total, 30 second idle timeout.
    fn default() -> Self {
        const PER_REMOTE_CAP: usize = 8;
        const GLOBAL_CAP: usize = 256;
        const IDLE_TIMEOUT_SECS: u64 = 30;

        OutboundPoolConfig {
            idle_timeout: Duration::from_secs(IDLE_TIMEOUT_SECS),
            global_cap: GLOBAL_CAP,
            per_remote_cap: PER_REMOTE_CAP,
        }
    }
}
132
133// ── OutboundPool ──────────────────────────────────────────────────────────────
134
135/// Bounded pool of idle outbound SMTP connections.
136///
137/// # Thread safety
138///
139/// `OutboundPool` is `Send + Sync`; share it via `Arc<OutboundPool>`.
140pub struct OutboundPool {
141    /// Keyed by "host:port" string.
142    conns: DashMap<String, Mutex<VecDeque<PooledConn>>>,
143    config: OutboundPoolConfig,
144    /// Running total of idle connections across all remotes.
145    total_idle: Arc<AtomicUsize>,
146}
147
148impl OutboundPool {
149    /// Create a new pool and spawn the background idle-reaper task.
150    ///
151    /// The reaper stops automatically when `shutdown_rx` yields `true`.
152    pub fn new(
153        config: OutboundPoolConfig,
154        mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
155    ) -> Arc<Self> {
156        let pool = Arc::new(Self {
157            conns: DashMap::new(),
158            config: config.clone(),
159            total_idle: Arc::new(AtomicUsize::new(0)),
160        });
161
162        // Spawn the background reaper.
163        let reaper_pool = pool.clone();
164        let reap_interval = config.idle_timeout / 2;
165
166        tokio::spawn(async move {
167            loop {
168                tokio::select! {
169                    _ = tokio::time::sleep(reap_interval) => {}
170                    _ = shutdown_rx.changed() => {
171                        if *shutdown_rx.borrow() {
172                            break;
173                        }
174                    }
175                }
176                reaper_pool.reap_idle().await;
177            }
178        });
179
180        pool
181    }
182
183    /// Obtain a connection to `remote_key` (format: `"host:port"`).
184    ///
185    /// If a pooled idle connection is available, returns it immediately.
186    /// Otherwise opens a new TCP connection, reads the `220` greeting, sends
187    /// `EHLO localhost`, reads the response, and wraps everything in a
188    /// [`PooledConn`].
189    pub async fn get_or_connect(&self, remote_key: &str) -> anyhow::Result<PooledConn> {
190        // Attempt to pop an idle connection.
191        if let Some(bucket) = self.conns.get(remote_key) {
192            let mut deque = bucket.lock().await;
193            if let Some(conn) = deque.pop_front() {
194                self.total_idle.fetch_sub(1, Ordering::Relaxed);
195                return Ok(conn);
196            }
197        }
198
199        // No idle connection — open a fresh one.
200        self.open_fresh(remote_key).await
201    }
202
203    /// Return a connection to the pool after use.
204    ///
205    /// Sends `RSET\r\n` and waits for a `250` response.  On any I/O or
206    /// protocol error the connection is silently dropped.  If the pool would
207    /// exceed its caps the connection is also dropped.
208    pub async fn return_conn(&self, mut conn: PooledConn) {
209        // Send RSET to clear server-side transaction state.
210        if let Err(e) = rset_connection(&mut conn).await {
211            tracing::debug!(
212                remote = conn.remote_key.as_str(),
213                "dropping connection after failed RSET: {}",
214                e
215            );
216            return;
217        }
218
219        // Enforce global cap before taking per-remote lock.
220        if self.total_idle.load(Ordering::Relaxed) >= self.config.global_cap {
221            tracing::debug!(
222                remote = conn.remote_key.as_str(),
223                "global pool cap reached, dropping connection"
224            );
225            return;
226        }
227
228        let remote_key = conn.remote_key.clone();
229
230        // Insert the bucket lazily.
231        let bucket = self
232            .conns
233            .entry(remote_key.clone())
234            .or_insert_with(|| Mutex::new(VecDeque::new()));
235
236        let mut deque = bucket.lock().await;
237
238        // Enforce per-remote cap.
239        if deque.len() >= self.config.per_remote_cap {
240            tracing::debug!(
241                remote = remote_key.as_str(),
242                "per-remote cap reached, dropping connection"
243            );
244            return;
245        }
246
247        conn.last_used = SystemTime::now();
248        deque.push_back(conn);
249        self.total_idle.fetch_add(1, Ordering::Relaxed);
250    }
251
252    /// Count currently idle connections (for testing / metrics).
253    pub fn idle_count(&self) -> usize {
254        self.total_idle.load(Ordering::Relaxed)
255    }
256
257    // ── Internal helpers ──────────────────────────────────────────────────
258
259    /// Open a fresh TCP connection, perform the SMTP handshake (220 greeting +
260    /// EHLO) and return the ready-to-use connection.
261    async fn open_fresh(&self, remote_key: &str) -> anyhow::Result<PooledConn> {
262        let stream = TcpStream::connect(remote_key)
263            .await
264            .map_err(|e| anyhow::anyhow!("SMTP outbound connect to {}: {}", remote_key, e))?;
265
266        let mut reader = BufReader::new(stream);
267
268        // Read 220 greeting.
269        let greeting = smtp_read_response_raw(&mut reader).await?;
270        if !greeting.starts_with("220") {
271            anyhow::bail!(
272                "unexpected SMTP greeting from {}: {}",
273                remote_key,
274                greeting.trim()
275            );
276        }
277
278        // Send EHLO.
279        smtp_write(&mut reader, "EHLO localhost\r\n").await?;
280        let ehlo_text = smtp_read_response_raw(&mut reader).await?;
281        if !ehlo_text.starts_with("250") {
282            anyhow::bail!("EHLO rejected by {}: {}", remote_key, ehlo_text.trim());
283        }
284
285        let extensions = SmtpExtensions::from_ehlo(&ehlo_text);
286
287        Ok(PooledConn {
288            reader,
289            last_used: SystemTime::now(),
290            extensions,
291            remote_key: remote_key.to_string(),
292        })
293    }
294
295    /// Iterate all buckets and drop connections idle longer than
296    /// `config.idle_timeout`.
297    async fn reap_idle(&self) {
298        let now = SystemTime::now();
299        let mut total_reaped = 0usize;
300
301        for bucket_ref in self.conns.iter() {
302            let mut deque = bucket_ref.value().lock().await;
303            let before = deque.len();
304            deque.retain(|conn| {
305                match conn.last_used.elapsed() {
306                    Ok(elapsed) => elapsed <= self.config.idle_timeout,
307                    // If the system clock went backward, keep the connection.
308                    Err(_) => true,
309                }
310            });
311            let reaped = before - deque.len();
312            total_reaped += reaped;
313        }
314
315        if total_reaped > 0 {
316            self.total_idle.fetch_sub(total_reaped, Ordering::Relaxed);
317            tracing::debug!(
318                "outbound pool idle reaper: closed {} connections",
319                total_reaped
320            );
321        }
322
323        let _ = now; // suppress unused warning
324    }
325}
326
327// ── Low-level SMTP helpers ────────────────────────────────────────────────────
328
329/// Write a command to the underlying `TcpStream` inside a `BufReader`.
330pub(crate) async fn smtp_write(
331    reader: &mut BufReader<TcpStream>,
332    cmd: &str,
333) -> std::io::Result<()> {
334    let stream = reader.get_mut();
335    stream.write_all(cmd.as_bytes()).await?;
336    stream.flush().await
337}
338
339/// Read a (possibly multi-line) SMTP response from `reader`.
340///
341/// Returns the entire response as a single `String` (lines joined with `\n`).
342pub(crate) async fn smtp_read_response_raw(
343    reader: &mut BufReader<TcpStream>,
344) -> std::io::Result<String> {
345    let mut full = String::new();
346    loop {
347        let mut line = String::new();
348        reader.read_line(&mut line).await?;
349        let is_last = line.len() >= 4 && line.as_bytes().get(3) == Some(&b' ');
350        full.push_str(&line);
351        if is_last {
352            break;
353        }
354    }
355    Ok(full)
356}
357
358/// Send `RSET\r\n` on `conn` and read the response.  Returns `Ok(())` on a
359/// `250` response; `Err` for any I/O error or unexpected response code.
360async fn rset_connection(conn: &mut PooledConn) -> anyhow::Result<()> {
361    smtp_write(&mut conn.reader, "RSET\r\n").await?;
362    let rset_resp = smtp_read_response_raw(&mut conn.reader).await?;
363    if !rset_resp.starts_with("250") {
364        anyhow::bail!("RSET rejected: {}", rset_resp.trim());
365    }
366    Ok(())
367}
368
369// ── Tests ─────────────────────────────────────────────────────────────────────
370
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;

    // ── Minimal fake SMTP server ──────────────────────────────────────────

    /// Describes the behaviour the fake server should exhibit.
    #[derive(Debug, Clone)]
    struct FakeServerBehaviour {
        /// How many connections to accept.
        accept_count: usize,
        /// Canned multi-line EHLO response, sent verbatim (include the
        /// trailing CRLF).
        ehlo_response: String,
        /// Whether to accept RSET.
        accept_rset: bool,
        /// Whether to accept MAIL FROM.
        accept_mail: bool,
        /// Whether to accept RCPT TO.
        accept_rcpt: bool,
        /// Whether to accept DATA.
        accept_data: bool,
    }

    impl Default for FakeServerBehaviour {
        fn default() -> Self {
            Self {
                accept_count: 1,
                ehlo_response: "250-localhost\r\n250 PIPELINING\r\n".to_string(),
                accept_rset: true,
                accept_mail: true,
                accept_rcpt: true,
                accept_data: true,
            }
        }
    }

    /// Returns (port, connect_count_receiver).
    ///
    /// The channel yields the cumulative number of accepted connections. Read
    /// it after your operations to assert exactly how many times a new TCP
    /// connection was made.
    ///
    /// NOTE: each `read` is assumed to carry exactly one command, which holds
    /// for the strictly request/response exchanges used in these tests.
    async fn spawn_fake_smtp(
        behaviour: FakeServerBehaviour,
    ) -> (u16, tokio::sync::watch::Receiver<usize>) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind fake smtp");
        let port = listener.local_addr().expect("local addr").port();
        let (tx, rx) = tokio::sync::watch::channel(0usize);

        tokio::spawn(async move {
            let mut count = 0usize;
            while count < behaviour.accept_count {
                let Ok((mut socket, _)) = listener.accept().await else {
                    break;
                };
                count += 1;
                // Published before the greeting is written, so a client that
                // has read the greeting always observes the updated count.
                let _ = tx.send(count);
                let beh = behaviour.clone();

                tokio::spawn(async move {
                    // Greeting
                    socket.write_all(b"220 localhost ESMTP\r\n").await.ok();

                    // Read lines and respond until connection is closed or QUIT received.
                    let mut buf = [0u8; 4096];
                    loop {
                        let n = match socket.read(&mut buf).await {
                            Ok(0) | Err(_) => break,
                            Ok(n) => n,
                        };
                        let raw = String::from_utf8_lossy(&buf[..n]);
                        let cmd = raw.trim().to_ascii_uppercase();

                        if cmd.starts_with("EHLO") || cmd.starts_with("HELO") {
                            socket.write_all(beh.ehlo_response.as_bytes()).await.ok();
                        } else if cmd.starts_with("RSET") {
                            if beh.accept_rset {
                                socket.write_all(b"250 OK\r\n").await.ok();
                            } else {
                                socket
                                    .write_all(b"500 Command not recognized\r\n")
                                    .await
                                    .ok();
                            }
                        } else if cmd.starts_with("MAIL") {
                            if beh.accept_mail {
                                socket.write_all(b"250 OK\r\n").await.ok();
                            } else {
                                socket.write_all(b"550 Rejected\r\n").await.ok();
                            }
                        } else if cmd.starts_with("RCPT") {
                            if beh.accept_rcpt {
                                socket.write_all(b"250 OK\r\n").await.ok();
                            } else {
                                socket.write_all(b"550 Rejected\r\n").await.ok();
                            }
                        } else if cmd.starts_with("DATA") {
                            if beh.accept_data {
                                socket.write_all(b"354 Go ahead\r\n").await.ok();
                                // Read until ".\r\n"
                                let mut data_buf = [0u8; 4096];
                                loop {
                                    let dn = socket.read(&mut data_buf).await.unwrap_or(0);
                                    if dn == 0 {
                                        break;
                                    }
                                    let chunk = String::from_utf8_lossy(&data_buf[..dn]);
                                    if chunk.contains("\r\n.\r\n") || chunk.trim() == "." {
                                        socket.write_all(b"250 Queued\r\n").await.ok();
                                        break;
                                    }
                                }
                            } else {
                                socket.write_all(b"550 Rejected\r\n").await.ok();
                            }
                        } else if cmd.starts_with("QUIT") {
                            socket.write_all(b"221 Bye\r\n").await.ok();
                            break;
                        }
                        // Unknown commands: ignore.
                    }
                });
            }
        });

        (port, rx)
    }

    // ── Helper: build a pool with the fake server ─────────────────────────

    /// Build a pool plus the sender that controls its reaper's shutdown. Keep
    /// the sender alive for the duration of the test.
    fn make_pool(
        config: OutboundPoolConfig,
    ) -> (Arc<OutboundPool>, tokio::sync::watch::Sender<bool>) {
        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let pool = OutboundPool::new(config, shutdown_rx);
        (pool, shutdown_tx)
    }

    /// Open two simultaneous connections to one fake server, return both to a
    /// pool with the given caps, and report how many ended up idle.
    async fn idle_after_returning_two(per_remote_cap: usize, global_cap: usize) -> usize {
        let beh = FakeServerBehaviour {
            accept_count: 2,
            ..Default::default()
        };
        let (port, connect_rx) = spawn_fake_smtp(beh).await;
        let remote = format!("127.0.0.1:{}", port);

        let config = OutboundPoolConfig {
            per_remote_cap,
            global_cap,
            idle_timeout: Duration::from_secs(30),
        };
        let (pool, _tx) = make_pool(config);

        // Both are checked out at once, so the second cannot reuse the first.
        let conn1 = pool
            .get_or_connect(&remote)
            .await
            .expect("first connect should succeed");
        let conn2 = pool
            .get_or_connect(&remote)
            .await
            .expect("second connect should succeed");
        assert_eq!(
            *connect_rx.borrow(),
            2,
            "two simultaneous conns need two TCP connects"
        );

        pool.return_conn(conn1).await;
        pool.return_conn(conn2).await;
        pool.idle_count()
    }

    // ── Tests ─────────────────────────────────────────────────────────────

    /// Two back-to-back `get_or_connect` + `return_conn` cycles should result
    /// in only ONE TCP connection being established (the second delivery reuses
    /// the pooled connection).
    #[tokio::test]
    async fn test_outbound_pool_basic_reuse() {
        let beh = FakeServerBehaviour {
            accept_count: 2, // allow up to 2 but we should only use 1
            ..Default::default()
        };
        let (port, connect_rx) = spawn_fake_smtp(beh).await;
        let remote = format!("127.0.0.1:{}", port);

        let config = OutboundPoolConfig {
            per_remote_cap: 4,
            global_cap: 16,
            idle_timeout: Duration::from_secs(30),
        };
        let (pool, _tx) = make_pool(config);

        // First delivery.
        let conn1 = pool
            .get_or_connect(&remote)
            .await
            .expect("first connect should succeed");
        assert_eq!(
            *connect_rx.borrow(),
            1,
            "one TCP connection after first get"
        );
        pool.return_conn(conn1).await;
        assert_eq!(pool.idle_count(), 1, "one idle conn after return");

        // Second delivery — should reuse.
        let conn2 = pool
            .get_or_connect(&remote)
            .await
            .expect("second get should succeed");
        // Still exactly 1 TCP connect.
        assert_eq!(
            *connect_rx.borrow(),
            1,
            "connection count must stay at 1 (pooled reuse)"
        );
        pool.return_conn(conn2).await;
        assert_eq!(pool.idle_count(), 1);
    }

    /// After idle_timeout elapses the background reaper must close idle connections.
    #[tokio::test]
    async fn test_outbound_pool_idle_reaper() {
        let beh = FakeServerBehaviour {
            accept_count: 1,
            ..Default::default()
        };
        let (port, _connect_rx) = spawn_fake_smtp(beh).await;
        let remote = format!("127.0.0.1:{}", port);

        // Very short idle_timeout so the reaper fires quickly.
        let idle_timeout = Duration::from_millis(80);
        let config = OutboundPoolConfig {
            per_remote_cap: 4,
            global_cap: 16,
            idle_timeout,
        };
        let (pool, _tx) = make_pool(config);

        // Get a connection, return it, then wait longer than idle_timeout.
        let conn = pool
            .get_or_connect(&remote)
            .await
            .expect("connect must succeed");
        pool.return_conn(conn).await;
        assert_eq!(pool.idle_count(), 1, "one idle conn before timeout");

        // Wait for the reaper to run (reaps every idle_timeout/2 = 40 ms).
        tokio::time::sleep(idle_timeout * 3).await;

        assert_eq!(
            pool.idle_count(),
            0,
            "idle conn must be reaped after timeout"
        );
    }

    /// `return_conn` must send RSET before putting the connection back.
    #[tokio::test]
    async fn test_outbound_pool_rset_on_return() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let port = listener.local_addr().expect("local_addr").port();
        let remote = format!("127.0.0.1:{}", port);

        // Collect commands seen by the server.
        let (seen_tx, mut seen_rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        tokio::spawn(async move {
            let Ok((mut socket, _)) = listener.accept().await else {
                return;
            };
            socket.write_all(b"220 localhost ESMTP\r\n").await.ok();

            let mut buf = [0u8; 4096];
            loop {
                let n = match socket.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(n) => n,
                };
                let raw = String::from_utf8_lossy(&buf[..n]).to_string();
                let cmd = raw.trim().to_ascii_uppercase();

                if cmd.starts_with("EHLO") || cmd.starts_with("HELO") {
                    socket.write_all(b"250 localhost\r\n").await.ok();
                } else if cmd.starts_with("RSET") {
                    let _ = seen_tx.send("RSET".to_string());
                    socket.write_all(b"250 OK\r\n").await.ok();
                } else if cmd.starts_with("QUIT") {
                    socket.write_all(b"221 Bye\r\n").await.ok();
                    break;
                }
            }
        });

        let config = OutboundPoolConfig::default();
        let (pool, _tx) = make_pool(config);

        let conn = pool
            .get_or_connect(&remote)
            .await
            .expect("connect must succeed");
        pool.return_conn(conn).await;

        // The server should have received RSET.
        let cmd = tokio::time::timeout(Duration::from_secs(2), seen_rx.recv())
            .await
            .expect("timed out waiting for RSET")
            .expect("channel closed");
        assert_eq!(cmd, "RSET");
    }

    /// A connection whose RSET is rejected must be dropped, not re-pooled.
    #[tokio::test]
    async fn test_outbound_pool_rset_rejected_drops_conn() {
        let beh = FakeServerBehaviour {
            accept_rset: false,
            ..Default::default()
        };
        let (port, _connect_rx) = spawn_fake_smtp(beh).await;
        let remote = format!("127.0.0.1:{}", port);
        let (pool, _tx) = make_pool(OutboundPoolConfig::default());

        let conn = pool
            .get_or_connect(&remote)
            .await
            .expect("connect must succeed");
        pool.return_conn(conn).await;
        assert_eq!(
            pool.idle_count(),
            0,
            "conn with rejected RSET must not be pooled"
        );
    }

    /// Returning more connections than `per_remote_cap` drops the excess.
    #[tokio::test]
    async fn test_outbound_pool_per_remote_cap() {
        assert_eq!(
            idle_after_returning_two(1, 16).await,
            1,
            "per-remote cap of 1 must be enforced"
        );
    }

    /// Returning more connections than `global_cap` drops the excess.
    #[tokio::test]
    async fn test_outbound_pool_global_cap() {
        assert_eq!(
            idle_after_returning_two(4, 1).await,
            1,
            "global cap of 1 must be enforced"
        );
    }

    /// `SmtpExtensions::from_ehlo` correctly parses a multi-line EHLO response.
    #[test]
    fn test_smtp_extensions_parsing() {
        let ehlo = "250-localhost\r\n250-SIZE 10240000\r\n250-PIPELINING\r\n250-8BITMIME\r\n250 STARTTLS\r\n";
        let ext = SmtpExtensions::from_ehlo(ehlo);
        assert_eq!(ext.max_size, Some(10_240_000));
        assert!(ext.pipelining);
        assert!(ext.eight_bit_mime);
        assert!(ext.starttls);
    }
}