1use crate::auth::Sig;
18use crate::fwid::{Key, compute_key, fw_check, fw_to_key};
19use crate::types::{
20 Device, DeviceId, DeviceType, Endpoint, Identity, IdentityHandle, MAX_REPLICATION_TARGET,
21 MlDsaKeyPair, Presence, PresenceReceipt, StorageHandle, StorageStrategy,
22};
23use anyhow::{Context, Result};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use std::sync::Arc;
27use tokio::sync::RwLock as AsyncRwLock;
28use tokio::sync::RwLock;
29struct MockDht {
33 storage: HashMap<Key, Vec<u8>>,
34}
35
36impl MockDht {
37 fn new() -> Self {
38 Self {
39 storage: HashMap::new(),
40 }
41 }
42
43 async fn put(&mut self, key: Key, value: Vec<u8>) -> Result<()> {
44 self.storage.insert(key, value);
45 Ok(())
46 }
47
48 async fn get(&self, key: &Key) -> Result<Vec<u8>> {
49 self.storage
50 .get(key)
51 .cloned()
52 .ok_or_else(|| anyhow::anyhow!("Key not found"))
53 }
54}
55
56static DHT: once_cell::sync::Lazy<Arc<RwLock<MockDht>>> =
58 once_cell::sync::Lazy::new(|| Arc::new(RwLock::new(MockDht::new())));
59
60static GLOBAL_DHT_CLIENT: once_cell::sync::Lazy<
62 AsyncRwLock<Option<Arc<crate::dht::client::DhtClient>>>,
63> = once_cell::sync::Lazy::new(|| AsyncRwLock::new(None));
64
65pub async fn set_dht_client(client: crate::dht::client::DhtClient) -> bool {
67 let mut guard = GLOBAL_DHT_CLIENT.write().await;
68 let was_empty = guard.is_none();
69 *guard = Some(Arc::new(client));
70 was_empty
71}
72
73pub async fn clear_dht_client() {
75 GLOBAL_DHT_CLIENT.write().await.take();
76}
77
78async fn get_dht_client_async() -> Option<Arc<crate::dht::client::DhtClient>> {
79 GLOBAL_DHT_CLIENT.read().await.clone()
80}
81
82async fn dht_put_bytes(key: &Key, value: Vec<u8>) -> Result<()> {
83 if let Some(client) = get_dht_client_async().await {
84 let k = hex::encode(key.as_bytes());
85 let _ = client
86 .put(k, value)
87 .await
88 .context("Failed to store data in DHT client")?;
89 Ok(())
90 } else {
91 let mut dht = DHT.write().await;
92 dht.put(key.clone(), value).await
93 }
94}
95
96async fn dht_try_get_bytes(key: &Key) -> Result<Option<Vec<u8>>> {
97 if let Some(client) = get_dht_client_async().await {
98 let k = hex::encode(key.as_bytes());
99 client.get(k).await.context("DHT get failed")
100 } else {
101 let dht = DHT.read().await;
102 match dht.get(key).await {
103 Ok(bytes) => Ok(Some(bytes)),
104 Err(_) => Ok(None),
105 }
106 }
107}
108
109async fn dht_get_bytes(key: &Key) -> Result<Vec<u8>> {
110 match dht_try_get_bytes(key).await? {
111 Some(bytes) => Ok(bytes),
112 None => anyhow::bail!("Key not found"),
113 }
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct IdentityPacketV1 {
123 pub v: u8,
124 pub words: [String; 4],
125 pub id: Key,
126 pub pk: Vec<u8>,
127 pub sig: Option<Vec<u8>>, pub device_set_root: Key,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct MemberRef {
134 pub member_id: Key,
135 pub member_pk: Vec<u8>,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct GroupIdentityPacketV1 {
141 pub v: u8,
142 pub words: [String; 4],
143 pub id: Key,
144 pub group_pk: Vec<u8>,
145 pub group_sig: Vec<u8>,
146 pub members: Vec<MemberRef>,
147 pub membership_root: Key,
148 pub created_at: u64,
149 pub mls_ciphersuite: Option<u16>,
150}
151
152#[derive(Clone)]
154pub struct GroupKeyPair {
155 pub group_pk: crate::quantum_crypto::MlDsaPublicKey,
156 pub group_sk: crate::quantum_crypto::MlDsaSecretKey,
157}
158
159impl std::fmt::Debug for GroupKeyPair {
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 write!(
162 f,
163 "GroupKeyPair {{ group_pk: <{} bytes>, group_sk: <hidden> }}",
164 self.group_pk.as_bytes().len()
165 )
166 }
167}
168
169pub async fn register_identity(words: [&str; 4], keypair: &MlDsaKeyPair) -> Result<IdentityHandle> {
182 let words_owned: [String; 4] = [
184 words[0].to_string(),
185 words[1].to_string(),
186 words[2].to_string(),
187 words[3].to_string(),
188 ];
189
190 if !fw_check(words_owned.clone()) {
192 anyhow::bail!("Invalid word in identity");
193 }
194
195 let key = fw_to_key(words_owned.clone())?;
197
198 if dht_try_get_bytes(&key).await?.is_some() {
200 anyhow::bail!("Identity already registered");
201 }
202
203 let identity = Identity {
205 words: words_owned.clone(),
206 key: key.clone(),
207 public_key: keypair.public_key.clone(),
208 };
209
210 let packet = IdentityPacketV1 {
211 v: 1,
212 words: words_owned.clone(),
213 id: key.clone(),
214 pk: keypair.public_key.clone(),
215 sig: None,
216 device_set_root: compute_key("device-set", key.as_bytes()),
217 };
218
219 dht_put_bytes(&key, serde_json::to_vec(&packet)?).await?;
220
221 Ok(IdentityHandle::new(identity, keypair.clone()))
222}
223
224pub async fn get_identity(key: Key) -> Result<Identity> {
232 let data = dht_get_bytes(&key).await.context("Identity not found")?;
234 if let Ok(pkt) = serde_json::from_slice::<IdentityPacketV1>(&data) {
235 let identity = Identity {
236 words: pkt.words,
237 key: pkt.id,
238 public_key: pkt.pk,
239 };
240 return Ok(identity);
241 }
242 let identity: Identity = serde_json::from_slice(&data)?;
244 Ok(identity)
245}
246
247pub async fn identity_fetch(key: Key) -> Result<IdentityPacketV1> {
249 let data = dht_get_bytes(&key).await.context("Identity not found")?;
250 let pkt: IdentityPacketV1 = serde_json::from_slice(&data)?;
251 Ok(pkt)
252}
253
254pub async fn register_presence(
268 handle: &IdentityHandle,
269 devices: Vec<Device>,
270 active_device: DeviceId,
271) -> Result<PresenceReceipt> {
272 if !devices.iter().any(|d| d.id == active_device) {
274 anyhow::bail!("Active device not in device list");
275 }
276
277 let presence = Presence {
279 identity: handle.key(),
280 devices,
281 active_device: Some(active_device),
282 timestamp: std::time::SystemTime::now()
283 .duration_since(std::time::UNIX_EPOCH)?
284 .as_secs(),
285 signature: vec![], };
287
288 let presence_bytes = serde_json::to_vec(&presence)?;
290 let signature = handle.sign(&presence_bytes)?;
291
292 let mut signed_presence = presence;
293 signed_presence.signature = signature;
294
295 let presence_key = derive_presence_key(handle.key());
297 dht_put_bytes(&presence_key, serde_json::to_vec(&signed_presence)?).await?;
298
299 let receipt = PresenceReceipt {
301 identity: handle.key(),
302 timestamp: signed_presence.timestamp,
303 storing_nodes: vec![Key::from([0u8; 32])], };
305
306 Ok(receipt)
307}
308
309pub async fn get_presence(identity_key: Key) -> Result<Presence> {
317 let presence_key = derive_presence_key(identity_key);
318 let data = dht_get_bytes(&presence_key)
319 .await
320 .context("Presence not found")?;
321 let presence: Presence = serde_json::from_slice(&data)?;
322 Ok(presence)
323}
324
325pub async fn register_headless(
335 handle: &IdentityHandle,
336 storage_gb: u32,
337 endpoint: Endpoint,
338) -> Result<DeviceId> {
339 let mut presence = get_presence(handle.key()).await?;
341
342 let device = Device {
344 id: DeviceId::generate(),
345 device_type: crate::types::presence::DeviceType::Headless,
346 storage_gb: storage_gb as u64,
347 endpoint,
348 capabilities: crate::types::presence::DeviceCapabilities {
349 storage_bytes: storage_gb as u64 * 1_000_000_000,
350 always_online: true,
351 supports_fec: true,
352 supports_seal: true,
353 ..Default::default()
354 },
355 };
356
357 let device_id = device.id;
358 presence.devices.push(device);
359
360 let active = presence.active_device.unwrap_or(device_id);
362 register_presence(handle, presence.devices, active).await?;
363
364 Ok(device_id)
365}
366
367pub async fn set_active_device(handle: &IdentityHandle, device_id: DeviceId) -> Result<()> {
373 let presence = get_presence(handle.key()).await?;
375
376 if !presence.devices.iter().any(|d| d.id == device_id) {
378 anyhow::bail!("Device not found in presence");
379 }
380
381 register_presence(handle, presence.devices, device_id).await?;
383 Ok(())
384}
385
386pub async fn store_data(
400 handle: &IdentityHandle,
401 data: Vec<u8>,
402 group_size: usize,
403) -> Result<StorageHandle> {
404 let strategy = StorageStrategy::from_group_size(group_size);
406
407 match strategy {
408 StorageStrategy::Direct => store_direct(handle, data).await,
409 StorageStrategy::FullReplication { replicas } => {
410 store_replicated(handle, data, replicas).await
411 }
412 }
413}
414
415pub async fn store_dyad(
425 handle1: &IdentityHandle,
426 _handle2_key: Key,
427 data: Vec<u8>,
428) -> Result<StorageHandle> {
429 store_replicated(handle1, data, 2).await
431}
432
433pub async fn store_with_fec(
439 handle: &IdentityHandle,
440 data: Vec<u8>,
441 data_shards: usize,
442 parity_shards: usize,
443) -> Result<StorageHandle> {
444 let requested = data_shards.saturating_add(parity_shards).max(1);
445 let replicas = requested.min(MAX_REPLICATION_TARGET);
446 store_replicated(handle, data, replicas).await
447}
448
449pub async fn get_data(handle: &StorageHandle) -> Result<Vec<u8>> {
457 let dht = DHT.read().await;
461 let data = dht.get(&handle.id).await.context("Data not found")?;
462 Ok(data)
463}
464
465fn derive_presence_key(identity_key: Key) -> Key {
471 let mut hasher = blake3::Hasher::new();
472 hasher.update(b"presence:");
473 hasher.update(identity_key.as_bytes());
474 Key::from(*hasher.finalize().as_bytes())
475}
476
477async fn store_direct(handle: &IdentityHandle, data: Vec<u8>) -> Result<StorageHandle> {
479 let storage_id = Key::from(*blake3::hash(&data).as_bytes());
480
481 let presence = get_presence(handle.key()).await?;
483 let device = presence.devices.first().context("No devices available")?;
484 let device_id = device.id;
485
486 dht_put_bytes(&storage_id, data.clone()).await?;
488
489 let mut shard_map = crate::types::storage::ShardMap::new();
490 shard_map.assign_shard(device_id, 0);
491
492 Ok(StorageHandle {
493 id: storage_id,
494 size: data.len() as u64,
495 strategy: StorageStrategy::Direct,
496 shard_map,
497 sealed_key: None,
498 })
499}
500
501pub fn group_identity_canonical_sign_bytes(id: &Key, membership_root: &Key) -> Vec<u8> {
507 let mut out = Vec::with_capacity(16 + 32 + 32);
508 out.extend_from_slice(b"saorsa-group:identity:v1");
509 out.extend_from_slice(id.as_bytes());
510 out.extend_from_slice(membership_root.as_bytes());
511 out
512}
513
514fn compute_membership_root(members: &[MemberRef]) -> Key {
515 let mut ids: Vec<[u8; 32]> = members.iter().map(|m| *m.member_id.as_bytes()).collect();
516 ids.sort_unstable();
517 let mut hasher = blake3::Hasher::new();
518 for id in ids {
519 hasher.update(&id);
520 }
521 Key::from(*hasher.finalize().as_bytes())
522}
523
524pub fn group_identity_create(
526 words: [String; 4],
527 members: Vec<MemberRef>,
528) -> Result<(GroupIdentityPacketV1, GroupKeyPair)> {
529 if !fw_check(words.clone()) {
531 anyhow::bail!("Invalid group words");
532 }
533 let id = fw_to_key(words.clone())?;
534
535 use crate::quantum_crypto::{MlDsa65, MlDsaOperations};
537 let ml = MlDsa65::new();
538 let (group_pk, group_sk) = ml
539 .generate_keypair()
540 .map_err(|e| anyhow::anyhow!("group keypair generation failed: {e:?}"))?;
541
542 let membership_root = compute_membership_root(&members);
544 let msg = group_identity_canonical_sign_bytes(&id, &membership_root);
545 let sig = ml
546 .sign(&group_sk, &msg)
547 .map_err(|e| anyhow::anyhow!("group sign failed: {e:?}"))?;
548
549 let pkt = GroupIdentityPacketV1 {
550 v: 1,
551 words,
552 id: id.clone(),
553 group_pk: group_pk.as_bytes().to_vec(),
554 group_sig: sig.0.to_vec(),
555 members,
556 membership_root,
557 created_at: std::time::SystemTime::now()
558 .duration_since(std::time::UNIX_EPOCH)
559 .unwrap_or_default()
560 .as_secs(),
561 mls_ciphersuite: None,
562 };
563
564 Ok((pkt, GroupKeyPair { group_pk, group_sk }))
565}
566
567pub async fn group_identity_publish(packet: GroupIdentityPacketV1) -> Result<()> {
569 let root = compute_membership_root(&packet.members);
571 if root != packet.membership_root {
572 anyhow::bail!("membership_root mismatch");
573 }
574 use crate::quantum_crypto::{MlDsa65, MlDsaOperations, MlDsaPublicKey, MlDsaSignature};
576 const SIG_LEN: usize = 3309;
577 if packet.group_sig.len() != SIG_LEN {
578 anyhow::bail!("invalid signature length");
579 }
580 let mut sig_arr = [0u8; SIG_LEN];
581 sig_arr.copy_from_slice(&packet.group_sig);
582 let sig = MlDsaSignature(Box::new(sig_arr));
583 let pk = MlDsaPublicKey::from_bytes(&packet.group_pk)
584 .map_err(|_| anyhow::anyhow!("invalid group_pk"))?;
585 let ml = MlDsa65::new();
586 let msg = group_identity_canonical_sign_bytes(&packet.id, &packet.membership_root);
587 let ok = ml
588 .verify(&pk, &msg, &sig)
589 .map_err(|e| anyhow::anyhow!("verify failed: {e:?}"))?;
590 if !ok {
591 anyhow::bail!("group signature invalid");
592 }
593 dht_put_bytes(&packet.id, serde_json::to_vec(&packet)?).await
594}
595
596pub async fn group_identity_fetch(id_key: Key) -> Result<GroupIdentityPacketV1> {
598 let data = dht_get_bytes(&id_key).await.context("Group not found")?;
599 let pkt: GroupIdentityPacketV1 = serde_json::from_slice(&data)?;
600 Ok(pkt)
601}
602
603pub async fn group_identity_update_members_signed(
605 id_key: Key,
606 new_members: Vec<MemberRef>,
607 group_pk: Vec<u8>,
608 group_sig: Sig,
609) -> Result<()> {
610 let new_root = compute_membership_root(&new_members);
612 use crate::quantum_crypto::{MlDsa65, MlDsaOperations, MlDsaPublicKey, MlDsaSignature};
613 const SIG_LEN: usize = 3309;
614 let sig_bytes = group_sig.as_bytes();
615 if sig_bytes.len() != SIG_LEN {
616 anyhow::bail!("invalid signature length");
617 }
618 let mut sig_arr = [0u8; SIG_LEN];
619 sig_arr.copy_from_slice(sig_bytes);
620 let sig = MlDsaSignature(Box::new(sig_arr));
621 let pk =
622 MlDsaPublicKey::from_bytes(&group_pk).map_err(|_| anyhow::anyhow!("invalid group_pk"))?;
623 let ml = MlDsa65::new();
624 let msg = group_identity_canonical_sign_bytes(&id_key, &new_root);
625 let ok = ml
626 .verify(&pk, &msg, &sig)
627 .map_err(|e| anyhow::anyhow!("verify failed: {e:?}"))?;
628 if !ok {
629 anyhow::bail!("group signature invalid");
630 }
631
632 let mut pkt = match group_identity_fetch(id_key.clone()).await {
634 Ok(p) => p,
635 Err(_) => GroupIdentityPacketV1 {
636 v: 1,
637 words: [String::new(), String::new(), String::new(), String::new()],
638 id: id_key.clone(),
639 group_pk: group_pk.clone(),
640 group_sig: sig.0.clone().to_vec(),
641 members: Vec::new(),
642 membership_root: new_root.clone(),
643 created_at: std::time::SystemTime::now()
644 .duration_since(std::time::UNIX_EPOCH)
645 .unwrap_or_default()
646 .as_secs(),
647 mls_ciphersuite: None,
648 },
649 };
650
651 pkt.members = new_members;
652 pkt.membership_root = new_root;
653 pkt.group_pk = group_pk;
654 pkt.group_sig = sig.0.to_vec();
655
656 group_identity_publish(pkt).await
657}
658
659async fn store_replicated(
661 handle: &IdentityHandle,
662 data: Vec<u8>,
663 replicas: usize,
664) -> Result<StorageHandle> {
665 let storage_id = Key::from(*blake3::hash(&data).as_bytes());
666
667 let presence = get_presence(handle.key()).await?;
668
669 if presence.devices.is_empty() {
670 anyhow::bail!("No devices available");
671 }
672
673 let shard_map = build_replication_plan(&presence.devices, replicas);
674
675 dht_put_bytes(&storage_id, data.clone()).await?;
677
678 Ok(StorageHandle {
679 id: storage_id,
680 size: data.len() as u64,
681 strategy: StorageStrategy::FullReplication { replicas },
682 shard_map,
683 sealed_key: None,
684 })
685}
686
687fn build_replication_plan(
688 devices: &[Device],
689 desired_shards: usize,
690) -> crate::types::storage::ShardMap {
691 let mut shard_map = crate::types::storage::ShardMap::new();
692 if devices.is_empty() || desired_shards == 0 {
693 return shard_map;
694 }
695
696 let mut headless_devices: Vec<&Device> = devices
697 .iter()
698 .filter(|d| d.device_type == DeviceType::Headless)
699 .collect();
700 let mut active_devices: Vec<&Device> = devices
701 .iter()
702 .filter(|d| d.device_type == DeviceType::Active)
703 .collect();
704 let mobile_devices: Vec<&Device> = devices
705 .iter()
706 .filter(|d| d.device_type == DeviceType::Mobile)
707 .collect();
708
709 headless_devices.sort_by(|a, b| b.storage_gb.cmp(&a.storage_gb));
710 active_devices.sort_by(|a, b| b.storage_gb.cmp(&a.storage_gb));
711
712 let total_shards = desired_shards;
713 let mut shard_idx = 0u32;
714
715 if !headless_devices.is_empty() {
716 let min_headless_shards = (total_shards * 3).div_ceil(5);
717 let shards_per_headless = min_headless_shards.div_ceil(headless_devices.len());
718
719 for device in &headless_devices {
720 for _ in 0..shards_per_headless {
721 if (shard_idx as usize) < total_shards {
722 shard_map.assign_shard(device.id, shard_idx);
723 shard_idx += 1;
724 }
725 }
726 }
727 }
728
729 for device in &active_devices {
730 if (shard_idx as usize) < total_shards {
731 shard_map.assign_shard(device.id, shard_idx);
732 shard_idx += 1;
733 }
734 }
735
736 if (shard_idx as usize) < total_shards
737 && headless_devices.is_empty()
738 && active_devices.is_empty()
739 {
740 for device in &mobile_devices {
741 if (shard_idx as usize) < total_shards {
742 shard_map.assign_shard(device.id, shard_idx);
743 shard_idx += 1;
744 }
745 }
746 }
747
748 while (shard_idx as usize) < total_shards {
749 let all_devices: Vec<&Device> = headless_devices
750 .iter()
751 .chain(active_devices.iter())
752 .copied()
753 .collect();
754 if all_devices.is_empty() {
755 break;
756 }
757 let device = all_devices[(shard_idx as usize) % all_devices.len()];
758 shard_map.assign_shard(device.id, shard_idx);
759 shard_idx += 1;
760 }
761
762 shard_map
763}