1use crate::email::parse_message;
2use crate::repository::MessageRepository;
3use log::{info, warn};
4use std::net::TcpListener as StdTcpListener;
5use std::sync::{Arc, Mutex};
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::{TcpListener, TcpStream};
8
/// Result alias used by the functions in this file.
///
/// The error side is a boxed `std::error::Error` trait object that is `Send + Sync`,
/// so any such error can be propagated with `?`.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
10
11pub async fn run_smtp_server(
12 tcp_listener: StdTcpListener,
13 repository: Arc<Mutex<MessageRepository>>,
14) -> Result<()> {
15 info!(
16 "Starting SMTP server on {}",
17 tcp_listener.local_addr().unwrap()
18 );
19
20 tcp_listener.set_nonblocking(true).unwrap();
21 let listener = TcpListener::from_std(tcp_listener)?;
22
23 loop {
24 let (socket, remote_ip) = listener.accept().await?;
25 let session_repository = Arc::clone(&repository);
26
27 tokio::spawn(async move {
28 let server_impl = SmtpServerImplementation::new(session_repository);
29 let mut protocol = SmtpProtocol::new(socket, server_impl);
30
31 match protocol.execute().await {
32 Ok(_) => {}
33 Err(e) => {
34 warn!(
35 "An error occurred while executing the SMTP protocol with {}: {}",
36 remote_ip, e
37 );
38 }
39 };
40 });
41 }
42}
43
/// Milestones an SMTP session can have reached.
///
/// The protocol handler keeps the milestones reached so far in a list and checks
/// membership to decide whether a command is allowed at this point.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SmtpProtocolState {
    /// The client has greeted the server (recorded for both `EHLO` and `HELO`).
    ReceivedEhlo,
    /// A `MAIL FROM:` command has been accepted.
    ReceivedSender,
    /// At least one `RCPT TO:` command has been accepted.
    ReceivedRecipients,
}
50
51struct SmtpProtocol {
52 server_impl: SmtpServerImplementation,
53 stream: SmtpStream,
54 state: Vec<SmtpProtocolState>,
55}
56
57impl SmtpProtocol {
58 fn new(stream: TcpStream, server_impl: SmtpServerImplementation) -> Self {
59 SmtpProtocol {
60 server_impl,
61 stream: SmtpStream::new(stream),
62 state: vec![],
63 }
64 }
65
66 async fn execute(&mut self) -> Result<()> {
67 self.say("220 Hello from Tiny MailCatcher's simple and dumb SMTP server\r\n")
68 .await?;
69
70 loop {
71 let line = self.stream.read_line().await?;
72 if line.is_none() {
73 return Ok(());
75 }
76
77 let line = std::str::from_utf8(line.unwrap());
78 if line.is_err() {
79 self.say("500 Invalid UTF-8 encountered\r\n").await?;
80
81 continue;
82 }
83
84 let line = line.unwrap();
85 let line_upper = line.to_ascii_uppercase();
86
87 if line_upper.starts_with("EHLO") {
88 let message = "\
89 250-Tiny MailCatcher - do not expect too much from me please\r\n\
90 250-NO-SOLLICITING\r\n\
91 250 SIZE 20000000\r\n";
92
93 self.say(message).await?;
94 self.reset();
95 self.state.push(SmtpProtocolState::ReceivedEhlo);
96 } else if line_upper.starts_with("HELO") {
97 self.say("250-Tiny MailCatcher - do not expect too much from me please\r\n")
98 .await?;
99 self.reset();
100 self.state.push(SmtpProtocolState::ReceivedEhlo);
101 } else if line_upper.starts_with("QUIT") {
102 self.say("220 Ok\r\n").await?;
103 return Ok(());
104 } else if line_upper.starts_with("MAIL FROM:") {
105 if self.state.contains(&SmtpProtocolState::ReceivedSender) {
106 self.say("503 MAIL already given\r\n").await?;
107 } else {
108 let sender = line["MAIL FROM:".len()..].trim();
109 self.server_impl.mail(sender);
110 self.say("250 Ok\r\n").await?;
111 self.state.push(SmtpProtocolState::ReceivedSender);
112 }
113 } else if line_upper.starts_with("RCPT TO:") {
114 if self.state.contains(&SmtpProtocolState::ReceivedSender) {
115 let recipient = line["RCPT TO:".len()..].trim();
116 self.server_impl.rcpt(recipient);
117 self.state.push(SmtpProtocolState::ReceivedRecipients);
118 self.say("250 Ok\r\n").await?;
119 } else {
120 self.say("503 MAIL is required before RCPT\r\n").await?;
121 }
122 } else if line_upper.starts_with("DATA") {
123 if self.state.contains(&SmtpProtocolState::ReceivedRecipients) {
124 self.say("354 Send it\r\n").await?;
125
126 let data = self.stream.read_data().await?;
127 if data.is_none() {
128 return Ok(());
129 }
130
131 self.server_impl.data(data.unwrap())?;
132
133 self.say("250 Message accepted\r\n").await?;
134
135 self.state
137 .retain(|s| s != &SmtpProtocolState::ReceivedSender);
138 self.state
139 .retain(|s| s != &SmtpProtocolState::ReceivedRecipients);
140 } else {
141 self.say("503 Operation sequence error\r\n").await?;
142 }
143 } else if line_upper.starts_with("NOOP") {
144 self.say("250 Ok\r\n").await?;
145 } else if line_upper.starts_with("RSET") {
146 self.reset();
147 self.say("250 Ok\r\n").await?;
148 } else {
149 self.say("500 Unknown command\r\n").await?;
151 }
152 }
153 }
154
155 async fn say(&mut self, message: &str) -> Result<()> {
156 self.stream.write_all(message.as_bytes()).await?;
157
158 Ok(())
159 }
160
161 fn reset(&mut self) {
162 self.state.clear();
163 self.server_impl.reset();
164 }
165}
166
167struct SmtpStream {
168 inner: TcpStream,
169 line_buffer: Vec<u8>,
170 data_buffer: Vec<u8>,
171}
172
173impl SmtpStream {
174 fn new(stream: TcpStream) -> Self {
175 SmtpStream {
176 inner: stream,
177 line_buffer: Vec::with_capacity(8 * 1024),
178 data_buffer: Vec::with_capacity(64 * 1024),
179 }
180 }
181
182 async fn write_all(&mut self, message: &[u8]) -> Result<()> {
183 self.inner.write_all(message).await?;
184
185 Ok(())
186 }
187
188 async fn read_line(&mut self) -> Result<Option<&[u8]>> {
189 self.line_buffer.clear();
190
191 loop {
192 let mut read_buffer = [0u8; 8 * 1024];
193 let bytes_read = self.inner.read(&mut read_buffer).await?;
194
195 if bytes_read == 0 {
196 return Ok(None);
197 }
198
199 if let Some(i) = memchr::memchr(b'\n', &read_buffer[..bytes_read]) {
200 self.line_buffer.extend_from_slice(&read_buffer[0..=i]);
201
202 return Ok(Some(&self.line_buffer[..self.line_buffer.len() - 2]));
204 } else {
205 self.line_buffer
206 .extend_from_slice(&read_buffer[..bytes_read]);
207 }
208 }
209 }
210
211 async fn read_data(&mut self) -> Result<Option<&[u8]>> {
212 self.data_buffer.clear();
213
214 loop {
215 let mut read_buffer = [0u8; 16 * 1024];
216 let bytes_read = self.inner.read(&mut read_buffer).await?;
217
218 if bytes_read == 0 {
219 return Ok(None);
220 }
221
222 self.data_buffer
223 .extend_from_slice(&read_buffer[..bytes_read]);
224
225 let data_end_delimiter = b"\r\n.\r\n";
226 if self.data_buffer.len() >= data_end_delimiter.len()
227 && &self.data_buffer[self.data_buffer.len() - data_end_delimiter.len()..]
228 == data_end_delimiter
229 {
230 return Ok(Some(
231 &self.data_buffer[..self.data_buffer.len() - data_end_delimiter.len() + 2],
232 ));
233 }
234 }
235 }
236}
237
238struct SmtpServerImplementation {
239 sender: Option<String>,
240 recipients: Vec<String>,
241 repository: Arc<Mutex<MessageRepository>>,
242}
243
244impl SmtpServerImplementation {
245 fn new(repository: Arc<Mutex<MessageRepository>>) -> Self {
246 SmtpServerImplementation {
247 sender: None,
248 recipients: Vec::new(),
249 repository,
250 }
251 }
252
253 fn reset(&mut self) {
254 self.sender = None;
255 self.recipients = Vec::new();
256 }
257
258 fn mail(&mut self, sender: &str) {
259 self.sender = Some(sender.to_string());
260 }
261
262 fn rcpt(&mut self, recipient: &str) {
263 self.recipients.push(recipient.to_string());
264 }
265
266 fn data(&mut self, buf: &[u8]) -> Result<()> {
267 let size_extension_index = self.sender.as_ref().and_then(|s| s.rfind("SIZE="));
269
270 let sender = if let Some(idx) = size_extension_index {
271 self.sender.as_ref().map(|s| s[..idx].to_string())
272 } else {
273 self.sender.clone()
274 };
275
276 let message = parse_message(&sender, &self.recipients, buf);
277 if message.is_err() {
278 return Err(Box::new(message.err().unwrap()));
279 }
280 let message = message.unwrap();
281
282 info!(
283 "Received message from {} ({} bytes)",
284 sender.as_deref().unwrap_or("unknown sender"),
285 buf.len(),
286 );
287
288 self.repository.lock().unwrap().persist(message);
289
290 Ok(())
291 }
292}