yellowstone_shield_cli/command/
identity.rs1use std::collections::{HashSet, VecDeque};
2
3use super::{RunCommand, RunResult};
4use crate::{
5 command::{send_batched_tx, CommandContext},
6 policy::PolicyVersion,
7 CommandComplete, LogPolicy, SolanaAccount,
8};
9use borsh::BorshDeserialize;
10
11use solana_pubkey::Pubkey;
12use solana_signer::Signer;
13use spl_associated_token_account::get_associated_token_address_with_program_id;
14use spl_token_2022::{
15 extension::{BaseStateWithExtensions, PodStateWithExtensions},
16 pod::PodMint,
17};
18use spl_token_metadata_interface::state::TokenMetadata;
19
20use yellowstone_shield_client::{
21 accounts::{Policy, PolicyV2},
22 instructions::ReplaceIdentityBuilder,
23 types::Kind,
24};
25use yellowstone_shield_client::{
26 instructions::{AddIdentityBuilder, RemoveIdentityBuilder},
27 PolicyTrait,
28};
29
/// Chunk size handed to `send_batched_tx` by every command in this file.
///
/// NOTE(review): presumably the number of instructions packed into one transaction —
/// confirm against the definition of `send_batched_tx`.
const CHUNK_SIZE: usize = 20;
31
32#[derive(Debug, Clone)]
34pub struct AddBatchCommandBuilder<'a> {
35 mint: Option<&'a Pubkey>,
36 identities: Option<Vec<Pubkey>>,
37}
38
39impl Default for AddBatchCommandBuilder<'_> {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl<'a> AddBatchCommandBuilder<'a> {
46 pub fn new() -> Self {
48 Self {
49 mint: None,
50 identities: None,
51 }
52 }
53
54 pub fn mint(mut self, mint: &'a Pubkey) -> Self {
56 self.mint = Some(mint);
57 self
58 }
59
60 pub fn identities(mut self, identities: Vec<Pubkey>) -> Self {
62 self.identities = Some(identities);
63 self
64 }
65}
66
67#[async_trait::async_trait]
68impl RunCommand for AddBatchCommandBuilder<'_> {
69 async fn run(&mut self, context: CommandContext) -> RunResult {
71 let CommandContext { keypair, client } = context;
72
73 let mint = self.mint.expect("mint must be set");
74
75 let (address, _) = Policy::find_pda(mint);
77 let mut identities = self.identities.take().expect("identities must be set");
78 let mut seen = std::collections::HashSet::new();
79 identities.retain(|pk| seen.insert(*pk));
80
81 let token_account = get_associated_token_address_with_program_id(
82 &keypair.pubkey(),
83 mint,
84 &spl_token_2022::ID,
85 );
86
87 let account_data = client.get_account(&address).await?;
88 let account_data: &[u8] = &account_data.data;
89
90 let policy_version = Kind::try_from_slice(&[account_data[0]])?;
91
92 let current = match policy_version {
93 Kind::Policy => Policy::try_deserialize_identities(account_data),
94 Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data),
95 }?;
96
97 let empty_identity_indices = current
98 .iter()
99 .enumerate()
100 .filter_map(|(idx, p)| {
101 if p == &Pubkey::default() {
102 return Some(idx);
103 }
104 None
105 })
106 .collect::<Vec<usize>>();
107
108 let mut add_or_replace: Vec<Pubkey> = identities
109 .into_iter()
110 .filter(|identity| !current.contains(identity))
111 .collect();
112
113 let mut replace = Vec::new();
114
115 for i in empty_identity_indices {
116 if let Some(iden) = add_or_replace.pop() {
117 replace.push((i, iden));
118 }
119 }
120
121 send_batched_tx(
123 &client,
124 &keypair,
125 &replace,
126 CHUNK_SIZE,
127 |(idx, identity)| {
128 ReplaceIdentityBuilder::new()
129 .policy(address)
130 .mint(*mint)
131 .token_account(token_account)
132 .owner(keypair.pubkey())
133 .identity(*identity)
134 .index(*idx as u64)
135 .instruction()
136 },
137 )
138 .await?;
139
140 send_batched_tx(&client, &keypair, &add_or_replace, CHUNK_SIZE, |identity| {
142 AddIdentityBuilder::new()
143 .policy(address)
144 .mint(*mint)
145 .token_account(token_account)
146 .payer(keypair.pubkey())
147 .owner(keypair.pubkey())
148 .identity(*identity)
149 .instruction()
150 })
151 .await?;
152
153 let account_data = client.get_account(&address).await?;
154 let account_data: &[u8] = &account_data.data;
155
156 let policy_version = Kind::try_from_slice(&[account_data[0]])?;
157
158 let policy = match policy_version {
159 Kind::Policy => PolicyVersion::V1(Policy::from_bytes(&account_data[..Policy::LEN])?),
160 Kind::PolicyV2 => {
161 PolicyVersion::V2(PolicyV2::from_bytes(&account_data[..PolicyV2::LEN])?)
162 }
163 };
164
165 let mint_data = client.get_account(mint).await?;
166 let mint_account_data: &[u8] = &mint_data.data;
167
168 let mint_pod = PodStateWithExtensions::<PodMint>::unpack(mint_account_data).unwrap();
169 let mint_bytes = mint_pod.get_extension_bytes::<TokenMetadata>().unwrap();
170 let token_metadata = TokenMetadata::try_from_slice(mint_bytes).unwrap();
171
172 let identities = match policy_version {
173 Kind::Policy => Policy::try_deserialize_identities(account_data)?,
174 Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data)?,
175 };
176
177 LogPolicy::new(mint, &token_metadata, &address, &policy, Some(&identities)).log();
178
179 Ok(CommandComplete(
180 SolanaAccount(*mint, Some(token_metadata)),
181 SolanaAccount(address, Some(policy)),
182 ))
183 }
184}
185
186pub struct UpdateBatchCommandBuilder<'a> {
188 mint: Option<&'a Pubkey>,
189 identities: Option<Vec<Pubkey>>,
190}
191
192impl Default for UpdateBatchCommandBuilder<'_> {
193 fn default() -> Self {
194 Self::new()
195 }
196}
197
198impl<'a> UpdateBatchCommandBuilder<'a> {
199 pub fn new() -> Self {
201 Self {
202 mint: None,
203 identities: None,
204 }
205 }
206
207 pub fn mint(mut self, mint: &'a Pubkey) -> Self {
209 self.mint = Some(mint);
210 self
211 }
212
213 pub fn identities(mut self, identities: Vec<Pubkey>) -> Self {
215 self.identities = Some(identities);
216 self
217 }
218}
219
220#[async_trait::async_trait]
221impl RunCommand for UpdateBatchCommandBuilder<'_> {
222 async fn run(&mut self, context: CommandContext) -> RunResult {
224 let CommandContext { keypair, client } = context;
225
226 let mint = self.mint.expect("mint must be set");
227
228 let (address, _) = Policy::find_pda(mint);
230
231 let mut identities = self.identities.take().expect("identities must be set");
232 let mut seen = std::collections::HashSet::new();
233 identities.retain(|pk| seen.insert(*pk));
234
235 let token_account = get_associated_token_address_with_program_id(
236 &keypair.pubkey(),
237 mint,
238 &spl_token_2022::ID,
239 );
240
241 let account_data = client.get_account(&address).await?;
242 let account_data: &[u8] = &account_data.data;
243
244 let policy_version = Kind::try_from_slice(&[account_data[0]])?;
245
246 let current = match policy_version {
247 Kind::Policy => Policy::try_deserialize_identities(account_data)?,
248 Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data)?,
249 };
250
251 let current_set: HashSet<_> = current.iter().collect();
252
253 let mut iden_to_replace_or_add = VecDeque::new();
254 let identities_set: HashSet<_> = identities.iter().collect();
255
256 for i in &identities {
257 if !current_set.contains(&i) {
258 iden_to_replace_or_add.push_back(*i);
259 }
260 }
261
262 let mut iden_to_be_replaced_or_deleted_indices = current
263 .iter()
264 .enumerate()
265 .filter_map(|(idx, p)| {
266 if p == &Pubkey::default() {
267 return Some((idx, true));
268 }
269 if !identities_set.contains(p) {
270 return Some((idx, false));
271 }
272 None
273 })
274 .collect::<VecDeque<(usize, bool)>>();
275
276 let len_current = current.len();
277 let len_identities = identities.len();
278
279 let len_diff = len_current.saturating_sub(len_identities);
281
282 let mut remove = Vec::new();
283 for _ in 0..len_diff {
284 if let Some((idx, already_deleted)) = iden_to_be_replaced_or_deleted_indices.pop_back()
285 {
286 if !already_deleted {
287 remove.push(idx);
288 }
289 }
290 }
291
292 let mut replace = Vec::new();
293
294 let min_len = usize::min(
295 iden_to_be_replaced_or_deleted_indices.len(),
296 iden_to_replace_or_add.len(),
297 );
298
299 for i in 0..min_len {
300 let (idx, _) = iden_to_be_replaced_or_deleted_indices[i];
301 let identity = iden_to_replace_or_add[i];
302 replace.push((idx, identity));
303 }
304
305 iden_to_be_replaced_or_deleted_indices.drain(0..min_len);
306 iden_to_replace_or_add.drain(0..min_len);
307
308 let add: Vec<_> = iden_to_replace_or_add.into_iter().collect();
309
310 send_batched_tx(&client, &keypair, &remove, CHUNK_SIZE, |idx| {
312 RemoveIdentityBuilder::new()
313 .policy(address)
314 .mint(*mint)
315 .token_account(token_account)
316 .owner(keypair.pubkey())
317 .index(*idx as u64)
318 .instruction()
319 })
320 .await?;
321
322 send_batched_tx(
324 &client,
325 &keypair,
326 &replace,
327 CHUNK_SIZE,
328 |(idx, identity)| {
329 ReplaceIdentityBuilder::new()
330 .policy(address)
331 .mint(*mint)
332 .token_account(token_account)
333 .owner(keypair.pubkey())
334 .identity(*identity)
335 .index(*idx as u64)
336 .instruction()
337 },
338 )
339 .await?;
340
341 send_batched_tx(&client, &keypair, &add, CHUNK_SIZE, |identity| {
343 AddIdentityBuilder::new()
344 .policy(address)
345 .mint(*mint)
346 .token_account(token_account)
347 .payer(keypair.pubkey())
348 .owner(keypair.pubkey())
349 .identity(*identity)
350 .instruction()
351 })
352 .await?;
353
354 let account_data = client.get_account(&address).await?;
355 let account_data: &[u8] = &account_data.data;
356
357 let policy_version = Kind::try_from_slice(&[account_data[0]])?;
358
359 let policy = match policy_version {
360 Kind::Policy => PolicyVersion::V1(Policy::from_bytes(&account_data[..Policy::LEN])?),
361 Kind::PolicyV2 => {
362 PolicyVersion::V2(PolicyV2::from_bytes(&account_data[..PolicyV2::LEN])?)
363 }
364 };
365
366 let mint_data = client.get_account(mint).await?;
367 let mint_account_data: &[u8] = &mint_data.data;
368
369 let mint_pod = PodStateWithExtensions::<PodMint>::unpack(mint_account_data).unwrap();
370 let mint_bytes = mint_pod.get_extension_bytes::<TokenMetadata>().unwrap();
371 let token_metadata = TokenMetadata::try_from_slice(mint_bytes).unwrap();
372
373 let identities = match policy_version {
374 Kind::Policy => Policy::try_deserialize_identities(account_data)?,
375 Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data)?,
376 };
377
378 LogPolicy::new(mint, &token_metadata, &address, &policy, Some(&identities)).log();
379
380 Ok(CommandComplete(
381 SolanaAccount(*mint, Some(token_metadata)),
382 SolanaAccount(address, Some(policy)),
383 ))
384 }
385}
386
387pub struct RemoveBatchCommandBuilder<'a> {
389 mint: Option<&'a Pubkey>,
390 identities: Option<Vec<Pubkey>>,
391}
392
393impl Default for RemoveBatchCommandBuilder<'_> {
394 fn default() -> Self {
395 Self::new()
396 }
397}
398
399impl<'a> RemoveBatchCommandBuilder<'a> {
400 pub fn new() -> Self {
402 Self {
403 mint: None,
404 identities: None,
405 }
406 }
407
408 pub fn mint(mut self, mint: &'a Pubkey) -> Self {
410 self.mint = Some(mint);
411 self
412 }
413
414 pub fn identities(mut self, identities: Vec<Pubkey>) -> Self {
416 self.identities = Some(identities);
417 self
418 }
419}
420
421#[async_trait::async_trait]
422impl RunCommand for RemoveBatchCommandBuilder<'_> {
423 async fn run(&mut self, context: CommandContext) -> RunResult {
425 let CommandContext { keypair, client } = context;
426
427 let mint = self.mint.expect("mint must be set");
428 let (address, _) = Policy::find_pda(mint);
430
431 let mut identities = self.identities.take().expect("identities must be set");
432 let mut seen = std::collections::HashSet::new();
433 identities.retain(|pk| seen.insert(*pk));
434
435 let token_account = get_associated_token_address_with_program_id(
436 &keypair.pubkey(),
437 mint,
438 &spl_token_2022::ID,
439 );
440
441 let account_data = client.get_account(&address).await?;
442 let account_data: &[u8] = &account_data.data;
443
444 let policy_version = Kind::try_from_slice(&[account_data[0]])?;
445
446 let current = match policy_version {
447 Kind::Policy => Policy::try_deserialize_identities(account_data),
448 Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data),
449 }?;
450
451 let remove: Vec<usize> = identities
452 .into_iter()
453 .filter_map(|identity| {
454 current
455 .iter()
456 .position(|¤t_identity| current_identity == identity)
457 })
458 .collect();
459
460 send_batched_tx(&client, &keypair, &remove, CHUNK_SIZE, |idx| {
461 RemoveIdentityBuilder::new()
462 .policy(address)
463 .mint(*mint)
464 .token_account(token_account)
465 .owner(keypair.pubkey())
466 .index(*idx as u64)
467 .instruction()
468 })
469 .await?;
470
471 let account_data = client.get_account(&address).await?;
472 let account_data: &[u8] = &account_data.data;
473
474 let policy_version = Kind::try_from_slice(&[account_data[0]])?;
475
476 let policy = match policy_version {
477 Kind::Policy => PolicyVersion::V1(Policy::from_bytes(&account_data[..Policy::LEN])?),
478 Kind::PolicyV2 => {
479 PolicyVersion::V2(PolicyV2::from_bytes(&account_data[..PolicyV2::LEN])?)
480 }
481 };
482
483 let mint_data = client.get_account(mint).await?;
484 let mint_account_data: &[u8] = &mint_data.data;
485
486 let mint_pod = PodStateWithExtensions::<PodMint>::unpack(mint_account_data).unwrap();
487 let mint_bytes = mint_pod.get_extension_bytes::<TokenMetadata>().unwrap();
488 let token_metadata = TokenMetadata::try_from_slice(mint_bytes).unwrap();
489
490 let identities = match policy_version {
491 Kind::Policy => Policy::try_deserialize_identities(account_data)?,
492 Kind::PolicyV2 => PolicyV2::try_deserialize_identities(account_data)?,
493 };
494
495 LogPolicy::new(mint, &token_metadata, &address, &policy, Some(&identities)).log();
496
497 Ok(CommandComplete(
498 SolanaAccount(*mint, Some(token_metadata)),
499 SolanaAccount(address, Some(policy)),
500 ))
501 }
502}