// signal_cli_jsonrpc_client/trust_set.rs
use crate::{Envelope, Identity, RpcClient};
8use serde::{
9 Deserialize, Deserializer,
10 de::{MapAccess, SeqAccess, Visitor},
11};
12use std::{borrow::Borrow, collections::HashMap, fmt, ops::Deref, str::FromStr};
13use tracing::{debug, info, warn};
14
15#[derive(Clone, PartialEq, Eq, Hash, Deserialize)]
17#[serde(try_from = "String")]
18pub struct Uuid(String);
19
20impl fmt::Debug for Uuid {
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 fmt::Debug::fmt(&self.0, f)
23 }
24}
25
26impl FromStr for Uuid {
27 type Err = String;
28
29 fn from_str(s: &str) -> Result<Self, Self::Err> {
30 if s.len() != 36 {
32 return Err(format!("UUID must be 36 characters, got {}", s.len()));
33 }
34
35 let parts: Vec<&str> = s.split('-').collect();
36 if parts.len() != 5 {
37 return Err(format!(
38 "UUID must have 5 dash-separated parts, got {}",
39 parts.len()
40 ));
41 }
42
43 let expected_lens = [8, 4, 4, 4, 12];
44 for (i, (part, &expected)) in parts.iter().zip(&expected_lens).enumerate() {
45 if part.len() != expected {
46 return Err(format!(
47 "UUID part {} has wrong length: expected {}, got {}",
48 i + 1,
49 expected,
50 part.len()
51 ));
52 }
53 if !part.chars().all(|c| c.is_ascii_hexdigit()) {
54 return Err(format!("UUID part {} contains non-hex characters", i + 1));
55 }
56 }
57
58 Ok(Uuid(s.to_owned()))
59 }
60}
61
62impl TryFrom<String> for Uuid {
63 type Error = String;
64
65 fn try_from(s: String) -> Result<Self, Self::Error> {
66 s.parse()
67 }
68}
69
70impl TryFrom<&str> for Uuid {
71 type Error = String;
72
73 fn try_from(s: &str) -> Result<Self, Self::Error> {
74 s.parse()
75 }
76}
77
78impl Deref for Uuid {
79 type Target = str;
80
81 fn deref(&self) -> &Self::Target {
82 &self.0
83 }
84}
85
86impl AsRef<str> for Uuid {
87 fn as_ref(&self) -> &str {
88 &self.0
89 }
90}
91
92impl Borrow<str> for Uuid {
93 fn borrow(&self) -> &str {
94 &self.0
95 }
96}
97
98impl fmt::Display for Uuid {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 write!(f, "{}", self.0)
101 }
102}
103
104#[derive(Clone, PartialEq, Eq, Hash, Deserialize)]
106#[serde(try_from = "String")]
107pub struct SafetyNumber(String);
108
109impl fmt::Debug for SafetyNumber {
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 fmt::Debug::fmt(&self.0, f)
112 }
113}
114
115impl FromStr for SafetyNumber {
116 type Err = String;
117
118 fn from_str(s: &str) -> Result<Self, Self::Err> {
119 let mut digit_count = 0;
121 for c in s.chars() {
122 if c.is_ascii_digit() {
123 digit_count += 1;
124 } else if !c.is_whitespace() {
125 return Err(format!(
126 "Safety number must contain only digits and whitespace, found '{c}'"
127 ));
128 }
129 }
130
131 if digit_count != 60 {
132 return Err(format!(
133 "Safety number must contain exactly 60 digits, got {digit_count}"
134 ));
135 }
136
137 Ok(SafetyNumber(s.to_owned()))
138 }
139}
140
141impl TryFrom<String> for SafetyNumber {
142 type Error = String;
143
144 fn try_from(s: String) -> Result<Self, Self::Error> {
145 s.parse()
146 }
147}
148
149impl TryFrom<&str> for SafetyNumber {
150 type Error = String;
151
152 fn try_from(s: &str) -> Result<Self, Self::Error> {
153 s.parse()
154 }
155}
156
157impl Deref for SafetyNumber {
158 type Target = str;
159
160 fn deref(&self) -> &Self::Target {
161 &self.0
162 }
163}
164
165impl AsRef<str> for SafetyNumber {
166 fn as_ref(&self) -> &str {
167 &self.0
168 }
169}
170
171impl Borrow<str> for SafetyNumber {
172 fn borrow(&self) -> &str {
173 &self.0
174 }
175}
176
177impl fmt::Display for SafetyNumber {
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 write!(f, "{}", self.0)
180 }
181}
182
183#[derive(Clone, Default)]
189pub struct SignalTrustSet {
190 map: HashMap<Uuid, Vec<SafetyNumber>>,
191}
192
193impl fmt::Debug for SignalTrustSet {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 fmt::Debug::fmt(&self.map, f)
196 }
197}
198
199impl SignalTrustSet {
200 pub fn new() -> Self {
202 Self::default()
203 }
204
205 pub fn is_trusted(&self, envelope: &Envelope) -> bool {
210 self.map.contains_key(envelope.source_uuid.as_str())
211 }
212
213 pub fn uuids(&self) -> impl Iterator<Item = &str> {
215 self.map.keys().map(|u| u.as_ref())
216 }
217
218 pub fn len(&self) -> usize {
220 self.map.len()
221 }
222
223 pub fn is_empty(&self) -> bool {
225 self.map.is_empty()
226 }
227
228 pub fn iter(&self) -> impl Iterator<Item = (&Uuid, &Vec<SafetyNumber>)> {
230 self.map.iter()
231 }
232
233 pub fn get(&self, uuid: &str) -> Option<&Vec<SafetyNumber>> {
235 self.map.get(uuid)
236 }
237
238 pub async fn update_trust(
247 &self,
248 signal_cli: &impl RpcClient,
249 signal_account: &str,
250 ) -> Result<(), String> {
251 info!("Updating trust for {} configured UUIDs", self.map.len());
252
253 for (uuid, safety_numbers) in &self.map {
254 if safety_numbers.is_empty() {
255 continue;
256 }
257
258 let is_configured = |sn: &str| safety_numbers.iter().any(|s| s.as_ref() == sn);
260
261 let current_identities: Vec<Identity> = match signal_cli
263 .list_identities(Some(signal_account.to_owned()), Some(uuid.to_string()))
264 .await
265 {
266 Ok(value) => serde_json::from_value(value).unwrap_or_default(),
267 Err(err) => {
268 debug!("Could not list identities for {uuid} (may not exist yet): {err}");
269 Vec::new()
270 }
271 };
272
273 let has_revoked_identity = current_identities
275 .iter()
276 .any(|id| id.trust_level.is_trusted() && !is_configured(&id.safety_number));
277
278 if has_revoked_identity {
279 for id in ¤t_identities {
281 if id.trust_level.is_trusted() && !is_configured(&id.safety_number) {
282 warn!(
283 "Revoking trust for {uuid}: safety number {} is trusted in signal-cli but not in config",
284 id.safety_number
285 );
286 }
287 }
288
289 warn!("Resetting all trust for {uuid} due to revoked identity");
291 signal_cli
292 .remove_contact(
293 Some(signal_account.to_owned()),
294 uuid.to_string(),
295 true, false, )
298 .await
299 .map_err(|err| format!("Failed to remove contact {uuid}: {err}"))?;
300
301 for safety_number in safety_numbers {
303 info!("Trusting safety number for {uuid}");
304 signal_cli
305 .trust(
306 Some(signal_account.to_owned()),
307 uuid.to_string(),
308 false,
309 Some(safety_number.to_string()),
310 )
311 .await
312 .map_err(|err| format!("Failed to trust {uuid}: {err}"))?;
313 }
314
315 let new_identities: Vec<Identity> = signal_cli
317 .list_identities(Some(signal_account.to_owned()), Some(uuid.to_string()))
318 .await
319 .map_err(|err| format!("Failed to verify trust reset for {uuid}: {err}"))
320 .and_then(|value| {
321 serde_json::from_value(value)
322 .map_err(|err| format!("Failed to parse identities for {uuid}: {err}"))
323 })?;
324
325 let trusted_now: Vec<&str> = new_identities
326 .iter()
327 .filter(|id| id.trust_level.is_trusted())
328 .map(|id| id.safety_number.as_str())
329 .collect();
330
331 for safety_number in safety_numbers {
333 if !trusted_now.contains(&safety_number.as_ref()) {
334 return Err(format!(
335 "Verification failed for {uuid}: safety number {} should be trusted but isn't",
336 safety_number
337 ));
338 }
339 }
340
341 for sn in &trusted_now {
343 if !is_configured(sn) {
344 return Err(format!(
345 "Verification failed for {uuid}: safety number {} is trusted but not in config",
346 sn
347 ));
348 }
349 }
350
351 info!(
352 "Trust reset verified for {uuid}: {} safety numbers trusted",
353 trusted_now.len()
354 );
355 } else {
356 let already_trusted: Vec<&str> = current_identities
358 .iter()
359 .filter(|id| id.trust_level.is_trusted())
360 .map(|id| id.safety_number.as_str())
361 .collect();
362
363 for safety_number in safety_numbers {
364 if !already_trusted.contains(&safety_number.as_ref()) {
365 info!("Trusting new safety number for {uuid}");
366 signal_cli
367 .trust(
368 Some(signal_account.to_owned()),
369 uuid.to_string(),
370 false,
371 Some(safety_number.to_string()),
372 )
373 .await
374 .map_err(|err| format!("Failed to trust {uuid}: {err}"))?;
375 }
376 }
377 }
378 }
379
380 Ok(())
381 }
382}
383
384impl<'de> Deserialize<'de> for SignalTrustSet {
385 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
386 where
387 D: Deserializer<'de>,
388 {
389 deserializer.deserialize_any(SignalTrustSetVisitor)
390 }
391}
392
393struct SignalTrustSetVisitor;
394
395impl<'de> Visitor<'de> for SignalTrustSetVisitor {
396 type Value = SignalTrustSet;
397
398 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
399 formatter.write_str("a map of UUIDs to safety numbers, or a sequence of UUIDs")
400 }
401
402 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
403 where
404 M: MapAccess<'de>,
405 {
406 let mut map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
407 while let Some((key, value)) = access.next_entry::<Uuid, Vec<SafetyNumber>>()? {
408 map.insert(key, value);
409 }
410 Ok(SignalTrustSet { map })
411 }
412
413 fn visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error>
414 where
415 S: SeqAccess<'de>,
416 {
417 let mut map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
418 while let Some(uuid) = access.next_element::<Uuid>()? {
419 map.insert(uuid, Vec::new());
420 }
421 Ok(SignalTrustSet { map })
422 }
423}
424
425impl FromIterator<Uuid> for SignalTrustSet {
426 fn from_iter<I: IntoIterator<Item = Uuid>>(iter: I) -> Self {
427 Self {
428 map: iter.into_iter().map(|uuid| (uuid, Vec::new())).collect(),
429 }
430 }
431}
432
433impl FromIterator<(Uuid, Vec<SafetyNumber>)> for SignalTrustSet {
434 fn from_iter<I: IntoIterator<Item = (Uuid, Vec<SafetyNumber>)>>(iter: I) -> Self {
435 Self {
436 map: iter.into_iter().collect(),
437 }
438 }
439}
440
441impl<'a> IntoIterator for &'a SignalTrustSet {
442 type Item = (&'a Uuid, &'a Vec<SafetyNumber>);
443 type IntoIter = std::collections::hash_map::Iter<'a, Uuid, Vec<SafetyNumber>>;
444
445 fn into_iter(self) -> Self::IntoIter {
446 self.map.iter()
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    // Well-formed fixtures shared by the tests below.
    const UUID1: &str = "12345678-1234-1234-1234-123456789abc";
    const UUID2: &str = "abcdef12-abcd-abcd-abcd-abcdef123456";
    const UUID3: &str = "00000000-0000-0000-0000-000000000000";
    const SAFETY1: &str = "123456789012345678901234567890123456789012345678901234567890";
    const SAFETY2: &str = "098765432109876543210987654321098765432109876543210987654321";

    /// Parses `json` into a trust set, panicking when it is rejected.
    fn parse_set(json: &str) -> SignalTrustSet {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_uuid_validation() {
        assert!(UUID1.parse::<Uuid>().is_ok());

        let rejected = [
            // No UUID structure at all.
            "not-a-uuid",
            // Last group too short.
            "12345678-1234-1234-1234-12345678",
            // Last group too long.
            "12345678-1234-1234-1234-123456789abcdef",
            // Non-hex characters in the last group.
            "12345678-1234-1234-1234-123456789xyz",
        ];
        for candidate in rejected.iter() {
            assert!(candidate.parse::<Uuid>().is_err());
        }
    }

    #[test]
    fn test_safety_number_validation() {
        let accepted = [
            SAFETY1,
            // Same shape, written in space-separated groups of five.
            "12345 67890 12345 67890 12345 67890 12345 67890 12345 67890 12345 67890",
        ];
        for candidate in accepted.iter() {
            assert!(candidate.parse::<SafetyNumber>().is_ok());
        }

        let rejected = [
            // Far too few digits.
            "12345",
            // Right length, but the last character is not a digit.
            "12345678901234567890123456789012345678901234567890123456789x",
        ];
        for candidate in rejected.iter() {
            assert!(candidate.parse::<SafetyNumber>().is_err());
        }
    }

    #[test]
    fn test_deserialize_map() {
        let json = format!(r#"{{"{UUID1}": ["{SAFETY1}", "{SAFETY2}"], "{UUID2}": []}}"#);
        let trust_set = parse_set(&json);

        assert_eq!(trust_set.len(), 2);
        assert!(trust_set.get(UUID1).is_some());
        assert!(trust_set.get(UUID2).is_some());
        assert_eq!(trust_set.get(UUID1).unwrap().len(), 2);
        assert!(trust_set.get(UUID2).unwrap().is_empty());
    }

    #[test]
    fn test_deserialize_seq() {
        let json = format!(r#"["{UUID1}", "{UUID2}", "{UUID3}"]"#);
        let trust_set = parse_set(&json);

        assert_eq!(trust_set.len(), 3);
        for uuid in [UUID1, UUID2, UUID3].iter() {
            assert!(trust_set.get(uuid).is_some());
        }
        // The sequence form carries no safety numbers.
        for uuid in [UUID1, UUID2, UUID3].iter() {
            assert!(trust_set.get(uuid).unwrap().is_empty());
        }
    }

    #[test]
    fn test_empty_map() {
        assert!(parse_set(r#"{}"#).is_empty());
    }

    #[test]
    fn test_empty_seq() {
        assert!(parse_set(r#"[]"#).is_empty());
    }

    #[test]
    fn test_invalid_uuid_rejected() {
        let outcome = serde_json::from_str::<SignalTrustSet>(r#"["not-a-valid-uuid"]"#);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_invalid_safety_number_rejected() {
        let json = format!(r#"{{"{UUID1}": ["invalid"]}}"#);
        let outcome = serde_json::from_str::<SignalTrustSet>(&json);
        assert!(outcome.is_err());
    }
}