1use crate::prelude::{
2 NoopNotifier, ObservationBuilder, PositionalMetricType, SortTrack, TrackStoreBuilder,
3 VisualSortObservation, VisualSortOptions,
4};
5use crate::store::track_distance::TrackDistanceOkIterator;
6use crate::store::TrackStore;
7use crate::track::utils::FromVec;
8use crate::track::{Feature, Track};
9use crate::trackers::batch::{PredictionBatchRequest, PredictionBatchResult, SceneTracks};
10use crate::trackers::epoch_db::EpochDb;
11use crate::trackers::sort::{
12 AutoWaste, SortAttributesOptions, DEFAULT_AUTO_WASTE_PERIODICITY,
13 MAHALANOBIS_NEW_TRACK_THRESHOLD,
14};
15use crate::trackers::tracker_api::TrackerAPI;
16use crate::trackers::visual_sort::metric::{VisualMetric, VisualMetricOptions};
17use crate::trackers::visual_sort::observation_attributes::VisualObservationAttributes;
18use crate::trackers::visual_sort::track_attributes::{
19 VisualAttributes, VisualAttributesUpdate, VisualSortLookup,
20};
21use crate::trackers::visual_sort::voting::VisualVoting;
22use crate::utils::clipping::bbox_own_areas::{
23 exclusively_owned_areas, exclusively_owned_areas_normalized_shares,
24};
25use crate::voting::Voting;
26use crossbeam::channel::{Receiver, Sender};
27use log::warn;
28use rand::Rng;
29use std::mem;
30use std::sync::{Arc, Condvar, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
31use std::thread::{spawn, JoinHandle};
32
33type VotingSenderChannel = Sender<VotingCommands>;
34type VotingReceiverChannel = Receiver<VotingCommands>;
35
36type MiddlewareVisualSortTrackStore =
37 TrackStore<VisualAttributes, VisualMetric, VisualObservationAttributes>;
38type MiddlewareSortTrack = Track<VisualAttributes, VisualMetric, VisualObservationAttributes>;
39type BatchBusyMonitor = Arc<(Mutex<usize>, Condvar)>;
40
41enum VotingCommands {
42 Distances {
43 scene_id: u64,
44 distances: TrackDistanceOkIterator<VisualObservationAttributes>,
45 channel: Sender<SceneTracks>,
46 tracks: Vec<MiddlewareSortTrack>,
47 monitor: BatchBusyMonitor,
48 },
49 Exit,
50}
51
52pub struct BatchVisualSort {
55 monitor: Option<BatchBusyMonitor>,
56 store: Arc<RwLock<MiddlewareVisualSortTrackStore>>,
57 wasted_store: RwLock<MiddlewareVisualSortTrackStore>,
58 metric_opts: Arc<VisualMetricOptions>,
59 track_opts: Arc<SortAttributesOptions>,
60 voting_threads: Vec<(VotingSenderChannel, JoinHandle<()>)>,
61 auto_waste: AutoWaste,
62}
63
64impl Drop for BatchVisualSort {
65 fn drop(&mut self) {
66 let voting_threads = mem::take(&mut self.voting_threads);
67 for (tx, t) in voting_threads {
68 tx.send(VotingCommands::Exit)
69 .expect("Voting thread must be alive.");
70 drop(tx);
71 t.join()
72 .expect("Voting thread is expected to shutdown successfully.");
73 }
74 }
75}
76
77fn voting_thread(
78 store: Arc<RwLock<MiddlewareVisualSortTrackStore>>,
79 rx: VotingReceiverChannel,
80 metric_opts: Arc<VisualMetricOptions>,
81 track_id: Arc<RwLock<u64>>,
82) {
83 while let Ok(command) = rx.recv() {
84 match command {
85 VotingCommands::Distances {
86 scene_id,
87 distances,
88 channel,
89 tracks,
90 monitor,
91 } => {
92 let voting = VisualVoting::new(
93 match metric_opts.positional_kind {
94 PositionalMetricType::Mahalanobis => MAHALANOBIS_NEW_TRACK_THRESHOLD,
95 PositionalMetricType::IoU(t) => t,
96 },
97 f32::MAX,
98 metric_opts.visual_min_votes,
99 );
100 let winners = voting.winners(distances);
101 let mut res = Vec::default();
102 for mut t in tracks {
103 let source = t.get_track_id();
104
105 let tid = {
106 let mut track_id = track_id.write().unwrap();
107 *track_id += 1;
108 *track_id
109 };
110
111 let track_id: u64 = if let Some(dest) = winners.get(&source) {
112 let (dest, vt) = dest[0];
113 if dest == source {
114 t.set_track_id(tid);
115 store.write().unwrap().add_track(t).unwrap();
116 tid
117 } else {
118 t.add_observation(
119 0,
120 None,
121 None,
122 Some(VisualAttributesUpdate::new_voting_type(vt)),
123 )
124 .unwrap();
125 store
126 .write()
127 .unwrap()
128 .merge_external(dest, &t, Some(&[0]), false)
129 .unwrap();
130 dest
131 }
132 } else {
133 t.set_track_id(tid);
134 store.write().unwrap().add_track(t).unwrap();
135 tid
136 };
137
138 let lock = store.read().unwrap();
139 let store = lock.get_store(track_id as usize);
140 let track = store.get(&track_id).unwrap();
141
142 res.push(SortTrack::from(track))
143 }
144
145 let res = channel.send((scene_id, res));
146 if let Err(e) = res {
147 warn!("Unable to send results to a caller, likely the caller already closed the channel. Error is: {:?}", e);
148 }
149
150 let (lock, cvar) = &*monitor;
151 let mut lock = lock.lock().unwrap();
152 *lock -= 1;
153 cvar.notify_one();
154 }
155 VotingCommands::Exit => break,
156 }
157 }
158}
159
160impl BatchVisualSort {
161 pub fn new(distance_shards: usize, voting_shards: usize, opts: &VisualSortOptions) -> Self {
162 let (track_opts, metric) = opts.clone().build();
163 let track_opts = Arc::new(track_opts);
164 let metric_opts = metric.opts.clone();
165 let store = Arc::new(RwLock::new(
166 TrackStoreBuilder::new(distance_shards)
167 .default_attributes(VisualAttributes::new(track_opts.clone()))
168 .metric(metric.clone())
169 .notifier(NoopNotifier)
170 .build(),
171 ));
172
173 let wasted_store = RwLock::new(
174 TrackStoreBuilder::new(distance_shards)
175 .default_attributes(VisualAttributes::new(track_opts.clone()))
176 .metric(metric)
177 .notifier(NoopNotifier)
178 .build(),
179 );
180
181 let track_id = Arc::new(RwLock::new(0));
182
183 let voting_threads = (0..voting_shards)
184 .map(|_e| {
185 let (tx, rx) = crossbeam::channel::unbounded();
186 let thread_store = store.clone();
187 let thread_track_id = track_id.clone();
188 let thread_metric_opts = metric_opts.clone();
189
190 (
191 tx,
192 spawn(move || {
193 voting_thread(thread_store, rx, thread_metric_opts, thread_track_id)
194 }),
195 )
196 })
197 .collect::<Vec<_>>();
198
199 Self {
200 monitor: None,
201 store,
202 wasted_store,
203 track_opts,
204 metric_opts,
205 voting_threads,
206 auto_waste: AutoWaste {
207 periodicity: DEFAULT_AUTO_WASTE_PERIODICITY,
208 counter: DEFAULT_AUTO_WASTE_PERIODICITY,
209 },
210 }
211 }
212
213 pub fn predict(&mut self, batch_request: PredictionBatchRequest<VisualSortObservation>) {
214 if self.auto_waste.counter == 0 {
215 self.auto_waste();
216 self.auto_waste.counter = self.auto_waste.periodicity;
217 } else {
218 self.auto_waste.counter -= 1;
219 }
220
221 if let Some(m) = &self.monitor {
222 let (lock, cvar) = &**m;
223 let _guard = cvar.wait_while(lock.lock().unwrap(), |v| *v > 0).unwrap();
224 }
225
226 self.monitor = Some(Arc::new((
227 Mutex::new(batch_request.batch_size()),
228 Condvar::new(),
229 )));
230
231 for (i, (scene_id, observations)) in batch_request.get_batch().iter().enumerate() {
232 let mut percentages = Vec::default();
233 let use_own_area_percentage =
234 self.metric_opts.visual_minimal_own_area_percentage_collect
235 + self.metric_opts.visual_minimal_own_area_percentage_use
236 > 0.0;
237
238 if use_own_area_percentage {
239 percentages.reserve(observations.len());
240 let boxes = observations
241 .iter()
242 .map(|o| &o.bounding_box)
243 .collect::<Vec<_>>();
244
245 percentages = exclusively_owned_areas_normalized_shares(
246 boxes.as_ref(),
247 exclusively_owned_areas(boxes.as_ref()).as_ref(),
248 );
249 }
250
251 let mut rng = rand::thread_rng();
252 let epoch = self.track_opts.next_epoch(*scene_id).unwrap();
253
254 let tracks = observations
255 .iter()
256 .enumerate()
257 .map(|(i, o)| {
258 self.store
259 .read()
260 .expect("Access to store must always succeed")
261 .new_track(rng.gen())
262 .observation({
263 let mut obs = ObservationBuilder::new(0).observation_attributes(
264 if use_own_area_percentage {
265 VisualObservationAttributes::with_own_area_percentage(
266 o.feature_quality.unwrap_or(1.0),
267 o.bounding_box.clone(),
268 percentages[i],
269 )
270 } else {
271 VisualObservationAttributes::new(
272 o.feature_quality.unwrap_or(1.0),
273 o.bounding_box.clone(),
274 )
275 },
276 );
277
278 if let Some(feature) = &o.feature {
279 obs = obs.observation(Feature::from_vec(feature.to_vec()));
280 }
281
282 obs.track_attributes_update(
283 VisualAttributesUpdate::new_init_with_scene(
284 epoch,
285 *scene_id,
286 o.custom_object_id,
287 ),
288 )
289 .build()
290 })
291 .build()
292 .expect("Track creation must always succeed!")
293 })
294 .collect::<Vec<_>>();
295
296 let (dists, errs) = {
297 let mut store = self
298 .store
299 .write()
300 .expect("Access to store must always succeed");
301 store.foreign_track_distances(tracks.clone(), 0, false)
302 };
303
304 assert!(errs.all().is_empty());
305 let thread_id = i % self.voting_threads.len();
306 self.voting_threads[thread_id]
307 .0
308 .send(VotingCommands::Distances {
309 monitor: self.monitor.as_ref().unwrap().clone(),
310 scene_id: *scene_id,
311 distances: dists.into_iter(),
312 channel: batch_request.get_sender(),
313 tracks,
314 })
315 .expect("Sending voting request to voting thread must not fail");
316 }
317 }
318
319 pub fn idle_tracks(&mut self) -> Vec<SortTrack> {
320 self.idle_tracks_with_scene(0)
321 }
322
323 pub fn idle_tracks_with_scene(&mut self, scene_id: u64) -> Vec<SortTrack> {
324 let store = self.store.read().unwrap();
325
326 store
327 .lookup(VisualSortLookup::IdleLookup(scene_id))
328 .iter()
329 .map(|(track_id, _status)| {
330 let shard = store.get_store(*track_id as usize);
331 let track = shard.get(track_id).unwrap();
332 SortTrack::from(track)
333 })
334 .collect()
335 }
336}
337
338impl
339 TrackerAPI<
340 VisualAttributes,
341 VisualMetric,
342 VisualObservationAttributes,
343 SortAttributesOptions,
344 NoopNotifier,
345 > for BatchVisualSort
346{
347 fn get_auto_waste_obj_mut(&mut self) -> &mut AutoWaste {
348 &mut self.auto_waste
349 }
350
351 fn get_opts(&self) -> &SortAttributesOptions {
352 &self.track_opts
353 }
354
355 fn get_main_store_mut(&mut self) -> RwLockWriteGuard<MiddlewareVisualSortTrackStore> {
356 self.store.write().unwrap()
357 }
358
359 fn get_wasted_store_mut(&mut self) -> RwLockWriteGuard<MiddlewareVisualSortTrackStore> {
360 self.wasted_store.write().unwrap()
361 }
362
363 fn get_main_store(&self) -> RwLockReadGuard<MiddlewareVisualSortTrackStore> {
364 self.store.read().unwrap()
365 }
366
367 fn get_wasted_store(&self) -> RwLockReadGuard<MiddlewareVisualSortTrackStore> {
368 self.wasted_store.read().unwrap()
369 }
370}
371
#[cfg(test)]
mod tests {
    use crate::prelude::{
        BoundingBox, PositionalMetricType, VisualSortMetricType, VisualSortObservation,
        VisualSortOptions,
    };
    use crate::trackers::batch::PredictionBatchRequest;
    use crate::trackers::visual_sort::batch_api::BatchVisualSort;

    /// Two consecutive batches: checks that every scene reports one track per observation.
    #[test]
    fn test() {
        let opts = VisualSortOptions::default()
            .max_idle_epochs(3)
            .kept_history_length(3)
            .visual_metric(VisualSortMetricType::Euclidean(1.0))
            .positional_metric(PositionalMetricType::Mahalanobis)
            .visual_minimal_track_length(2)
            .visual_minimal_area(5.0)
            .visual_minimal_quality_use(0.45)
            .visual_minimal_quality_collect(0.7)
            .visual_max_observations(3)
            .visual_min_votes(2);

        let mut tracker = BatchVisualSort::new(1, 1, &opts);

        // First batch: one observation in scene 1.
        let (mut batch, predictions) = PredictionBatchRequest::<VisualSortObservation>::new();
        let vec = &vec![1.0, 1.0];
        batch.add(
            1,
            VisualSortObservation::new(
                Some(vec),
                Some(0.9),
                BoundingBox::new(1.0, 1.0, 3.0, 5.0).as_xyaah(),
                Some(13),
            ),
        );
        tracker.predict(batch);
        for _ in 0..predictions.batch_size() {
            let (scene, tracks) = predictions.get();
            assert_eq!(scene, 1);
            assert_eq!(tracks.len(), 1);
        }

        // Second batch: one observation in scene 1, two in scene 2.
        let (mut batch, predictions) = PredictionBatchRequest::<VisualSortObservation>::new();
        let vec1 = &vec![1.0, 1.0];
        let vec2 = &vec![0.1, 0.15];
        batch.add(
            1,
            VisualSortObservation::new(
                Some(vec1),
                Some(0.9),
                BoundingBox::new(1.0, 1.0, 3.0, 5.0).as_xyaah(),
                Some(13),
            ),
        );

        batch.add(
            2,
            VisualSortObservation::new(
                Some(vec2),
                Some(0.87),
                BoundingBox::new(5.0, 10.0, 3.0, 5.0).as_xyaah(),
                Some(23),
            ),
        );

        batch.add(
            2,
            VisualSortObservation::new(
                None,
                None,
                BoundingBox::new(25.0, 15.0, 3.0, 5.0).as_xyaah(),
                Some(33),
            ),
        );

        tracker.predict(batch);
        let mut per_scene = (0..predictions.batch_size())
            .map(|_| {
                let (scene, tracks) = predictions.get();
                (scene, tracks.len())
            })
            .collect::<Vec<_>>();
        // Scenes may be reported in any order, so compare in a canonical order.
        per_scene.sort_unstable();
        assert_eq!(per_scene, vec![(1, 1), (2, 2)]);
    }
}
454}
455
456#[derive(Debug, Clone)]
457pub struct VisualSortPredictionBatchRequest<'a> {
458 pub batch: PredictionBatchRequest<VisualSortObservation<'a>>,
459 pub result: Option<PredictionBatchResult>,
460}
461
462impl<'a> VisualSortPredictionBatchRequest<'a> {
463 pub fn new() -> Self {
464 let (batch, result) = PredictionBatchRequest::new();
465 Self {
466 batch,
467 result: Some(result),
468 }
469 }
470
471 pub fn prediction(&mut self) -> Option<PredictionBatchResult> {
472 self.result.take()
473 }
474
475 pub fn add(&mut self, scene_id: u64, elt: VisualSortObservation<'a>) {
476 self.batch.add(scene_id, elt);
477 }
478}
479
480impl Default for VisualSortPredictionBatchRequest<'_> {
481 fn default() -> Self {
482 Self::new()
483 }
484}
485
/// Python bindings for [`BatchVisualSort`] and its batch request type.
#[cfg(feature = "python")]
pub mod python {
    use pyo3::prelude::*;

