wasm_tensorflow_models_pose_detection/
model.rs1use core::panic;
2use std::fmt::Display;
3
4use enum_iterator::Sequence;
5use serde::{Deserialize, Serialize, Serializer};
6use strum_macros::{Display, EnumString, IntoStaticStr};
7use wasm_bindgen::JsValue;
8use wasm_bindgen_futures::js_sys::{Object, Reflect};
9
10#[derive(Display)]
11pub enum PoseNetArchitecture {
12 ResNet50,
13 MobileNetV1,
14}
15
/// Output stride of the PoseNet model.
///
/// The discriminants are the stride values themselves, so `stride as i32`
/// yields 32, 16 or 8.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoseNetOutputStride {
    Is32 = 32,
    Is16 = 16,
    Is8 = 8,
}
22
/// Depth multiplier for the PoseNet model (see [`PoseNetModelConfig::multiplier`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MobileNetMultiplier {
    Is1,
    Is0Point5,
    Is0Point75,
}

/// Converts the multiplier into its numeric value (1.0, 0.5 or 0.75).
///
/// Implemented as `From` rather than `Into` so callers get both
/// `f64::from(m)` and `m.into()`; the latter comes from std's blanket impl.
impl From<MobileNetMultiplier> for f64 {
    fn from(multiplier: MobileNetMultiplier) -> Self {
        match multiplier {
            MobileNetMultiplier::Is1 => 1.0,
            MobileNetMultiplier::Is0Point5 => 0.5,
            MobileNetMultiplier::Is0Point75 => 0.75,
        }
    }
}
37
/// Quantization width used for the PoseNet `quant_bytes` option.
///
/// The discriminant is the byte count itself, so `q as i32` yields 1, 2 or 4.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantBytes {
    Is1 = 1,
    Is2 = 2,
    Is4 = 4,
}
44
/// Model input resolution (units presumably pixels — TODO confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InputResolution {
    pub width: i32,
    pub height: i32,
}
49
/// Configuration for the PoseNet model.
///
/// Unlike the BlazePose and MoveNet configs in this file, this struct does not
/// derive `Serialize`.
pub struct PoseNetModelConfig {
    /// Backbone network.
    pub architecture: PoseNetArchitecture,
    /// Output stride of the network.
    pub output_stride: PoseNetOutputStride,
    /// Model input resolution.
    pub input_resolution: InputResolution,
    /// Optional depth multiplier. NOTE(review): judging by its type name this is
    /// presumably only meaningful for the MobileNetV1 architecture — confirm.
    pub multiplier: Option<MobileNetMultiplier>,
    /// Optional custom model URL.
    pub model_url: Option<String>,
    /// Optional quantization width.
    pub quant_bytes: Option<QuantBytes>,
}
58
59#[derive(Serialize)]
60#[serde(rename_all = "camelCase")]
61pub struct BlazePoseMediaPipeModelConfig {
62 pub solution_path: Option<String>,
63}
64impl Into<JsValue> for BlazePoseMediaPipeModelConfig {
65 fn into(self) -> JsValue {
66 serde_wasm_bindgen::to_value(&self).unwrap()
67 }
68}
69
70#[derive(Serialize)]
71#[serde(rename_all = "camelCase")]
72pub struct BlazePoseTfjsModelConfig {
73 pub detector_model_url: Option<String>,
75 pub landmark_model_url: Option<String>,
77}
78impl Into<JsValue> for BlazePoseTfjsModelConfig {
79 fn into(self) -> JsValue {
80 serde_wasm_bindgen::to_value(&self).unwrap()
81 }
82}
83
84pub enum Runtime {
85 Mediapipe(BlazePoseMediaPipeModelConfig),
86 Tfjs(BlazePoseTfjsModelConfig),
87}
88impl Display for Runtime {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 let string = match &self {
91 Self::Mediapipe(_) => "mediapipe",
92 Self::Tfjs(_) => "tfjs",
93 };
94 write!(f, "{string}")
95 }
96}
97impl Into<JsValue> for Runtime {
98 fn into(self) -> JsValue {
99 let runtime = self.to_string();
100 let o: JsValue = match self {
101 Self::Mediapipe(c) => c.into(),
102 Self::Tfjs(c) => c.into(),
103 };
104 Reflect::set(&o, &"runtime".into(), &runtime.into()).unwrap();
105 o
106 }
107}
108
/// BlazePose model variant.
///
/// Its `Display` form is the lowercase name: `lite`, `full` or `heavy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlazePoseModelType {
    Lite,
    Full,
    Heavy,
}
impl Display for BlazePoseModelType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Lite => "lite",
            Self::Full => "full",
            Self::Heavy => "heavy",
        })
    }
}
124impl Serialize for BlazePoseModelType {
125 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
126 where
127 S: Serializer,
128 {
129 serializer.serialize_str(&self.to_string()[..])
130 }
131}
132
/// Configuration for the BlazePose model, serialized with camelCase keys.
///
/// `runtime` is skipped by serde because `Runtime` has no `Serialize` impl. The
/// `JsValue` conversion for this type serializes the remaining fields and then
/// copies the runtime's own JS object, including its `runtime` name, onto
/// the result.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct BlazePoseModelConfig {
    /// Backend to run on, together with its backend-specific options.
    #[serde(skip_serializing)]
    pub runtime: Runtime,
    pub enable_smoothing: Option<bool>,
    pub enable_segmentation: Option<bool>,
    pub smooth_segmentation: Option<bool>,
    /// Model variant; serialized as `"lite"`, `"full"` or `"heavy"`.
    pub model_type: Option<BlazePoseModelType>,
}
143
144impl Into<JsValue> for BlazePoseModelConfig {
145 fn into(self) -> JsValue {
146 let config = Object::from(serde_wasm_bindgen::to_value(&self).unwrap());
147 Object::assign(
148 &config,
149 &Object::from({
150 let runtime_config: JsValue = self.runtime.into();
151 runtime_config
152 }),
153 );
154 config.into()
155 }
156}
157
/// Tracker used when pose tracking is enabled.
///
/// Its `Display` form is `"keypoint"` or `"boundingBox"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrackerType {
    Keypoint,
    BoundingBox,
}
impl Display for TrackerType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Keypoint => "keypoint",
            Self::BoundingBox => "boundingBox",
        })
    }
}
171impl Into<JsValue> for TrackerType {
172 fn into(self) -> JsValue {
173 self.to_string().into()
174 }
175}
176impl Serialize for TrackerType {
177 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
178 where
179 S: Serializer,
180 {
181 serializer.serialize_str(&self.to_string()[..])
182 }
183}
184
185#[derive(Serialize)]
186#[serde(rename_all = "camelCase")]
187pub struct KeypointTrackerConfig {
188 pub keypoint_confidence_threshold: i32,
189 pub keypoint_falloff: Vec<i32>,
190 pub min_number_of_keypoints: i32,
191}
192
193#[derive(Serialize)]
194pub struct BoundingBoxTrackerConfig;
195
/// Tracker options, serialized with camelCase keys.
///
/// NOTE(review): judging by the field names, presumably only the params struct
/// that matches the selected tracker type needs to be set. Confirm.
///
/// NOTE(review): `min_similarity` is an `i32`. If the JS side expects a
/// fractional similarity threshold, this type cannot express it. Confirm.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TrackerConfig {
    /// Maximum number of tracks kept at once (presumably; confirm).
    pub max_tracks: i32,
    /// Maximum track age. The units are not visible here (TODO confirm, e.g. ms).
    pub max_age: i32,
    /// Minimum similarity for matching a detection to a track (presumably;
    /// confirm against the JS API).
    pub min_similarity: i32,
    /// Keypoint-tracker params (`keypointTrackerParams`).
    pub keypoint_tracker_params: Option<KeypointTrackerConfig>,
    /// Bounding-box-tracker params (`boundingBoxTrackerParams`).
    pub bounding_box_tracker_params: Option<BoundingBoxTrackerConfig>,
}
205
/// Configuration for the MoveNet model, serialized with camelCase keys.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct MoveNetModelConfig {
    pub enable_smoothing: Option<bool>,
    /// Model variant as a free-form string.
    /// NOTE(review): presumably this must exactly match an upstream MoveNet
    /// model-type identifier. An enum would catch typos. Confirm valid values.
    pub model_type: Option<String>,
    /// Optional custom model URL.
    pub model_url: Option<String>,
    /// Minimum pose score (presumably lower-scoring poses are dropped; confirm).
    pub min_pose_score: Option<f64>,
    /// Maximum dimension for multi-pose input (semantics not visible here).
    pub multi_pose_max_dimension: Option<i32>,
    pub enable_tracking: Option<bool>,
    /// Tracker to use; serialized as `"keypoint"` or `"boundingBox"`.
    pub tracker_type: Option<TrackerType>,
    /// Tracker parameters.
    pub tracker_config: Option<TrackerConfig>,
}
219impl Into<JsValue> for MoveNetModelConfig {
220 fn into(self) -> JsValue {
221 serde_wasm_bindgen::to_value(&self).unwrap()
222 }
223}
224
225#[derive(
226 IntoStaticStr, EnumString, PartialEq, Eq, Hash, Serialize, Deserialize, Clone, Sequence,
227)]
228pub enum Model {
229 PoseNet,
230 BlazePose,
231 MoveNet,
232}
233impl Into<JsValue> for Model {
234 fn into(self) -> JsValue {
235 let value: &'static str = self.into();
236 value.into()
237 }
238}
239
/// A model selection paired with its optional configuration.
///
/// For the BlazePose and MoveNet variants, `get_config` turns a `None` config
/// into `undefined`.
pub enum ModelWithConfig {
    PoseNet(Option<PoseNetModelConfig>),
    BlazePose(Option<BlazePoseModelConfig>),
    MoveNet(Option<MoveNetModelConfig>),
}
245
246impl ModelWithConfig {
247 pub fn get_name(&self) -> &'static str {
248 match self {
249 Self::PoseNet(_) => "PoseNet",
250 Self::BlazePose(_) => "BlazePose",
251 Self::MoveNet(_) => "MoveNet",
252 }
253 }
254
255 pub fn get_config(self) -> JsValue {
256 match self {
257 Self::BlazePose(blaze_pose_model_config) => blaze_pose_model_config
258 .map(|config| config.into())
259 .unwrap_or(JsValue::UNDEFINED),
260 Self::MoveNet(move_net_model_config) => move_net_model_config
261 .map(|config| config.into())
262 .unwrap_or(JsValue::UNDEFINED),
263 _ => panic!("Not implemented. Make an issue :)"),
264 }
265 }
266}