1use std::collections::HashMap;
4
5use anyhow::{Result, anyhow};
6use wacore::poll;
7use wacore_binary::jid::{Jid, JidExt};
8use waproto::whatsapp as wa;
9
10use crate::client::Client;
11
/// Tally for a single poll option, as produced by `Polls::aggregate_votes`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PollOptionResult {
    /// The option's display name, exactly as passed in `poll_options`.
    pub name: String,
    /// JID strings (normalized with `to_non_ad()`) of everyone whose latest vote
    /// selects this option.
    pub voters: Vec<String>,
}
17
18pub struct Polls<'a> {
19 client: &'a Client,
20}
21
22impl<'a> Polls<'a> {
23 pub(crate) fn new(client: &'a Client) -> Self {
24 Self { client }
25 }
26
27 pub async fn create(
29 &self,
30 to: &Jid,
31 name: &str,
32 options: &[String],
33 selectable_count: u32,
34 ) -> Result<(String, Vec<u8>)> {
35 if options.len() < 2 {
36 return Err(anyhow!("Poll must have at least 2 options"));
37 }
38 if options.len() > 12 {
39 return Err(anyhow!("Polls can have a maximum of 12 options"));
40 }
41 if selectable_count < 1 || selectable_count > options.len() as u32 {
42 return Err(anyhow!(
43 "selectable_count must be between 1 and {} (got {selectable_count})",
44 options.len()
45 ));
46 }
47
48 let mut seen = std::collections::HashSet::new();
50 for opt in options {
51 if !seen.insert(opt) {
52 return Err(anyhow!("Duplicate option name: {opt}"));
53 }
54 }
55
56 let poll_options: Vec<wa::message::poll_creation_message::Option> = options
57 .iter()
58 .map(|name| wa::message::poll_creation_message::Option {
59 option_name: Some(name.clone()),
60 option_hash: None,
61 })
62 .collect();
63
64 let poll_msg = wa::message::PollCreationMessage {
65 enc_key: None,
66 name: Some(name.to_string()),
67 options: poll_options,
68 selectable_options_count: Some(selectable_count),
69 context_info: None,
70 poll_content_type: None,
71 poll_type: None,
72 correct_answer: None,
73 };
74
75 let mut message = if selectable_count == 1 {
77 wa::Message {
78 poll_creation_message_v3: Some(Box::new(poll_msg)),
79 ..Default::default()
80 }
81 } else {
82 wa::Message {
83 poll_creation_message: Some(Box::new(poll_msg)),
84 ..Default::default()
85 }
86 };
87
88 let message_secret: Vec<u8> = {
91 use rand::Rng;
92 let mut secret = vec![0u8; 32];
93 rand::make_rng::<rand::rngs::StdRng>().fill_bytes(&mut secret);
94 secret
95 };
96
97 message.message_context_info = Some(wa::MessageContextInfo {
98 message_secret: Some(message_secret.clone()),
99 ..Default::default()
100 });
101
102 let msg_id = self.client.send_message(to.clone(), message).await?;
103 Ok((msg_id, message_secret))
104 }
105
106 pub async fn vote(
107 &self,
108 chat_jid: &Jid,
109 poll_msg_id: &str,
110 poll_creator_jid: &Jid,
111 message_secret: &[u8],
112 option_names: &[String],
113 ) -> Result<String> {
114 let my_jid = self
115 .client
116 .get_pn()
117 .await
118 .ok_or_else(|| anyhow!("Not logged in — cannot determine own JID"))?;
119 let voter_jid_str = my_jid.to_non_ad().to_string();
120 let creator_jid_str = poll_creator_jid.to_non_ad().to_string();
121
122 let selected_hashes: Vec<Vec<u8>> = option_names
123 .iter()
124 .map(|name| poll::compute_option_hash(name).to_vec())
125 .collect();
126
127 let key = poll::derive_vote_encryption_key(
128 message_secret,
129 poll_msg_id,
130 &creator_jid_str,
131 &voter_jid_str,
132 )?;
133
134 let (enc_payload, iv) =
135 poll::encrypt_poll_vote(&selected_hashes, &key, poll_msg_id, &voter_jid_str)?;
136
137 let from_me = my_jid.to_non_ad() == poll_creator_jid.to_non_ad();
138
139 let poll_update = wa::message::PollUpdateMessage {
140 poll_creation_message_key: Some(wa::MessageKey {
141 remote_jid: Some(chat_jid.to_string()),
142 from_me: Some(from_me),
143 id: Some(poll_msg_id.to_string()),
144 participant: if chat_jid.is_group() {
145 Some(poll_creator_jid.to_string())
146 } else {
147 None
148 },
149 }),
150 vote: Some(wa::message::PollEncValue {
151 enc_payload: Some(enc_payload),
152 enc_iv: Some(iv.to_vec()),
153 }),
154 metadata: Some(wa::message::PollUpdateMessageMetadata {}),
155 sender_timestamp_ms: Some(wacore::time::now_millis()),
156 };
157
158 let message = wa::Message {
159 poll_update_message: Some(poll_update),
160 ..Default::default()
161 };
162
163 self.client.send_message(chat_jid.clone(), message).await
164 }
165
166 pub fn decrypt_vote(
169 enc_payload: &[u8],
170 enc_iv: &[u8],
171 message_secret: &[u8],
172 poll_msg_id: &str,
173 poll_creator_jid: &Jid,
174 voter_jid: &Jid,
175 ) -> Result<Vec<Vec<u8>>> {
176 let creator = poll_creator_jid.to_non_ad().to_string();
177 let voter = voter_jid.to_non_ad().to_string();
178 let key = poll::derive_vote_encryption_key(message_secret, poll_msg_id, &creator, &voter)?;
179 poll::decrypt_poll_vote(enc_payload, enc_iv, &key, poll_msg_id, &voter)
180 }
181
182 pub fn aggregate_votes(
186 poll_options: &[String],
187 votes: &[(&Jid, &[u8], &[u8])], message_secret: &[u8],
189 poll_msg_id: &str,
190 poll_creator_jid: &Jid,
191 ) -> Result<Vec<PollOptionResult>> {
192 let option_hashes: Vec<([u8; 32], &str)> = poll_options
193 .iter()
194 .map(|name| (poll::compute_option_hash(name), name.as_str()))
195 .collect();
196
197 let mut latest_votes: HashMap<String, Vec<Vec<u8>>> = HashMap::new();
199 for (voter_jid, enc_payload, enc_iv) in votes {
200 let voter_key = voter_jid.to_non_ad().to_string();
201 match Self::decrypt_vote(
202 enc_payload,
203 enc_iv,
204 message_secret,
205 poll_msg_id,
206 poll_creator_jid,
207 voter_jid,
208 ) {
209 Ok(selected_hashes) => {
210 if selected_hashes.is_empty() {
211 latest_votes.remove(&voter_key);
213 } else {
214 latest_votes.insert(voter_key, selected_hashes);
215 }
216 }
217 Err(e) => {
218 log::warn!("Failed to decrypt vote from {voter_jid}: {e}");
219 }
220 }
221 }
222
223 let mut results: Vec<PollOptionResult> = poll_options
224 .iter()
225 .map(|name| PollOptionResult {
226 name: name.clone(),
227 voters: Vec::new(),
228 })
229 .collect();
230
231 for (voter_jid, selected_hashes) in &latest_votes {
232 for hash in selected_hashes {
233 if let Ok(hash_arr) = <[u8; 32]>::try_from(hash.as_slice())
234 && let Some(idx) = option_hashes.iter().position(|(h, _)| *h == hash_arr)
235 {
236 results[idx].voters.push(voter_jid.clone());
237 }
238 }
239 }
240
241 Ok(results)
242 }
243}
244
245impl Client {
246 pub fn polls(&self) -> Polls<'_> {
247 Polls::new(self)
248 }
249}