    use crate::{
        prelude::VisualSortObservation,
        trackers::{
            batch::{python::PyPredictionBatchResult, PredictionBatchRequest},
            sort::python::PySortTrack,
            tracker_api::TrackerAPI,
            visual_sort::{
                options::python::PyVisualSortOptions,
                python::{PyVisualSortObservation, PyWastedVisualSortTrack},
                WastedVisualSortTrack,
            },
        },
    };

    use super::{BatchVisualSort, VisualSortPredictionBatchRequest};

    #[pyclass]
    #[pyo3(name = "BatchVisualSort")]
    pub struct PyBatchVisualSort(pub(crate) BatchVisualSort);

    #[pymethods]
    impl PyBatchVisualSort {
        // Panics (raised in Python) if a shard count does not fit into `usize`.
        #[new]
        #[pyo3(signature = (distance_shards, voting_shards, opts))]
        pub fn new(distance_shards: i64, voting_shards: i64, opts: &PyVisualSortOptions) -> Self {
            Self(BatchVisualSort::new(
                distance_shards
                    .try_into()
                    .expect("Positive number expected"),
                voting_shards.try_into().expect("Positive number expected"),
                &opts.0,
            ))
        }

        #[pyo3(signature = (n))]
        fn skip_epochs(&mut self, n: i64) {
            assert!(n > 0);
            self.0.skip_epochs(n.try_into().unwrap())
        }

