1use std::path::Path;
4
5use arcstr::ArcStr;
6
7use crate::error::YoloError;
8
9#[derive(Debug)]
13pub struct YoloModelSession {
14 pub session: ort::session::Session,
15 pub labels: Vec<ArcStr>,
16
17 pub probability_threshold: Option<f32>, pub iou_threshold: Option<f32>, }
20
21impl YoloModelSession {
22 pub fn new(
26 session: ort::session::Session,
27 labels: impl Iterator<Item = impl Into<ArcStr>>,
28 ) -> Self {
29 Self {
30 session,
31 labels: labels.map(Into::into).collect(),
32 probability_threshold: None,
33 iou_threshold: None,
34 }
35 }
36
37 pub fn new_v8(session: ort::session::Session) -> Self {
39 const LABELS: &[ArcStr] = &[
40 arcstr::literal!("person"),
41 arcstr::literal!("bicycle"),
42 arcstr::literal!("car"),
43 arcstr::literal!("motorcycle"),
44 arcstr::literal!("airplane"),
45 arcstr::literal!("bus"),
46 arcstr::literal!("train"),
47 arcstr::literal!("truck"),
48 arcstr::literal!("boat"),
49 arcstr::literal!("traffic light"),
50 arcstr::literal!("fire hydrant"),
51 arcstr::literal!("stop sign"),
52 arcstr::literal!("parking meter"),
53 arcstr::literal!("bench"),
54 arcstr::literal!("bird"),
55 arcstr::literal!("cat"),
56 arcstr::literal!("dog"),
57 arcstr::literal!("horse"),
58 arcstr::literal!("sheep"),
59 arcstr::literal!("cow"),
60 arcstr::literal!("elephant"),
61 arcstr::literal!("bear"),
62 arcstr::literal!("zebra"),
63 arcstr::literal!("giraffe"),
64 arcstr::literal!("backpack"),
65 arcstr::literal!("umbrella"),
66 arcstr::literal!("handbag"),
67 arcstr::literal!("tie"),
68 arcstr::literal!("suitcase"),
69 arcstr::literal!("frisbee"),
70 arcstr::literal!("skis"),
71 arcstr::literal!("snowboard"),
72 arcstr::literal!("sports ball"),
73 arcstr::literal!("kite"),
74 arcstr::literal!("baseball bat"),
75 arcstr::literal!("baseball glove"),
76 arcstr::literal!("skateboard"),
77 arcstr::literal!("surfboard"),
78 arcstr::literal!("tennis racket"),
79 arcstr::literal!("bottle"),
80 arcstr::literal!("wine glass"),
81 arcstr::literal!("cup"),
82 arcstr::literal!("fork"),
83 arcstr::literal!("knife"),
84 arcstr::literal!("spoon"),
85 arcstr::literal!("bowl"),
86 arcstr::literal!("banana"),
87 arcstr::literal!("apple"),
88 arcstr::literal!("sandwich"),
89 arcstr::literal!("orange"),
90 arcstr::literal!("broccoli"),
91 arcstr::literal!("carrot"),
92 arcstr::literal!("hot dog"),
93 arcstr::literal!("pizza"),
94 arcstr::literal!("donut"),
95 arcstr::literal!("cake"),
96 arcstr::literal!("chair"),
97 arcstr::literal!("couch"),
98 arcstr::literal!("potted plant"),
99 arcstr::literal!("bed"),
100 arcstr::literal!("dining table"),
101 arcstr::literal!("toilet"),
102 arcstr::literal!("tv"),
103 arcstr::literal!("laptop"),
104 arcstr::literal!("mouse"),
105 arcstr::literal!("remote"),
106 arcstr::literal!("keyboard"),
107 arcstr::literal!("cell phone"),
108 arcstr::literal!("microwave"),
109 arcstr::literal!("oven"),
110 arcstr::literal!("toaster"),
111 arcstr::literal!("sink"),
112 arcstr::literal!("refrigerator"),
113 arcstr::literal!("book"),
114 arcstr::literal!("clock"),
115 arcstr::literal!("vase"),
116 arcstr::literal!("scissors"),
117 arcstr::literal!("teddy bear"),
118 arcstr::literal!("hair drier"),
119 arcstr::literal!("toothbrush"),
120 ];
121
122 Self {
123 session,
124 labels: LABELS.to_vec(),
125 probability_threshold: None,
126 iou_threshold: None,
127 }
128 }
129
130 pub fn from_filename_v8(filename: impl AsRef<Path>) -> Result<Self, YoloError> {
138 let session = ort::session::Session::builder()
139 .map_err(YoloError::OrtSessionBuildError)?
140 .commit_from_file(filename)
141 .map_err(YoloError::OrtSessionLoadError)?;
142
143 Ok(Self::new_v8(session))
144 }
145
146 pub fn get_labels(&self) -> &[ArcStr] {
147 &self.labels
148 }
149
150 pub fn get_probability_threshold(&self) -> f32 {
151 self.probability_threshold.unwrap_or(0.5)
152 }
153
154 pub fn get_iou_threshold(&self) -> f32 {
155 self.iou_threshold.unwrap_or(0.7)
156 }
157}
158
159impl AsRef<ort::session::Session> for YoloModelSession {
160 fn as_ref(&self) -> &ort::session::Session {
161 &self.session
162 }
163}
164
165impl AsMut<ort::session::Session> for YoloModelSession {
166 fn as_mut(&mut self) -> &mut ort::session::Session {
167 &mut self.session
168 }
169}