threema_client/
messaging_client.rs1use tokio::sync::oneshot;
4use std::sync::atomic::AtomicU32;
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::{transport,Credentials};
9use crate::transport::{ThreemaServer, Ack, BoxedMessage};
10
11#[derive(Clone, Debug)]
12pub enum Event{
13 Connected,
14 Disconnected,
15 BoxedMessage(transport::BoxedMessage),
16 Alert(String),
17 Error{reconnect_allowed: bool, message: String},
18 QueueSendComplete,
19}
20
/// Reason why the client connection stopped or a request could not be completed.
#[derive(Clone, Debug)]
pub enum Closed {
    /// The background task is gone (shut down, or its queues were closed).
    Shutdown,
    /// Login to the server failed (not produced in this file).
    LoginFailed,
    /// The connection was rejected; the string presumably carries the reason — confirm.
    Rejected(String),
}
27
28pub struct Client {
29 send_queue: tokio::sync::mpsc::UnboundedSender<SendItem>,
30 event_queue: tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<Event>>,
31 jh: tokio::task::JoinHandle<Closed>,
32}
33impl Client {
34 pub fn new(server: ThreemaServer, creds: Credentials) -> Self{
35 let (send_queue_s,send_queue_r) = tokio::sync::mpsc::unbounded_channel();
37 let (event_queue_s,event_queue_r) = tokio::sync::mpsc::unbounded_channel();
38
39 let jh = tokio::task::spawn(Self::keep_alive(server, creds, event_queue_s, send_queue_r));
40 Client{
41 send_queue: send_queue_s,
42 event_queue: tokio::sync::Mutex::new(event_queue_r),
43 jh,
44 }
45 }
46 pub async fn event(&self) -> Option<Event> {
48 self.event_queue.lock().await.recv().await
49 }
50 pub async fn send_message(&self, msg: transport::BoxedMessage) -> Result<(),Closed>{
53 let (signal_s, signal_r) = tokio::sync::oneshot::channel();
55 log::trace!("Putting mesage to sent queue");
58 self.send_queue.send(SendItem::Message(msg, signal_s)).map_err(|_| Closed::Shutdown)?;
59 signal_r.await.map_err(|_| Closed::Shutdown)?;
60 Ok(())
61 }
62 pub async fn send_ack(&self, ack: &Ack) -> Result<(), Closed>{
63 self.send_queue.send(SendItem::Ack(*ack)).map_err(|_| Closed::Shutdown)?;
64 Ok(())
65 }
67
68 pub fn shutdown(&self) {
71 self.jh.abort();
72 }
73
74 async fn keep_alive(server: ThreemaServer, creds: Credentials,
75 event_queue: tokio::sync::mpsc::UnboundedSender<Event>,
76 mut send_queue: tokio::sync::mpsc::UnboundedReceiver<SendItem>,
77 ) -> Closed {
78 const BACKOFF_MAX: Duration = Duration::from_secs(180);
79 const BACKOFF_MIN: Duration = Duration::from_millis(50);
80 const KEEP_ALIVE: Duration = Duration::from_secs(180);
81 let poison_pill = Arc::new(tokio::sync::Notify::new());
83 'reconnect: loop {
84 let mut write_half;
85 let mut recv_task;
87
88 let mut backoff = BACKOFF_MIN;
89 let last_rcvd_echo_seq = Arc::new(AtomicU32::new(0));
90 let mut next_echo;
91 let mut last_sent_echo_seq = 0;
92 let ack_queue = AckQueue::default();
93 loop {
94 let res = transport::connect(&server, &creds).await;
95 match res {
96 Ok((r,w)) => {
97 write_half = w;
98 recv_task = AbortOnDrop{task:
99 tokio::task::spawn(Self::receive_loop(r, event_queue.clone(), last_rcvd_echo_seq.clone(), ack_queue.clone()))};
100 break;
101 }
102 Err(e) => {
103 log::warn!("Connection failed: {}, reconnecting in {}ms", e, backoff.as_millis());
104 tokio::time::sleep(backoff).await;
105 backoff = std::cmp::min(BACKOFF_MAX, 2*backoff);
106 }
107 }
108 };
109
110 let previously_sent = ack_queue.lock().unwrap().values().map(|(m,_ping)| m.clone()).collect::<Vec<_>>();
112 for msg in previously_sent {
113 let sent = write_half.send_message(&msg, transport::ClientToServer).await;
114 if sent.is_err() {
115 continue 'reconnect;
116 }
117 }
118 let sent = event_queue.send(Event::Connected);
119 if sent.is_err() {
120 return Closed::Shutdown;
121 }
122 next_echo = tokio::time::Instant::now() + KEEP_ALIVE;
123 loop {
124 tokio::select!{
125 _a = &mut recv_task.task => {
126 log::trace!("receiver ended");
127 break; }
130 () = tokio::time::sleep_until(next_echo) => {
131 log::trace!("echo timer");
132 if last_rcvd_echo_seq.load(std::sync::atomic::Ordering::SeqCst) != last_sent_echo_seq {
134 break;
135 }
136 last_sent_echo_seq = last_sent_echo_seq.wrapping_add(1);
138 let sent = write_half.echo_request(last_sent_echo_seq).await;
139 if sent.is_err() {
140 break;
141 }
142 next_echo = tokio::time::Instant::now() + KEEP_ALIVE;
143 }
144 msgo = send_queue.recv() => {
145 log::trace!("got from send queue: {:?}", msgo);
146 match msgo {
147 Some(SendItem::Message(msg, pingback)) => {
148
149 let expected_ack = msg.envelope.recipient_ack();
150 let sent = write_half.send_message(&msg, transport::ClientToServer).await;
151 ack_queue.lock().unwrap().insert(expected_ack,(msg,pingback));
152 if let Err(e) = sent {
153 log::trace!("error while sending: {}", &e);
154 break;
155 }
156 }
157 Some(SendItem::Ack(ref ack)) => {
158 let sent = write_half.send_ack(ack, transport::Direction::ClientToServer).await;
159 if let Err(e) = sent {
160 log::trace!("error while sending: {}", &e);
161 break;
162 }
163 }
164 None => {
165 poison_pill.notify_waiters();
166 poison_pill.notify_one();
167 return Closed::Shutdown;
168 }
169 }
170 }
171 };
172 }
173 let _ = event_queue.send(Event::Disconnected);
174 }
176 }
177
178 async fn receive_loop(
179 mut r: transport::ReadHalf,
180 event_queue: tokio::sync::mpsc::UnboundedSender<Event>,
181 last_received_echo: Arc<AtomicU32>,
182 ack_queue: AckQueue,
183 ) {
184 loop {
185 let res = r.receive_packet().await;
186 match res {
187 Ok(packet) => {
188 match packet {
189 transport::Packet::BoxedMessageDownload(m) => {
190 let _ = event_queue.send(Event::BoxedMessage(m));
191 }
192 transport::Packet::QueueSendComplete => {
193 let _ = event_queue.send(Event::QueueSendComplete);
194 }
195 transport::Packet::EchoReply(i) => {
196 last_received_echo.store(i, std::sync::atomic::Ordering::SeqCst);
197 }
198 transport::Packet::AckUpload(ack) =>{
199 if let Some((_, signal)) = ack_queue.lock().unwrap().remove(&ack) {
200 let _res = signal.send(());
201 }
202 else {
203 log::warn!("INCOMING_MESSAGE_ACK for unknown message: {:?}", ack);
204 }
205 }
206 unexpected => {
207 log::warn!("Received packet with unexpected payload type from server: {:?}", unexpected);
208 }
209 }
210 }
211 Err(e) => {
212 log::warn!("receive_loop: {}", e);
213 return;
214 }
215 }
216 }
217 }
218}
219type AckQueue = Arc<std::sync::Mutex<std::collections::HashMap<Ack, (BoxedMessage, oneshot::Sender<()>)>>>;
220#[derive(Debug)]
221enum SendItem {
222 Message(transport::BoxedMessage, tokio::sync::oneshot::Sender<()>),
223 Ack(Ack)
224}
225struct AbortOnDrop<T>{task: tokio::task::JoinHandle<T>}
226impl<T> Drop for AbortOnDrop<T> {
227 fn drop(&mut self) {
228 log::trace!("AbortOnDrop::drop");
229 self.task.abort();
230 }
231}