Skip to main content

timsrust_utils/
graph_buffer.rs

1// use std::{collections::HashMap, sync::Arc};
2
3// use indicatif::{ParallelProgressIterator, ProgressStyle};
4// use rayon::prelude::*;
5
6// use crate::{Peak, PeakReader, utils::Synced};
7
8// #[derive(Debug)]
9// enum Node<T> {
10//     Loading,
11//     Ready {
12//         index: usize,
13//         data: Arc<T>,
14//         neighbors: HashMap<usize, Arc<T>>,
15//     },
16// }
17
18// impl<T> Node<T> {
19//     fn new(index: usize, data: Arc<T>) -> Self {
20//         Self::Ready {
21//             index,
22//             data,
23//             neighbors: HashMap::new(),
24//         }
25//     }
26
27//     fn get_data(&self) -> Option<Arc<T>> {
28//         if let Self::Ready { data, .. } = self {
29//             Some(Arc::clone(data))
30//         } else {
31//             None
32//         }
33//     }
34
35//     fn get_neighbors(&self) -> Option<&HashMap<usize, Arc<T>>> {
36//         if let Self::Ready { neighbors, .. } = self {
37//             Some(neighbors)
38//         } else {
39//             None
40//         }
41//     }
42// }
43
44// pub trait GraphProcessor<T> {
45//     fn load(&self, index: usize) -> Option<T>;
46//     fn neighbor_indices(&self, index: usize, data: Arc<T>) -> Vec<usize>;
47//     fn process(&self, data: &Arc<T>, neighbors: &HashMap<usize, Arc<T>>);
48// }
49
50// pub struct Graph<T, P: GraphProcessor<T>> {
51//     buffer: Synced<HashMap<usize, Node<T>>>,
52//     processor: P,
53// }
54
55// impl<T, P: GraphProcessor<T>> Graph<T, P> {
56//     pub fn new(processor: P) -> Self {
57//         Self {
58//             buffer: Synced::default(),
59//             processor,
60//         }
61//     }
62
63//     fn load(&self, index: usize) -> Option<Arc<T>> {
64//         loop {
65//             match self
66//                 .buffer
67//                 .with_lock(|b| match b.get(&index) {
68//                     Some(opt) => opt.get_data(),
69//                     _ => unreachable!("set in previous step"),
70//                 })
71//                 .unwrap()
72//             {
73//                 Some(data) => return Some(data),
74//                 None => {},
75//             }
76//             std::thread::sleep(std::time::Duration::from_millis(1));
77//         }
78//     }
79
80//     fn set(&self, index: usize) -> Option<Arc<T>> {
81//         let result = Arc::new(self.processor.load(index)?);
82//         _ = self.buffer.with_lock(|b| {
83//             b.insert(index, Node::new(index, result.clone()));
84//         });
85//         Some(result)
86//     }
87
88//     fn is_available(&self, index: usize) -> bool {
89//         self.buffer
90//             .with_lock(|b| match b.get_mut(&index) {
91//                 Some(_) => true,
92//                 None => {
93//                     b.insert(index, Node::Loading);
94//                     false
95//                 },
96//             })
97//             .unwrap()
98//     }
99
100//     fn get(&self, index: usize) -> Option<Arc<T>> {
101//         if self.is_available(index) {
102//             self.load(index)
103//         } else {
104//             self.set(index)
105//         }
106//     }
107
108//     fn get_neighbors(&self, index: usize) -> Option<HashMap<usize, Arc<T>>> {
109//         let data = self.get(index)?;
110//         let neighbor_indices =
111//             self.processor.neighbor_indices(index, data.clone());
112//         let neighbors_todo = self
113//             .buffer
114//             .with_lock(|b| {
115//                 let neighbors = b
116//                     .get(&index)
117//                     .expect("Ready node")
118//                     .get_neighbors()
119//                     .expect("Ready node");
120//                 let neighbors_todo = neighbor_indices
121//                     .iter()
122//                     .filter(|i| !neighbors.contains_key(i))
123//                     .filter(|i| match b.get(i) {
124//                         Some(Node::Ready { .. }) => false,
125//                         Some(Node::Loading) => false,
126//                         None => {
127//                             b.insert(**i, Node::Loading);
128//                             true
129//                         },
130//                     })
131//                     .collect::<Vec<_>>();
132//                 neighbors_todo
133//             })
134//             .unwrap();
135//         let neighbors_todo = neighbor_indices
136//             .into_iter()
137//             .map(|i| {
138//                 let neighbor = self
139//                     .buffer
140//                     .with_lock(|b| {
141//                         let neighbors = b
142//                             .get(&index)
143//                             .expect("Ready node")
144//                             .get_neighbors()
145//                             .expect("Ready node");
146//                     })
147//                     .unwrap();
148//                 (i, neighbor.unwrap())
149//             })
150//             .collect::<HashMap<_, _>>();
151//         Some(neighbors)
152//     }
153
154//     pub fn process(&self, index: usize) -> Option<()> {
155//         let data = self.get(index)?;
156//         let neighbors = self.get_neighbors(index)?;
157//         self.processor.process(&data, &neighbors);
158//         _ = self.remove(index);
159//         Some(())
160//     }
161
162//     fn remove(&self, index: usize) -> Option<()> {
163//         match self.buffer.with_lock(|b| b.remove(&index)).unwrap() {
164//             Some(_) => Some(()),
165//             None => None,
166//         }
167//     }
168// }
169
170// struct FramePeaks {
171//     index: usize,
172//     peaks: Vec<Peak>,
173// }
174
175// struct PeakProcessor {
176//     peak_loader: PeakReader,
177// }
178
179// impl GraphProcessor<FramePeaks> for PeakProcessor {
180//     fn load(&self, index: usize) -> Option<FramePeaks> {
181//         let peaks = self.peak_loader.get_peaks_from_frame(index).ok()?;
182//         Some(FramePeaks { index, peaks })
183//     }
184
185//     fn neighbor_indices(
186//         &self,
187//         index: usize,
188//         data: Arc<FramePeaks>,
189//     ) -> Vec<usize> {
190//         todo!("Define neighbor logic");
191//         let mut neighbors = Vec::new();
192//         if index > 0 {
193//             neighbors.push(index - 1);
194//         }
195//         neighbors.push(index + 1);
196//         neighbors
197//     }
198
199//     fn process(
200//         &self,
201//         data: &Arc<FramePeaks>,
202//         neighbors: &HashMap<usize, Arc<FramePeaks>>,
203//     ) {
204//         todo!("Implement peak processing with neighbors");
205//     }
206// }
207
208// pub fn run(peak_reader: PeakReader) {
209//     let len = peak_reader.frame_count();
210//     let processor = PeakProcessor {
211//         peak_loader: peak_reader,
212//     };
213//     let buffer = Graph::new(processor);
//     // todo!("Set atomic usize for index order");
215//     (0..len)
216//         .into_par_iter()
217//         .progress_with_style(
218//             ProgressStyle::default_bar()
219//                 .template(" [{elapsed_precise}] {bar} {pos:>7}/{len:7} ({eta}, {per_sec} frames/s)")
220//                 .expect("Failed to set progress style")
221//         )
222//         .for_each(|index| {
223//             buffer.process(index);
224//         });
225// }
226
227use rayon::prelude::*;
228use std::{
229    collections::HashMap,
230    sync::atomic::{AtomicUsize, Ordering},
231    sync::{Arc, Condvar, Mutex},
232};
233
234// ------------------ Node state ------------------
235
236#[derive(Debug)]
237enum NodeState<T> {
238    Loading,
239    Ready(Arc<T>),
240}
241
242struct Node<T> {
243    state: Mutex<NodeState<T>>,
244    ready: Condvar,
245    ref_count: AtomicUsize, // counts threads and neighbors
246}
247
248impl<T> Node<T> {
249    fn new_loading() -> Arc<Self> {
250        Arc::new(Self {
251            state: Mutex::new(NodeState::Loading),
252            ready: Condvar::new(),
253            ref_count: AtomicUsize::new(1), // initial ref for the loader
254        })
255    }
256
257    fn set_ready(&self, data: Arc<T>) {
258        let mut state = self.state.lock().unwrap();
259        *state = NodeState::Ready(data);
260        self.ready.notify_all();
261    }
262
263    fn wait_ready(&self) -> Arc<T> {
264        let mut state = self.state.lock().unwrap();
265        loop {
266            match &*state {
267                NodeState::Ready(data) => return Arc::clone(data),
268                NodeState::Loading => {
269                    state = self.ready.wait(state).unwrap();
270                },
271            }
272        }
273    }
274
275    fn increment(&self) {
276        self.ref_count.fetch_add(1, Ordering::Relaxed);
277    }
278
279    fn decrement(&self) -> usize {
280        self.ref_count.fetch_sub(1, Ordering::Release) - 1
281    }
282}
283
284// ------------------ Trait for user ------------------
285
286pub trait GraphProcessor<T>: Sync + Send {
287    fn load(&self, index: usize) -> Option<T>;
288    fn neighbor_indices(&self, index: usize, data: Arc<T>) -> Vec<usize>;
289    fn process(&self, data: &Arc<T>, neighbors: &HashMap<usize, Arc<T>>);
290}
291
292// ------------------ Graph ------------------
293
294pub struct Graph<T, P: GraphProcessor<T>> {
295    buffer: Mutex<HashMap<usize, Arc<Node<T>>>>,
296    processor: Arc<P>,
297}
298
299impl<T: Send + Sync + 'static, P: GraphProcessor<T>> Graph<T, P> {
300    pub fn new(processor: P) -> Self {
301        Self {
302            buffer: Mutex::new(HashMap::new()),
303            processor: Arc::new(processor),
304        }
305    }
306
307    fn get(&self, index: usize) -> Option<Arc<T>> {
308        // Fast path: node already exists
309        if let Some(node) = self.buffer.lock().unwrap().get(&index) {
310            node.increment();
311            return Some(node.wait_ready());
312        }
313
314        // Create new loading node
315        let node = Node::new_loading();
316        let mut buffer = self.buffer.lock().unwrap();
317        // Another thread might have inserted in the meantime
318        let node = match buffer.entry(index) {
319            std::collections::hash_map::Entry::Occupied(e) => {
320                let existing: &Arc<Node<T>> = e.get();
321                existing.increment();
322                existing.clone()
323            },
324            std::collections::hash_map::Entry::Vacant(e) => {
325                e.insert(node.clone());
326                node
327            },
328        };
329        drop(buffer);
330
331        // Load data only if we were the loader
332        let state = node.state.lock().unwrap();
333        if let NodeState::Loading = &*state {
334            drop(state); // release lock before loading
335            let data = Arc::new(self.processor.load(index)?);
336            node.set_ready(data.clone());
337            Some(data)
338        } else {
339            Some(node.wait_ready())
340        }
341    }
342
343    fn get_neighbors(
344        &self,
345        index: usize,
346        data: Arc<T>,
347    ) -> HashMap<usize, Arc<T>> {
348        let mut neighbors = HashMap::new();
349        for n_idx in self.processor.neighbor_indices(index, data.clone()) {
350            if let Some(n_data) = self.get(n_idx) {
351                neighbors.insert(n_idx, n_data);
352            }
353        }
354        neighbors
355    }
356
357    pub fn process(&self, index: usize) -> Option<()> {
358        let data = self.get(index)?;
359        let neighbors = self.get_neighbors(index, data.clone());
360        self.processor.process(&data, &neighbors);
361
362        // Decrement self
363        #[allow(clippy::collapsible_if)]
364        if let Some(node) = self.buffer.lock().unwrap().get(&index) {
365            if node.decrement() == 0 {
366                self.buffer.lock().unwrap().remove(&index);
367            }
368        }
369
370        // Decrement neighbors
371        #[allow(clippy::collapsible_if)]
372        for (n_idx, _) in neighbors {
373            if let Some(node) = self.buffer.lock().unwrap().get(&n_idx) {
374                if node.decrement() == 0 {
375                    self.buffer.lock().unwrap().remove(&n_idx);
376                }
377            }
378        }
379
380        Some(())
381    }
382
383    // Optional: process all indices in parallel
384    pub fn process_all<I: IntoParallelIterator<Item = usize>>(
385        &self,
386        indices: I,
387    ) {
388        indices.into_par_iter().for_each(|i| {
389            self.process(i);
390        });
391    }
392}