similari/trackers/visual_sort/
batch_api.rs

1use crate::prelude::{
2    NoopNotifier, ObservationBuilder, PositionalMetricType, SortTrack, TrackStoreBuilder,
3    VisualSortObservation, VisualSortOptions,
4};
5use crate::store::track_distance::TrackDistanceOkIterator;
6use crate::store::TrackStore;
7use crate::track::utils::FromVec;
8use crate::track::{Feature, Track};
9use crate::trackers::batch::{PredictionBatchRequest, PredictionBatchResult, SceneTracks};
10use crate::trackers::epoch_db::EpochDb;
11use crate::trackers::sort::{
12    AutoWaste, SortAttributesOptions, DEFAULT_AUTO_WASTE_PERIODICITY,
13    MAHALANOBIS_NEW_TRACK_THRESHOLD,
14};
15use crate::trackers::tracker_api::TrackerAPI;
16use crate::trackers::visual_sort::metric::{VisualMetric, VisualMetricOptions};
17use crate::trackers::visual_sort::observation_attributes::VisualObservationAttributes;
18use crate::trackers::visual_sort::track_attributes::{
19    VisualAttributes, VisualAttributesUpdate, VisualSortLookup,
20};
21use crate::trackers::visual_sort::voting::VisualVoting;
22use crate::utils::clipping::bbox_own_areas::{
23    exclusively_owned_areas, exclusively_owned_areas_normalized_shares,
24};
25use crate::voting::Voting;
26use crossbeam::channel::{Receiver, Sender};
27use log::warn;
28use rand::Rng;
29use std::mem;
30use std::sync::{Arc, Condvar, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
31use std::thread::{spawn, JoinHandle};
32
33type VotingSenderChannel = Sender<VotingCommands>;
34type VotingReceiverChannel = Receiver<VotingCommands>;
35
36type MiddlewareVisualSortTrackStore =
37    TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes>;
38type MiddlewareSortTrack = Track<VisualAttributes, VisualMetric, VisualObservationAttributes>;
39type BatchBusyMonitor = Arc<(Mutex<usize>, Condvar)>;
40
41enum VotingCommands {
42    Distances {
43        scene_id: u64,
44        distances: TrackDistanceOkIterator<VisualObservationAttributes>,
45        channel: Sender<SceneTracks>,
46        tracks: Vec<MiddlewareSortTrack>,
47        monitor: BatchBusyMonitor,
48    },
49    Exit,
50}
51
52// /// Easy to use Visual SORT tracker implementation
53// ///
54pub struct BatchVisualSort {
55    monitor: Option<BatchBusyMonitor>,
56    store: Arc<RwLock<MiddlewareVisualSortTrackStore>>,
57    wasted_store: RwLock<MiddlewareVisualSortTrackStore>,
58    metric_opts: Arc<VisualMetricOptions>,
59    track_opts: Arc<SortAttributesOptions>,
60    voting_threads: Vec<(VotingSenderChannel, JoinHandle<()>)>,
61    auto_waste: AutoWaste,
62}
63
64impl Drop for BatchVisualSort {
65    fn drop(&mut self) {
66        let voting_threads = mem::take(&mut self.voting_threads);
67        for (tx, t) in voting_threads {
68            tx.send(VotingCommands::Exit)
69                .expect("Voting thread must be alive.");
70            drop(tx);
71            t.join()
72                .expect("Voting thread is expected to shutdown successfully.");
73        }
74    }
75}
76
77fn voting_thread(
78    store: Arc<RwLock<MiddlewareVisualSortTrackStore>>,
79    rx: VotingReceiverChannel,
80    metric_opts: Arc<VisualMetricOptions>,
81    track_id: Arc<RwLock<u64>>,
82) {
83    while let Ok(command) = rx.recv() {
84        match command {
85            VotingCommands::Distances {
86                scene_id,
87                distances,
88                channel,
89                tracks,
90                monitor,
91            } => {
92                let voting = VisualVoting::new(
93                    match metric_opts.positional_kind {
94                        PositionalMetricType::Mahalanobis => MAHALANOBIS_NEW_TRACK_THRESHOLD,
95                        PositionalMetricType::IoU(t) => t,
96                    },
97                    f32::MAX,
98                    metric_opts.visual_min_votes,
99                );
100                let winners = voting.winners(distances);
101                let mut res = Vec::default();
102                for mut t in tracks {
103                    let source = t.get_track_id();
104
105                    let tid = {
106                        let mut track_id = track_id.write().unwrap();
107                        *track_id += 1;
108                        *track_id
109                    };
110
111                    let track_id: u64 = if let Some(dest) = winners.get(&source) {
112                        let (dest, vt) = dest[0];
113                        if dest == source {
114                            t.set_track_id(tid);
115                            store.write().unwrap().add_track(t).unwrap();
116                            tid
117                        } else {
118                            t.add_observation(
119                                0,
120                                None,
121                                None,
122                                Some(VisualAttributesUpdate::new_voting_type(vt)),
123                            )
124                            .unwrap();
125                            store
126                                .write()
127                                .unwrap()
128                                .merge_external(dest, &t, Some(&[0]), false)
129                                .unwrap();
130                            dest
131                        }
132                    } else {
133                        t.set_track_id(tid);
134                        store.write().unwrap().add_track(t).unwrap();
135                        tid
136                    };
137
138                    let lock = store.read().unwrap();
139                    let store = lock.get_store(track_id as usize);
140                    let track = store.get(&track_id).unwrap();
141
142                    res.push(SortTrack::from(track))
143                }
144
145                let res = channel.send((scene_id, res));
146                if let Err(e) = res {
147                    warn!("Unable to send results to a caller, likely the caller already closed the channel. Error is: {:?}", e);
148                }
149
150                let (lock, cvar) = &*monitor;
151                let mut lock = lock.lock().unwrap();
152                *lock -= 1;
153                cvar.notify_one();
154            }
155            VotingCommands::Exit => break,
156        }
157    }
158}
159
160impl BatchVisualSort {
161    pub fn new(distance_shards: usize, voting_shards: usize, opts: &VisualSortOptions) -> Self {
162        let (track_opts, metric) = opts.clone().build();
163        let track_opts = Arc::new(track_opts);
164        let metric_opts = metric.opts.clone();
165        let store = Arc::new(RwLock::new(
166            TrackStoreBuilder::new(distance_shards)
167                .default_attributes(VisualAttributes::new(track_opts.clone()))
168                .metric(metric.clone())
169                .notifier(NoopNotifier)
170                .build(),
171        ));
172
173        let wasted_store = RwLock::new(
174            TrackStoreBuilder::new(distance_shards)
175                .default_attributes(VisualAttributes::new(track_opts.clone()))
176                .metric(metric)
177                .notifier(NoopNotifier)
178                .build(),
179        );
180
181        let track_id = Arc::new(RwLock::new(0));
182
183        let voting_threads = (0..voting_shards)
184            .map(|_e| {
185                let (tx, rx) = crossbeam::channel::unbounded();
186                let thread_store = store.clone();
187                let thread_track_id = track_id.clone();
188                let thread_metric_opts = metric_opts.clone();
189
190                (
191                    tx,
192                    spawn(move || {
193                        voting_thread(thread_store, rx, thread_metric_opts, thread_track_id)
194                    }),
195                )
196            })
197            .collect::<Vec<_>>();
198
199        Self {
200            monitor: None,
201            store,
202            wasted_store,
203            track_opts,
204            metric_opts,
205            voting_threads,
206            auto_waste: AutoWaste {
207                periodicity: DEFAULT_AUTO_WASTE_PERIODICITY,
208                counter: DEFAULT_AUTO_WASTE_PERIODICITY,
209            },
210        }
211    }
212
213    pub fn predict(&mut self, batch_request: PredictionBatchRequest<VisualSortObservation>) {
214        if self.auto_waste.counter == 0 {
215            self.auto_waste();
216            self.auto_waste.counter = self.auto_waste.periodicity;
217        } else {
218            self.auto_waste.counter -= 1;
219        }
220
221        if let Some(m) = &self.monitor {
222            let (lock, cvar) = &**m;
223            let _guard = cvar.wait_while(lock.lock().unwrap(), |v| *v > 0).unwrap();
224        }
225
226        self.monitor = Some(Arc::new((
227            Mutex::new(batch_request.batch_size()),
228            Condvar::new(),
229        )));
230
231        for (i, (scene_id, observations)) in batch_request.get_batch().iter().enumerate() {
232            let mut percentages = Vec::default();
233            let use_own_area_percentage =
234                self.metric_opts.visual_minimal_own_area_percentage_collect
235                    + self.metric_opts.visual_minimal_own_area_percentage_use
236                    > 0.0;
237
238            if use_own_area_percentage {
239                percentages.reserve(observations.len());
240                let boxes = observations
241                    .iter()
242                    .map(|o| &o.bounding_box)
243                    .collect::<Vec<_>>();
244
245                percentages = exclusively_owned_areas_normalized_shares(
246                    boxes.as_ref(),
247                    exclusively_owned_areas(boxes.as_ref()).as_ref(),
248                );
249            }
250
251            let mut rng = rand::thread_rng();
252            let epoch = self.track_opts.next_epoch(*scene_id).unwrap();
253
254            let tracks = observations
255                .iter()
256                .enumerate()
257                .map(|(i, o)| {
258                    self.store
259                        .read()
260                        .expect("Access to store must always succeed")
261                        .new_track(rng.gen())
262                        .observation({
263                            let mut obs = ObservationBuilder::new(0).observation_attributes(
264                                if use_own_area_percentage {
265                                    VisualObservationAttributes::with_own_area_percentage(
266                                        o.feature_quality.unwrap_or(1.0),
267                                        o.bounding_box.clone(),
268                                        percentages[i],
269                                    )
270                                } else {
271                                    VisualObservationAttributes::new(
272                                        o.feature_quality.unwrap_or(1.0),
273                                        o.bounding_box.clone(),
274                                    )
275                                },
276                            );
277
278                            if let Some(feature) = &o.feature {
279                                obs = obs.observation(Feature::from_vec(feature.to_vec()));
280                            }
281
282                            obs.track_attributes_update(
283                                VisualAttributesUpdate::new_init_with_scene(
284                                    epoch,
285                                    *scene_id,
286                                    o.custom_object_id,
287                                ),
288                            )
289                            .build()
290                        })
291                        .build()
292                        .expect("Track creation must always succeed!")
293                })
294                .collect::<Vec<_>>();
295
296            let (dists, errs) = {
297                let mut store = self
298                    .store
299                    .write()
300                    .expect("Access to store must always succeed");
301                store.foreign_track_distances(tracks.clone(), 0, false)
302            };
303
304            assert!(errs.all().is_empty());
305            let thread_id = i % self.voting_threads.len();
306            self.voting_threads[thread_id]
307                .0
308                .send(VotingCommands::Distances {
309                    monitor: self.monitor.as_ref().unwrap().clone(),
310                    scene_id: *scene_id,
311                    distances: dists.into_iter(),
312                    channel: batch_request.get_sender(),
313                    tracks,
314                })
315                .expect("Sending voting request to voting thread must not fail");
316        }
317    }
318
319    pub fn idle_tracks(&mut self) -> Vec<SortTrack> {
320        self.idle_tracks_with_scene(0)
321    }
322
323    pub fn idle_tracks_with_scene(&mut self, scene_id: u64) -> Vec<SortTrack> {
324        let store = self.store.read().unwrap();
325
326        store
327            .lookup(VisualSortLookup::IdleLookup(scene_id))
328            .iter()
329            .map(|(track_id, _status)| {
330                let shard = store.get_store(*track_id as usize);
331                let track = shard.get(track_id).unwrap();
332                SortTrack::from(track)
333            })
334            .collect()
335    }
336}
337
338impl
339    TrackerAPI<
340        VisualAttributes,
341        VisualMetric,
342        VisualObservationAttributes,
343        SortAttributesOptions,
344        NoopNotifier,
345    > for BatchVisualSort
346{
347    fn get_auto_waste_obj_mut(&mut self) -> &mut AutoWaste {
348        &mut self.auto_waste
349    }
350
351    fn get_opts(&self) -> &SortAttributesOptions {
352        &self.track_opts
353    }
354
355    fn get_main_store_mut(&mut self) -> RwLockWriteGuard<MiddlewareVisualSortTrackStore> {
356        self.store.write().unwrap()
357    }
358
359    fn get_wasted_store_mut(&mut self) -> RwLockWriteGuard<MiddlewareVisualSortTrackStore> {
360        self.wasted_store.write().unwrap()
361    }
362
363    fn get_main_store(&self) -> RwLockReadGuard<MiddlewareVisualSortTrackStore> {
364        self.store.read().unwrap()
365    }
366
367    fn get_wasted_store(&self) -> RwLockReadGuard<MiddlewareVisualSortTrackStore> {
368        self.wasted_store.read().unwrap()
369    }
370}
371
#[cfg(test)]
mod tests {
    use crate::prelude::{
        BoundingBox, PositionalMetricType, VisualSortMetricType, VisualSortObservation,
        VisualSortOptions,
    };
    use crate::trackers::batch::PredictionBatchRequest;
    use crate::trackers::visual_sort::batch_api::BatchVisualSort;
    use std::collections::HashMap;

    /// Runs two consecutive batches and checks that every scene gets exactly one result
    /// holding one track per submitted observation.
    #[test]
    fn test() {
        let opts = VisualSortOptions::default()
            .max_idle_epochs(3)
            .kept_history_length(3)
            .visual_metric(VisualSortMetricType::Euclidean(1.0))
            .positional_metric(PositionalMetricType::Mahalanobis)
            .visual_minimal_track_length(2)
            .visual_minimal_area(5.0)
            .visual_minimal_quality_use(0.45)
            .visual_minimal_quality_collect(0.7)
            .visual_max_observations(3)
            .visual_min_votes(2);

        let mut tracker = BatchVisualSort::new(1, 1, &opts);

        // First batch: one observation in scene 1.
        let (mut batch, predictions) = PredictionBatchRequest::<VisualSortObservation>::new();
        let vec = &vec![1.0, 1.0];
        batch.add(
            1,
            VisualSortObservation::new(
                Some(vec),
                Some(0.9),
                BoundingBox::new(1.0, 1.0, 3.0, 5.0).as_xyaah(),
                Some(13),
            ),
        );
        tracker.predict(batch);
        assert_eq!(predictions.batch_size(), 1);
        for _ in 0..predictions.batch_size() {
            let (scene, tracks) = predictions.get();
            assert_eq!(scene, 1);
            assert_eq!(tracks.len(), 1);
        }

        // Second batch: one observation in scene 1, two in scene 2 (one without a feature).
        let (mut batch, predictions) = PredictionBatchRequest::<VisualSortObservation>::new();
        let vec1 = &vec![1.0, 1.0];
        let vec2 = &vec![0.1, 0.15];
        batch.add(
            1,
            VisualSortObservation::new(
                Some(vec1),
                Some(0.9),
                BoundingBox::new(1.0, 1.0, 3.0, 5.0).as_xyaah(),
                Some(13),
            ),
        );

        batch.add(
            2,
            VisualSortObservation::new(
                Some(vec2),
                Some(0.87),
                BoundingBox::new(5.0, 10.0, 3.0, 5.0).as_xyaah(),
                Some(23),
            ),
        );

        batch.add(
            2,
            VisualSortObservation::new(
                None,
                None,
                BoundingBox::new(25.0, 15.0, 3.0, 5.0).as_xyaah(),
                Some(33),
            ),
        );

        tracker.predict(batch);
        assert_eq!(predictions.batch_size(), 2);

        // Results arrive in completion order, so collect them per scene before asserting.
        let mut tracks_per_scene = HashMap::new();
        for _ in 0..predictions.batch_size() {
            let (scene, tracks) = predictions.get();
            tracks_per_scene.insert(scene, tracks.len());
        }
        assert_eq!(tracks_per_scene.len(), 2);
        assert_eq!(tracks_per_scene.get(&1), Some(&1));
        assert_eq!(tracks_per_scene.get(&2), Some(&2));
    }
}
455
456#[derive(Debug, Clone)]
457pub struct VisualSortPredictionBatchRequest<'a> {
458    pub batch: PredictionBatchRequest<VisualSortObservation<'a>>,
459    pub result: Option<PredictionBatchResult>,
460}
461
462impl<'a> VisualSortPredictionBatchRequest<'a> {
463    pub fn new() -> Self {
464        let (batch, result) = PredictionBatchRequest::new();
465        Self {
466            batch,
467            result: Some(result),
468        }
469    }
470
471    pub fn prediction(&mut self) -> Option<PredictionBatchResult> {
472        self.result.take()
473    }
474
475    pub fn add(&mut self, scene_id: u64, elt: VisualSortObservation<'a>) {
476        self.batch.add(scene_id, elt);
477    }
478}
479
480impl Default for VisualSortPredictionBatchRequest<'_> {
481    fn default() -> Self {
482        Self::new()
483    }
484}
485
#[cfg(feature = "python")]
pub mod python {
    use pyo3::prelude::*;

