Skip to main content

saorsa_core/
api.rs

1// Copyright 2024 Saorsa Labs Limited
2//
3// This software is dual-licensed under:
4// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
5// - Commercial License
6//
7// For AGPL-3.0 license, see LICENSE-AGPL-3.0
8// For commercial licensing, contact: david@saorsalabs.com
9
10//! Clean API implementation for saorsa-core
11//!
12//! This module provides the simplified public API for:
13//! - Identity registration and management
14//! - Presence and device management
15//! - Storage with replication
16
17use crate::auth::Sig;
18use crate::fwid::{Key, compute_key, fw_check, fw_to_key};
19use crate::types::{
20    Device, DeviceId, DeviceType, Endpoint, Identity, IdentityHandle, MAX_REPLICATION_TARGET,
21    MlDsaKeyPair, Presence, PresenceReceipt, StorageHandle, StorageStrategy,
22};
23use anyhow::{Context, Result};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use std::sync::Arc;
27use tokio::sync::RwLock as AsyncRwLock;
28use tokio::sync::RwLock;
29// tracing not currently used in this module
30
31// Mock DHT for fallback when no global DHT client is installed
32struct MockDht {
33    storage: HashMap<Key, Vec<u8>>,
34}
35
36impl MockDht {
37    fn new() -> Self {
38        Self {
39            storage: HashMap::new(),
40        }
41    }
42
43    async fn put(&mut self, key: Key, value: Vec<u8>) -> Result<()> {
44        self.storage.insert(key, value);
45        Ok(())
46    }
47
48    async fn get(&self, key: &Key) -> Result<Vec<u8>> {
49        self.storage
50            .get(key)
51            .cloned()
52            .ok_or_else(|| anyhow::anyhow!("Key not found"))
53    }
54}
55
56// Global DHT instance for testing
57static DHT: once_cell::sync::Lazy<Arc<RwLock<MockDht>>> =
58    once_cell::sync::Lazy::new(|| Arc::new(RwLock::new(MockDht::new())));
59
60// Optional global DHT client (real engine). If not set, we fall back to MockDht.
61static GLOBAL_DHT_CLIENT: once_cell::sync::Lazy<
62    AsyncRwLock<Option<Arc<crate::dht::client::DhtClient>>>,
63> = once_cell::sync::Lazy::new(|| AsyncRwLock::new(None));
64
65/// Install a process-global DHT client for API operations.
66pub async fn set_dht_client(client: crate::dht::client::DhtClient) -> bool {
67    let mut guard = GLOBAL_DHT_CLIENT.write().await;
68    let was_empty = guard.is_none();
69    *guard = Some(Arc::new(client));
70    was_empty
71}
72
73/// Remove any configured global DHT client, forcing API calls to use the mock fallback.
74pub async fn clear_dht_client() {
75    GLOBAL_DHT_CLIENT.write().await.take();
76}
77
78async fn get_dht_client_async() -> Option<Arc<crate::dht::client::DhtClient>> {
79    GLOBAL_DHT_CLIENT.read().await.clone()
80}
81
82async fn dht_put_bytes(key: &Key, value: Vec<u8>) -> Result<()> {
83    if let Some(client) = get_dht_client_async().await {
84        let k = hex::encode(key.as_bytes());
85        let _ = client
86            .put(k, value)
87            .await
88            .context("Failed to store data in DHT client")?;
89        Ok(())
90    } else {
91        let mut dht = DHT.write().await;
92        dht.put(key.clone(), value).await
93    }
94}
95
96async fn dht_try_get_bytes(key: &Key) -> Result<Option<Vec<u8>>> {
97    if let Some(client) = get_dht_client_async().await {
98        let k = hex::encode(key.as_bytes());
99        client.get(k).await.context("DHT get failed")
100    } else {
101        let dht = DHT.read().await;
102        match dht.get(key).await {
103            Ok(bytes) => Ok(Some(bytes)),
104            Err(_) => Ok(None),
105        }
106    }
107}
108
109async fn dht_get_bytes(key: &Key) -> Result<Vec<u8>> {
110    match dht_try_get_bytes(key).await? {
111        Some(bytes) => Ok(bytes),
112        None => anyhow::bail!("Key not found"),
113    }
114}
115
116// =============================================================================
117// API-visible record types (minimal, per AGENTS_API.md)
118// =============================================================================
119
120/// Minimal identity packet compatible with Communitas group flows
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct IdentityPacketV1 {
123    pub v: u8,
124    pub words: [String; 4],
125    pub id: Key,
126    pub pk: Vec<u8>,
127    pub sig: Option<Vec<u8>>, // optional when registered locally
128    pub device_set_root: Key,
129}
130
131/// Member reference for group identities
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct MemberRef {
134    pub member_id: Key,
135    pub member_pk: Vec<u8>,
136}
137
138/// Group identity packet (canonical)
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct GroupIdentityPacketV1 {
141    pub v: u8,
142    pub words: [String; 4],
143    pub id: Key,
144    pub group_pk: Vec<u8>,
145    pub group_sig: Vec<u8>,
146    pub members: Vec<MemberRef>,
147    pub membership_root: Key,
148    pub created_at: u64,
149    pub mls_ciphersuite: Option<u16>,
150}
151
152/// Keypair for group signatures
153#[derive(Clone)]
154pub struct GroupKeyPair {
155    pub group_pk: crate::quantum_crypto::MlDsaPublicKey,
156    pub group_sk: crate::quantum_crypto::MlDsaSecretKey,
157}
158
159impl std::fmt::Debug for GroupKeyPair {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        write!(
162            f,
163            "GroupKeyPair {{ group_pk: <{} bytes>, group_sk: <hidden> }}",
164            self.group_pk.as_bytes().len()
165        )
166    }
167}
168
169// ============================================================================
170// IDENTITY API
171// ============================================================================
172
173/// Register a new identity on the network
174///
175/// # Arguments
176/// * `words` - Four-word identifier (must be valid dictionary words)
177/// * `keypair` - ML-DSA keypair for signing
178///
179/// # Returns
180/// * `IdentityHandle` - Handle for identity operations
181pub async fn register_identity(words: [&str; 4], keypair: &MlDsaKeyPair) -> Result<IdentityHandle> {
182    // Convert to owned strings
183    let words_owned: [String; 4] = [
184        words[0].to_string(),
185        words[1].to_string(),
186        words[2].to_string(),
187        words[3].to_string(),
188    ];
189
190    // Validate words
191    if !fw_check(words_owned.clone()) {
192        anyhow::bail!("Invalid word in identity");
193    }
194
195    // Generate key from words
196    let key = fw_to_key(words_owned.clone())?;
197
198    // Check if already registered
199    if dht_try_get_bytes(&key).await?.is_some() {
200        anyhow::bail!("Identity already registered");
201    }
202
203    // Create identity (typed) and store packet for compatibility
204    let identity = Identity {
205        words: words_owned.clone(),
206        key: key.clone(),
207        public_key: keypair.public_key.clone(),
208    };
209
210    let packet = IdentityPacketV1 {
211        v: 1,
212        words: words_owned.clone(),
213        id: key.clone(),
214        pk: keypair.public_key.clone(),
215        sig: None,
216        device_set_root: compute_key("device-set", key.as_bytes()),
217    };
218
219    dht_put_bytes(&key, serde_json::to_vec(&packet)?).await?;
220
221    Ok(IdentityHandle::new(identity, keypair.clone()))
222}
223
224/// Get an identity by its key
225///
226/// # Arguments
227/// * `key` - Identity key (derived from four-word address)
228///
229/// # Returns
230/// * `Identity` - The identity information
231pub async fn get_identity(key: Key) -> Result<Identity> {
232    // Try to read the identity packet and map back to Identity struct
233    let data = dht_get_bytes(&key).await.context("Identity not found")?;
234    if let Ok(pkt) = serde_json::from_slice::<IdentityPacketV1>(&data) {
235        let identity = Identity {
236            words: pkt.words,
237            key: pkt.id,
238            public_key: pkt.pk,
239        };
240        return Ok(identity);
241    }
242    // Fallback: legacy storage of Identity
243    let identity: Identity = serde_json::from_slice(&data)?;
244    Ok(identity)
245}
246
247/// Fetch identity packet in canonical format
248pub async fn identity_fetch(key: Key) -> Result<IdentityPacketV1> {
249    let data = dht_get_bytes(&key).await.context("Identity not found")?;
250    let pkt: IdentityPacketV1 = serde_json::from_slice(&data)?;
251    Ok(pkt)
252}
253
254// ============================================================================
255// PRESENCE API
256// ============================================================================
257
258/// Register presence on the network
259///
260/// # Arguments
261/// * `handle` - Identity handle
262/// * `devices` - List of devices for this identity
263/// * `active_device` - Currently active device ID
264///
265/// # Returns
266/// * `PresenceReceipt` - Receipt of presence registration
267pub async fn register_presence(
268    handle: &IdentityHandle,
269    devices: Vec<Device>,
270    active_device: DeviceId,
271) -> Result<PresenceReceipt> {
272    // Validate active device is in list
273    if !devices.iter().any(|d| d.id == active_device) {
274        anyhow::bail!("Active device not in device list");
275    }
276
277    // Create presence packet
278    let presence = Presence {
279        identity: handle.key(),
280        devices,
281        active_device: Some(active_device),
282        timestamp: std::time::SystemTime::now()
283            .duration_since(std::time::UNIX_EPOCH)?
284            .as_secs(),
285        signature: vec![], // Will be filled
286    };
287
288    // Sign presence
289    let presence_bytes = serde_json::to_vec(&presence)?;
290    let signature = handle.sign(&presence_bytes)?;
291
292    let mut signed_presence = presence;
293    signed_presence.signature = signature;
294
295    // Store in DHT with presence key
296    let presence_key = derive_presence_key(handle.key());
297    dht_put_bytes(&presence_key, serde_json::to_vec(&signed_presence)?).await?;
298
299    // Create receipt
300    let receipt = PresenceReceipt {
301        identity: handle.key(),
302        timestamp: signed_presence.timestamp,
303        storing_nodes: vec![Key::from([0u8; 32])], // Mock node
304    };
305
306    Ok(receipt)
307}
308
309/// Get presence information for an identity
310///
311/// # Arguments
312/// * `identity_key` - Key of the identity
313///
314/// # Returns
315/// * `Presence` - Current presence information
316pub async fn get_presence(identity_key: Key) -> Result<Presence> {
317    let presence_key = derive_presence_key(identity_key);
318    let data = dht_get_bytes(&presence_key)
319        .await
320        .context("Presence not found")?;
321    let presence: Presence = serde_json::from_slice(&data)?;
322    Ok(presence)
323}
324
325/// Register a headless storage node
326///
327/// # Arguments
328/// * `handle` - Identity handle
329/// * `storage_gb` - Storage capacity in GB
330/// * `endpoint` - Network endpoint
331///
332/// # Returns
333/// * `DeviceId` - ID of the registered headless node
334pub async fn register_headless(
335    handle: &IdentityHandle,
336    storage_gb: u32,
337    endpoint: Endpoint,
338) -> Result<DeviceId> {
339    // Get current presence
340    let mut presence = get_presence(handle.key()).await?;
341
342    // Create headless device
343    let device = Device {
344        id: DeviceId::generate(),
345        device_type: crate::types::presence::DeviceType::Headless,
346        storage_gb: storage_gb as u64,
347        endpoint,
348        capabilities: crate::types::presence::DeviceCapabilities {
349            storage_bytes: storage_gb as u64 * 1_000_000_000,
350            always_online: true,
351            supports_fec: true,
352            supports_seal: true,
353            ..Default::default()
354        },
355    };
356
357    let device_id = device.id;
358    presence.devices.push(device);
359
360    // Update presence
361    let active = presence.active_device.unwrap_or(device_id);
362    register_presence(handle, presence.devices, active).await?;
363
364    Ok(device_id)
365}
366
367/// Set the active device for an identity
368///
369/// # Arguments
370/// * `handle` - Identity handle
371/// * `device_id` - Device to make active
372pub async fn set_active_device(handle: &IdentityHandle, device_id: DeviceId) -> Result<()> {
373    // Get current presence
374    let presence = get_presence(handle.key()).await?;
375
376    // Validate device exists
377    if !presence.devices.iter().any(|d| d.id == device_id) {
378        anyhow::bail!("Device not found in presence");
379    }
380
381    // Update with new active device
382    register_presence(handle, presence.devices, device_id).await?;
383    Ok(())
384}
385
386// ============================================================================
387// STORAGE API
388// ============================================================================
389
390/// Store data on the network
391///
392/// # Arguments
393/// * `handle` - Identity handle
394/// * `data` - Data to store
395/// * `group_size` - Size of the group (affects storage strategy)
396///
397/// # Returns
398/// * `StorageHandle` - Handle to retrieve the data
399pub async fn store_data(
400    handle: &IdentityHandle,
401    data: Vec<u8>,
402    group_size: usize,
403) -> Result<StorageHandle> {
404    // Select strategy based on group size
405    let strategy = StorageStrategy::from_group_size(group_size);
406
407    match strategy {
408        StorageStrategy::Direct => store_direct(handle, data).await,
409        StorageStrategy::FullReplication { replicas } => {
410            store_replicated(handle, data, replicas).await
411        }
412    }
413}
414
415/// Store data for a dyad (2-person group)
416///
417/// # Arguments
418/// * `handle1` - First identity handle
419/// * `handle2_key` - Key of second identity
420/// * `data` - Data to store
421///
422/// # Returns
423/// * `StorageHandle` - Handle to retrieve the data
424pub async fn store_dyad(
425    handle1: &IdentityHandle,
426    _handle2_key: Key,
427    data: Vec<u8>,
428) -> Result<StorageHandle> {
429    // For dyads, use full replication (2 copies)
430    store_replicated(handle1, data, 2).await
431}
432
433/// Store data with a custom replication target (legacy FEC API)
434///
435/// This function now interprets `data_shards + parity_shards` as the desired
436/// replica count (clamped to the global maximum) and stores full copies of
437/// the data on the selected devices.
438pub async fn store_with_fec(
439    handle: &IdentityHandle,
440    data: Vec<u8>,
441    data_shards: usize,
442    parity_shards: usize,
443) -> Result<StorageHandle> {
444    let requested = data_shards.saturating_add(parity_shards).max(1);
445    let replicas = requested.min(MAX_REPLICATION_TARGET);
446    store_replicated(handle, data, replicas).await
447}
448
449/// Retrieve data from the network
450///
451/// # Arguments
452/// * `handle` - Storage handle
453///
454/// # Returns
455/// * `Vec<u8>` - The retrieved data
456pub async fn get_data(handle: &StorageHandle) -> Result<Vec<u8>> {
457    // TODO: Handle different strategies (e.g. encrypted blobs, multi-peer retrieval)
458    // For now, just retrieve from DHT
459
460    let dht = DHT.read().await;
461    let data = dht.get(&handle.id).await.context("Data not found")?;
462    Ok(data)
463}
464
465// ============================================================================
466// HELPER FUNCTIONS
467// ============================================================================
468
469/// Derive presence key from identity key
470fn derive_presence_key(identity_key: Key) -> Key {
471    let mut hasher = blake3::Hasher::new();
472    hasher.update(b"presence:");
473    hasher.update(identity_key.as_bytes());
474    Key::from(*hasher.finalize().as_bytes())
475}
476
477/// Store data directly (no redundancy)
478async fn store_direct(handle: &IdentityHandle, data: Vec<u8>) -> Result<StorageHandle> {
479    let storage_id = Key::from(*blake3::hash(&data).as_bytes());
480
481    // Get presence BEFORE acquiring DHT write lock to avoid deadlock
482    let presence = get_presence(handle.key()).await?;
483    let device = presence.devices.first().context("No devices available")?;
484    let device_id = device.id;
485
486    // Store in DHT
487    dht_put_bytes(&storage_id, data.clone()).await?;
488
489    let mut shard_map = crate::types::storage::ShardMap::new();
490    shard_map.assign_shard(device_id, 0);
491
492    Ok(StorageHandle {
493        id: storage_id,
494        size: data.len() as u64,
495        strategy: StorageStrategy::Direct,
496        shard_map,
497        sealed_key: None,
498    })
499}
500
501// ============================================================================
502// GROUP API (per AGENTS_API.md, minimal subset used by Communitas)
503// ============================================================================
504
505/// Canonical bytes for group identity signing: b"saorsa-group:identity:v1" || id || membership_root
506pub fn group_identity_canonical_sign_bytes(id: &Key, membership_root: &Key) -> Vec<u8> {
507    let mut out = Vec::with_capacity(16 + 32 + 32);
508    out.extend_from_slice(b"saorsa-group:identity:v1");
509    out.extend_from_slice(id.as_bytes());
510    out.extend_from_slice(membership_root.as_bytes());
511    out
512}
513
514fn compute_membership_root(members: &[MemberRef]) -> Key {
515    let mut ids: Vec<[u8; 32]> = members.iter().map(|m| *m.member_id.as_bytes()).collect();
516    ids.sort_unstable();
517    let mut hasher = blake3::Hasher::new();
518    for id in ids {
519        hasher.update(&id);
520    }
521    Key::from(*hasher.finalize().as_bytes())
522}
523
524/// Create a canonical group identity and keypair
525pub fn group_identity_create(
526    words: [String; 4],
527    members: Vec<MemberRef>,
528) -> Result<(GroupIdentityPacketV1, GroupKeyPair)> {
529    // Validate words and id
530    if !fw_check(words.clone()) {
531        anyhow::bail!("Invalid group words");
532    }
533    let id = fw_to_key(words.clone())?;
534
535    // Generate ML-DSA group keypair
536    use crate::quantum_crypto::{MlDsa65, MlDsaOperations};
537    let ml = MlDsa65::new();
538    let (group_pk, group_sk) = ml
539        .generate_keypair()
540        .map_err(|e| anyhow::anyhow!("group keypair generation failed: {e:?}"))?;
541
542    // Compute membership root and sign canonical bytes
543    let membership_root = compute_membership_root(&members);
544    let msg = group_identity_canonical_sign_bytes(&id, &membership_root);
545    let sig = ml
546        .sign(&group_sk, &msg)
547        .map_err(|e| anyhow::anyhow!("group sign failed: {e:?}"))?;
548
549    let pkt = GroupIdentityPacketV1 {
550        v: 1,
551        words,
552        id: id.clone(),
553        group_pk: group_pk.as_bytes().to_vec(),
554        group_sig: sig.0.to_vec(),
555        members,
556        membership_root,
557        created_at: std::time::SystemTime::now()
558            .duration_since(std::time::UNIX_EPOCH)
559            .unwrap_or_default()
560            .as_secs(),
561        mls_ciphersuite: None,
562    };
563
564    Ok((pkt, GroupKeyPair { group_pk, group_sk }))
565}
566
567/// Publish a group identity packet under its id key
568pub async fn group_identity_publish(packet: GroupIdentityPacketV1) -> Result<()> {
569    // Basic validation: recompute root and signature check
570    let root = compute_membership_root(&packet.members);
571    if root != packet.membership_root {
572        anyhow::bail!("membership_root mismatch");
573    }
574    // Verify signature
575    use crate::quantum_crypto::{MlDsa65, MlDsaOperations, MlDsaPublicKey, MlDsaSignature};
576    const SIG_LEN: usize = 3309;
577    if packet.group_sig.len() != SIG_LEN {
578        anyhow::bail!("invalid signature length");
579    }
580    let mut sig_arr = [0u8; SIG_LEN];
581    sig_arr.copy_from_slice(&packet.group_sig);
582    let sig = MlDsaSignature(Box::new(sig_arr));
583    let pk = MlDsaPublicKey::from_bytes(&packet.group_pk)
584        .map_err(|_| anyhow::anyhow!("invalid group_pk"))?;
585    let ml = MlDsa65::new();
586    let msg = group_identity_canonical_sign_bytes(&packet.id, &packet.membership_root);
587    let ok = ml
588        .verify(&pk, &msg, &sig)
589        .map_err(|e| anyhow::anyhow!("verify failed: {e:?}"))?;
590    if !ok {
591        anyhow::bail!("group signature invalid");
592    }
593    dht_put_bytes(&packet.id, serde_json::to_vec(&packet)?).await
594}
595
596/// Fetch a group identity by id key
597pub async fn group_identity_fetch(id_key: Key) -> Result<GroupIdentityPacketV1> {
598    let data = dht_get_bytes(&id_key).await.context("Group not found")?;
599    let pkt: GroupIdentityPacketV1 = serde_json::from_slice(&data)?;
600    Ok(pkt)
601}
602
603/// Update group members with signature verification over canonical bytes
604pub async fn group_identity_update_members_signed(
605    id_key: Key,
606    new_members: Vec<MemberRef>,
607    group_pk: Vec<u8>,
608    group_sig: Sig,
609) -> Result<()> {
610    // Compute new root and verify signature
611    let new_root = compute_membership_root(&new_members);
612    use crate::quantum_crypto::{MlDsa65, MlDsaOperations, MlDsaPublicKey, MlDsaSignature};
613    const SIG_LEN: usize = 3309;
614    let sig_bytes = group_sig.as_bytes();
615    if sig_bytes.len() != SIG_LEN {
616        anyhow::bail!("invalid signature length");
617    }
618    let mut sig_arr = [0u8; SIG_LEN];
619    sig_arr.copy_from_slice(sig_bytes);
620    let sig = MlDsaSignature(Box::new(sig_arr));
621    let pk =
622        MlDsaPublicKey::from_bytes(&group_pk).map_err(|_| anyhow::anyhow!("invalid group_pk"))?;
623    let ml = MlDsa65::new();
624    let msg = group_identity_canonical_sign_bytes(&id_key, &new_root);
625    let ok = ml
626        .verify(&pk, &msg, &sig)
627        .map_err(|e| anyhow::anyhow!("verify failed: {e:?}"))?;
628    if !ok {
629        anyhow::bail!("group signature invalid");
630    }
631
632    // Fetch current (if exists) to preserve metadata
633    let mut pkt = match group_identity_fetch(id_key.clone()).await {
634        Ok(p) => p,
635        Err(_) => GroupIdentityPacketV1 {
636            v: 1,
637            words: [String::new(), String::new(), String::new(), String::new()],
638            id: id_key.clone(),
639            group_pk: group_pk.clone(),
640            group_sig: sig.0.clone().to_vec(),
641            members: Vec::new(),
642            membership_root: new_root.clone(),
643            created_at: std::time::SystemTime::now()
644                .duration_since(std::time::UNIX_EPOCH)
645                .unwrap_or_default()
646                .as_secs(),
647            mls_ciphersuite: None,
648        },
649    };
650
651    pkt.members = new_members;
652    pkt.membership_root = new_root;
653    pkt.group_pk = group_pk;
654    pkt.group_sig = sig.0.to_vec();
655
656    group_identity_publish(pkt).await
657}
658
659/// Store data with full replication
660async fn store_replicated(
661    handle: &IdentityHandle,
662    data: Vec<u8>,
663    replicas: usize,
664) -> Result<StorageHandle> {
665    let storage_id = Key::from(*blake3::hash(&data).as_bytes());
666
667    let presence = get_presence(handle.key()).await?;
668
669    if presence.devices.is_empty() {
670        anyhow::bail!("No devices available");
671    }
672
673    let shard_map = build_replication_plan(&presence.devices, replicas);
674
675    // Store in DHT
676    dht_put_bytes(&storage_id, data.clone()).await?;
677
678    Ok(StorageHandle {
679        id: storage_id,
680        size: data.len() as u64,
681        strategy: StorageStrategy::FullReplication { replicas },
682        shard_map,
683        sealed_key: None,
684    })
685}
686
687fn build_replication_plan(
688    devices: &[Device],
689    desired_shards: usize,
690) -> crate::types::storage::ShardMap {
691    let mut shard_map = crate::types::storage::ShardMap::new();
692    if devices.is_empty() || desired_shards == 0 {
693        return shard_map;
694    }
695
696    let mut headless_devices: Vec<&Device> = devices
697        .iter()
698        .filter(|d| d.device_type == DeviceType::Headless)
699        .collect();
700    let mut active_devices: Vec<&Device> = devices
701        .iter()
702        .filter(|d| d.device_type == DeviceType::Active)
703        .collect();
704    let mobile_devices: Vec<&Device> = devices
705        .iter()
706        .filter(|d| d.device_type == DeviceType::Mobile)
707        .collect();
708
709    headless_devices.sort_by(|a, b| b.storage_gb.cmp(&a.storage_gb));
710    active_devices.sort_by(|a, b| b.storage_gb.cmp(&a.storage_gb));
711
712    let total_shards = desired_shards;
713    let mut shard_idx = 0u32;
714
715    if !headless_devices.is_empty() {
716        let min_headless_shards = (total_shards * 3).div_ceil(5);
717        let shards_per_headless = min_headless_shards.div_ceil(headless_devices.len());
718
719        for device in &headless_devices {
720            for _ in 0..shards_per_headless {
721                if (shard_idx as usize) < total_shards {
722                    shard_map.assign_shard(device.id, shard_idx);
723                    shard_idx += 1;
724                }
725            }
726        }
727    }
728
729    for device in &active_devices {
730        if (shard_idx as usize) < total_shards {
731            shard_map.assign_shard(device.id, shard_idx);
732            shard_idx += 1;
733        }
734    }
735
736    if (shard_idx as usize) < total_shards
737        && headless_devices.is_empty()
738        && active_devices.is_empty()
739    {
740        for device in &mobile_devices {
741            if (shard_idx as usize) < total_shards {
742                shard_map.assign_shard(device.id, shard_idx);
743                shard_idx += 1;
744            }
745        }
746    }
747
748    while (shard_idx as usize) < total_shards {
749        let all_devices: Vec<&Device> = headless_devices
750            .iter()
751            .chain(active_devices.iter())
752            .copied()
753            .collect();
754        if all_devices.is_empty() {
755            break;
756        }
757        let device = all_devices[(shard_idx as usize) % all_devices.len()];
758        shard_map.assign_shard(device.id, shard_idx);
759        shard_idx += 1;
760    }
761
762    shard_map
763}