1use std::sync::Arc;
2
3use anyhow::Result;
4use arc_swap::ArcSwap;
5use futures::StreamExt;
6use hashbrown::{HashMap, HashSet};
7use parking_lot::RwLock;
8use serde::Deserialize;
9
10use solana_account_decoder_client_types::UiAccountEncoding;
11use solana_client::{
12 nonblocking::rpc_client::RpcClient,
13 rpc_config::{RpcAccountInfoConfig, RpcProgramAccountsConfig},
14};
15use solana_commitment_config::CommitmentConfig;
16use solana_pubkey::Pubkey;
17use std::time::Duration;
18use yellowstone_grpc_client::{ClientTlsConfig, GeyserGrpcClient};
19use yellowstone_grpc_proto::geyser::{
20 subscribe_update::UpdateOneof, CommitmentLevel, SubscribeRequest,
21 SubscribeRequestFilterAccounts,
22};
23use yellowstone_shield_parser::accounts::{parse_account, PermissionStrategy, Policy, ShieldProgramState, ID as PROGRAM_ID};
24
/// A cached value tagged with the slot it was observed at.
///
/// The slot lets [`PolicyCache::insert`] refuse writes whose slot is not greater than the one
/// already stored.
pub struct SlotCacheItem<T> {
    /// Slot at which `item` was observed.
    slot: u64,
    /// The cached value.
    item: T,
}
29
30pub struct PolicyCache {
32 policies: RwLock<HashMap<Pubkey, SlotCacheItem<Policy>>>,
35}
36
37impl Default for PolicyCache {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl PolicyCache {
44 pub fn new() -> Self {
50 Self {
51 policies: RwLock::new(HashMap::new()),
52 }
53 }
54
55 pub fn insert(&self, pubkey: Pubkey, slot: u64, item: Policy) {
64 let mut policies = self.policies.write();
65 if let Some(current_item) = policies.get(&pubkey) {
66 if slot > current_item.slot {
67 policies.insert(pubkey, SlotCacheItem { slot, item });
68 }
69 } else {
70 policies.insert(pubkey, SlotCacheItem { slot, item });
71 }
72 }
73
74 pub fn get(&self, pubkey: &Pubkey) -> Option<Policy> {
85 self.policies
86 .read()
87 .get(pubkey)
88 .map(|item| item.item.clone())
89 }
90
91 pub fn remove(&self, pubkey: &Pubkey) -> Option<()> {
101 self.policies.write().remove(pubkey).map(|_| ())
102 }
103
104 pub fn all(&self) -> Vec<(Pubkey, Policy)> {
111 self.policies
112 .read()
113 .iter()
114 .map(|(k, item)| (*k, item.item.clone()))
115 .collect()
116 }
117}
118
119#[derive(Debug, thiserror::Error, PartialEq, Eq)]
120pub enum CheckError {
121 #[error("Policy not found")]
122 PolicyNotFound,
123}
124
125#[derive(Default)]
132pub struct Snapshot {
133 lookup: HashSet<(Pubkey, Pubkey)>,
135 strategies: HashMap<Pubkey, PermissionStrategy>,
136}
137
138impl Snapshot {
139 pub fn new(cache: &PolicyCache) -> Self {
150 let mut lookup = HashSet::new();
151 let mut strategies = HashMap::new();
152
153 for (address, policy) in cache.all().iter() {
154 strategies.insert(*address, policy.strategy);
155 for identity in &policy.identities {
156 lookup.insert((*address, *identity));
157 }
158 }
159
160 Self { lookup, strategies }
161 }
162
163 pub fn is_allowed(&self, policies: &[Pubkey], identity: &Pubkey) -> Result<bool, CheckError> {
197 let mut not_found = true;
198
199 for address in policies.iter() {
200 if let Some(strategy) = self.strategies.get(address) {
201 if self.lookup.contains(&(*address, *identity)) {
202 match strategy {
203 PermissionStrategy::Deny => {
204 return Ok(false);
205 }
206 PermissionStrategy::Allow => {
207 return Ok(true);
208 }
209 }
210 } else if let PermissionStrategy::Allow = strategy {
211 not_found = false;
212 }
213 } else {
214 return Err(CheckError::PolicyNotFound);
215 }
216 }
217
218 Ok(not_found)
219 }
220}
221
/// An RPC payload paired with the slot it is associated with.
#[derive(Debug, Default)]
pub struct SlotRpcResponse<T> {
    /// Slot the `result` is associated with.
    slot: u64,
    /// The decoded payload.
    result: T,
}
227
228pub type PoliciesSlotRpcResponse = SlotRpcResponse<Vec<(Pubkey, Policy)>>;
229pub struct PolicyRpcClient(RpcClient);
230
231impl PolicyRpcClient {
232 pub fn new(client: RpcClient) -> Self {
233 Self(client)
234 }
235
236 pub async fn list(&self, program_id: &Pubkey) -> Result<PoliciesSlotRpcResponse> {
237 let slot = self.0.get_slot().await?;
238
239 let result = self
240 .0
241 .get_program_accounts_with_config(
242 program_id,
243 RpcProgramAccountsConfig {
244 account_config: RpcAccountInfoConfig {
245 encoding: Some(UiAccountEncoding::Base64),
246 commitment: Some(CommitmentConfig::confirmed()),
247 ..Default::default()
248 },
249 ..Default::default()
250 },
251 )
252 .await?
253 .into_iter()
254 .filter_map(|(address, account)| {
255 let data: &[u8] = &account.data;
256 let owner = &account.owner;
257
258 match parse_account(slot, address, owner, data, Some(program_id)) {
259 Ok(ShieldProgramState::Policy(_slot, _pubkey, policy)) => {
260 Some((address, policy))
261 }
262 Err(e) => {
263 log::warn!("Failed to parse policy account {}: {}", address, e);
264 None
265 }
266 }
267 })
268 .collect::<Vec<_>>();
269
270 Ok(SlotRpcResponse { slot, result })
271 }
272}
273
274impl From<PoliciesSlotRpcResponse> for PolicyCache {
275 fn from(response: PoliciesSlotRpcResponse) -> Self {
276 let cache = Self::new();
277
278 for (address, policy) in response.result.into_iter() {
279 cache.insert(address, response.slot, policy);
280 }
281
282 cache
283 }
284}
285
286pub trait PolicyStoreTrait {
287 fn snapshot(&self) -> Arc<Snapshot>;
288}
289
290pub struct PolicyStore {
292 snapshot: Arc<ArcSwap<Snapshot>>,
294}
295
296impl PolicyStore {
297 pub fn new(snapshot: Arc<ArcSwap<Snapshot>>) -> Self {
307 Self { snapshot }
308 }
309}
310
311impl PolicyStoreTrait for PolicyStore {
312 fn snapshot(&self) -> Arc<Snapshot> {
313 self.snapshot.load_full()
314 }
315}
316
317pub struct MockPolicyStore {
319 snapshot: Arc<Snapshot>,
320}
321
322impl MockPolicyStore {
323 pub fn new(snapshot: Arc<Snapshot>) -> Self {
333 Self { snapshot }
334 }
335}
336
337impl PolicyStoreTrait for MockPolicyStore {
338 fn snapshot(&self) -> Arc<Snapshot> {
339 Arc::clone(&self.snapshot)
340 }
341}
342
/// A boxed, pinned, `'static` future with no output.
pub type SubscriptionTask =
    std::pin::Pin<Box<dyn std::future::Future<Output = ()> + 'static>>;
344
345#[derive(Deserialize, Clone)]
346pub struct PolicyStoreRpcConfig {
347 pub endpoint: String,
348}
349
350#[derive(Deserialize, Clone)]
351pub struct PolicyStoreGrpcConfig {
352 pub endpoint: String,
353
354 #[serde(default = "default_commitment")]
355 pub commitment: Option<ShieldStoreCommitmentLevel>,
356
357 pub x_token: Option<String>,
358
359 #[serde(with = "humantime_serde", default = "default_timeout")]
360 pub timeout: Duration,
361
362 #[serde(with = "humantime_serde", default = "default_connect_timeout")]
363 pub connect_timeout: Duration,
364
365 #[serde(default = "default_tcp_nodelay")]
366 pub tcp_nodelay: bool,
367
368 #[serde(default = "default_http2_adaptive_window")]
369 pub http2_adaptive_window: bool,
370
371 #[serde(default = "default_http2_keep_alive")]
372 pub http2_keep_alive: bool,
373
374 #[serde(with = "humantime_serde")]
375 pub http2_keep_alive_interval: Option<Duration>,
376
377 #[serde(with = "humantime_serde")]
378 pub http2_keep_alive_timeout: Option<Duration>,
379
380 pub http2_keep_alive_while_idle: Option<bool>,
381
382 #[serde(default = "default_max_decoding_message_size")]
383 pub max_decoding_message_size: Option<usize>,
384
385 pub initial_connection_window_size: Option<u32>,
386
387 pub initial_stream_window_size: Option<u32>,
388}
389
390fn default_commitment() -> Option<ShieldStoreCommitmentLevel> {
391 Some(ShieldStoreCommitmentLevel::Confirmed)
392}
393
/// Serde default for `PolicyStoreGrpcConfig::timeout`: one minute.
fn default_timeout() -> Duration {
    Duration::from_millis(60_000)
}
397
/// Serde default for `PolicyStoreGrpcConfig::connect_timeout`: ten seconds.
fn default_connect_timeout() -> Duration {
    Duration::from_millis(10_000)
}
401
/// Serde default for `PolicyStoreGrpcConfig::tcp_nodelay`: enabled.
fn default_tcp_nodelay() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
405
/// Serde default for `PolicyStoreGrpcConfig::max_decoding_message_size`: 16 MiB.
fn default_max_decoding_message_size() -> Option<usize> {
    Some(16 * 1024 * 1024)
}
409
/// Serde default for `PolicyStoreGrpcConfig::http2_adaptive_window`: enabled.
fn default_http2_adaptive_window() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
413
/// Serde default for `PolicyStoreGrpcConfig::http2_keep_alive`: disabled.
fn default_http2_keep_alive() -> bool {
    const ENABLED: bool = false;
    ENABLED
}
417
/// Commitment level accepted in the store configuration (`processed`, `confirmed`, `finalized`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShieldStoreCommitmentLevel {
    /// Configured as `"processed"`.
    Processed,
    /// Configured as `"confirmed"`.
    Confirmed,
    /// Configured as `"finalized"`.
    Finalized,
}
424
425impl From<ShieldStoreCommitmentLevel> for CommitmentLevel {
426 fn from(def: ShieldStoreCommitmentLevel) -> Self {
427 match def {
428 ShieldStoreCommitmentLevel::Processed => CommitmentLevel::Processed,
429 ShieldStoreCommitmentLevel::Confirmed => CommitmentLevel::Confirmed,
430 ShieldStoreCommitmentLevel::Finalized => CommitmentLevel::Finalized,
431 }
432 }
433}
434
435impl<'de> Deserialize<'de> for ShieldStoreCommitmentLevel {
436 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
437 where
438 D: serde::Deserializer<'de>,
439 {
440 let s: &str = serde::Deserialize::deserialize(deserializer)?;
441 match s {
442 "processed" => Ok(ShieldStoreCommitmentLevel::Processed),
443 "confirmed" => Ok(ShieldStoreCommitmentLevel::Confirmed),
444 "finalized" => Ok(ShieldStoreCommitmentLevel::Finalized),
445 _ => Err(serde::de::Error::custom(format!(
446 "Invalid commitment level: {}",
447 s
448 ))),
449 }
450 }
451}
452
453#[derive(Deserialize, Clone)]
454pub struct PolicyStoreConfig {
455 pub rpc: PolicyStoreRpcConfig,
456 pub grpc: PolicyStoreGrpcConfig,
457}
458
459#[derive(Debug, thiserror::Error)]
460pub enum StoreError {
461 #[error("No config provided")]
462 NoConfig,
463 #[error("Unable to deserialize policy")]
464 DeserializePolicy,
465 #[error("RPC error: {0}")]
466 RpcError(String),
467 #[error("gRPC client error: {0}")]
468 GrpcClientError(String),
469 #[error("gRPC connection error: {0}")]
470 GrpcConnectionError(String),
471 #[error("gRPC subscription error: {0}")]
472 GrpcSubscriptionError(String),
473}
474
475pub type StoreResult<T> = std::result::Result<T, StoreError>;
476
477impl From<solana_client::client_error::ClientError> for StoreError {
478 fn from(e: solana_client::client_error::ClientError) -> Self {
479 StoreError::RpcError(e.to_string())
480 }
481}
482
483#[derive(Default)]
484pub struct PolicyStoreBuilder {
485 config: Option<PolicyStoreConfig>,
486}
487
488impl PolicyStoreBuilder {
489 pub fn config(&mut self, config: PolicyStoreConfig) -> &mut Self {
490 self.config = Some(config);
491
492 self
493 }
494
495 pub async fn run(&mut self) -> StoreResult<PolicyStore> {
496 let config = self.config.take().ok_or(StoreError::NoConfig)?;
497 let rpc = RpcClient::new(config.rpc.endpoint);
498
499 let policies = PolicyRpcClient::new(rpc).list(&PROGRAM_ID).await.map_err(|e| StoreError::RpcError(e.to_string()))?;
500
501 let cache = Arc::new(policies.into());
502 let snapshot = Arc::new(ArcSwap::from_pointee(Snapshot::new(&cache)));
503
504 let mut builder = GeyserGrpcClient::build_from_shared(config.grpc.endpoint.clone())
506 .map_err(|e| {
507 StoreError::GrpcClientError(format!("Failed to build gRPC client: {}", e))
508 })?
509 .connect_timeout(config.grpc.connect_timeout)
510 .timeout(config.grpc.timeout);
511
512 if config.grpc.tcp_nodelay {
513 builder = builder.tcp_nodelay(true);
514 }
515
516 if config.grpc.http2_adaptive_window {
517 builder = builder.http2_adaptive_window(true);
518 }
519
520 builder = builder.tls_config(ClientTlsConfig::new().with_native_roots()).expect("Failed to set TLS config");
521
522 if config.grpc.http2_keep_alive {
524 if let Some(interval) = config.grpc.http2_keep_alive_interval {
525 builder = builder.http2_keep_alive_interval(interval);
526 }
527
528 if let Some(timeout) = config.grpc.http2_keep_alive_timeout {
529 builder = builder.keep_alive_timeout(timeout);
530 }
531
532 if let Some(while_idle) = config.grpc.http2_keep_alive_while_idle {
533 builder = builder.keep_alive_while_idle(while_idle);
534 }
535 }
536
537 if let Some(max_size) = config.grpc.max_decoding_message_size {
538 builder = builder.max_decoding_message_size(max_size)
539 }
540
541 if let Some(window_size) = config.grpc.initial_connection_window_size {
542 builder = builder.initial_connection_window_size(window_size);
543 }
544
545 if let Some(stream_window_size) = config.grpc.initial_stream_window_size {
546 builder = builder.initial_stream_window_size(stream_window_size);
547 }
548
549 let builder = if let Some(ref token) = config.grpc.x_token {
551 builder
552 .x_token(Some(token.clone()))
553 .map_err(|e| StoreError::GrpcClientError(format!("Failed to set x-token: {}", e)))?
554 } else {
555 builder
556 };
557
558 let mut client = builder.connect().await.map_err(|e| {
559 StoreError::GrpcConnectionError(format!("Failed to connect to gRPC server: {}", e))
560 })?;
561
562 log::info!("Connected to gRPC endpoint: {}", config.grpc.endpoint);
563
564 let mut accounts = std::collections::HashMap::new();
566 accounts.insert(
567 "".to_string(),
568 SubscribeRequestFilterAccounts {
569 account: vec![],
570 owner: vec![PROGRAM_ID.to_string()],
571 filters: vec![],
572 nonempty_txn_signature: None,
573 },
574 );
575
576 let subscribe_request = SubscribeRequest {
577 accounts,
578 ..Default::default()
579 };
580
581 let mut stream = client
582 .subscribe_once(subscribe_request)
583 .await
584 .map_err(|e| {
585 StoreError::GrpcSubscriptionError(format!(
586 "Failed to subscribe to gRPC stream: {}",
587 e
588 ))
589 })?;
590
591 log::info!("Subscribed to Shield program account updates");
592
593 let cache_clone = Arc::clone(&cache);
595 let snapshot_clone = Arc::clone(&snapshot);
596
597 tokio::spawn(async move {
598 while let Some(message) = stream.next().await {
599 match message {
600 Ok(msg) => {
601 if let Some(UpdateOneof::Account(account_update)) = msg.update_oneof {
602 if let Some(account) = account_update.account {
604 let pubkey_bytes: [u8; 32] = match account.pubkey.try_into() {
605 Ok(bytes) => bytes,
606 Err(_) => {
607 log::warn!("Invalid pubkey length in account update");
608 continue;
609 }
610 };
611 let pubkey = Pubkey::from(pubkey_bytes);
612
613 let owner_bytes: [u8; 32] = match account.owner.try_into() {
614 Ok(bytes) => bytes,
615 Err(_) => {
616 log::warn!("Invalid owner length in account update");
617 continue;
618 }
619 };
620 let owner = Pubkey::from(owner_bytes);
621
622 match parse_account(
624 account_update.slot,
625 pubkey,
626 &owner,
627 &account.data,
628 Some(&PROGRAM_ID),
629 ) {
630 Ok(ShieldProgramState::Policy(slot, pubkey, policy)) => {
631 cache_clone.insert(pubkey, slot, policy);
632 snapshot_clone.store(Arc::new(Snapshot::new(&cache_clone)));
633 log::debug!("Updated policy for pubkey: {}", pubkey);
634 }
635 Err(e) => {
636 log::warn!("Failed to parse account update: {}", e);
637 }
638 }
639 }
640 }
641 }
642 Err(e) => {
643 log::error!("Error receiving gRPC message: {}", e);
644 break;
645 }
646 }
647 }
648
649 log::warn!("gRPC stream ended");
650 });
651
652 Ok(PolicyStore::new(snapshot))
653 }
654}
655
656impl PolicyStore {
657 pub fn build() -> PolicyStoreBuilder {
658 PolicyStoreBuilder::default()
659 }
660}
661
// Unit tests for the in-memory cache, snapshot permission logic, and the mock store.
#[cfg(test)]
mod tests {
    use super::*;
    use solana_pubkey::Pubkey;
    use yellowstone_shield_parser::accounts::Policy;