    use crate::{
        prelude::VisualSortObservation,
        trackers::{
            batch::{python::PyPredictionBatchResult, PredictionBatchRequest},
            sort::python::PySortTrack,
            tracker_api::TrackerAPI,
            visual_sort::{
                options::python::PyVisualSortOptions,
                python::{PyVisualSortObservation, PyWastedVisualSortTrack},
                WastedVisualSortTrack,
            },
        },
    };

    use super::{BatchVisualSort, VisualSortPredictionBatchRequest};

    /// Python binding of [`BatchVisualSort`], exposed as `BatchVisualSort`.
    #[pyclass]
    #[pyo3(name = "BatchVisualSort")]
    pub struct PyBatchVisualSort(pub(crate) BatchVisualSort);

    #[pymethods]
    impl PyBatchVisualSort {
        /// Creates the tracker.
        ///
        /// # Parameters
        /// * `distance_shards` - number of track store shards
        /// * `voting_shards` - number of voting threads
        /// * `opts` - tracker options
        ///
        #[new]
        #[pyo3(signature = (distance_shards, voting_shards, opts))]
        pub fn new(distance_shards: i64, voting_shards: i64, opts: &PyVisualSortOptions) -> Self {
            Self(BatchVisualSort::new(
                distance_shards
                    .try_into()
                    .expect("Positive number expected"),
                voting_shards.try_into().expect("Positive number expected"),
                &opts.0,
            ))
        }

