1use dashmap::DashMap;
30use std::collections::VecDeque;
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::sync::Arc;
33use std::time::{Duration, SystemTime};
34use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
35use tokio::net::TcpStream;
36use tokio::sync::Mutex;
37
/// SMTP service extensions advertised by a server in its EHLO reply.
#[derive(Debug, Clone, Default)]
pub struct SmtpExtensions {
    /// Maximum message size from the `SIZE` keyword, in bytes.
    /// `None` when SIZE is absent, has no parsable value, or is advertised as `0`
    /// (RFC 1870: zero means "no fixed maximum").
    pub max_size: Option<usize>,
    /// The server advertised `PIPELINING`.
    pub pipelining: bool,
    /// The server advertised `8BITMIME`.
    pub eight_bit_mime: bool,
    /// The server advertised `STARTTLS`.
    pub starttls: bool,
}

impl SmtpExtensions {
    /// Parses the raw text of an EHLO reply, e.g.
    /// `"250-host\r\n250-SIZE 1000\r\n250 STARTTLS\r\n"`.
    ///
    /// For each line the leading reply code and its `-`/space separator are stripped.
    /// The first whitespace-separated token is then compared case-insensitively with
    /// the known extension keywords. Unknown keywords, including the domain on the
    /// greeting line, are ignored.
    pub fn from_ehlo(ehlo_text: &str) -> Self {
        let mut ext = SmtpExtensions::default();
        for line in ehlo_text.lines() {
            // "250-SIZE 1000" -> "SIZE 1000"; `lines()` has already removed the CRLF.
            let body = line
                .trim_start_matches(|c: char| c.is_ascii_digit())
                .trim_start_matches(['-', ' ']);
            let mut tokens = body.split_whitespace();
            let Some(keyword) = tokens.next() else {
                continue;
            };
            // Exact keyword match: "SIZEX" must not count as SIZE, and trailing
            // whitespace must not hide "PIPELINING" and friends.
            match keyword.to_ascii_uppercase().as_str() {
                "SIZE" => {
                    // RFC 1870: a value of 0 means no fixed limit, so report `None`
                    // rather than a zero-byte maximum.
                    ext.max_size = tokens
                        .next()
                        .and_then(|value| value.parse::<usize>().ok())
                        .filter(|&limit| limit > 0);
                }
                "PIPELINING" => ext.pipelining = true,
                "8BITMIME" => ext.eight_bit_mime = true,
                "STARTTLS" => ext.starttls = true,
                _ => {}
            }
        }
        ext
    }
}
84
85pub struct PooledConn {
89 pub reader: BufReader<TcpStream>,
91 pub last_used: SystemTime,
93 pub extensions: SmtpExtensions,
95 pub remote_key: String,
97}
98
99impl PooledConn {
100 pub fn stream_mut(&mut self) -> &mut TcpStream {
103 self.reader.get_mut()
104 }
105}
106
/// Capacity limits and idle timing for an `OutboundPool`.
#[derive(Debug, Clone)]
pub struct OutboundPoolConfig {
    /// Maximum number of idle connections kept for a single remote.
    pub per_remote_cap: usize,
    /// Maximum number of idle connections kept across all remotes.
    pub global_cap: usize,
    /// How long a pooled connection may sit unused before it is reaped.
    pub idle_timeout: Duration,
}

