1use std::future::Future;
2use std::io;
3use std::net::ToSocketAddrs;
4use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio::net::TcpStream;
9use tokio_rustls::rustls::ClientConfig;
10use tokio_rustls::webpki::DNSNameRef;
11use tokio_rustls::{client::TlsStream, TlsConnector};
12use tokio_util::codec::Decoder;
13
14use crate::proto::{ImapCodec, ImapTransport, ResponseData};
15use imap_proto::builders::command::Command;
16use imap_proto::{Request, RequestId, State};
17
18pub mod builder {
19 pub use imap_proto::builders::command::{
20 CommandBuilder, FetchBuilderAttributes, FetchBuilderMessages, FetchBuilderModifiers,
21 FetchCommand, FetchCommandAttributes, FetchCommandMessages,
22 };
23}
24
25pub type TlsClient = Client<TlsStream<TcpStream>>;
26
27pub struct Client<T> {
28 transport: ImapTransport<T>,
29 state: ClientState,
30}
31
32impl TlsClient {
33 pub async fn connect(server: &str) -> io::Result<(ResponseData, Self)> {
34 let addr = (server, 993).to_socket_addrs()?.next().ok_or_else(|| {
35 io::Error::new(
36 io::ErrorKind::Other,
37 format!("no IP addresses found for {}", server),
38 )
39 })?;
40
41 let mut tls_config = ClientConfig::new();
42 tls_config
43 .root_store
44 .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
45 let connector: TlsConnector = Arc::new(tls_config).into();
46
47 let stream = TcpStream::connect(&addr).await?;
48 let stream = connector
49 .connect(DNSNameRef::try_from_ascii_str(server).unwrap(), stream)
50 .await?;
51 let mut transport = ImapCodec::default().framed(stream);
52
53 let greeting = match transport.next().await {
54 Some(greeting) => Ok(greeting),
55 None => Err(io::Error::new(io::ErrorKind::Other, "no greeting found")),
56 }?;
57 let client = Client {
58 transport,
59 state: ClientState::new(),
60 };
61
62 greeting.map(|greeting| (greeting, client))
63 }
64
65 pub fn call(&mut self, cmd: Command) -> ResponseStream<TlsStream<TcpStream>> {
66 ResponseStream::new(self, cmd)
67 }
68}
69
70pub struct ResponseStream<'a, T> {
71 client: &'a mut Client<T>,
72 request: Request,
73 next_state: Option<State>,
74 sending: bool,
75 done: bool,
76}
77
78impl<'a, T> ResponseStream<'a, T>
79where
80 T: AsyncRead + AsyncWrite + Unpin,
81{
82 pub fn new(client: &mut Client<T>, cmd: Command) -> ResponseStream<'_, T> {
83 let request_id = client.state.request_ids.next().unwrap(); let (cmd_bytes, next_state) = cmd.into_parts();
85 let request = Request(request_id, cmd_bytes);
86
87 ResponseStream {
88 client,
89 request,
90 next_state,
91 sending: true,
92 done: false,
93 }
94 }
95
96 #[allow(clippy::should_implement_trait)]
97 pub async fn next(&mut self) -> Option<Result<ResponseData, io::Error>> {
98 if self.done {
99 return None;
100 }
101
102 if self.sending {
103 match self.client.transport.send(self.request.clone()).await {
104 Ok(()) => {
105 self.sending = false;
106 }
107 Err(e) => return Some(Err(e)),
108 }
109 }
110
111 match self.client.transport.next().await {
112 Some(Ok(rsp)) => {
113 if let Some(req_id) = rsp.request_id() {
114 self.done = *req_id == self.request.0;
115 }
116
117 if self.done {
118 if let Some(next_state) = self.next_state.take() {
119 self.client.state.state = next_state;
120 }
121 }
122
123 Some(Ok(rsp))
124 }
125 Some(Err(e)) => Some(Err(e)),
126 None => Some(Err(io::Error::new(
127 io::ErrorKind::Other,
128 "stream ended before command completion",
129 ))),
130 }
131 }
132
133 pub async fn try_collect(&mut self) -> Result<Vec<ResponseData>, io::Error> {
134 let mut data = vec![];
135 loop {
136 match self.next().await {
137 Some(Ok(rsp)) => {
138 data.push(rsp);
139 }
140 Some(Err(e)) => return Err(e),
141 None => return Ok(data),
142 }
143 }
144 }
145
146 pub async fn try_for_each<F, Fut>(&mut self, mut f: F) -> Result<(), io::Error>
147 where
148 F: FnMut(ResponseData) -> Fut,
149 Fut: Future<Output = Result<(), io::Error>>,
150 {
151 loop {
152 match self.next().await {
153 Some(Ok(rsp)) => f(rsp).await?,
154 Some(Err(e)) => return Err(e),
155 None => return Ok(()),
156 }
157 }
158 }
159
160 pub async fn try_fold<S, Fut, F>(&mut self, mut state: S, mut f: F) -> Result<S, io::Error>
161 where
162 F: FnMut(S, ResponseData) -> Fut,
163 Fut: Future<Output = Result<S, io::Error>>,
164 {
165 loop {
166 match self.next().await {
167 Some(Ok(rsp)) => match f(state, rsp).await {
168 Ok(new) => {
169 state = new;
170 }
171 Err(e) => return Err(e),
172 },
173 Some(Err(e)) => return Err(e),
174 None => return Ok(state),
175 }
176 }
177 }
178}
179
180pub struct ClientState {
181 state: State,
182 request_ids: IdGenerator,
183}
184
185impl ClientState {
186 pub fn new() -> Self {
187 Self {
188 state: State::NotAuthenticated,
189 request_ids: IdGenerator::new(),
190 }
191 }
192}
193
194impl Default for ClientState {
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
200pub struct IdGenerator {
201 next: u64,
202}
203
204impl IdGenerator {
205 pub fn new() -> Self {
206 Self { next: 0 }
207 }
208}
209
210impl Default for IdGenerator {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
216impl Iterator for IdGenerator {
217 type Item = RequestId;
218 fn next(&mut self) -> Option<Self::Item> {
219 self.next += 1;
220 Some(RequestId(format!("A{:04}", self.next % 10_000)))
221 }
222}