use std::collections::HashMap;
use std::ops::Deref;
use std::sync::Arc;

use async_nats::jetstream::kv::{Entry, EntryError, Operation};
use async_nats::{jetstream::kv::Store, Client};
use futures::StreamExt;
use futures::TryStreamExt;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tracing::{debug, error};

use crate::LinkDefinition;
use crate::Result;

use super::{
    delete_link, ld_hash, ld_hash_raw, put_link, KvStore, CLAIMS_PREFIX, LINKDEF_PREFIX,
    SUBJECT_KEY,
};

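/// Claims keyed by the subject (`sub`) of the claims JSON, which is the public nkey of the actor
/// or provider the claims describe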
type ClaimsMap = HashMap<String, HashMap<String, String>>;

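/// A cached view of the lattice metadata KV bucket. Writes go directly to the underlying
/// [`Store`], while reads are served from in-memory maps kept up to date by a background watcher
/// task that is aborted when the last clone of the store is dropped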
#[derive(Clone, Debug)]
pub struct CachedKvStore {
    store: Store,
    linkdefs: Arc<RwLock<HashMap<String, LinkDefinition>>>,
    claims: Arc<RwLock<ClaimsMap>>,
    handle: Arc<JoinHandle<()>>,
}

impl Drop for CachedKvStore {
    fn drop(&mut self) {
        // The watcher task handle is shared between clones of this store, so only abort it when
        // the last clone is dropped
        if Arc::strong_count(&self.handle) == 1 {
            self.handle.abort();
        }
    }
}

impl AsRef<Store> for CachedKvStore {
    fn as_ref(&self) -> &Store {
        &self.store
    }
}

impl Deref for CachedKvStore {
    type Target = Store;

    fn deref(&self) -> &Self::Target {
        &self.store
    }
}

impl CachedKvStore {
    /// Create a new KV store with the given configuration. This function will do an initial fetch
    /// of all claims and linkdefs from the store and then start a watcher to keep the cache up to
    /// date. All data fetched from this store will be served from the in-memory cache.
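    ///
    /// A minimal usage sketch (the import paths for `CachedKvStore` and the `KvStore` trait are
    /// assumptions, so the example is not compiled):
    ///
    /// ```ignore
    /// // Connect to NATS and build a cache for the "default" lattice with no JetStream domain
    /// let nc = async_nats::connect("127.0.0.1:4222").await?;
    /// let store = CachedKvStore::new(nc, "default", None).await?;
    /// // Reads are served from the in-memory cache kept current by the background watcher
    /// let links = store.get_links().await?;
    /// ```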
    pub async fn new(nc: Client, lattice_prefix: &str, js_domain: Option<String>) -> Result<Self> {
        let store = super::get_kv_store(nc, lattice_prefix, js_domain).await?;
        let linkdefs = Arc::new(RwLock::new(HashMap::new()));
        let claims = Arc::new(RwLock::new(ClaimsMap::default()));
        let linkdefs_clone = linkdefs.clone();
        let claims_clone = claims.clone();
        let cloned_store = store.clone();
        let (tx, rx) = tokio::sync::oneshot::channel::<Result<()>>();
        let kvstore = CachedKvStore {
            store,
            linkdefs,
            claims,
            handle: Arc::new(tokio::spawn(async move {
                // We have to create the watcher inside the spawned task and report any setup
                // errors back through the oneshot because of lifetimes
                let mut watcher = match cloned_store.watch_all().await {
                    // NOTE(thomastaylor312) We unwrap the sends here because a send only fails if
                    // the rx has hung up. Since we await the rx right away in `new`, this
                    // shouldn't happen, and if it does it is a programmer error
                    Ok(w) => w,
                    Err(e) => {
                        error!(error = %e, "Unable to start watcher");
                        tx.send(Err(e.into())).unwrap();
                        return;
                    }
                };
                // Start with an initial listing of the data before consuming events from the
                // watcher. Because the watcher was started first, this ensures we end up with all
                // existing entries in the store as well as any updates that arrive while listing
                let keys = match cloned_store.keys().await {
                    Ok(k) => k,
                    Err(e) => {
                        error!(error = %e, "Unable to get keys from store");
                        tx.send(Err(e.into())).unwrap();
                        return;
                    }
                };

                let futs = match keys
                    .map_ok(|k| cloned_store.entry(k))
                    .try_collect::<Vec<_>>()
                    .await
                {
                    Ok(f) => f,
                    Err(e) => {
                        error!(error = %e, "Unable to get keys from store");
                        tx.send(Err(e.into())).unwrap();
                        return;
                    }
                };

                let all_entries = match futures::future::join_all(futs)
                    .await
                    .into_iter()
                    .filter_map(|res| res.transpose())
                    .collect::<std::result::Result<Vec<_>, EntryError>>()
                {
                    Ok(entries) => entries,
                    Err(e) => {
                        error!(error = %e, "Unable to get values from store");
                        tx.send(Err(e.into())).unwrap();
                        return;
                    }
                };

                // Populate the cache with the initial entries before signaling success so that
                // callers never observe an empty cache right after `new` returns
                for entry in all_entries {
                    handle_entry(entry, linkdefs_clone.clone(), claims_clone.clone()).await;
                }

                tx.send(Ok(())).unwrap();

                while let Some(event) = watcher.next().await {
                    let entry = match event {
                        Ok(en) => en,
                        Err(e) => {
                            error!(error = %e, "Error from latticedata watcher");
                            continue;
                        }
                    };
                    handle_entry(entry, linkdefs_clone.clone(), claims_clone.clone()).await;
                }
                // NOTE(thomastaylor312): We should probably do something to automatically restart
                // the watch if something fails. But for now this should be ok
                error!("Cache watcher has exited");
            })),
        };
        rx.await??;
        Ok(kvstore)
    }
}

#[async_trait::async_trait]
impl KvStore for CachedKvStore {
    /// Return a copy of all link definitions in the store
    // TODO(thomastaylor312): This should probably return a reference to the link definitions, but
    // that involves wrapping this in an `OwnedRwLockReadGuard`, which is probably overkill for now
    async fn get_links(&self) -> Result<Vec<LinkDefinition>> {
        Ok(self.linkdefs.read().await.values().cloned().collect())
    }

    /// Return a copy of all claims in the store
    // See comment above about get_links
    async fn get_all_claims(&self) -> Result<Vec<HashMap<String, String>>> {
        Ok(self.claims.read().await.values().cloned().collect())
    }

    /// Return a copy of all provider claims in the store
    // See comment above about get_links
    async fn get_provider_claims(&self) -> Result<Vec<HashMap<String, String>>> {
        Ok(self
            .claims
            .read()
            .await
            .iter()
            // V is the first character of a provider nkey
            .filter_map(|(key, values)| key.starts_with('V').then_some(values))
            .cloned()
            .collect())
    }

    /// Return a copy of all actor claims in the store
    // See comment above about get_links
    async fn get_actor_claims(&self) -> Result<Vec<HashMap<String, String>>> {
        Ok(self
            .claims
            .read()
            .await
            .iter()
            // M is the first character of an actor nkey
            .filter_map(|(key, values)| key.starts_with('M').then_some(values))
            .cloned()
            .collect())
    }

    /// A convenience function to get a list of link definitions filtered using the given filter
    /// function
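    ///
    /// A sketch of filtering on a single field (the `contract_id` field name on `LinkDefinition`
    /// is an assumption here, so the example is not compiled):
    ///
    /// ```ignore
    /// // Keep only links for the HTTP server contract
    /// let http_links = store
    ///     .get_filtered_links(|ld| ld.contract_id == "wasmcloud:httpserver")
    ///     .await?;
    /// ```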
    async fn get_filtered_links<F>(&self, mut filter_fn: F) -> Result<Vec<LinkDefinition>>
    where
        F: FnMut(&LinkDefinition) -> bool + Send,
    {
        Ok(self
            .linkdefs
            .read()
            .await
            .values()
            // We have to call the filter manually because this is technically an &&LinkDefinition
            .filter(|ld| filter_fn(ld))
            .cloned()
            .collect())
    }

    /// Get a link definition for a specific ID (actor_id, contract_id, link_name)
    async fn get_link(
        &self,
        actor_id: &str,
        link_name: &str,
        contract_id: &str,
    ) -> Result<Option<LinkDefinition>> {
        Ok(self
            .linkdefs
            .read()
            .await
            .get(&ld_hash_raw(actor_id, contract_id, link_name))
            .cloned())
    }

    /// Get claims for a specific provider or actor id
    async fn get_claims(&self, id: &str) -> Result<Option<HashMap<String, String>>> {
        Ok(self.claims.read().await.get(id).cloned())
    }

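    /// Put a link definition into the store and immediately update the local cache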
    async fn put_link(&self, ld: LinkDefinition) -> Result<()> {
        put_link(&self.store, &ld).await?;
        // Immediately add the link to the local cache. It will get overwritten by the watcher as
        // soon as the event comes in, but this way a user can immediately get the link they just
        // put if needed
        self.linkdefs.write().await.insert(ld_hash(&ld), ld);
        Ok(())
    }

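    /// Delete a link definition from the store and immediately remove it from the local cache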
    async fn delete_link(&self, actor_id: &str, contract_id: &str, link_name: &str) -> Result<()> {
        delete_link(&self.store, actor_id, contract_id, link_name).await?;
        // Immediately delete the link from the local cache. It will get deleted by the watcher as
        // soon as the event comes in, but this way a user that calls `get_links` will see it gone
        // immediately
        self.linkdefs
            .write()
            .await
            .remove(&ld_hash_raw(actor_id, contract_id, link_name));
        Ok(())
    }
}

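/// Dispatch a single KV entry to the appropriate cache based on its key prefix: linkdef keys
/// update the linkdef cache, claims keys update the claims cache, and anything else is ignored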
async fn handle_entry(
    entry: Entry,
    linkdefs: Arc<RwLock<HashMap<String, LinkDefinition>>>,
    claims: Arc<RwLock<ClaimsMap>>,
) {
    if entry.key.starts_with(LINKDEF_PREFIX) {
        handle_linkdef(entry, linkdefs).await;
    } else if entry.key.starts_with(CLAIMS_PREFIX) {
        handle_claim(entry, claims).await;
    } else {
        debug!(key = %entry.key, "Ignoring entry with unrecognized key");
    }
}

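/// Apply a linkdef entry to the cache: delete and purge operations remove the entry, while put
/// operations deserialize the value as a [`LinkDefinition`] and insert it keyed by the entry key
/// with the linkdef prefix stripped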
async fn handle_linkdef(entry: Entry, linkdefs: Arc<RwLock<HashMap<String, LinkDefinition>>>) {
    match entry.operation {
        Operation::Delete | Operation::Purge => {
            let mut linkdefs = linkdefs.write().await;
            linkdefs.remove(entry.key.trim_start_matches(LINKDEF_PREFIX));
        }
        Operation::Put => {
            let ld: LinkDefinition = match serde_json::from_slice(&entry.value) {
                Ok(ld) => ld,
                Err(e) => {
                    error!(error = %e, "Unable to deserialize as link definition");
                    return;
                }
            };
            let key = entry.key.trim_start_matches(LINKDEF_PREFIX).to_owned();
            linkdefs.write().await.insert(key, ld);
        }
    }
}

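/// Apply a claims entry to the cache: delete and purge operations remove the entry, while put
/// operations deserialize the value as JSON and index it by its subject (`sub`) field. Claims
/// without a subject are ignored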
async fn handle_claim(entry: Entry, claims: Arc<RwLock<ClaimsMap>>) {
    match entry.operation {
        Operation::Delete | Operation::Purge => {
            let mut claims = claims.write().await;
            claims.remove(entry.key.trim_start_matches(CLAIMS_PREFIX));
        }
        Operation::Put => {
            let json: HashMap<String, String> = match serde_json::from_slice(&entry.value) {
                Ok(j) => j,
                Err(e) => {
                    error!(error = %e, "Unable to deserialize claim as json");
                    return;
                }
            };
            let sub = match json.get(SUBJECT_KEY) {
                Some(s) => s.to_owned(),
                None => {
                    debug!("Ignoring claim without sub");
                    return;
                }
            };
            claims.write().await.insert(sub, json);
        }
    }
}