1use crate::inner::{create_client, decode, uncompress};
2use crate::{VkApiError, VkApiResult};
3use bytes::Buf;
4use cfg_if::cfg_if;
5use reqwest::header::{ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE};
6use reqwest::Client;
7use serde::de::DeserializeOwned;
8use serde::{Deserialize, Deserializer, Serialize};
9use std::error::Error;
10use std::fmt::{Display, Formatter};
11
/// Client for VK long poll servers.
///
/// Wraps a `reqwest::Client`; cloning a `VkLongPoll` clones that client.
#[derive(Debug, Clone)]
pub struct VkLongPoll {
    /// HTTP client used to send every `a_check` request.
    client: Client,
}
35
36impl VkLongPoll {
37 #[cfg(feature = "longpoll_stream")]
60 pub fn subscribe<T: Serialize + Clone + Send, I: DeserializeOwned>(
61 &self,
62 mut request: LongPollRequest<T>,
63 ) -> impl futures_util::Stream<Item = VkApiResult<I>> {
64 let client = self.client.clone();
65
66 async_stream::stream! {
67 loop {
68 match Self::subscribe_once_with_client(&client, request.clone()).await {
69 Err(VkApiError::LongPoll(LongPollError { ts: Some(ts), .. })) => {
70 request.ts = ts;
71 },
72 Ok(LongPollSuccess{ ts, updates }) => {
73 request.ts = ts.clone();
74 for update in updates {
75 yield Ok(update);
76 }
77 },
78 Err(e) => {
79 yield Err(e);
80 break;
81 },
82 };
83 }
84 }
85 }
86
87 pub async fn subscribe_once<T: Serialize + Send, I: DeserializeOwned>(
108 &self,
109 request: LongPollRequest<T>,
110 ) -> VkApiResult<LongPollSuccess<I>> {
111 Self::subscribe_once_with_client(&self.client, request).await
112 }
113
114 async fn subscribe_once_with_client<T: Serialize + Send, I: DeserializeOwned>(
115 client: &Client,
116 request: LongPollRequest<T>,
117 ) -> VkApiResult<LongPollSuccess<I>> {
118 let LongPollInnerRequest(LongPollServer(server), params) =
119 LongPollInnerRequest::from(request);
120
121 let params = serde_urlencoded::to_string(params).map_err(VkApiError::RequestSerialize)?;
122
123 let url = if server.starts_with("http") {
124 format!("{server}?act=a_check&{params}")
125 } else {
126 format!("https://{server}?act=a_check&{params}")
127 };
128
129 cfg_if! {
130 if #[cfg(feature = "compression_gzip")] {
131 let encoding = "gzip";
132 } else {
133 let encoding = "identity";
134 }
135 }
136
137 cfg_if! {
138 if #[cfg(feature = "encode_json")] {
139 let serialisation = "application/json";
140 } else {
141 let serialisation = "text/*";
142 }
143 }
144
145 let request = client
146 .get(url)
147 .header(ACCEPT_ENCODING, encoding)
148 .header(ACCEPT, serialisation);
149
150 let response = request.send().await.map_err(VkApiError::Request)?;
151
152 let headers = response.headers();
153
154 let content_type = headers.get(CONTENT_TYPE).cloned();
155 let content_encoding = headers.get(CONTENT_ENCODING).cloned();
156
157 let body = response.bytes().await.map_err(VkApiError::Request)?;
158
159 let resp = decode::<LongPollResponse<I>, _>(
160 content_type,
161 uncompress(content_encoding, body.reader())?,
162 )?;
163
164 match resp {
165 LongPollResponse::Success(r) => Ok(r),
166 LongPollResponse::Error(e) => Err(VkApiError::LongPoll(e)),
167 }
168 }
169}
170
171impl From<Client> for VkLongPoll {
172 fn from(client: Client) -> Self {
173 Self { client }
174 }
175}
176
177impl Default for VkLongPoll {
178 fn default() -> Self {
179 Self::from(create_client())
180 }
181}
182
/// Raw long poll response body: either a success payload or an error object.
///
/// With `untagged`, serde tries the variants in declaration order. A body is
/// therefore read as `Success` whenever it deserializes as a `LongPollSuccess`,
/// and only otherwise as `Error`. If neither fits, deserialization fails.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum LongPollResponse<R> {
    Success(LongPollSuccess<R>),
    Error(LongPollError),
}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct LongPollSuccess<R> {
193 #[serde(deserialize_with = "deserialize_usize_or_string")]
194 ts: String,
195 updates: Vec<R>,
196}
197
198#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct LongPollError {
202 failed: usize,
203 #[serde(default)]
204 #[serde(deserialize_with = "deserialize_usize_or_string_option")]
205 ts: Option<String>,
206 #[serde(default)]
207 min_version: Option<usize>,
208 #[serde(default)]
209 max_version: Option<usize>,
210}
211
212impl Display for LongPollError {
213 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
214 write!(f, "long poll error occurred, code: {}", self.failed,)
215 }
216}
217
218impl Error for LongPollError {}
219
/// Parameters for a long poll `a_check` request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LongPollRequest<T> {
    /// Long poll server address; `https://` is prepended when it does not
    /// already carry an http(s) scheme.
    pub server: String,
    /// Sent as the `key` query parameter.
    pub key: String,
    /// Sent as the `ts` query parameter. When deserialized it may be given as a
    /// number or a string. `subscribe` advances it as responses arrive.
    #[serde(deserialize_with = "deserialize_usize_or_string")]
    pub ts: String,
    /// Sent as the `wait` query parameter (NOTE(review): presumably the
    /// server-side wait time in seconds — confirm against the VK docs).
    pub wait: usize,
    /// Extra parameters, flattened into the query string after `key`, `ts` and `wait`.
    #[serde(flatten)]
    pub additional_params: T,
}
234
/// Address of the long poll server, kept apart from the URL-encoded query parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LongPollServer(String);
237
/// Query parameters of an `a_check` request, serialized with `serde_urlencoded`.
/// Field order here is the order the parameters appear in the query string.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LongPollQueryParams<T> {
    key: String,
    // Accepts a number or a string when deserialized.
    #[serde(deserialize_with = "deserialize_usize_or_string")]
    ts: String,
    wait: usize,
    // Caller-supplied extra parameters, emitted inline with the fields above.
    #[serde(flatten)]
    additional_params: T,
}
247
/// A [`LongPollRequest`] split into its server address and its query parameters.
struct LongPollInnerRequest<T>(LongPollServer, LongPollQueryParams<T>);
249
250impl<T> From<LongPollRequest<T>> for LongPollInnerRequest<T> {
251 fn from(
252 LongPollRequest {
253 server,
254 key,
255 ts,
256 wait,
257 additional_params,
258 }: LongPollRequest<T>,
259 ) -> Self {
260 Self(
261 LongPollServer(server),
262 LongPollQueryParams {
263 key,
264 ts,
265 wait,
266 additional_params,
267 },
268 )
269 }
270}
271
272struct DeserializeUsizeOrString;
273
274impl<'de> serde::de::Visitor<'de> for DeserializeUsizeOrString {
275 type Value = String;
276
277 fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
278 formatter.write_str("an integer or a string")
279 }
280
281 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
282 where
283 E: serde::de::Error,
284 {
285 Ok(v.to_string())
286 }
287
288 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
289 where
290 E: serde::de::Error,
291 {
292 Ok(v.to_owned())
293 }
294
295 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
296 where
297 E: serde::de::Error,
298 {
299 Ok(v)
300 }
301}
302
303struct DeserializeUsizeOrStringOption;
304
305impl<'de> serde::de::Visitor<'de> for DeserializeUsizeOrStringOption {
306 type Value = Option<String>;
307
308 fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
309 formatter.write_str("an integer or a string or a null")
310 }
311
312 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
313 where
314 E: serde::de::Error,
315 {
316 Ok(Some(v.to_string()))
317 }
318
319 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
320 where
321 E: serde::de::Error,
322 {
323 Ok(Some(v.to_owned()))
324 }
325
326 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
327 where
328 E: serde::de::Error,
329 {
330 Ok(Some(v))
331 }
332
333 fn visit_none<E>(self) -> Result<Self::Value, E>
334 where
335 E: serde::de::Error,
336 {
337 Ok(None)
338 }
339}
340
341fn deserialize_usize_or_string<'de, D>(
342 deserializer: D,
343) -> Result<String, <D as Deserializer<'de>>::Error>
344where
345 D: Deserializer<'de>,
346{
347 deserializer.deserialize_any(DeserializeUsizeOrString)
348}
349
350fn deserialize_usize_or_string_option<'de, D>(
351 deserializer: D,
352) -> Result<Option<String>, <D as Deserializer<'de>>::Error>
353where
354 D: Deserializer<'de>,
355{
356 deserializer.deserialize_any(DeserializeUsizeOrStringOption)
357}
358
#[cfg(test)]
mod tests {
    use crate::longpoll::{deserialize_usize_or_string, deserialize_usize_or_string_option};
    use serde::Deserialize;

