tiny_mailcatcher/
smtp.rs

1use crate::email::parse_message;
2use crate::repository::MessageRepository;
3use log::{info, warn};
4use std::net::TcpListener as StdTcpListener;
5use std::sync::{Arc, Mutex};
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::{TcpListener, TcpStream};
8
/// Fallible result used by every part of the SMTP server.
///
/// The error side is a boxed trait object that is `Send + Sync`, which lets it
/// flow through `?` inside the session tasks started with `tokio::spawn`.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
10
11pub async fn run_smtp_server(
12    tcp_listener: StdTcpListener,
13    repository: Arc<Mutex<MessageRepository>>,
14) -> Result<()> {
15    info!(
16        "Starting SMTP server on {}",
17        tcp_listener.local_addr().unwrap()
18    );
19
20    tcp_listener.set_nonblocking(true).unwrap();
21    let listener = TcpListener::from_std(tcp_listener)?;
22
23    loop {
24        let (socket, remote_ip) = listener.accept().await?;
25        let session_repository = Arc::clone(&repository);
26
27        tokio::spawn(async move {
28            let server_impl = SmtpServerImplementation::new(session_repository);
29            let mut protocol = SmtpProtocol::new(socket, server_impl);
30
31            match protocol.execute().await {
32                Ok(_) => {}
33                Err(e) => {
34                    warn!(
35                        "An error occurred while executing the SMTP protocol with {}: {}",
36                        remote_ip, e
37                    );
38                }
39            };
40        });
41    }
42}
43
/// Milestones an SMTP session has reached; the protocol handler keeps the ones
/// reached so far and checks them to enforce command order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SmtpProtocolState {
    /// `EHLO` or `HELO` was accepted.
    ReceivedEhlo,
    /// `MAIL FROM:` was accepted for the current transaction.
    ReceivedSender,
    /// At least one `RCPT TO:` was accepted for the current transaction.
    ReceivedRecipients,
}
50
51struct SmtpProtocol {
52    server_impl: SmtpServerImplementation,
53    stream: SmtpStream,
54    state: Vec<SmtpProtocolState>,
55}
56
57impl SmtpProtocol {
58    fn new(stream: TcpStream, server_impl: SmtpServerImplementation) -> Self {
59        SmtpProtocol {
60            server_impl,
61            stream: SmtpStream::new(stream),
62            state: vec![],
63        }
64    }
65
66    async fn execute(&mut self) -> Result<()> {
67        self.say("220 Hello from Tiny MailCatcher's simple and dumb SMTP server\r\n")
68            .await?;
69
70        loop {
71            let line = self.stream.read_line().await?;
72            if line.is_none() {
73                // EOF
74                return Ok(());
75            }
76
77            let line = std::str::from_utf8(line.unwrap());
78            if line.is_err() {
79                self.say("500 Invalid UTF-8 encountered\r\n").await?;
80
81                continue;
82            }
83
84            let line = line.unwrap();
85            let line_upper = line.to_ascii_uppercase();
86
87            if line_upper.starts_with("EHLO") {
88                let message = "\
89                    250-Tiny MailCatcher - do not expect too much from me please\r\n\
90                    250-NO-SOLLICITING\r\n\
91                    250 SIZE 20000000\r\n";
92
93                self.say(message).await?;
94                self.reset();
95                self.state.push(SmtpProtocolState::ReceivedEhlo);
96            } else if line_upper.starts_with("HELO") {
97                self.say("250-Tiny MailCatcher - do not expect too much from me please\r\n")
98                    .await?;
99                self.reset();
100                self.state.push(SmtpProtocolState::ReceivedEhlo);
101            } else if line_upper.starts_with("QUIT") {
102                self.say("220 Ok\r\n").await?;
103                return Ok(());
104            } else if line_upper.starts_with("MAIL FROM:") {
105                if self.state.contains(&SmtpProtocolState::ReceivedSender) {
106                    self.say("503 MAIL already given\r\n").await?;
107                } else {
108                    let sender = line["MAIL FROM:".len()..].trim();
109                    self.server_impl.mail(sender);
110                    self.say("250 Ok\r\n").await?;
111                    self.state.push(SmtpProtocolState::ReceivedSender);
112                }
113            } else if line_upper.starts_with("RCPT TO:") {
114                if self.state.contains(&SmtpProtocolState::ReceivedSender) {
115                    let recipient = line["RCPT TO:".len()..].trim();
116                    self.server_impl.rcpt(recipient);
117                    self.state.push(SmtpProtocolState::ReceivedRecipients);
118                    self.say("250 Ok\r\n").await?;
119                } else {
120                    self.say("503 MAIL is required before RCPT\r\n").await?;
121                }
122            } else if line_upper.starts_with("DATA") {
123                if self.state.contains(&SmtpProtocolState::ReceivedRecipients) {
124                    self.say("354 Send it\r\n").await?;
125
126                    let data = self.stream.read_data().await?;
127                    if data.is_none() {
128                        return Ok(());
129                    }
130
131                    self.server_impl.data(data.unwrap())?;
132
133                    self.say("250 Message accepted\r\n").await?;
134
135                    // Remove these two states
136                    self.state
137                        .retain(|s| s != &SmtpProtocolState::ReceivedSender);
138                    self.state
139                        .retain(|s| s != &SmtpProtocolState::ReceivedRecipients);
140                } else {
141                    self.say("503 Operation sequence error\r\n").await?;
142                }
143            } else if line_upper.starts_with("NOOP") {
144                self.say("250 Ok\r\n").await?;
145            } else if line_upper.starts_with("RSET") {
146                self.reset();
147                self.say("250 Ok\r\n").await?;
148            } else {
149                // Unknown command
150                self.say("500 Unknown command\r\n").await?;
151            }
152        }
153    }
154
155    async fn say(&mut self, message: &str) -> Result<()> {
156        self.stream.write_all(message.as_bytes()).await?;
157
158        Ok(())
159    }
160
161    fn reset(&mut self) {
162        self.state.clear();
163        self.server_impl.reset();
164    }
165}
166
167struct SmtpStream {
168    inner: TcpStream,
169    line_buffer: Vec<u8>,
170    data_buffer: Vec<u8>,
171}
172
173impl SmtpStream {
174    fn new(stream: TcpStream) -> Self {
175        SmtpStream {
176            inner: stream,
177            line_buffer: Vec::with_capacity(8 * 1024),
178            data_buffer: Vec::with_capacity(64 * 1024),
179        }
180    }
181
182    async fn write_all(&mut self, message: &[u8]) -> Result<()> {
183        self.inner.write_all(message).await?;
184
185        Ok(())
186    }
187
188    async fn read_line(&mut self) -> Result<Option<&[u8]>> {
189        self.line_buffer.clear();
190
191        loop {
192            let mut read_buffer = [0u8; 8 * 1024];
193            let bytes_read = self.inner.read(&mut read_buffer).await?;
194
195            if bytes_read == 0 {
196                return Ok(None);
197            }
198
199            if let Some(i) = memchr::memchr(b'\n', &read_buffer[..bytes_read]) {
200                self.line_buffer.extend_from_slice(&read_buffer[0..=i]);
201
202                // slice off \r\n
203                return Ok(Some(&self.line_buffer[..self.line_buffer.len() - 2]));
204            } else {
205                self.line_buffer
206                    .extend_from_slice(&read_buffer[..bytes_read]);
207            }
208        }
209    }
210
211    async fn read_data(&mut self) -> Result<Option<&[u8]>> {
212        self.data_buffer.clear();
213
214        loop {
215            let mut read_buffer = [0u8; 16 * 1024];
216            let bytes_read = self.inner.read(&mut read_buffer).await?;
217
218            if bytes_read == 0 {
219                return Ok(None);
220            }
221
222            self.data_buffer
223                .extend_from_slice(&read_buffer[..bytes_read]);
224
225            let data_end_delimiter = b"\r\n.\r\n";
226            if self.data_buffer.len() >= data_end_delimiter.len()
227                && &self.data_buffer[self.data_buffer.len() - data_end_delimiter.len()..]
228                    == data_end_delimiter
229            {
230                return Ok(Some(
231                    &self.data_buffer[..self.data_buffer.len() - data_end_delimiter.len() + 2],
232                ));
233            }
234        }
235    }
236}
237
238struct SmtpServerImplementation {
239    sender: Option<String>,
240    recipients: Vec<String>,
241    repository: Arc<Mutex<MessageRepository>>,
242}
243
244impl SmtpServerImplementation {
245    fn new(repository: Arc<Mutex<MessageRepository>>) -> Self {
246        SmtpServerImplementation {
247            sender: None,
248            recipients: Vec::new(),
249            repository,
250        }
251    }
252
253    fn reset(&mut self) {
254        self.sender = None;
255        self.recipients = Vec::new();
256    }
257
258    fn mail(&mut self, sender: &str) {
259        self.sender = Some(sender.to_string());
260    }
261
262    fn rcpt(&mut self, recipient: &str) {
263        self.recipients.push(recipient.to_string());
264    }
265
266    fn data(&mut self, buf: &[u8]) -> Result<()> {
267        // strip size extension
268        let size_extension_index = self.sender.as_ref().and_then(|s| s.rfind("SIZE="));
269
270        let sender = if let Some(idx) = size_extension_index {
271            self.sender.as_ref().map(|s| s[..idx].to_string())
272        } else {
273            self.sender.clone()
274        };
275
276        let message = parse_message(&sender, &self.recipients, buf);
277        if message.is_err() {
278            return Err(Box::new(message.err().unwrap()));
279        }
280        let message = message.unwrap();
281
282        info!(
283            "Received message from {} ({} bytes)",
284            sender.as_deref().unwrap_or("unknown sender"),
285            buf.len(),
286        );
287
288        self.repository.lock().unwrap().persist(message);
289
290        Ok(())
291    }
292}