wacore/
prekeys.rs

1use crate::libsignal::protocol::{IdentityKey, PreKeyBundle, PreKeyId, PublicKey, SignedPreKeyId};
2use crate::xml::DisplayableNode;
3use std::collections::HashMap;
4use wacore_binary::builder::NodeBuilder;
5use wacore_binary::jid::Jid;
6use wacore_binary::node::{Node, NodeContent};
7
8pub struct PreKeyUtils;
9
10impl PreKeyUtils {
11    pub fn build_fetch_prekeys_request(jids: &[Jid], reason: Option<&str>) -> Node {
12        let user_nodes = jids.iter().map(|jid| {
13            let mut user_builder = NodeBuilder::new("user").attr("jid", jid.to_string());
14            if let Some(r) = reason {
15                user_builder = user_builder.attr("reason", r);
16            }
17            user_builder.build()
18        });
19
20        NodeBuilder::new("key").children(user_nodes).build()
21    }
22
23    pub fn build_upload_prekeys_request(
24        registration_id: u32,
25        identity_key_bytes: Vec<u8>,
26        signed_pre_key_id: u32,
27        signed_pre_key_public_bytes: Vec<u8>,
28        signed_pre_key_signature: Vec<u8>,
29        pre_keys: &[(u32, Vec<u8>)],
30    ) -> Vec<Node> {
31        let mut pre_key_nodes = Vec::new();
32        for (pre_key_id, public_bytes) in pre_keys {
33            let id_bytes = pre_key_id.to_be_bytes()[1..].to_vec();
34            let node = NodeBuilder::new("key")
35                .children([
36                    NodeBuilder::new("id").bytes(id_bytes).build(),
37                    NodeBuilder::new("value")
38                        .bytes(public_bytes.clone())
39                        .build(),
40                ])
41                .build();
42            pre_key_nodes.push(node);
43        }
44
45        let registration_id_bytes = registration_id.to_be_bytes().to_vec();
46
47        let signed_pre_key_node = NodeBuilder::new("skey")
48            .children([
49                NodeBuilder::new("id")
50                    .bytes(signed_pre_key_id.to_be_bytes()[1..].to_vec())
51                    .build(),
52                NodeBuilder::new("value")
53                    .bytes(signed_pre_key_public_bytes)
54                    .build(),
55                NodeBuilder::new("signature")
56                    .bytes(signed_pre_key_signature)
57                    .build(),
58            ])
59            .build();
60
61        let type_bytes = vec![5u8];
62
63        vec![
64            NodeBuilder::new("registration")
65                .bytes(registration_id_bytes)
66                .build(),
67            NodeBuilder::new("type").bytes(type_bytes.clone()).build(),
68            NodeBuilder::new("identity")
69                .bytes(identity_key_bytes)
70                .build(),
71            NodeBuilder::new("list").children(pre_key_nodes).build(),
72            signed_pre_key_node,
73        ]
74    }
75
76    pub fn parse_prekeys_response(
77        resp_node: &Node,
78    ) -> Result<HashMap<Jid, PreKeyBundle>, anyhow::Error> {
79        let list_node = resp_node
80            .get_optional_child("list")
81            .ok_or_else(|| anyhow::anyhow!("<list> not found in pre-key response"))?;
82
83        let mut bundles = HashMap::new();
84        for user_node in list_node.children().unwrap_or_default() {
85            if user_node.tag != "user" {
86                continue;
87            }
88            let mut attrs = user_node.attrs();
89            let jid = attrs.jid("jid");
90            let bundle = match Self::node_to_pre_key_bundle(&jid, user_node) {
91                Ok(b) => b,
92                Err(_e) => {
93                    continue;
94                }
95            };
96            bundles.insert(jid, bundle);
97        }
98
99        Ok(bundles)
100    }
101
102    fn node_to_pre_key_bundle(jid: &Jid, node: &Node) -> Result<PreKeyBundle, anyhow::Error> {
103        fn extract_bytes(node: Option<&Node>) -> Result<Vec<u8>, anyhow::Error> {
104            match node.and_then(|n| n.content.as_ref()) {
105                Some(NodeContent::Bytes(b)) => Ok(b.clone()),
106                _ => Err(anyhow::anyhow!("Expected bytes in node content")),
107            }
108        }
109
110        if let Some(error_node) = node.get_optional_child("error") {
111            return Err(anyhow::anyhow!(
112                "Error getting prekeys: {}",
113                DisplayableNode(error_node)
114            ));
115        }
116
117        let reg_id_bytes = extract_bytes(node.get_optional_child("registration"))?;
118        if reg_id_bytes.len() != 4 {
119            return Err(anyhow::anyhow!("Invalid registration ID length"));
120        }
121        let registration_id = u32::from_be_bytes(reg_id_bytes.try_into().unwrap());
122
123        let keys_node = node.get_optional_child("keys").unwrap_or(node);
124
125        let identity_key_bytes = extract_bytes(keys_node.get_optional_child("identity"))?;
126
127        let identity_key_array: [u8; 32] =
128            identity_key_bytes.try_into().map_err(|v: Vec<u8>| {
129                anyhow::anyhow!("Invalid identity key length: got {}, expected 32", v.len())
130            })?;
131
132        let identity_key =
133            IdentityKey::new(PublicKey::from_djb_public_key_bytes(&identity_key_array)?);
134
135        let mut pre_key_tuple = None;
136        if let Some(pre_key_node) = keys_node.get_optional_child("key")
137            && let Some((id, key_bytes)) = Self::node_to_pre_key(pre_key_node)?
138        {
139            let pre_key_id: PreKeyId = id.into();
140            let pre_key_public = PublicKey::from_djb_public_key_bytes(&key_bytes)?;
141            pre_key_tuple = Some((pre_key_id, pre_key_public));
142        }
143
144        let signed_pre_key_node = keys_node
145            .get_optional_child("skey")
146            .ok_or(anyhow::anyhow!("Missing signed prekey"))?;
147        let (signed_pre_key_id_u32, signed_pre_key_public_bytes, signed_pre_key_signature) =
148            Self::node_to_signed_pre_key(signed_pre_key_node)?;
149
150        let signed_pre_key_id: SignedPreKeyId = signed_pre_key_id_u32.into();
151        let signed_pre_key_public =
152            PublicKey::from_djb_public_key_bytes(&signed_pre_key_public_bytes)?;
153
154        let bundle = PreKeyBundle::new(
155            registration_id,
156            (jid.device as u32).into(),
157            pre_key_tuple,
158            signed_pre_key_id,
159            signed_pre_key_public,
160            signed_pre_key_signature.to_vec(),
161            identity_key,
162        )?;
163
164        Ok(bundle)
165    }
166
167    fn node_to_pre_key(node: &Node) -> Result<Option<(u32, [u8; 32])>, anyhow::Error> {
168        let id_node_content = node
169            .get_optional_child("id")
170            .and_then(|n| n.content.as_ref());
171
172        let id = match id_node_content {
173            Some(NodeContent::Bytes(b)) if !b.is_empty() => {
174                if b.len() == 3 {
175                    Ok(u32::from_be_bytes([0, b[0], b[1], b[2]]))
176                } else if let Ok(s) = std::str::from_utf8(b) {
177                    let trimmed_s = s.trim();
178                    if trimmed_s.is_empty() {
179                        Err(anyhow::anyhow!("ID content is only whitespace"))
180                    } else {
181                        u32::from_str_radix(trimmed_s, 16).map_err(|e| e.into())
182                    }
183                } else {
184                    Err(anyhow::anyhow!("ID is not valid UTF-8 hex or 3-byte int"))
185                }
186            }
187            _ => Err(anyhow::anyhow!("Missing or empty pre-key ID content")),
188        };
189
190        let id = match id {
191            Ok(val) => val,
192            Err(_e) => return Ok(None),
193        };
194
195        let value_bytes = node
196            .get_optional_child("value")
197            .and_then(|n| n.content.as_ref())
198            .and_then(|c| {
199                if let NodeContent::Bytes(b) = c {
200                    Some(b.clone())
201                } else {
202                    None
203                }
204            })
205            .ok_or(anyhow::anyhow!("Missing pre-key value"))?;
206        if value_bytes.len() != 32 {
207            return Err(anyhow::anyhow!("Invalid pre-key value length"));
208        }
209
210        Ok(Some((id, value_bytes.try_into().unwrap())))
211    }
212
213    fn node_to_signed_pre_key(node: &Node) -> Result<(u32, [u8; 32], [u8; 64]), anyhow::Error> {
214        let (id, public_key_bytes) = match Self::node_to_pre_key(node)? {
215            Some((id, key)) => (id, key),
216            None => return Err(anyhow::anyhow!("Signed pre-key is missing ID or value")),
217        };
218        let signature_bytes = node
219            .get_optional_child("signature")
220            .and_then(|n| n.content.as_ref())
221            .and_then(|c| {
222                if let NodeContent::Bytes(b) = c {
223                    Some(b.clone())
224                } else {
225                    None
226                }
227            })
228            .ok_or(anyhow::anyhow!("Missing signed pre-key signature"))?;
229        if signature_bytes.len() != 64 {
230            return Err(anyhow::anyhow!("Invalid signature length"));
231        }
232
233        Ok((id, public_key_bytes, signature_bytes.try_into().unwrap()))
234    }
235}