Skip to main content

whatsapp_rust/features/
polls.rs

1//! Poll creation, voting, and vote decryption.
2
3use std::collections::HashMap;
4
5use anyhow::{Result, anyhow};
6use wacore::poll;
7use wacore_binary::jid::{Jid, JidExt};
8use waproto::whatsapp as wa;
9
10use crate::client::Client;
11
/// Tally for a single poll option, as produced by [`Polls::aggregate_votes`].
///
/// Derives `PartialEq`/`Eq` so tallies can be compared by value (e.g. in tests or when
/// checking whether a re-aggregation changed anything).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PollOptionResult {
    /// Option text exactly as it was passed in the poll's option list.
    pub name: String,
    /// AD-stripped JID strings of the voters whose latest vote selected this option.
    pub voters: Vec<String>,
}
17
18pub struct Polls<'a> {
19    client: &'a Client,
20}
21
22impl<'a> Polls<'a> {
23    pub(crate) fn new(client: &'a Client) -> Self {
24        Self { client }
25    }
26
27    /// Returns `(message_id, message_secret)`. Caller needs `message_secret` to decrypt votes.
28    pub async fn create(
29        &self,
30        to: &Jid,
31        name: &str,
32        options: &[String],
33        selectable_count: u32,
34    ) -> Result<(String, Vec<u8>)> {
35        if options.len() < 2 {
36            return Err(anyhow!("Poll must have at least 2 options"));
37        }
38        if options.len() > 12 {
39            return Err(anyhow!("Polls can have a maximum of 12 options"));
40        }
41        if selectable_count < 1 || selectable_count > options.len() as u32 {
42            return Err(anyhow!(
43                "selectable_count must be between 1 and {} (got {selectable_count})",
44                options.len()
45            ));
46        }
47
48        // Duplicate names would produce identical SHA-256 hashes, making votes indistinguishable
49        let mut seen = std::collections::HashSet::new();
50        for opt in options {
51            if !seen.insert(opt) {
52                return Err(anyhow!("Duplicate option name: {opt}"));
53            }
54        }
55
56        let poll_options: Vec<wa::message::poll_creation_message::Option> = options
57            .iter()
58            .map(|name| wa::message::poll_creation_message::Option {
59                option_name: Some(name.clone()),
60                option_hash: None,
61            })
62            .collect();
63
64        let poll_msg = wa::message::PollCreationMessage {
65            enc_key: None,
66            name: Some(name.to_string()),
67            options: poll_options,
68            selectable_options_count: Some(selectable_count),
69            context_info: None,
70            poll_content_type: None,
71            poll_type: None,
72            correct_answer: None,
73        };
74
75        // WA Web: v3 for single-select, v1 for multi-select (GeneratePollCreationMessageProto.js:39-41)
76        let mut message = if selectable_count == 1 {
77            wa::Message {
78                poll_creation_message_v3: Some(Box::new(poll_msg)),
79                ..Default::default()
80            }
81        } else {
82            wa::Message {
83                poll_creation_message: Some(Box::new(poll_msg)),
84                ..Default::default()
85            }
86        };
87
88        // WA Web generates a 32-byte random secret at poll creation time
89        // (SendPollCreationMsgAction.js:158). Voters need this to derive their encryption key.
90        let message_secret: Vec<u8> = {
91            use rand::Rng;
92            let mut secret = vec![0u8; 32];
93            rand::make_rng::<rand::rngs::StdRng>().fill_bytes(&mut secret);
94            secret
95        };
96
97        message.message_context_info = Some(wa::MessageContextInfo {
98            message_secret: Some(message_secret.clone()),
99            ..Default::default()
100        });
101
102        let msg_id = self.client.send_message(to.clone(), message).await?;
103        Ok((msg_id, message_secret))
104    }
105
106    pub async fn vote(
107        &self,
108        chat_jid: &Jid,
109        poll_msg_id: &str,
110        poll_creator_jid: &Jid,
111        message_secret: &[u8],
112        option_names: &[String],
113    ) -> Result<String> {
114        let my_jid = self
115            .client
116            .get_pn()
117            .await
118            .ok_or_else(|| anyhow!("Not logged in — cannot determine own JID"))?;
119        let voter_jid_str = my_jid.to_non_ad().to_string();
120        let creator_jid_str = poll_creator_jid.to_non_ad().to_string();
121
122        let selected_hashes: Vec<Vec<u8>> = option_names
123            .iter()
124            .map(|name| poll::compute_option_hash(name).to_vec())
125            .collect();
126
127        let key = poll::derive_vote_encryption_key(
128            message_secret,
129            poll_msg_id,
130            &creator_jid_str,
131            &voter_jid_str,
132        )?;
133
134        let (enc_payload, iv) =
135            poll::encrypt_poll_vote(&selected_hashes, &key, poll_msg_id, &voter_jid_str)?;
136
137        let from_me = my_jid.to_non_ad() == poll_creator_jid.to_non_ad();
138
139        let poll_update = wa::message::PollUpdateMessage {
140            poll_creation_message_key: Some(wa::MessageKey {
141                remote_jid: Some(chat_jid.to_string()),
142                from_me: Some(from_me),
143                id: Some(poll_msg_id.to_string()),
144                participant: if chat_jid.is_group() {
145                    Some(poll_creator_jid.to_string())
146                } else {
147                    None
148                },
149            }),
150            vote: Some(wa::message::PollEncValue {
151                enc_payload: Some(enc_payload),
152                enc_iv: Some(iv.to_vec()),
153            }),
154            metadata: Some(wa::message::PollUpdateMessageMetadata {}),
155            sender_timestamp_ms: Some(wacore::time::now_millis()),
156        };
157
158        let message = wa::Message {
159            poll_update_message: Some(poll_update),
160            ..Default::default()
161        };
162
163        self.client.send_message(chat_jid.clone(), message).await
164    }
165
166    /// Returns the selected option hashes (each 32 bytes).
167    /// JIDs are normalized (AD suffix stripped) to match the key derivation in `vote()`.
168    pub fn decrypt_vote(
169        enc_payload: &[u8],
170        enc_iv: &[u8],
171        message_secret: &[u8],
172        poll_msg_id: &str,
173        poll_creator_jid: &Jid,
174        voter_jid: &Jid,
175    ) -> Result<Vec<Vec<u8>>> {
176        let creator = poll_creator_jid.to_non_ad().to_string();
177        let voter = voter_jid.to_non_ad().to_string();
178        let key = poll::derive_vote_encryption_key(message_secret, poll_msg_id, &creator, &voter)?;
179        poll::decrypt_poll_vote(enc_payload, enc_iv, &key, poll_msg_id, &voter)
180    }
181
182    /// Decrypts each vote and tallies per-option results.
183    /// Later votes from the same voter replace earlier ones (last-vote-wins).
184    /// `votes` should be ordered oldest-first.
185    pub fn aggregate_votes(
186        poll_options: &[String],
187        votes: &[(&Jid, &[u8], &[u8])], // (voter_jid, enc_payload, enc_iv)
188        message_secret: &[u8],
189        poll_msg_id: &str,
190        poll_creator_jid: &Jid,
191    ) -> Result<Vec<PollOptionResult>> {
192        let option_hashes: Vec<([u8; 32], &str)> = poll_options
193            .iter()
194            .map(|name| (poll::compute_option_hash(name), name.as_str()))
195            .collect();
196
197        // Last-vote-wins: each new vote from the same voter replaces the previous
198        let mut latest_votes: HashMap<String, Vec<Vec<u8>>> = HashMap::new();
199        for (voter_jid, enc_payload, enc_iv) in votes {
200            let voter_key = voter_jid.to_non_ad().to_string();
201            match Self::decrypt_vote(
202                enc_payload,
203                enc_iv,
204                message_secret,
205                poll_msg_id,
206                poll_creator_jid,
207                voter_jid,
208            ) {
209                Ok(selected_hashes) => {
210                    if selected_hashes.is_empty() {
211                        // Empty selection = voter cleared their vote
212                        latest_votes.remove(&voter_key);
213                    } else {
214                        latest_votes.insert(voter_key, selected_hashes);
215                    }
216                }
217                Err(e) => {
218                    log::warn!("Failed to decrypt vote from {voter_jid}: {e}");
219                }
220            }
221        }
222
223        let mut results: Vec<PollOptionResult> = poll_options
224            .iter()
225            .map(|name| PollOptionResult {
226                name: name.clone(),
227                voters: Vec::new(),
228            })
229            .collect();
230
231        for (voter_jid, selected_hashes) in &latest_votes {
232            for hash in selected_hashes {
233                if let Ok(hash_arr) = <[u8; 32]>::try_from(hash.as_slice())
234                    && let Some(idx) = option_hashes.iter().position(|(h, _)| *h == hash_arr)
235                {
236                    results[idx].voters.push(voter_jid.clone());
237                }
238            }
239        }
240
241        Ok(results)
242    }
243}
244
245impl Client {
246    pub fn polls(&self) -> Polls<'_> {
247        Polls::new(self)
248    }
249}