        /// Skip `n` epochs (see `TrackerAPI::skip_epochs`); `n` must be positive.
        ///
        #[pyo3(signature = (n))]
        fn skip_epochs(&mut self, n: i64) {
            assert!(n > 0);
            self.0.skip_epochs(n.try_into().unwrap())
        }

        /// Skip `n` epochs for `scene_id`; `n` must be positive and `scene_id` non-negative.
        ///
        #[pyo3(signature = (scene_id, n))]
        fn skip_epochs_for_scene(&mut self, scene_id: i64, n: i64) {
            assert!(n > 0 && scene_id >= 0);
            self.0
                .skip_epochs_for_scene(scene_id.try_into().unwrap(), n.try_into().unwrap())
        }

        /// Get the amount of stored tracks per shard
        ///
        #[pyo3(signature = ())]
        fn shard_stats(&self) -> Vec<i64> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .store
                        .read()
                        .unwrap()
                        .shard_stats()
                        .into_iter()
                        .map(|e| i64::try_from(e).unwrap())
                        .collect()
                })
            })
        }

        /// Get the current epoch for `scene_id` == 0
        ///
        #[pyo3(signature = ())]
        fn current_epoch(&self) -> i64 {
            self.current_epoch_with_scene(0).try_into().unwrap()
        }

        /// Get the current epoch for `scene_id`
        ///
        /// # Parameters
        /// * `scene_id` - scene id
        ///
        #[pyo3(signature = (scene_id))]
        fn current_epoch_with_scene(&self, scene_id: i64) -> isize {
            assert!(scene_id >= 0);
            self.0
                .current_epoch_with_scene(scene_id.try_into().unwrap())
                .try_into()
                .unwrap()
        }

        /// Submit a batch of observations for tracking
        ///
        /// The observations of `py_batch` are copied into a fresh request. The result handle
        /// stored inside `py_batch` is not used; the predictions are delivered through the
        /// returned result object instead, one `(scene_id, tracks)` entry per scene.
        ///
        /// # Parameters
        /// * `py_batch` - observations grouped by scene id
        ///
        #[pyo3(signature = (py_batch))]
        fn predict(
            &mut self,
            py_batch: PyVisualSortPredictionBatchRequest,
        ) -> PyPredictionBatchResult {
            let (mut batch, res) = PredictionBatchRequest::<VisualSortObservation>::new();
            for (scene_id, observations) in py_batch.0.batch.get_batch() {
                for o in observations {
                    let f = o.feature.as_ref();
                    batch.add(
                        *scene_id,
                        VisualSortObservation::new(
                            f.map(|x| x.as_ref()),
                            o.feature_quality,
                            o.bounding_box.clone(),
                            o.custom_object_id,
                        ),
                    );
                }
            }
            self.0.predict(batch);

            PyPredictionBatchResult(res)
        }

        /// Remove all the tracks with expired life
        ///
        #[pyo3(signature = ())]
        fn wasted(&mut self) -> Vec<PyWastedVisualSortTrack> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .wasted()
                        .into_iter()
                        .map(WastedVisualSortTrack::from)
                        .map(PyWastedVisualSortTrack)
                        .collect()
                })
            })
        }

        /// Clear all tracks with expired life
        ///
        #[pyo3(signature = ())]
        pub fn clear_wasted(&mut self) {
            Python::with_gil(|py| py.allow_threads(|| self.0.clear_wasted()));
        }

        /// Get idle tracks with not expired life for `scene_id`
        ///
        #[pyo3(signature = (scene_id))]
        pub fn idle_tracks(&mut self, scene_id: i64) -> Vec<PySortTrack> {
            assert!(scene_id >= 0);
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    // Wrap element by element: transmuting `Vec<SortTrack>` into
                    // `Vec<PySortTrack>` is not sound, because `Vec`'s layout is not
                    // guaranteed across element types. The in-place `collect` does not
                    // reallocate for a same-layout newtype anyway.
                    self.0
                        .idle_tracks_with_scene(scene_id.try_into().unwrap())
                        .into_iter()
                        .map(PySortTrack)
                        .collect()
                })
            })
        }
    }

    /// Python binding of [`VisualSortPredictionBatchRequest`], exposed as
    /// `VisualSortPredictionBatchRequest`.
    #[derive(Debug, Clone)]
    #[pyclass]
    #[pyo3(name = "VisualSortPredictionBatchRequest")]
    pub(crate) struct PyVisualSortPredictionBatchRequest(
        pub(crate) VisualSortPredictionBatchRequest<'static>,
    );

    #[pymethods]
    impl PyVisualSortPredictionBatchRequest {
        /// Create an empty batch request.
        #[new]
        fn new() -> Self {
            Self(VisualSortPredictionBatchRequest::new())
        }

        /// Take the result handle paired with this request (only once; then `None`).
        fn prediction(&mut self) -> Option<PyPredictionBatchResult> {
            self.0.prediction().map(PyPredictionBatchResult)
        }

        /// Add an observation to scene `scene_id`.
        fn add(&mut self, scene_id: u64, elt: PyVisualSortObservation) {
            self.0.add(scene_id, elt.0)
        }
    }
}