whatsapp_rust/store/
signal_adapter.rs

1use crate::store::Device;
2use async_trait::async_trait;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5use wacore::libsignal::protocol::{
6    Direction, IdentityChange, IdentityKey, IdentityKeyPair, IdentityKeyStore, PreKeyId,
7    PreKeyRecord, PreKeyStore, ProtocolAddress, SessionRecord, SessionStore, SignalProtocolError,
8    SignedPreKeyId, SignedPreKeyRecord, SignedPreKeyStore,
9};
10
11use wacore::libsignal::store::record_helpers as wacore_record;
12use wacore::libsignal::store::sender_key_name::SenderKeyName;
13use wacore::libsignal::store::{
14    PreKeyStore as WacorePreKeyStore, SignedPreKeyStore as WacoreSignedPreKeyStore,
15};
16
17#[derive(Clone)]
18struct SharedDevice {
19    device: Arc<RwLock<Device>>,
20}
21
22#[derive(Clone)]
23pub struct SessionAdapter(SharedDevice);
24#[derive(Clone)]
25pub struct IdentityAdapter(SharedDevice);
26#[derive(Clone)]
27pub struct PreKeyAdapter(SharedDevice);
28#[derive(Clone)]
29pub struct SignedPreKeyAdapter(SharedDevice);
30
31#[derive(Clone)]
32pub struct SenderKeyAdapter(SharedDevice);
33
34#[derive(Clone)]
35pub struct SignalProtocolStoreAdapter {
36    pub session_store: SessionAdapter,
37    pub identity_store: IdentityAdapter,
38    pub pre_key_store: PreKeyAdapter,
39    pub signed_pre_key_store: SignedPreKeyAdapter,
40    pub sender_key_store: SenderKeyAdapter,
41}
42
43impl SignalProtocolStoreAdapter {
44    pub fn new(device: Arc<RwLock<Device>>) -> Self {
45        let shared = SharedDevice { device };
46        Self {
47            session_store: SessionAdapter(shared.clone()),
48            identity_store: IdentityAdapter(shared.clone()),
49            pre_key_store: PreKeyAdapter(shared.clone()),
50            signed_pre_key_store: SignedPreKeyAdapter(shared.clone()),
51            sender_key_store: SenderKeyAdapter(shared),
52        }
53    }
54}
55
56#[async_trait]
57impl SessionStore for SessionAdapter {
58    async fn load_session(
59        &self,
60        address: &ProtocolAddress,
61    ) -> Result<Option<SessionRecord>, SignalProtocolError> {
62        let device = self.0.device.read().await;
63        match device
64            .backend
65            .get_session(&address.to_string())
66            .await
67            .map_err(|e| SignalProtocolError::InvalidState("backend", e.to_string()))?
68        {
69            Some(data) => Ok(Some(SessionRecord::deserialize(&data)?)),
70            None => Ok(None),
71        }
72    }
73
74    async fn store_session(
75        &mut self,
76        address: &ProtocolAddress,
77        record: &SessionRecord,
78    ) -> Result<(), SignalProtocolError> {
79        let device = self.0.device.read().await;
80        let record_bytes = record.serialize()?;
81        device
82            .backend
83            .put_session(&address.to_string(), &record_bytes)
84            .await
85            .map_err(|e| SignalProtocolError::InvalidState("backend", e.to_string()))
86    }
87}
88
89#[async_trait]
90impl IdentityKeyStore for IdentityAdapter {
91    async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair, SignalProtocolError> {
92        let device = self.0.device.read().await;
93        IdentityKeyStore::get_identity_key_pair(&*device)
94            .await
95            .map_err(|e| SignalProtocolError::InvalidState("get_identity_key_pair", e.to_string()))
96    }
97
98    async fn get_local_registration_id(&self) -> Result<u32, SignalProtocolError> {
99        let device = self.0.device.read().await;
100        IdentityKeyStore::get_local_registration_id(&*device)
101            .await
102            .map_err(|e| {
103                SignalProtocolError::InvalidState("get_local_registration_id", e.to_string())
104            })
105    }
106
107    async fn save_identity(
108        &mut self,
109        address: &ProtocolAddress,
110        identity: &IdentityKey,
111    ) -> Result<IdentityChange, SignalProtocolError> {
112        let existing_identity = self.get_identity(address).await?;
113
114        let mut device = self.0.device.write().await;
115        IdentityKeyStore::save_identity(&mut *device, address, identity)
116            .await
117            .map_err(|e| SignalProtocolError::InvalidState("save_identity", e.to_string()))?;
118
119        match existing_identity {
120            None => Ok(IdentityChange::NewOrUnchanged),
121            Some(existing) if &existing == identity => Ok(IdentityChange::NewOrUnchanged),
122            Some(_) => Ok(IdentityChange::ReplacedExisting),
123        }
124    }
125
126    async fn is_trusted_identity(
127        &self,
128        address: &ProtocolAddress,
129        identity: &IdentityKey,
130        direction: Direction,
131    ) -> Result<bool, SignalProtocolError> {
132        let device = self.0.device.read().await;
133        IdentityKeyStore::is_trusted_identity(&*device, address, identity, direction)
134            .await
135            .map_err(|e| SignalProtocolError::InvalidState("is_trusted_identity", e.to_string()))
136    }
137
138    async fn get_identity(
139        &self,
140        address: &ProtocolAddress,
141    ) -> Result<Option<IdentityKey>, SignalProtocolError> {
142        let device = self.0.device.read().await;
143        IdentityKeyStore::get_identity(&*device, address)
144            .await
145            .map_err(|e| SignalProtocolError::InvalidState("get_identity", e.to_string()))
146    }
147}
148
149#[async_trait]
150impl PreKeyStore for PreKeyAdapter {
151    async fn get_pre_key(&self, prekey_id: PreKeyId) -> Result<PreKeyRecord, SignalProtocolError> {
152        let device = self.0.device.read().await;
153        WacorePreKeyStore::load_prekey(&*device, prekey_id.into())
154            .await
155            .map_err(|e| SignalProtocolError::InvalidState("backend", e.to_string()))?
156            .ok_or(SignalProtocolError::InvalidPreKeyId)
157            .and_then(wacore_record::prekey_structure_to_record)
158    }
159    async fn save_pre_key(
160        &mut self,
161        prekey_id: PreKeyId,
162        record: &PreKeyRecord,
163    ) -> Result<(), SignalProtocolError> {
164        let device = self.0.device.read().await;
165        let structure = wacore_record::prekey_record_to_structure(record)?;
166        WacorePreKeyStore::store_prekey(&*device, prekey_id.into(), structure, false)
167            .await
168            .map_err(|e| SignalProtocolError::InvalidState("backend", e.to_string()))
169    }
170    async fn remove_pre_key(&mut self, prekey_id: PreKeyId) -> Result<(), SignalProtocolError> {
171        let device = self.0.device.read().await;
172        WacorePreKeyStore::remove_prekey(&*device, prekey_id.into())
173            .await
174            .map_err(|e| SignalProtocolError::InvalidState("backend", e.to_string()))
175    }
176}
177
178#[async_trait]
179impl SignedPreKeyStore for SignedPreKeyAdapter {
180    async fn get_signed_pre_key(
181        &self,
182        signed_prekey_id: SignedPreKeyId,
183    ) -> Result<SignedPreKeyRecord, SignalProtocolError> {
184        let device = self.0.device.read().await;
185        WacoreSignedPreKeyStore::load_signed_prekey(&*device, signed_prekey_id.into())
186            .await
187            .map_err(|e| SignalProtocolError::InvalidState("backend", e.to_string()))?
188            .ok_or(SignalProtocolError::InvalidSignedPreKeyId)
189            .and_then(wacore_record::signed_prekey_structure_to_record)
190    }
191    async fn save_signed_pre_key(
192        &mut self,
193        _id: SignedPreKeyId,
194        _record: &SignedPreKeyRecord,
195    ) -> Result<(), SignalProtocolError> {
196        Ok(())
197    }
198}
199
200#[async_trait]
201impl wacore::libsignal::protocol::SenderKeyStore for SenderKeyAdapter {
202    async fn store_sender_key(
203        &mut self,
204        sender_key_name: &SenderKeyName,
205        record: &wacore::libsignal::protocol::SenderKeyRecord,
206    ) -> wacore::libsignal::protocol::error::Result<()> {
207        let mut device = self.0.device.write().await;
208        wacore::libsignal::protocol::SenderKeyStore::store_sender_key(
209            &mut *device,
210            sender_key_name,
211            record,
212        )
213        .await
214    }
215
216    async fn load_sender_key(
217        &mut self,
218        sender_key_name: &SenderKeyName,
219    ) -> wacore::libsignal::protocol::error::Result<
220        Option<wacore::libsignal::protocol::SenderKeyRecord>,
221    > {
222        let mut device = self.0.device.write().await;
223        wacore::libsignal::protocol::SenderKeyStore::load_sender_key(&mut *device, sender_key_name)
224            .await
225    }
226}