signal_cli_jsonrpc_client/
trust_set.rs

1//! Signal trust set - a set of Signal UUIDs with optional safety numbers.
2//!
3//! Supports two deserialization formats:
4//! - Map: `{"uuid1": ["safety1", "safety2"], "uuid2": []}` - UUIDs with safety numbers
5//! - Sequence: `["uuid1", "uuid2"]` - UUIDs with no safety numbers (simpler)
6
7use crate::{Envelope, Identity, RpcClient};
8use serde::{
9    Deserialize, Deserializer,
10    de::{MapAccess, SeqAccess, Visitor},
11};
12use std::{borrow::Borrow, collections::HashMap, fmt, ops::Deref, str::FromStr};
13use tracing::{debug, info, warn};
14
15/// A validated Signal UUID in the format `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`.
16#[derive(Clone, PartialEq, Eq, Hash, Deserialize)]
17#[serde(try_from = "String")]
18pub struct Uuid(String);
19
20impl fmt::Debug for Uuid {
21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22        fmt::Debug::fmt(&self.0, f)
23    }
24}
25
26impl FromStr for Uuid {
27    type Err = String;
28
29    fn from_str(s: &str) -> Result<Self, Self::Err> {
30        // UUID format: 8-4-4-4-12 hex chars (36 chars total with dashes)
31        if s.len() != 36 {
32            return Err(format!("UUID must be 36 characters, got {}", s.len()));
33        }
34
35        let parts: Vec<&str> = s.split('-').collect();
36        if parts.len() != 5 {
37            return Err(format!(
38                "UUID must have 5 dash-separated parts, got {}",
39                parts.len()
40            ));
41        }
42
43        let expected_lens = [8, 4, 4, 4, 12];
44        for (i, (part, &expected)) in parts.iter().zip(&expected_lens).enumerate() {
45            if part.len() != expected {
46                return Err(format!(
47                    "UUID part {} has wrong length: expected {}, got {}",
48                    i + 1,
49                    expected,
50                    part.len()
51                ));
52            }
53            if !part.chars().all(|c| c.is_ascii_hexdigit()) {
54                return Err(format!("UUID part {} contains non-hex characters", i + 1));
55            }
56        }
57
58        Ok(Uuid(s.to_owned()))
59    }
60}
61
62impl TryFrom<String> for Uuid {
63    type Error = String;
64
65    fn try_from(s: String) -> Result<Self, Self::Error> {
66        s.parse()
67    }
68}
69
70impl TryFrom<&str> for Uuid {
71    type Error = String;
72
73    fn try_from(s: &str) -> Result<Self, Self::Error> {
74        s.parse()
75    }
76}
77
78impl Deref for Uuid {
79    type Target = str;
80
81    fn deref(&self) -> &Self::Target {
82        &self.0
83    }
84}
85
86impl AsRef<str> for Uuid {
87    fn as_ref(&self) -> &str {
88        &self.0
89    }
90}
91
92impl Borrow<str> for Uuid {
93    fn borrow(&self) -> &str {
94        &self.0
95    }
96}
97
98impl fmt::Display for Uuid {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        write!(f, "{}", self.0)
101    }
102}
103
104/// A validated Signal safety number (60 digits, optionally separated by whitespace).
105#[derive(Clone, PartialEq, Eq, Hash, Deserialize)]
106#[serde(try_from = "String")]
107pub struct SafetyNumber(String);
108
109impl fmt::Debug for SafetyNumber {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        fmt::Debug::fmt(&self.0, f)
112    }
113}
114
115impl FromStr for SafetyNumber {
116    type Err = String;
117
118    fn from_str(s: &str) -> Result<Self, Self::Err> {
119        // Safety number format: 60 digits, optionally grouped with whitespace
120        let mut digit_count = 0;
121        for c in s.chars() {
122            if c.is_ascii_digit() {
123                digit_count += 1;
124            } else if !c.is_whitespace() {
125                return Err(format!(
126                    "Safety number must contain only digits and whitespace, found '{c}'"
127                ));
128            }
129        }
130
131        if digit_count != 60 {
132            return Err(format!(
133                "Safety number must contain exactly 60 digits, got {digit_count}"
134            ));
135        }
136
137        Ok(SafetyNumber(s.to_owned()))
138    }
139}
140
141impl TryFrom<String> for SafetyNumber {
142    type Error = String;
143
144    fn try_from(s: String) -> Result<Self, Self::Error> {
145        s.parse()
146    }
147}
148
149impl TryFrom<&str> for SafetyNumber {
150    type Error = String;
151
152    fn try_from(s: &str) -> Result<Self, Self::Error> {
153        s.parse()
154    }
155}
156
157impl Deref for SafetyNumber {
158    type Target = str;
159
160    fn deref(&self) -> &Self::Target {
161        &self.0
162    }
163}
164
165impl AsRef<str> for SafetyNumber {
166    fn as_ref(&self) -> &str {
167        &self.0
168    }
169}
170
171impl Borrow<str> for SafetyNumber {
172    fn borrow(&self) -> &str {
173        &self.0
174    }
175}
176
177impl fmt::Display for SafetyNumber {
178    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179        write!(f, "{}", self.0)
180    }
181}
182
183/// A set of Signal UUIDs with optional safety numbers for trust verification.
184///
185/// Can be deserialized from either:
186/// - A map of UUID -> safety numbers: `{"uuid1": ["12345..."], "uuid2": []}`
187/// - A sequence of UUIDs (no safety numbers): `["uuid1", "uuid2"]`
188#[derive(Clone, Default)]
189pub struct SignalTrustSet {
190    map: HashMap<Uuid, Vec<SafetyNumber>>,
191}
192
193impl fmt::Debug for SignalTrustSet {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        fmt::Debug::fmt(&self.map, f)
196    }
197}
198
199impl SignalTrustSet {
200    /// Create an empty container.
201    pub fn new() -> Self {
202        Self::default()
203    }
204
205    /// Check if the sender of an envelope is trusted.
206    ///
207    /// Currently checks if the source UUID is in the trust set.
208    /// Will eventually also verify safety numbers.
209    pub fn is_trusted(&self, envelope: &Envelope) -> bool {
210        self.map.contains_key(envelope.source_uuid.as_str())
211    }
212
213    /// Get all UUIDs as an iterator of string slices.
214    pub fn uuids(&self) -> impl Iterator<Item = &str> {
215        self.map.keys().map(|u| u.as_ref())
216    }
217
218    /// Get the number of admin UUIDs.
219    pub fn len(&self) -> usize {
220        self.map.len()
221    }
222
223    /// Check if empty.
224    pub fn is_empty(&self) -> bool {
225        self.map.is_empty()
226    }
227
228    /// Iterate over UUID and safety number pairs.
229    pub fn iter(&self) -> impl Iterator<Item = (&Uuid, &Vec<SafetyNumber>)> {
230        self.map.iter()
231    }
232
233    /// Get safety numbers for a specific UUID.
234    pub fn get(&self, uuid: &str) -> Option<&Vec<SafetyNumber>> {
235        self.map.get(uuid)
236    }
237
238    /// Update trust for all UUIDs with safety numbers configured.
239    ///
240    /// For each UUID with safety numbers:
241    /// 1. Check current identities in signal-cli via listIdentities
242    /// 2. If any trusted identity is NOT in our config, remove the contact entirely and re-add only configured ones
243    /// 3. Otherwise, just trust any new safety numbers from config that aren't already trusted
244    ///
245    /// Returns an error if any trust operation fails.
246    pub async fn update_trust(
247        &self,
248        signal_cli: &impl RpcClient,
249        signal_account: &str,
250    ) -> Result<(), String> {
251        info!("Updating trust for {} configured UUIDs", self.map.len());
252
253        for (uuid, safety_numbers) in &self.map {
254            if safety_numbers.is_empty() {
255                continue;
256            }
257
258            // Helper to check if a safety number string is in our configured list
259            let is_configured = |sn: &str| safety_numbers.iter().any(|s| s.as_ref() == sn);
260
261            // Get current identities from signal-cli
262            let current_identities: Vec<Identity> = match signal_cli
263                .list_identities(Some(signal_account.to_owned()), Some(uuid.to_string()))
264                .await
265            {
266                Ok(value) => serde_json::from_value(value).unwrap_or_default(),
267                Err(err) => {
268                    debug!("Could not list identities for {uuid} (may not exist yet): {err}");
269                    Vec::new()
270                }
271            };
272
273            // Check if any trusted identity in signal-cli is NOT in our config
274            let has_revoked_identity = current_identities
275                .iter()
276                .any(|id| id.trust_level.is_trusted() && !is_configured(&id.safety_number));
277
278            if has_revoked_identity {
279                // Log which identities are being revoked
280                for id in &current_identities {
281                    if id.trust_level.is_trusted() && !is_configured(&id.safety_number) {
282                        warn!(
283                            "Revoking trust for {uuid}: safety number {} is trusted in signal-cli but not in config",
284                            id.safety_number
285                        );
286                    }
287                }
288
289                // Remove contact to clear all existing trust
290                warn!("Resetting all trust for {uuid} due to revoked identity");
291                signal_cli
292                    .remove_contact(
293                        Some(signal_account.to_owned()),
294                        uuid.to_string(),
295                        true,  // forget - delete identity keys and sessions
296                        false, // hide
297                    )
298                    .await
299                    .map_err(|err| format!("Failed to remove contact {uuid}: {err}"))?;
300
301                // Re-add all configured safety numbers
302                for safety_number in safety_numbers {
303                    info!("Trusting safety number for {uuid}");
304                    signal_cli
305                        .trust(
306                            Some(signal_account.to_owned()),
307                            uuid.to_string(),
308                            false,
309                            Some(safety_number.to_string()),
310                        )
311                        .await
312                        .map_err(|err| format!("Failed to trust {uuid}: {err}"))?;
313                }
314
315                // Verify the reset worked correctly
316                let new_identities: Vec<Identity> = signal_cli
317                    .list_identities(Some(signal_account.to_owned()), Some(uuid.to_string()))
318                    .await
319                    .map_err(|err| format!("Failed to verify trust reset for {uuid}: {err}"))
320                    .and_then(|value| {
321                        serde_json::from_value(value)
322                            .map_err(|err| format!("Failed to parse identities for {uuid}: {err}"))
323                    })?;
324
325                let trusted_now: Vec<&str> = new_identities
326                    .iter()
327                    .filter(|id| id.trust_level.is_trusted())
328                    .map(|id| id.safety_number.as_str())
329                    .collect();
330
331                // Check all configured safety numbers are now trusted
332                for safety_number in safety_numbers {
333                    if !trusted_now.contains(&safety_number.as_ref()) {
334                        return Err(format!(
335                            "Verification failed for {uuid}: safety number {} should be trusted but isn't",
336                            safety_number
337                        ));
338                    }
339                }
340
341                // Check no unexpected safety numbers are trusted
342                for sn in &trusted_now {
343                    if !is_configured(sn) {
344                        return Err(format!(
345                            "Verification failed for {uuid}: safety number {} is trusted but not in config",
346                            sn
347                        ));
348                    }
349                }
350
351                info!(
352                    "Trust reset verified for {uuid}: {} safety numbers trusted",
353                    trusted_now.len()
354                );
355            } else {
356                // Just add any new safety numbers that aren't already trusted
357                let already_trusted: Vec<&str> = current_identities
358                    .iter()
359                    .filter(|id| id.trust_level.is_trusted())
360                    .map(|id| id.safety_number.as_str())
361                    .collect();
362
363                for safety_number in safety_numbers {
364                    if !already_trusted.contains(&safety_number.as_ref()) {
365                        info!("Trusting new safety number for {uuid}");
366                        signal_cli
367                            .trust(
368                                Some(signal_account.to_owned()),
369                                uuid.to_string(),
370                                false,
371                                Some(safety_number.to_string()),
372                            )
373                            .await
374                            .map_err(|err| format!("Failed to trust {uuid}: {err}"))?;
375                    }
376                }
377            }
378        }
379
380        Ok(())
381    }
382}
383
384impl<'de> Deserialize<'de> for SignalTrustSet {
385    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
386    where
387        D: Deserializer<'de>,
388    {
389        deserializer.deserialize_any(SignalTrustSetVisitor)
390    }
391}
392
393struct SignalTrustSetVisitor;
394
395impl<'de> Visitor<'de> for SignalTrustSetVisitor {
396    type Value = SignalTrustSet;
397
398    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
399        formatter.write_str("a map of UUIDs to safety numbers, or a sequence of UUIDs")
400    }
401
402    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
403    where
404        M: MapAccess<'de>,
405    {
406        let mut map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
407        while let Some((key, value)) = access.next_entry::<Uuid, Vec<SafetyNumber>>()? {
408            map.insert(key, value);
409        }
410        Ok(SignalTrustSet { map })
411    }
412
413    fn visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error>
414    where
415        S: SeqAccess<'de>,
416    {
417        let mut map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
418        while let Some(uuid) = access.next_element::<Uuid>()? {
419            map.insert(uuid, Vec::new());
420        }
421        Ok(SignalTrustSet { map })
422    }
423}
424
425impl FromIterator<Uuid> for SignalTrustSet {
426    fn from_iter<I: IntoIterator<Item = Uuid>>(iter: I) -> Self {
427        Self {
428            map: iter.into_iter().map(|uuid| (uuid, Vec::new())).collect(),
429        }
430    }
431}
432
433impl FromIterator<(Uuid, Vec<SafetyNumber>)> for SignalTrustSet {
434    fn from_iter<I: IntoIterator<Item = (Uuid, Vec<SafetyNumber>)>>(iter: I) -> Self {
435        Self {
436            map: iter.into_iter().collect(),
437        }
438    }
439}
440
441impl<'a> IntoIterator for &'a SignalTrustSet {
442    type Item = (&'a Uuid, &'a Vec<SafetyNumber>);
443    type IntoIter = std::collections::hash_map::Iter<'a, Uuid, Vec<SafetyNumber>>;
444
445    fn into_iter(self) -> Self::IntoIter {
446        self.map.iter()
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    const UUID1: &str = "12345678-1234-1234-1234-123456789abc";
    const UUID2: &str = "abcdef12-abcd-abcd-abcd-abcdef123456";
    const UUID3: &str = "00000000-0000-0000-0000-000000000000";
    const SAFETY1: &str = "123456789012345678901234567890123456789012345678901234567890";
    const SAFETY2: &str = "098765432109876543210987654321098765432109876543210987654321";

    /// Deserialize `json` into a trust set, panicking with serde's message on failure.
    fn parse(json: &str) -> SignalTrustSet {
        serde_json::from_str(json).expect("trust set should deserialize")
    }

    #[test]
    fn test_uuid_validation() {
        assert!(UUID1.parse::<Uuid>().is_ok());

        let rejected = [
            "not-a-uuid",
            "12345678-1234-1234-1234-12345678",        // too short
            "12345678-1234-1234-1234-123456789abcdef", // too long
            "12345678-1234-1234-1234-123456789xyz",    // non-hex
        ];
        for input in rejected {
            assert!(input.parse::<Uuid>().is_err(), "{input:?} should be rejected");
        }
    }

    #[test]
    fn test_safety_number_validation() {
        assert!(SAFETY1.parse::<SafetyNumber>().is_ok());

        // Grouped with whitespace, the common display format.
        let grouped = "12345 67890 12345 67890 12345 67890 12345 67890 12345 67890 12345 67890";
        assert!(grouped.parse::<SafetyNumber>().is_ok());

        let rejected = [
            "12345",                                                        // too short
            "12345678901234567890123456789012345678901234567890123456789x", // non-digit
        ];
        for input in rejected {
            assert!(
                input.parse::<SafetyNumber>().is_err(),
                "{input:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_deserialize_map() {
        let trust_set =
            parse(&format!(r#"{{"{UUID1}": ["{SAFETY1}", "{SAFETY2}"], "{UUID2}": []}}"#));

        assert_eq!(trust_set.len(), 2);
        let first = trust_set.get(UUID1).expect("UUID1 should be present");
        let second = trust_set.get(UUID2).expect("UUID2 should be present");
        assert_eq!(first.len(), 2);
        assert!(second.is_empty());
    }

    #[test]
    fn test_deserialize_seq() {
        let trust_set = parse(&format!(r#"["{UUID1}", "{UUID2}", "{UUID3}"]"#));

        assert_eq!(trust_set.len(), 3);
        // Sequence entries never carry safety numbers.
        for uuid in [UUID1, UUID2, UUID3] {
            let numbers = trust_set
                .get(uuid)
                .unwrap_or_else(|| panic!("{uuid} should be present"));
            assert!(numbers.is_empty());
        }
    }

    #[test]
    fn test_empty_map() {
        assert!(parse("{}").is_empty());
    }

    #[test]
    fn test_empty_seq() {
        assert!(parse("[]").is_empty());
    }

    #[test]
    fn test_invalid_uuid_rejected() {
        let outcome = serde_json::from_str::<SignalTrustSet>(r#"["not-a-valid-uuid"]"#);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_invalid_safety_number_rejected() {
        let json = format!(r#"{{"{UUID1}": ["invalid"]}}"#);
        let outcome = serde_json::from_str::<SignalTrustSet>(&json);
        assert!(outcome.is_err());
    }
}