    // A freshly inserted policy can be read back with the same strategy and identities.
    #[test]
    fn test_policy_cache_insert_and_get() {
        let cache = PolicyCache::new();
        let address = Pubkey::new_unique();
        let validator = Pubkey::new_unique();
        let policy = Policy::new(PermissionStrategy::Deny, vec![validator]);

        cache.insert(address, 1, policy.clone());
        let retrieved_policy = cache.get(&address).unwrap();

        assert_eq!(retrieved_policy.strategy, policy.strategy);
        assert_eq!(retrieved_policy.identities, policy.identities);
    }

    // `all` returns one entry per distinct inserted address.
    #[test]
    fn test_policy_cache_all() {
        let cache = PolicyCache::new();
        let validator = Pubkey::new_unique();

        let policies = [
            (
                Pubkey::new_unique(),
                Policy::new(PermissionStrategy::Deny, vec![validator]),
            ),
            (
                Pubkey::new_unique(),
                Policy::new(PermissionStrategy::Allow, vec![validator]),
            ),
        ];

        for (pubkey, policy) in policies.iter() {
            cache.insert(*pubkey, 1, policy.clone());
        }

        let policies = cache.all();
        assert_eq!(policies.len(), 2);
    }

    // Removing an existing entry succeeds and makes it unreadable afterwards.
    #[test]
    fn test_policy_cache_remove() {
        let cache = PolicyCache::new();
        let address = Pubkey::new_unique();
        let validator = Pubkey::new_unique();
        let policy = Policy::new(PermissionStrategy::Deny, vec![validator]);

        cache.insert(address, 1, policy.clone());
        cache.remove(&address).unwrap();

        assert!(cache.get(&address).is_none());
    }