impl Default for OutboundPoolConfig {
    /// 8 idle connections per remote, 256 in total, and a 30-second idle timeout.
    fn default() -> Self {
        const DEFAULT_IDLE_SECS: u64 = 30;
        OutboundPoolConfig {
            idle_timeout: Duration::from_secs(DEFAULT_IDLE_SECS),
            global_cap: 256,
            per_remote_cap: 8,
        }
    }
}
132
133pub struct OutboundPool {
141 conns: DashMap<String, Mutex<VecDeque<PooledConn>>>,
143 config: OutboundPoolConfig,
144 total_idle: Arc<AtomicUsize>,
146}
147
148impl OutboundPool {
149 pub fn new(
153 config: OutboundPoolConfig,
154 mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
155 ) -> Arc<Self> {
156 let pool = Arc::new(Self {
157 conns: DashMap::new(),
158 config: config.clone(),
159 total_idle: Arc::new(AtomicUsize::new(0)),
160 });
161
162 let reaper_pool = pool.clone();
164 let reap_interval = config.idle_timeout / 2;
165
166 tokio::spawn(async move {
167 loop {
168 tokio::select! {
169 _ = tokio::time::sleep(reap_interval) => {}
170 _ = shutdown_rx.changed() => {
171 if *shutdown_rx.borrow() {
172 break;
173 }
174 }
175 }
176 reaper_pool.reap_idle().await;
177 }
178 });
179
180 pool
181 }
182
183 pub async fn get_or_connect(&self, remote_key: &str) -> anyhow::Result<PooledConn> {
190 if let Some(bucket) = self.conns.get(remote_key) {
192 let mut deque = bucket.lock().await;
193 if let Some(conn) = deque.pop_front() {
194 self.total_idle.fetch_sub(1, Ordering::Relaxed);
195 return Ok(conn);
196 }
197 }
198
199 self.open_fresh(remote_key).await
201 }
202
203 pub async fn return_conn(&self, mut conn: PooledConn) {
209 if let Err(e) = rset_connection(&mut conn).await {
211 tracing::debug!(
212 remote = conn.remote_key.as_str(),
213 "dropping connection after failed RSET: {}",
214 e
215 );
216 return;
217 }
218
219 if self.total_idle.load(Ordering::Relaxed) >= self.config.global_cap {
221 tracing::debug!(
222 remote = conn.remote_key.as_str(),
223 "global pool cap reached, dropping connection"
224 );
225 return;
226 }
227
228 let remote_key = conn.remote_key.clone();
229
230 let bucket = self
232 .conns
233 .entry(remote_key.clone())
234 .or_insert_with(|| Mutex::new(VecDeque::new()));
235
236 let mut deque = bucket.lock().await;
237
238 if deque.len() >= self.config.per_remote_cap {
240 tracing::debug!(
241 remote = remote_key.as_str(),
242 "per-remote cap reached, dropping connection"
243 );
244 return;
245 }
246
247 conn.last_used = SystemTime::now();
248 deque.push_back(conn);
249 self.total_idle.fetch_add(1, Ordering::Relaxed);
250 }
251
252 pub fn idle_count(&self) -> usize {
254 self.total_idle.load(Ordering::Relaxed)
255 }
256
257 async fn open_fresh(&self, remote_key: &str) -> anyhow::Result<PooledConn> {
262 let stream = TcpStream::connect(remote_key)
263 .await
264 .map_err(|e| anyhow::anyhow!("SMTP outbound connect to {}: {}", remote_key, e))?;
265
266 let mut reader = BufReader::new(stream);
267
268 let greeting = smtp_read_response_raw(&mut reader).await?;
270 if !greeting.starts_with("220") {
271 anyhow::bail!(
272 "unexpected SMTP greeting from {}: {}",
273 remote_key,
274 greeting.trim()
275 );
276 }
277
278 smtp_write(&mut reader, "EHLO localhost\r\n").await?;
280 let ehlo_text = smtp_read_response_raw(&mut reader).await?;
281 if !ehlo_text.starts_with("250") {
282 anyhow::bail!("EHLO rejected by {}: {}", remote_key, ehlo_text.trim());
283 }
284
285 let extensions = SmtpExtensions::from_ehlo(&ehlo_text);
286
287 Ok(PooledConn {
288 reader,
289 last_used: SystemTime::now(),
290 extensions,
291 remote_key: remote_key.to_string(),
292 })
293 }
294
295 async fn reap_idle(&self) {
298 let now = SystemTime::now();
299 let mut total_reaped = 0usize;
300
301 for bucket_ref in self.conns.iter() {
302 let mut deque = bucket_ref.value().lock().await;
303 let before = deque.len();
304 deque.retain(|conn| {
305 match conn.last_used.elapsed() {
306 Ok(elapsed) => elapsed <= self.config.idle_timeout,
307 Err(_) => true,
309 }
310 });
311 let reaped = before - deque.len();
312 total_reaped += reaped;
313 }
314
315 if total_reaped > 0 {
316 self.total_idle.fetch_sub(total_reaped, Ordering::Relaxed);
317 tracing::debug!(
318 "outbound pool idle reaper: closed {} connections",
319 total_reaped
320 );
321 }
322
323 let _ = now; }
325}
326
327pub(crate) async fn smtp_write(
331 reader: &mut BufReader<TcpStream>,
332 cmd: &str,
333) -> std::io::Result<()> {
334 let stream = reader.get_mut();
335 stream.write_all(cmd.as_bytes()).await?;
336 stream.flush().await
337}
338
339pub(crate) async fn smtp_read_response_raw(
343 reader: &mut BufReader<TcpStream>,
344) -> std::io::Result<String> {
345 let mut full = String::new();
346 loop {
347 let mut line = String::new();
348 reader.read_line(&mut line).await?;
349 let is_last = line.len() >= 4 && line.as_bytes().get(3) == Some(&b' ');
350 full.push_str(&line);
351 if is_last {
352 break;
353 }
354 }
355 Ok(full)
356}
357
358async fn rset_connection(conn: &mut PooledConn) -> anyhow::Result<()> {
361 smtp_write(&mut conn.reader, "RSET\r\n").await?;
362 let rset_resp = smtp_read_response_raw(&mut conn.reader).await?;
363 if !rset_resp.starts_with("250") {
364 anyhow::bail!("RSET rejected: {}", rset_resp.trim());
365 }
366 Ok(())
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;

    /// Controls how the fake SMTP server started by `spawn_fake_smtp` replies.
    #[derive(Debug, Clone)]
    struct FakeServerBehaviour {
        /// Number of TCP connections accepted before the accept loop stops.
        accept_count: usize,
        /// Raw bytes written verbatim in reply to `EHLO`/`HELO`.
        ehlo_response: String,
        /// `true` => `250 OK` to `RSET`; `false` => `500 Command not recognized`.
        accept_rset: bool,
        /// `true` => `250 OK` to `MAIL`; `false` => `550 Rejected`.
        accept_mail: bool,
        /// `true` => `250 OK` to `RCPT`; `false` => `550 Rejected`.
        accept_rcpt: bool,
        /// `true` => `354 Go ahead`, then read the body until the end-of-data marker and
        /// reply `250 Queued`; `false` => `550 Rejected`.
        accept_data: bool,
    }

    impl Default for FakeServerBehaviour {
        /// One connection, a two-line EHLO reply advertising PIPELINING, and every
        /// command accepted.
        fn default() -> Self {
            Self {
                accept_count: 1,
                ehlo_response: "250-localhost\r\n250 PIPELINING\r\n".to_string(),
                accept_rset: true,
                accept_mail: true,
                accept_rcpt: true,
                accept_data: true,
            }
        }
    }

    /// Starts a minimal fake SMTP server on an ephemeral 127.0.0.1 port.
    ///
    /// Returns the bound port and a watch receiver holding the number of connections
    /// accepted so far. Each connection is served on its own task: it sends a `220`
    /// greeting, then treats every socket `read` as exactly one command (this assumes
    /// the client never splits or batches commands across reads). Commands it does not
    /// recognise get no reply.
    async fn spawn_fake_smtp(
        behaviour: FakeServerBehaviour,
    ) -> (u16, tokio::sync::watch::Receiver<usize>) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind fake smtp");
        let port = listener.local_addr().expect("local addr").port();
        let (tx, rx) = tokio::sync::watch::channel(0usize);

        tokio::spawn(async move {
            let mut count = 0usize;
            while count < behaviour.accept_count {
                let Ok((mut socket, _)) = listener.accept().await else {
                    break;
                };
                count += 1;
                // Published before the greeting is written. By the time a client has
                // finished its handshake, the count already includes this connection.
                let _ = tx.send(count);
                let beh = behaviour.clone();

                tokio::spawn(async move {
                    socket.write_all(b"220 localhost ESMTP\r\n").await.ok();

                    let mut buf = [0u8; 4096];
                    // One read == one command; EOF or a read error ends the session.
                    loop {
                        let n = match socket.read(&mut buf).await {
                            Ok(0) | Err(_) => break,
                            Ok(n) => n,
                        };
                        let raw = String::from_utf8_lossy(&buf[..n]);
                        let cmd = raw.trim().to_ascii_uppercase();

                        if cmd.starts_with("EHLO") || cmd.starts_with("HELO") {
                            socket.write_all(beh.ehlo_response.as_bytes()).await.ok();
                        } else if cmd.starts_with("RSET") {
                            if beh.accept_rset {
                                socket.write_all(b"250 OK\r\n").await.ok();
                            } else {
                                socket
                                    .write_all(b"500 Command not recognized\r\n")
                                    .await
                                    .ok();
                            }
                        } else if cmd.starts_with("MAIL") {
                            if beh.accept_mail {
                                socket.write_all(b"250 OK\r\n").await.ok();
                            } else {
                                socket.write_all(b"550 Rejected\r\n").await.ok();
                            }
                        } else if cmd.starts_with("RCPT") {
                            if beh.accept_rcpt {
                                socket.write_all(b"250 OK\r\n").await.ok();
                            } else {
                                socket.write_all(b"550 Rejected\r\n").await.ok();
                            }
                        } else if cmd.starts_with("DATA") {
                            if beh.accept_data {
                                socket.write_all(b"354 Go ahead\r\n").await.ok();
                                let mut data_buf = [0u8; 4096];
                                // Consume the body until a chunk carries the "." terminator,
                                // or the peer closes / errors (read error maps to 0 => stop).
                                loop {
                                    let dn = socket.read(&mut data_buf).await.unwrap_or(0);
                                    if dn == 0 {
                                        break;
                                    }
                                    let chunk = String::from_utf8_lossy(&data_buf[..dn]);
                                    if chunk.contains("\r\n.\r\n") || chunk.trim() == "." {
                                        socket.write_all(b"250 Queued\r\n").await.ok();
                                        break;
                                    }
                                }
                            } else {
                                socket.write_all(b"550 Rejected\r\n").await.ok();
                            }
                        } else if cmd.starts_with("QUIT") {
                            socket.write_all(b"221 Bye\r\n").await.ok();
                            break;
                        }
                    }
                });
            }
        });

        (port, rx)
    }

    /// Builds a pool wired to a fresh shutdown channel whose initial value is `false`.
    ///
    /// The sender is returned so the test owns the channel and keeps it open.
    fn make_pool(
        config: OutboundPoolConfig,
    ) -> (Arc<OutboundPool>, tokio::sync::watch::Sender<bool>) {
        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let pool = OutboundPool::new(config, shutdown_rx);
        (pool, shutdown_tx)
    }

    /// A connection returned to the pool is handed back on the next get, without a
    /// second TCP connect.
    #[tokio::test]
    async fn test_outbound_pool_basic_reuse() {
        // Accept up to two connections, so an unwanted second connect would show up
        // in the count instead of hanging.
        let beh = FakeServerBehaviour {
            accept_count: 2, ..Default::default()
        };
        let (port, connect_rx) = spawn_fake_smtp(beh).await;
        let remote = format!("127.0.0.1:{}", port);

        let config = OutboundPoolConfig {
            per_remote_cap: 4,
            global_cap: 16,
            idle_timeout: Duration::from_secs(30),
        };
        // `_tx` (not `_`) keeps the shutdown sender alive until the end of the test.
        let (pool, _tx) = make_pool(config);

        let conn1 = pool
            .get_or_connect(&remote)
            .await
            .expect("first connect should succeed");
        assert_eq!(
            *connect_rx.borrow(),
            1,
            "one TCP connection after first get"
        );
        pool.return_conn(conn1).await;
        assert_eq!(pool.idle_count(), 1, "one idle conn after return");

        let conn2 = pool
            .get_or_connect(&remote)
            .await
            .expect("second get should succeed");
        assert_eq!(
            *connect_rx.borrow(),
            1,
            "connection count must stay at 1 (pooled reuse)"
        );
        pool.return_conn(conn2).await;
        assert_eq!(pool.idle_count(), 1);
    }

    /// The background reaper removes a connection that stays idle past `idle_timeout`.
    #[tokio::test]
    async fn test_outbound_pool_idle_reaper() {
        let beh = FakeServerBehaviour {
            accept_count: 1,
            ..Default::default()
        };
        let (port, _connect_rx) = spawn_fake_smtp(beh).await;
        let remote = format!("127.0.0.1:{}", port);

        // Short timeout so the test runs in real time without waiting long.
        let idle_timeout = Duration::from_millis(80);
        let config = OutboundPoolConfig {
            per_remote_cap: 4,
            global_cap: 16,
            idle_timeout,
        };
        let (pool, _tx) = make_pool(config);

        let conn = pool
            .get_or_connect(&remote)
            .await
            .expect("connect must succeed");
        pool.return_conn(conn).await;
        assert_eq!(pool.idle_count(), 1, "one idle conn before timeout");

        // Three timeouts leave several reaper ticks (interval = idle_timeout / 2)
        // after the connection has expired.
        tokio::time::sleep(idle_timeout * 3).await;

        assert_eq!(
            pool.idle_count(),
            0,
            "idle conn must be reaped after timeout"
        );
    }

    /// Returning a connection to the pool sends `RSET` to the server.
    #[tokio::test]
    async fn test_outbound_pool_rset_on_return() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let port = listener.local_addr().expect("local_addr").port();
        let remote = format!("127.0.0.1:{}", port);

        // The inline fake server reports each RSET it receives on this channel.
        let (seen_tx, mut seen_rx) = tokio::sync::mpsc::unbounded_channel::<String>();

        tokio::spawn(async move {
            let Ok((mut socket, _)) = listener.accept().await else {
                return;
            };
            socket.write_all(b"220 localhost ESMTP\r\n").await.ok();

            let mut buf = [0u8; 4096];
            loop {
                let n = match socket.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(n) => n,
                };
                let raw = String::from_utf8_lossy(&buf[..n]).to_string();
                let cmd = raw.trim().to_ascii_uppercase();

                if cmd.starts_with("EHLO") || cmd.starts_with("HELO") {
                    socket.write_all(b"250 localhost\r\n").await.ok();
                } else if cmd.starts_with("RSET") {
                    let _ = seen_tx.send("RSET".to_string());
                    socket.write_all(b"250 OK\r\n").await.ok();
                } else if cmd.starts_with("QUIT") {
                    socket.write_all(b"221 Bye\r\n").await.ok();
                    break;
                }
            }
        });

        let config = OutboundPoolConfig::default();
        let (pool, _tx) = make_pool(config);

        let conn = pool
            .get_or_connect(&remote)
            .await
            .expect("connect must succeed");
        pool.return_conn(conn).await;

        // Bounded wait so a missing RSET fails the test instead of hanging it.
        let cmd = tokio::time::timeout(Duration::from_secs(2), seen_rx.recv())
            .await
            .expect("timed out waiting for RSET")
            .expect("channel closed");
        assert_eq!(cmd, "RSET");
    }

    /// A multi-line EHLO reply advertising all four known extensions is fully parsed.
    #[test]
    fn test_smtp_extensions_parsing() {
        let ehlo = "250-localhost\r\n250-SIZE 10240000\r\n250-PIPELINING\r\n250-8BITMIME\r\n250 STARTTLS\r\n";
        let ext = SmtpExtensions::from_ehlo(ehlo);
        assert_eq!(ext.max_size, Some(10_240_000));
        assert!(ext.pipelining);
        assert!(ext.eight_bit_mime);
        assert!(ext.starttls);
    }
}