// samotop_delivery/smtp/util/proto.rs

1use crate::smtp::authentication::Authentication;
2use crate::smtp::commands::*;
3use crate::smtp::extension::{ClientId, ServerInfo};
4use crate::smtp::response::parse_response;
5use crate::smtp::response::Response;
6use async_std::io::prelude::{ReadExt, WriteExt};
7use bytes::{Buf, BufMut, BytesMut};
8use samotop_core::common::*;
9use std::fmt::Display;
10use std::pin::Pin;
11use std::time::Duration;
12
13use crate::smtp::error::{Error, SmtpResult};
14use std::result::Result;
15
16/// Basic SMTP mail protocol client
17/// As a rule of thumb, this code only takes care of SMTP.
18/// No encryption or connection setup. Separating concerns.
19/// It wraps lightly around the provided stream
20/// to facilitate the execution of an SMTP session.
21#[derive(Debug)]
22pub struct SmtpProto<'s, S> {
23    stream: Pin<&'s mut S>,
24    buffer: BytesMut,
25    line_limit: usize,
26}
27
28impl<'s, S> SmtpProto<'s, S> {
29    pub fn new(stream: Pin<&'s mut S>) -> Self {
30        SmtpProto {
31            stream,
32            buffer: BytesMut::new(),
33            line_limit: 8000,
34        }
35    }
36    // pub fn with_line_limit(mut self, limit: usize) -> Self {
37    //     self.line_limit = limit;
38    //     self
39    // }
40    pub fn buffer(&self) -> &[u8] {
41        self.buffer.chunk()
42    }
43    pub fn stream_mut(&mut self) -> Pin<&mut S> {
44        self.stream.as_mut()
45    }
46    pub fn stream(&self) -> Pin<&S> {
47        self.stream.as_ref()
48    }
49    // pub fn into_stream(self) -> Pin<&'s mut S> {
50    //     self.stream
51    // }
52}
53impl<'s, S: io::Read + io::Write> SmtpProto<'s, S> {
54    /// Gets the server banner after connection.
55    pub async fn read_banner(&mut self, timeout: Duration) -> SmtpResult {
56        let banner_response = self.read_response(timeout).await?;
57        banner_response.is([220].as_ref())
58    }
59    /// Gets the server response after mail data have been fully sent.
60    pub async fn read_data_sent_response(&mut self, timeout: Duration) -> SmtpResult {
61        let data_response = self.read_response(timeout).await?;
62        data_response.is([250].as_ref())
63    }
64    /// Gets the EHLO (ESMTP) response and updates server information.
65    /// If this fails with 5xx error (pure SMTP), plain old HELO is used instead.
66    pub async fn execute_ehlo_or_helo(
67        &mut self,
68        me: ClientId,
69        timeout: Duration,
70    ) -> Result<(Response, ServerInfo), Error> {
71        match self.execute_ehlo(me.clone(), timeout).await {
72            Err(Error::Permanent(_resp)) => self.execute_helo(me, timeout).await,
73            otherwise => otherwise,
74        }
75    }
76    /// Gets the EHLO (ESMTP) response and updates server information.
77    /// If this fails with 5xx error (pure SMTP), one should try HELO instead.
78    pub async fn execute_ehlo(
79        &mut self,
80        me: ClientId,
81        timeout: Duration,
82    ) -> Result<(Response, ServerInfo), Error> {
83        // Extended Hello
84        let ehlo_response = self
85            .execute_command(HeloCommand::ehlo(me), [250], timeout)
86            .await?;
87        // extract extensions
88        let server_info = ServerInfo::from_response(&ehlo_response)?;
89        // Print server information
90        debug!("ehlo server info {}", server_info);
91
92        Ok((ehlo_response, server_info))
93    }
94    /// Gets the LHLO (LMTP) response and updates server information.
95    /// If this fails with 5xx error (pure SMTP), one should try HELO instead.
96    pub async fn execute_lhlo(
97        &mut self,
98        me: ClientId,
99        timeout: Duration,
100    ) -> Result<(Response, ServerInfo), Error> {
101        // LMTP HELO
102        let ehlo_response = self
103            .execute_command(HeloCommand::lhlo(me), [250], timeout)
104            .await?;
105        // extract extensions
106        let server_info = ServerInfo::from_response(&ehlo_response)?;
107        // Print server information
108        debug!("lhlo server info {}", server_info);
109
110        Ok((ehlo_response, server_info))
111    }
112    /// Gets the HELO (bare SMTP) response and updates server information.
113    pub async fn execute_helo(
114        &mut self,
115        me: ClientId,
116        timeout: Duration,
117    ) -> Result<(Response, ServerInfo), Error> {
118        // Basic HELO
119        let ehlo_response = self
120            .execute_command(HeloCommand::helo(me), [250], timeout)
121            .await?;
122        // extract extensions
123        let server_info = ServerInfo::from_response(&ehlo_response)?;
124        // Print server information
125        debug!("helo server info {}", server_info);
126
127        Ok((ehlo_response, server_info))
128    }
129    /// Sends STARTTLS, and confirms success message. Does not switch protocols!
130    /// Do that through the self.stream_mut() or self.into_inner()
131    pub async fn execute_starttls(&mut self, timeout: Duration) -> SmtpResult {
132        let response = self.execute_command(StarttlsCommand, [220], timeout).await;
133        response
134    }
135    /// Sends the rset command
136    pub async fn execute_rset(&mut self, timeout: Duration) -> SmtpResult {
137        let response = self.execute_command(RsetCommand, [250], timeout).await;
138        response
139    }
140    /// Sends the quit command
141    pub async fn execute_quit(&mut self, timeout: Duration) -> SmtpResult {
142        let response = self.execute_command(QuitCommand, [221], timeout).await;
143        response
144    }
145    /// Sends an AUTH command with the given mechanism, and handles challenge if needed
146    pub async fn authenticate<A: Authentication>(
147        &mut self,
148        mut authentication: A,
149        timeout: Duration,
150    ) -> SmtpResult {
151        // TODO
152        let mut challenges = 10u8;
153        let mut response = self
154            .execute_command(AuthCommand::new(&mut authentication)?, [334, 2], timeout)
155            .await?;
156
157        while challenges > 0 && response.has_code(334) {
158            challenges -= 1;
159            response = self
160                .execute_command(
161                    AuthResponse::new(&mut authentication, &response)?,
162                    [334, 2],
163                    timeout,
164                )
165                .await?;
166        }
167
168        if challenges == 0 {
169            Err(Error::ResponseParsing("Unexpected number of challenges"))
170        } else {
171            Ok(response)
172        }
173    }
174    pub async fn execute_command<C: Display, E: AsRef<[u16]>>(
175        &mut self,
176        command: C,
177        expected: E,
178        timeout: Duration,
179    ) -> SmtpResult {
180        let command = command.to_string();
181        debug!("C: {}", escape_crlf(command.as_str()));
182        let buff = command.as_bytes();
183        let written = self.write_bytes(buff, timeout).await?;
184        debug_assert!(written == buff.len(), "Make sure we write all the data");
185        self.stream.flush().await?;
186        let response = self.read_response(timeout).await?;
187        response.is(expected)
188    }
189    async fn write_bytes(&mut self, buf: &[u8], timeout: Duration) -> Result<usize, Error> {
190        with_timeout(timeout, self.stream.write(buf)).await
191    }
192    async fn read_response(&mut self, timeout: Duration) -> SmtpResult {
193        with_timeout(timeout, async move {
194            let mut enough = self.buffer.remaining() != 0;
195            loop {
196                self.buffer.reserve(1024);
197                let buf = self.buffer.chunk_mut();
198                if !enough {
199                    // It is OK to use uninitialized buffer as long as read fulfills the contract.
200                    // In other words, it will only use the given buffer for writing.
201                    // TODO: What's the story with clippy::transmute-ptr-to-ptr?
202                    #[allow(unsafe_code)]
203                    #[allow(clippy::transmute_ptr_to_ptr)]
204                    let buf = unsafe { std::mem::transmute(buf) };
205                    let read = self.stream.read(buf).await?;
206                    if read == 0 {
207                        return Err(io::Error::new(
208                            io::ErrorKind::Other,
209                            format!("incomplete after {} bytes", self.buffer().len()),
210                        )
211                        .into());
212                    }
213                    // It is OK to use uninitialized buffer as long as read fulfills the contract.
214                    // In other words, read bytes have been written at the beginning of the given buffer
215                    #[allow(unsafe_code)]
216                    unsafe {
217                        self.buffer.advance_mut(read)
218                    };
219                }
220                let response = std::str::from_utf8(self.buffer.chunk())?;
221                debug!("S: {}", escape_crlf(response));
222                break match parse_response(response) {
223                    Ok((remaining, response)) => {
224                        let consumed = self.buffer.remaining() - remaining.len();
225                        self.buffer.advance(consumed);
226                        response.is([2, 3].as_ref())
227                    }
228                    Err(nom::Err::Incomplete(_)) => {
229                        // read more unless there's too much
230                        if self.buffer.remaining() >= self.line_limit {
231                            Err(Error::ResponseParsing("Line limit reached"))
232                        } else {
233                            enough = false;
234                            continue;
235                        }
236                    }
237                    Err(nom::Err::Failure(e)) => Err(Error::Parsing(e.code)),
238                    Err(nom::Err::Error(e)) => Err(Error::Parsing(e.code)),
239                };
240            }
241        })
242        .await
243    }
244}
245
246/// Execute io operations with a timeout.
247async fn with_timeout<T, F, E, EOut>(timeout: Duration, f: F) -> std::result::Result<T, EOut>
248where
249    F: Future<Output = std::result::Result<T, E>>,
250    EOut: From<async_std::future::TimeoutError>,
251    EOut: From<E>,
252{
253    let res = async_std::future::timeout(timeout, f).await??;
254    Ok(res)
255}
256
/// Makes line endings visible for debug output by rendering every CRLF
/// pair as the literal text "\<CRLF\>".
fn escape_crlf(string: &str) -> String {
    let mut escaped = String::with_capacity(string.len());
    let mut pieces = string.split("\r\n");
    if let Some(head) = pieces.next() {
        escaped.push_str(head);
    }
    for piece in pieces {
        escaped.push_str("<CRLF>");
        escaped.push_str(piece);
    }
    escaped
}