    /// Wrapper exercising the required `ts` deserializer.
    #[derive(Deserialize)]
    struct Ts {
        #[serde(deserialize_with = "deserialize_usize_or_string")]
        ts: String,
    }

    /// Wrapper exercising the optional `ts` deserializer; a missing field
    /// falls back to `None` via `#[serde(default)]`.
    #[derive(Deserialize)]
    struct TsOpt {
        #[serde(default)]
        #[serde(deserialize_with = "deserialize_usize_or_string_option")]
        ts: Option<String>,
    }

    fn parse_ts(json: &str) -> String {
        serde_json::from_str::<Ts>(json).unwrap().ts
    }

    fn parse_ts_opt(json: &str) -> Option<String> {
        serde_json::from_str::<TsOpt>(json).unwrap().ts
    }

    #[test]
    fn test_deserialize_ts_string() {
        assert_eq!(parse_ts(r#"{"ts": "123"}"#), "123");
    }

    #[test]
    fn test_deserialize_ts_usize() {
        assert_eq!(parse_ts(r#"{"ts": 123}"#), "123");
    }

    #[test]
    fn test_deserialize_ts_opt_string() {
        assert_eq!(parse_ts_opt(r#"{"ts": "123"}"#).as_deref(), Some("123"));
    }

    #[test]
    fn test_deserialize_ts_opt_usize() {
        assert_eq!(parse_ts_opt(r#"{"ts": 123}"#).as_deref(), Some("123"));
    }

    #[test]
    fn test_deserialize_ts_opt_none() {
        assert!(parse_ts_opt("{}").is_none());
    }
}