// yellowstone_shield_cli/command/identity.rs

1use std::collections::{HashSet, VecDeque};
2
3use super::{RunCommand, RunResult};
4use crate::{
5    command::{send_batched_tx, CommandContext},
6    policy::PolicyVersion,
7    CommandComplete, LogPolicy, SolanaAccount,
8};
9use borsh::BorshDeserialize;
10
11use solana_pubkey::Pubkey;
12use solana_signer::Signer;
13use spl_associated_token_account::get_associated_token_address_with_program_id;
14use spl_token_2022::{
15    extension::{BaseStateWithExtensions, PodStateWithExtensions},
16    pod::PodMint,
17};
18use spl_token_metadata_interface::state::TokenMetadata;
19
20use yellowstone_shield_client::{
21    accounts::{Policy, PolicyV2},
22    instructions::ReplaceIdentityBuilder,
23    types::Kind,
24};
25use yellowstone_shield_client::{
26    instructions::{AddIdentityBuilder, RemoveIdentityBuilder},
27    PolicyTrait,
28};
29
/// Chunk size handed to `send_batched_tx` for every batch of identity
/// instructions sent by the commands in this module.
///
/// NOTE(review): presumably the maximum number of identity instructions packed
/// into a single transaction — confirm against `send_batched_tx`.
const CHUNK_SIZE: usize = 20;
31
32/// Builder for adding a identities to a policy
33#[derive(Debug, Clone)]
34pub struct AddBatchCommandBuilder<'a> {
35    mint: Option<&'a Pubkey>,
36    identities: Option<Vec<Pubkey>>,
37}
38
39impl Default for AddBatchCommandBuilder<'_> {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl<'a> AddBatchCommandBuilder<'a> {
46    /// Create a new AddCommandBuilder
47    pub fn new() -> Self {
48        Self {
49            mint: None,
50            identities: None,
51        }
52    }
53
54    /// Set the mint address
55    pub fn mint(mut self, mint: &'a Pubkey) -> Self {
56        self.mint = Some(mint);
57        self
58    }
59
60    /// Set the identities to add
61    pub fn identities(mut self, identities: Vec<Pubkey>) -> Self {
62        self.identities = Some(identities);
63        self
64    }
65}
66
67#[async_trait::async_trait]
68impl RunCommand for AddBatchCommandBuilder<'_> {
69    /// Execute the addition of a identity to the policy
70    async fn run(&mut self, context: CommandContext) -> RunResult {
71        let CommandContext { keypair, client } = context;
72
73        let mint = self.mint.expect("mint must be set");
74
75        // PDA seeds are same for both Policy and PolicyV2
76        let (address, _) = Policy::find_pda(mint);
77        let mut identities = self.identities.take().expect("identities must be set");
78        let mut seen = std::collections::HashSet::new();
79        identities.retain(|pk| seen.insert(*pk));
80
81        let token_account = get_associated_token_address_with_program_id(
82            &keypair.pubkey(),
83            mint,
84            &spl_token_2022::ID,
85        );
86
87        let account_data = client.get_account(&address).await?;
88        let account_data: &[u8] = &account_data.data;
89
90        let policy_version = Kind::try_from_slice(&[account_data[0]])?;
91
92        let current = match policy_version {
93            Kind::Policy => Policy::try_deserialize_identities(account_data),
94            Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data),
95        }?;
96
97        let empty_identity_indices = current
98            .iter()
99            .enumerate()
100            .filter_map(|(idx, p)| {
101                if p == &Pubkey::default() {
102                    return Some(idx);
103                }
104                None
105            })
106            .collect::<Vec<usize>>();
107
108        let mut add_or_replace: Vec<Pubkey> = identities
109            .into_iter()
110            .filter(|identity| !current.contains(identity))
111            .collect();
112
113        let mut replace = Vec::new();
114
115        for i in empty_identity_indices {
116            if let Some(iden) = add_or_replace.pop() {
117                replace.push((i, iden));
118            }
119        }
120
121        // REPLACE
122        send_batched_tx(
123            &client,
124            &keypair,
125            &replace,
126            CHUNK_SIZE,
127            |(idx, identity)| {
128                ReplaceIdentityBuilder::new()
129                    .policy(address)
130                    .mint(*mint)
131                    .token_account(token_account)
132                    .owner(keypair.pubkey())
133                    .identity(*identity)
134                    .index(*idx as u64)
135                    .instruction()
136            },
137        )
138        .await?;
139
140        // ADD
141        send_batched_tx(&client, &keypair, &add_or_replace, CHUNK_SIZE, |identity| {
142            AddIdentityBuilder::new()
143                .policy(address)
144                .mint(*mint)
145                .token_account(token_account)
146                .payer(keypair.pubkey())
147                .owner(keypair.pubkey())
148                .identity(*identity)
149                .instruction()
150        })
151        .await?;
152
153        let account_data = client.get_account(&address).await?;
154        let account_data: &[u8] = &account_data.data;
155
156        let policy_version = Kind::try_from_slice(&[account_data[0]])?;
157
158        let policy = match policy_version {
159            Kind::Policy => PolicyVersion::V1(Policy::from_bytes(&account_data[..Policy::LEN])?),
160            Kind::PolicyV2 => {
161                PolicyVersion::V2(PolicyV2::from_bytes(&account_data[..PolicyV2::LEN])?)
162            }
163        };
164
165        let mint_data = client.get_account(mint).await?;
166        let mint_account_data: &[u8] = &mint_data.data;
167
168        let mint_pod = PodStateWithExtensions::<PodMint>::unpack(mint_account_data).unwrap();
169        let mint_bytes = mint_pod.get_extension_bytes::<TokenMetadata>().unwrap();
170        let token_metadata = TokenMetadata::try_from_slice(mint_bytes).unwrap();
171
172        let identities = match policy_version {
173            Kind::Policy => Policy::try_deserialize_identities(account_data)?,
174            Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data)?,
175        };
176
177        LogPolicy::new(mint, &token_metadata, &address, &policy, Some(&identities)).log();
178
179        Ok(CommandComplete(
180            SolanaAccount(*mint, Some(token_metadata)),
181            SolanaAccount(address, Some(policy)),
182        ))
183    }
184}
185
186/// Builder for updating/replacing identities in a policy
187pub struct UpdateBatchCommandBuilder<'a> {
188    mint: Option<&'a Pubkey>,
189    identities: Option<Vec<Pubkey>>,
190}
191
192impl Default for UpdateBatchCommandBuilder<'_> {
193    fn default() -> Self {
194        Self::new()
195    }
196}
197
198impl<'a> UpdateBatchCommandBuilder<'a> {
199    /// Create a new UpdateCommandBuilder
200    pub fn new() -> Self {
201        Self {
202            mint: None,
203            identities: None,
204        }
205    }
206
207    /// Set the mint address
208    pub fn mint(mut self, mint: &'a Pubkey) -> Self {
209        self.mint = Some(mint);
210        self
211    }
212
213    /// Set the identities to replace/update
214    pub fn identities(mut self, identities: Vec<Pubkey>) -> Self {
215        self.identities = Some(identities);
216        self
217    }
218}
219
220#[async_trait::async_trait]
221impl RunCommand for UpdateBatchCommandBuilder<'_> {
222    /// Execute replace/update of identities
223    async fn run(&mut self, context: CommandContext) -> RunResult {
224        let CommandContext { keypair, client } = context;
225
226        let mint = self.mint.expect("mint must be set");
227
228        // PDA seeds are same for both Policy and PolicyV2
229        let (address, _) = Policy::find_pda(mint);
230
231        let mut identities = self.identities.take().expect("identities must be set");
232        let mut seen = std::collections::HashSet::new();
233        identities.retain(|pk| seen.insert(*pk));
234
235        let token_account = get_associated_token_address_with_program_id(
236            &keypair.pubkey(),
237            mint,
238            &spl_token_2022::ID,
239        );
240
241        let account_data = client.get_account(&address).await?;
242        let account_data: &[u8] = &account_data.data;
243
244        let policy_version = Kind::try_from_slice(&[account_data[0]])?;
245
246        let current = match policy_version {
247            Kind::Policy => Policy::try_deserialize_identities(account_data)?,
248            Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data)?,
249        };
250
251        let current_set: HashSet<_> = current.iter().collect();
252
253        let mut iden_to_replace_or_add = VecDeque::new();
254        let identities_set: HashSet<_> = identities.iter().collect();
255
256        for i in &identities {
257            if !current_set.contains(&i) {
258                iden_to_replace_or_add.push_back(*i);
259            }
260        }
261
262        let mut iden_to_be_replaced_or_deleted_indices = current
263            .iter()
264            .enumerate()
265            .filter_map(|(idx, p)| {
266                if p == &Pubkey::default() {
267                    return Some((idx, true));
268                }
269                if !identities_set.contains(p) {
270                    return Some((idx, false));
271                }
272                None
273            })
274            .collect::<VecDeque<(usize, bool)>>();
275
276        let len_current = current.len();
277        let len_identities = identities.len();
278
279        // REMOVE if current > identities
280        let len_diff = len_current.saturating_sub(len_identities);
281
282        let mut remove = Vec::new();
283        for _ in 0..len_diff {
284            if let Some((idx, already_deleted)) = iden_to_be_replaced_or_deleted_indices.pop_back()
285            {
286                if !already_deleted {
287                    remove.push(idx);
288                }
289            }
290        }
291
292        let mut replace = Vec::new();
293
294        let min_len = usize::min(
295            iden_to_be_replaced_or_deleted_indices.len(),
296            iden_to_replace_or_add.len(),
297        );
298
299        for i in 0..min_len {
300            let (idx, _) = iden_to_be_replaced_or_deleted_indices[i];
301            let identity = iden_to_replace_or_add[i];
302            replace.push((idx, identity));
303        }
304
305        iden_to_be_replaced_or_deleted_indices.drain(0..min_len);
306        iden_to_replace_or_add.drain(0..min_len);
307
308        let add: Vec<_> = iden_to_replace_or_add.into_iter().collect();
309
310        // REMOVE
311        send_batched_tx(&client, &keypair, &remove, CHUNK_SIZE, |idx| {
312            RemoveIdentityBuilder::new()
313                .policy(address)
314                .mint(*mint)
315                .token_account(token_account)
316                .owner(keypair.pubkey())
317                .index(*idx as u64)
318                .instruction()
319        })
320        .await?;
321
322        // REPLACE
323        send_batched_tx(
324            &client,
325            &keypair,
326            &replace,
327            CHUNK_SIZE,
328            |(idx, identity)| {
329                ReplaceIdentityBuilder::new()
330                    .policy(address)
331                    .mint(*mint)
332                    .token_account(token_account)
333                    .owner(keypair.pubkey())
334                    .identity(*identity)
335                    .index(*idx as u64)
336                    .instruction()
337            },
338        )
339        .await?;
340
341        // ADD
342        send_batched_tx(&client, &keypair, &add, CHUNK_SIZE, |identity| {
343            AddIdentityBuilder::new()
344                .policy(address)
345                .mint(*mint)
346                .token_account(token_account)
347                .payer(keypair.pubkey())
348                .owner(keypair.pubkey())
349                .identity(*identity)
350                .instruction()
351        })
352        .await?;
353
354        let account_data = client.get_account(&address).await?;
355        let account_data: &[u8] = &account_data.data;
356
357        let policy_version = Kind::try_from_slice(&[account_data[0]])?;
358
359        let policy = match policy_version {
360            Kind::Policy => PolicyVersion::V1(Policy::from_bytes(&account_data[..Policy::LEN])?),
361            Kind::PolicyV2 => {
362                PolicyVersion::V2(PolicyV2::from_bytes(&account_data[..PolicyV2::LEN])?)
363            }
364        };
365
366        let mint_data = client.get_account(mint).await?;
367        let mint_account_data: &[u8] = &mint_data.data;
368
369        let mint_pod = PodStateWithExtensions::<PodMint>::unpack(mint_account_data).unwrap();
370        let mint_bytes = mint_pod.get_extension_bytes::<TokenMetadata>().unwrap();
371        let token_metadata = TokenMetadata::try_from_slice(mint_bytes).unwrap();
372
373        let identities = match policy_version {
374            Kind::Policy => Policy::try_deserialize_identities(account_data)?,
375            Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data)?,
376        };
377
378        LogPolicy::new(mint, &token_metadata, &address, &policy, Some(&identities)).log();
379
380        Ok(CommandComplete(
381            SolanaAccount(*mint, Some(token_metadata)),
382            SolanaAccount(address, Some(policy)),
383        ))
384    }
385}
386
387/// Builder for removing identities from a policy
388pub struct RemoveBatchCommandBuilder<'a> {
389    mint: Option<&'a Pubkey>,
390    identities: Option<Vec<Pubkey>>,
391}
392
393impl Default for RemoveBatchCommandBuilder<'_> {
394    fn default() -> Self {
395        Self::new()
396    }
397}
398
399impl<'a> RemoveBatchCommandBuilder<'a> {
400    /// Create a new RemoveCommandBuilder
401    pub fn new() -> Self {
402        Self {
403            mint: None,
404            identities: None,
405        }
406    }
407
408    /// Set the mint address
409    pub fn mint(mut self, mint: &'a Pubkey) -> Self {
410        self.mint = Some(mint);
411        self
412    }
413
414    /// Set the identities to remove
415    pub fn identities(mut self, identities: Vec<Pubkey>) -> Self {
416        self.identities = Some(identities);
417        self
418    }
419}
420
421#[async_trait::async_trait]
422impl RunCommand for RemoveBatchCommandBuilder<'_> {
423    /// Execute the removal of an identity from the policy
424    async fn run(&mut self, context: CommandContext) -> RunResult {
425        let CommandContext { keypair, client } = context;
426
427        let mint = self.mint.expect("mint must be set");
428        // PDA seeds are same for both Policy and PolicyV2
429        let (address, _) = Policy::find_pda(mint);
430
431        let mut identities = self.identities.take().expect("identities must be set");
432        let mut seen = std::collections::HashSet::new();
433        identities.retain(|pk| seen.insert(*pk));
434
435        let token_account = get_associated_token_address_with_program_id(
436            &keypair.pubkey(),
437            mint,
438            &spl_token_2022::ID,
439        );
440
441        let account_data = client.get_account(&address).await?;
442        let account_data: &[u8] = &account_data.data;
443
444        let policy_version = Kind::try_from_slice(&[account_data[0]])?;
445
446        let current = match policy_version {
447            Kind::Policy => Policy::try_deserialize_identities(account_data),
448            Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data),
449        }?;
450
451        let remove: Vec<usize> = identities
452            .into_iter()
453            .filter_map(|identity| {
454                current
455                    .iter()
456                    .position(|&current_identity| current_identity == identity)
457            })
458            .collect();
459
460        send_batched_tx(&client, &keypair, &remove, CHUNK_SIZE, |idx| {
461            RemoveIdentityBuilder::new()
462                .policy(address)
463                .mint(*mint)
464                .token_account(token_account)
465                .owner(keypair.pubkey())
466                .index(*idx as u64)
467                .instruction()
468        })
469        .await?;
470
471        let account_data = client.get_account(&address).await?;
472        let account_data: &[u8] = &account_data.data;
473
474        let policy_version = Kind::try_from_slice(&[account_data[0]])?;
475
476        let policy = match policy_version {
477            Kind::Policy => PolicyVersion::V1(Policy::from_bytes(&account_data[..Policy::LEN])?),
478            Kind::PolicyV2 => {
479                PolicyVersion::V2(PolicyV2::from_bytes(&account_data[..PolicyV2::LEN])?)
480            }
481        };
482
483        let mint_data = client.get_account(mint).await?;
484        let mint_account_data: &[u8] = &mint_data.data;
485
486        let mint_pod = PodStateWithExtensions::<PodMint>::unpack(mint_account_data).unwrap();
487        let mint_bytes = mint_pod.get_extension_bytes::<TokenMetadata>().unwrap();
488        let token_metadata = TokenMetadata::try_from_slice(mint_bytes).unwrap();
489
490        let identities = match policy_version {
491            Kind::Policy => Policy::try_deserialize_identities(account_data)?,
492            Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data)?,
493        };
494
495        LogPolicy::new(mint, &token_metadata, &address, &policy, Some(&identities)).log();
496
497        Ok(CommandComplete(
498            SolanaAccount(*mint, Some(token_metadata)),
499            SolanaAccount(address, Some(policy)),
500        ))
501    }
502}