// tokio_imap/client.rs

1use std::future::Future;
2use std::io;
3use std::net::ToSocketAddrs;
4use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio::net::TcpStream;
9use tokio_rustls::rustls::ClientConfig;
10use tokio_rustls::webpki::DNSNameRef;
11use tokio_rustls::{client::TlsStream, TlsConnector};
12use tokio_util::codec::Decoder;
13
14use crate::proto::{ImapCodec, ImapTransport, ResponseData};
15use imap_proto::builders::command::Command;
16use imap_proto::{Request, RequestId, State};
17
18pub mod builder {
19    pub use imap_proto::builders::command::{
20        CommandBuilder, FetchBuilderAttributes, FetchBuilderMessages, FetchBuilderModifiers,
21        FetchCommand, FetchCommandAttributes, FetchCommandMessages,
22    };
23}
24
25pub type TlsClient = Client<TlsStream<TcpStream>>;
26
27pub struct Client<T> {
28    transport: ImapTransport<T>,
29    state: ClientState,
30}
31
32impl TlsClient {
33    pub async fn connect(server: &str) -> io::Result<(ResponseData, Self)> {
34        let addr = (server, 993).to_socket_addrs()?.next().ok_or_else(|| {
35            io::Error::new(
36                io::ErrorKind::Other,
37                format!("no IP addresses found for {}", server),
38            )
39        })?;
40
41        let mut tls_config = ClientConfig::new();
42        tls_config
43            .root_store
44            .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
45        let connector: TlsConnector = Arc::new(tls_config).into();
46
47        let stream = TcpStream::connect(&addr).await?;
48        let stream = connector
49            .connect(DNSNameRef::try_from_ascii_str(server).unwrap(), stream)
50            .await?;
51        let mut transport = ImapCodec::default().framed(stream);
52
53        let greeting = match transport.next().await {
54            Some(greeting) => Ok(greeting),
55            None => Err(io::Error::new(io::ErrorKind::Other, "no greeting found")),
56        }?;
57        let client = Client {
58            transport,
59            state: ClientState::new(),
60        };
61
62        greeting.map(|greeting| (greeting, client))
63    }
64
65    pub fn call(&mut self, cmd: Command) -> ResponseStream<TlsStream<TcpStream>> {
66        ResponseStream::new(self, cmd)
67    }
68}
69
70pub struct ResponseStream<'a, T> {
71    client: &'a mut Client<T>,
72    request: Request,
73    next_state: Option<State>,
74    sending: bool,
75    done: bool,
76}
77
78impl<'a, T> ResponseStream<'a, T>
79where
80    T: AsyncRead + AsyncWrite + Unpin,
81{
82    pub fn new(client: &mut Client<T>, cmd: Command) -> ResponseStream<'_, T> {
83        let request_id = client.state.request_ids.next().unwrap(); // safe: never returns Err
84        let (cmd_bytes, next_state) = cmd.into_parts();
85        let request = Request(request_id, cmd_bytes);
86
87        ResponseStream {
88            client,
89            request,
90            next_state,
91            sending: true,
92            done: false,
93        }
94    }
95
96    #[allow(clippy::should_implement_trait)]
97    pub async fn next(&mut self) -> Option<Result<ResponseData, io::Error>> {
98        if self.done {
99            return None;
100        }
101
102        if self.sending {
103            match self.client.transport.send(self.request.clone()).await {
104                Ok(()) => {
105                    self.sending = false;
106                }
107                Err(e) => return Some(Err(e)),
108            }
109        }
110
111        match self.client.transport.next().await {
112            Some(Ok(rsp)) => {
113                if let Some(req_id) = rsp.request_id() {
114                    self.done = *req_id == self.request.0;
115                }
116
117                if self.done {
118                    if let Some(next_state) = self.next_state.take() {
119                        self.client.state.state = next_state;
120                    }
121                }
122
123                Some(Ok(rsp))
124            }
125            Some(Err(e)) => Some(Err(e)),
126            None => Some(Err(io::Error::new(
127                io::ErrorKind::Other,
128                "stream ended before command completion",
129            ))),
130        }
131    }
132
133    pub async fn try_collect(&mut self) -> Result<Vec<ResponseData>, io::Error> {
134        let mut data = vec![];
135        loop {
136            match self.next().await {
137                Some(Ok(rsp)) => {
138                    data.push(rsp);
139                }
140                Some(Err(e)) => return Err(e),
141                None => return Ok(data),
142            }
143        }
144    }
145
146    pub async fn try_for_each<F, Fut>(&mut self, mut f: F) -> Result<(), io::Error>
147    where
148        F: FnMut(ResponseData) -> Fut,
149        Fut: Future<Output = Result<(), io::Error>>,
150    {
151        loop {
152            match self.next().await {
153                Some(Ok(rsp)) => f(rsp).await?,
154                Some(Err(e)) => return Err(e),
155                None => return Ok(()),
156            }
157        }
158    }
159
160    pub async fn try_fold<S, Fut, F>(&mut self, mut state: S, mut f: F) -> Result<S, io::Error>
161    where
162        F: FnMut(S, ResponseData) -> Fut,
163        Fut: Future<Output = Result<S, io::Error>>,
164    {
165        loop {
166            match self.next().await {
167                Some(Ok(rsp)) => match f(state, rsp).await {
168                    Ok(new) => {
169                        state = new;
170                    }
171                    Err(e) => return Err(e),
172                },
173                Some(Err(e)) => return Err(e),
174                None => return Ok(state),
175            }
176        }
177    }
178}
179
180pub struct ClientState {
181    state: State,
182    request_ids: IdGenerator,
183}
184
185impl ClientState {
186    pub fn new() -> Self {
187        Self {
188            state: State::NotAuthenticated,
189            request_ids: IdGenerator::new(),
190        }
191    }
192}
193
194impl Default for ClientState {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
200pub struct IdGenerator {
201    next: u64,
202}
203
204impl IdGenerator {
205    pub fn new() -> Self {
206        Self { next: 0 }
207    }
208}
209
210impl Default for IdGenerator {
211    fn default() -> Self {
212        Self::new()
213    }
214}
215
216impl Iterator for IdGenerator {
217    type Item = RequestId;
218    fn next(&mut self) -> Option<Self::Item> {
219        self.next += 1;
220        Some(RequestId(format!("A{:04}", self.next % 10_000)))
221    }
222}