    // Truth table for `Snapshot::is_allowed`; the order of `policies` matters.
    #[test]
    fn test_snapshot_is_allowed() {
        let cache = PolicyCache::new();

        let deny = Pubkey::new_unique();
        let allow = Pubkey::new_unique();
        // Never inserted into the cache.
        let missing = Pubkey::new_unique();

        // `good` is on the allow list; `sanctioned`/`sandwich` are on the deny list;
        // `other` is on neither.
        let good = Pubkey::new_unique();
        let other = Pubkey::new_unique();
        let sanctioned = Pubkey::new_unique();
        let sandwich = Pubkey::new_unique();

        let policies = [
            (allow, Policy::new(PermissionStrategy::Allow, vec![good])),
            (
                deny,
                Policy::new(PermissionStrategy::Deny, vec![sanctioned, sandwich]),
            ),
        ];

        for (address, policy) in policies.into_iter() {
            cache.insert(address, 1, policy.clone());
        }
        let snapshot = Snapshot::new(&cache);

        // An unknown policy errors unless an earlier policy already listed the identity.
        assert_eq!(
            snapshot.is_allowed(&[missing], &good),
            Err(CheckError::PolicyNotFound)
        );
        assert_eq!(
            snapshot.is_allowed(&[missing, allow], &good),
            Err(CheckError::PolicyNotFound)
        );
        assert_eq!(snapshot.is_allowed(&[allow, missing], &good), Ok(true));
        assert_eq!(
            snapshot.is_allowed(&[deny, missing], &good),
            Err(CheckError::PolicyNotFound)
        );

        // Deny list alone: listed identities are rejected, everyone else is allowed.
        assert_eq!(snapshot.is_allowed(&[deny], &sanctioned), Ok(false));
        assert_eq!(snapshot.is_allowed(&[deny], &sandwich), Ok(false));
        assert_eq!(snapshot.is_allowed(&[deny], &good), Ok(true));
        assert_eq!(snapshot.is_allowed(&[deny], &other), Ok(true));

        // Allow list alone: only listed identities are allowed.
        assert_eq!(snapshot.is_allowed(&[allow], &good), Ok(true));
        assert_eq!(snapshot.is_allowed(&[allow], &sanctioned), Ok(false));
        assert_eq!(snapshot.is_allowed(&[allow], &sandwich), Ok(false));
        assert_eq!(snapshot.is_allowed(&[allow], &other), Ok(false));

        // Combined, allow first: first listing policy decides; unlisted → rejected.
        assert_eq!(snapshot.is_allowed(&[allow, deny], &other), Ok(false));
        assert_eq!(snapshot.is_allowed(&[allow, deny], &good), Ok(true));
        assert_eq!(snapshot.is_allowed(&[allow, deny], &sandwich), Ok(false));

        // Combined, deny first: same outcomes for these identities.
        assert_eq!(snapshot.is_allowed(&[deny, allow], &other), Ok(false));
        assert_eq!(snapshot.is_allowed(&[deny, allow], &good), Ok(true));
        assert_eq!(snapshot.is_allowed(&[deny, allow], &sandwich), Ok(false));
    }

    // The mock store hands back the very same `Arc` it was constructed with.
    #[test]
    fn test_mock_policy_store() {
        let snapshot = Arc::new(Snapshot::default());

        let store = MockPolicyStore::new(Arc::clone(&snapshot));

        let fetched = store.snapshot();

        assert!(std::sync::Arc::ptr_eq(&fetched, &snapshot));
    }
}