        #[pyo3(signature = (scene_id, n))]
        fn skip_epochs_for_scene(&mut self, scene_id: i64, n: i64) {
            assert!(n > 0 && scene_id >= 0);
            self.0
                .skip_epochs_for_scene(scene_id.try_into().unwrap(), n.try_into().unwrap())
        }

        // Per-shard statistics of the main store, computed with the GIL released.
        #[pyo3(signature = ())]
        fn shard_stats(&self) -> Vec<i64> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .store
                        .read()
                        .unwrap()
                        .shard_stats()
                        .into_iter()
                        .map(|e| i64::try_from(e).unwrap())
                        .collect()
                })
            })
        }

        #[pyo3(signature = ())]
        fn current_epoch(&self) -> i64 {
            self.current_epoch_with_scene(0).try_into().unwrap()
        }

        #[pyo3(signature = (scene_id))]
        fn current_epoch_with_scene(&self, scene_id: i64) -> isize {
            assert!(scene_id >= 0);
            self.0
                .current_epoch_with_scene(scene_id.try_into().unwrap())
                .try_into()
                .unwrap()
        }

        // Copies the Python-side batch into a borrowed batch request, submits it and returns
        // the result handle.
        #[pyo3(signature = (py_batch))]
        fn predict(
            &mut self,
            py_batch: PyVisualSortPredictionBatchRequest,
        ) -> PyPredictionBatchResult {
            let (mut batch, res) = PredictionBatchRequest::<VisualSortObservation>::new();
            for (scene_id, observations) in py_batch.0.batch.get_batch() {
                for o in observations {
                    let f = o.feature.as_ref();
                    batch.add(
                        *scene_id,
                        VisualSortObservation::new(
                            f.map(|x| x.as_ref()),
                            o.feature_quality,
                            o.bounding_box.clone(),
                            o.custom_object_id,
                        ),
                    );
                }
            }
            self.0.predict(batch);

            PyPredictionBatchResult(res)
        }

        #[pyo3(signature = ())]
        fn wasted(&mut self) -> Vec<PyWastedVisualSortTrack> {
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .wasted()
                        .into_iter()
                        .map(WastedVisualSortTrack::from)
                        .map(PyWastedVisualSortTrack)
                        .collect()
                })
            })
        }

        #[pyo3(signature = ())]
        pub fn clear_wasted(&mut self) {
            Python::with_gil(|py| py.allow_threads(|| self.0.clear_wasted()));
        }

        // Wraps each track in the Python newtype. This replaces a former `mem::transmute`
        // of the whole `Vec`, which relied on an unguaranteed `Vec` layout.
        #[pyo3(signature = (scene_id))]
        pub fn idle_tracks(&mut self, scene_id: i64) -> Vec<PySortTrack> {
            assert!(scene_id >= 0);
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.0
                        .idle_tracks_with_scene(scene_id.try_into().unwrap())
                        .into_iter()
                        .map(PySortTrack)
                        .collect()
                })
            })
        }
    }

    #[derive(Debug, Clone)]
    #[pyclass]
    #[pyo3(name = "VisualSortPredictionBatchRequest")]
    pub(crate) struct PyVisualSortPredictionBatchRequest(
        pub(crate) VisualSortPredictionBatchRequest<'static>,
    );

    #[pymethods]
    impl PyVisualSortPredictionBatchRequest {
        #[new]
        fn new() -> Self {
            Self(VisualSortPredictionBatchRequest::new())
        }

        fn prediction(&mut self) -> Option<PyPredictionBatchResult> {
            self.0.prediction().map(PyPredictionBatchResult)
        }

        fn add(&mut self, scene_id: u64, elt: PyVisualSortObservation) {
            self.0.add(scene_id, elt.0)
        }
    }
}