scatter_net/legacy/net/methods/
fetch_encrypted_chunk.rs1use std::{
2 collections::VecDeque,
3 future::Future,
4 pin::Pin,
5 sync::Arc,
6 task::{
7 Poll::{Pending, Ready},
8 Waker,
9 },
10};
11
12use anyhow::Result;
13use n0_future::FutureExt;
14use parking_lot::RwLock;
15use ps_buffer::BufferError;
16use ps_datachunk::{DataChunk, SerializedDataChunk};
17use ps_datalake::error::PsDataLakeError;
18use ps_hkey::{Hash, Hkey};
19use tokio::task::JoinHandle;
20
21use crate::{FetchResponse, Peer, PeerGroup, ScatterNet};
22
23impl ScatterNet {
24 #[must_use]
25 pub fn fetch_encrypted_chunk(&self, hash: Arc<Hash>) -> ScatterNetFetchEncryptedChunk<'_> {
26 ScatterNetFetchEncryptedChunk::init(self, hash)
27 }
28}
29
30type SerializedDataChunkFuture = dyn Future<Output = Option<SerializedDataChunk>> + Send + Sync;
31type BoxedSerializedDataChunkFuture = Pin<Box<SerializedDataChunkFuture>>;
32
33pub struct ScatterNetFetchEncryptedChunk<'lt> {
34 futures: RwLock<Vec<BoxedSerializedDataChunkFuture>>,
35 net: &'lt ScatterNet,
36 hash: Arc<Hash>,
37 peer_groups: VecDeque<PeerGroup>,
38 value: Option<SerializedDataChunk>,
39 timeout: Option<JoinHandle<()>>,
40 num_attempts_all_peers: u8,
41}
42
43impl<'lt> ScatterNetFetchEncryptedChunk<'lt> {
44 pub fn init(net: &'lt ScatterNet, hash: Arc<Hash>) -> Self {
45 let locally_found = match net.lake.get_encrypted_chunk(&hash) {
46 Ok(chunk) => chunk.serialize().ok(),
47 Err(PsDataLakeError::NotFound) => None,
48 Err(err) => {
49 eprintln!("Fetching chunk {hash} from DataLake failed: {err:?}");
50 None
51 }
52 };
53
54 let peer_groups = match &locally_found {
55 None => net.read().peer_groups.clone().into(),
56 Some(_) => VecDeque::new(),
57 };
58
59 Self {
60 futures: RwLock::default(),
61 hash,
62 net,
63 peer_groups,
64 value: locally_found,
65 timeout: None,
66 num_attempts_all_peers: 0,
67 }
68 }
69
70 pub fn schedule(&mut self, waker: Waker) {
71 let new_task = tokio::spawn(async move {
72 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
73 waker.wake();
74 });
75
76 if let Some(old_task) = self.timeout.replace(new_task) {
77 old_task.abort();
78 }
79 }
80}
81
82impl Future for ScatterNetFetchEncryptedChunk<'_> {
83 type Output = Result<SerializedDataChunk, ScatterNetFetchEncryptedChunkError>;
84
85 fn poll(
86 self: std::pin::Pin<&mut Self>,
87 cx: &mut std::task::Context<'_>,
88 ) -> std::task::Poll<Self::Output> {
89 let this = self.get_mut();
90 let hash = this.hash.clone();
91
92 this.futures
93 .write()
94 .retain_mut(|future| match future.poll(cx) {
95 std::task::Poll::Pending => true,
96 std::task::Poll::Ready(None) => false,
97 std::task::Poll::Ready(Some(chunk)) => {
98 this.value = Some(chunk);
99 false
100 }
101 });
102
103 if let Some(chunk) = this.value.take() {
104 return Ready(Ok(chunk));
105 }
106
107 let mut request_from_peer = |peer: Peer, hash: Arc<Hash>| {
108 let mut future = Box::pin(async move {
109 let fetched = Peer::fetch_blob(peer, Hkey::Direct(hash)).await;
110
111 match fetched {
112 Ok(Some(FetchResponse::Success(buffer))) => {
113 SerializedDataChunk::from_data(buffer).ok()
114 }
115 _ => None,
116 }
117 });
118
119 match future.poll(cx) {
120 Pending => {
121 this.futures.write().push(future);
122 }
123 Ready(Some(chunk)) => {
124 this.value.replace(chunk);
125 }
126 Ready(None) => (),
127 }
128 };
129
130 if let Some(peer_group) = this.peer_groups.pop_front() {
131 if let Some(peer) = peer_group.get_peer_by_hash(&this.hash) {
132 request_from_peer(peer, hash.clone());
133 }
134 }
135
136 if this.futures.read().is_empty() && this.peer_groups.is_empty() {
137 if this.num_attempts_all_peers >= 3 {
138 return Ready(Err(ScatterNetFetchEncryptedChunkError::NotFound));
139 }
140
141 let peers: Vec<Peer> = this.net.read().peers.values().cloned().collect();
142
143 for peer in peers {
144 request_from_peer(peer, hash.clone());
145 }
146
147 this.num_attempts_all_peers += 1;
148 }
149
150 this.schedule(cx.waker().clone());
151
152 Pending
153 }
154}
155
156#[derive(thiserror::Error, Debug)]
157pub enum ScatterNetFetchEncryptedChunkError {
158 #[error(transparent)]
159 BufferError(#[from] BufferError),
160 #[error("Unable to fetch the data in question.")]
161 NotFound,
162}