wasm_tensorflow_models_pose_detection/
model.rs

1use core::panic;
2use std::fmt::Display;
3
4use enum_iterator::Sequence;
5use serde::{Deserialize, Serialize, Serializer};
6use strum_macros::{Display, EnumString, IntoStaticStr};
7use wasm_bindgen::JsValue;
8use wasm_bindgen_futures::js_sys::{Object, Reflect};
9
10#[derive(Display)]
11pub enum PoseNetArchitecture {
12    ResNet50,
13    MobileNetV1,
14}
15
/// Output stride of the PoseNet model.
///
/// The discriminants are the stride values themselves, so `stride as i32`
/// yields the numeric stride (32, 16 or 8).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum PoseNetOutputStride {
    Is32 = 32,
    Is16 = 16,
    Is8 = 8,
}
22
/// Width multiplier for the PoseNet MobileNetV1 backbone.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MobileNetMultiplier {
    Is1,
    Is0Point5,
    Is0Point75,
}

/// Converts a multiplier to its numeric value (`1.0`, `0.5` or `0.75`).
///
/// Implemented as `From` rather than `Into`: the standard library derives
/// `Into<f64>` from this via its blanket impl, so both `f64::from(m)` and
/// `m.into()` work.
impl From<MobileNetMultiplier> for f64 {
    fn from(multiplier: MobileNetMultiplier) -> Self {
        match multiplier {
            MobileNetMultiplier::Is1 => 1.0,
            MobileNetMultiplier::Is0Point5 => 0.5,
            MobileNetMultiplier::Is0Point75 => 0.75,
        }
    }
}
37
/// Byte width used for PoseNet weight quantization.
///
/// NOTE(review): presumably mirrors the TF.js `quantBytes` option — confirm.
/// The discriminants are the byte counts, so `q as i32` yields 1, 2 or 4.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum QuantBytes {
    Is1 = 1,
    Is2 = 2,
    Is4 = 4,
}
44
/// Width/height pair describing the PoseNet input resolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InputResolution {
    pub width: i32,
    pub height: i32,
}
49
50pub struct PoseNetModelConfig {
51    pub architecture: PoseNetArchitecture,
52    pub output_stride: PoseNetOutputStride,
53    pub input_resolution: InputResolution,
54    pub multiplier: Option<MobileNetMultiplier>,
55    pub model_url: Option<String>,
56    pub quant_bytes: Option<QuantBytes>,
57}
58
59#[derive(Serialize)]
60#[serde(rename_all = "camelCase")]
61pub struct BlazePoseMediaPipeModelConfig {
62    pub solution_path: Option<String>,
63}
64impl Into<JsValue> for BlazePoseMediaPipeModelConfig {
65    fn into(self) -> JsValue {
66        serde_wasm_bindgen::to_value(&self).unwrap()
67    }
68}
69
70#[derive(Serialize)]
71#[serde(rename_all = "camelCase")]
72pub struct BlazePoseTfjsModelConfig {
73    // TODO: Also allow io.IOHandler
74    pub detector_model_url: Option<String>,
75    // TODO: Also allow io.IOHandler
76    pub landmark_model_url: Option<String>,
77}
78impl Into<JsValue> for BlazePoseTfjsModelConfig {
79    fn into(self) -> JsValue {
80        serde_wasm_bindgen::to_value(&self).unwrap()
81    }
82}
83
84pub enum Runtime {
85    Mediapipe(BlazePoseMediaPipeModelConfig),
86    Tfjs(BlazePoseTfjsModelConfig),
87}
88impl Display for Runtime {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        let string = match &self {
91            Self::Mediapipe(_) => "mediapipe",
92            Self::Tfjs(_) => "tfjs",
93        };
94        write!(f, "{string}")
95    }
96}
97impl Into<JsValue> for Runtime {
98    fn into(self) -> JsValue {
99        let runtime = self.to_string();
100        let o: JsValue = match self {
101            Self::Mediapipe(c) => c.into(),
102            Self::Tfjs(c) => c.into(),
103        };
104        Reflect::set(&o, &"runtime".into(), &runtime.into()).unwrap();
105        o
106    }
107}
108
/// BlazePose model variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlazePoseModelType {
    Lite,
    Full,
    Heavy,
}

/// Writes the lowercase variant name: `"lite"`, `"full"` or `"heavy"`.
impl Display for BlazePoseModelType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Lite => "lite",
            Self::Full => "full",
            Self::Heavy => "heavy",
        })
    }
}
124impl Serialize for BlazePoseModelType {
125    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
126    where
127        S: Serializer,
128    {
129        serializer.serialize_str(&self.to_string()[..])
130    }
131}
132
133#[derive(Serialize)]
134#[serde(rename_all = "camelCase")]
135pub struct BlazePoseModelConfig {
136    #[serde(skip_serializing)]
137    pub runtime: Runtime,
138    pub enable_smoothing: Option<bool>,
139    pub enable_segmentation: Option<bool>,
140    pub smooth_segmentation: Option<bool>,
141    pub model_type: Option<BlazePoseModelType>,
142}
143
144impl Into<JsValue> for BlazePoseModelConfig {
145    fn into(self) -> JsValue {
146        let config = Object::from(serde_wasm_bindgen::to_value(&self).unwrap());
147        Object::assign(
148            &config,
149            &Object::from({
150                let runtime_config: JsValue = self.runtime.into();
151                runtime_config
152            }),
153        );
154        config.into()
155    }
156}
157
/// Tracker kind selectable through MoveNet's `tracker_type` option.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrackerType {
    Keypoint,
    BoundingBox,
}

/// Writes the tracker name: `"keypoint"` or `"boundingBox"`.
impl Display for TrackerType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Keypoint => "keypoint",
            Self::BoundingBox => "boundingBox",
        };
        f.write_str(name)
    }
}
171impl Into<JsValue> for TrackerType {
172    fn into(self) -> JsValue {
173        self.to_string().into()
174    }
175}
176impl Serialize for TrackerType {
177    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
178    where
179        S: Serializer,
180    {
181        serializer.serialize_str(&self.to_string()[..])
182    }
183}
184
185#[derive(Serialize)]
186#[serde(rename_all = "camelCase")]
187pub struct KeypointTrackerConfig {
188    pub keypoint_confidence_threshold: i32,
189    pub keypoint_falloff: Vec<i32>,
190    pub min_number_of_keypoints: i32,
191}
192
193#[derive(Serialize)]
194pub struct BoundingBoxTrackerConfig;
195
196#[derive(Serialize)]
197#[serde(rename_all = "camelCase")]
198pub struct TrackerConfig {
199    pub max_tracks: i32,
200    pub max_age: i32,
201    pub min_similarity: i32,
202    pub keypoint_tracker_params: Option<KeypointTrackerConfig>,
203    pub bounding_box_tracker_params: Option<BoundingBoxTrackerConfig>,
204}
205
206#[derive(Serialize)]
207#[serde(rename_all = "camelCase")]
208pub struct MoveNetModelConfig {
209    pub enable_smoothing: Option<bool>,
210    pub model_type: Option<String>,
211    // TODO: Also allow io.IOHandler
212    pub model_url: Option<String>,
213    pub min_pose_score: Option<f64>,
214    pub multi_pose_max_dimension: Option<i32>,
215    pub enable_tracking: Option<bool>,
216    pub tracker_type: Option<TrackerType>,
217    pub tracker_config: Option<TrackerConfig>,
218}
219impl Into<JsValue> for MoveNetModelConfig {
220    fn into(self) -> JsValue {
221        serde_wasm_bindgen::to_value(&self).unwrap()
222    }
223}
224
225#[derive(
226    IntoStaticStr, EnumString, PartialEq, Eq, Hash, Serialize, Deserialize, Clone, Sequence,
227)]
228pub enum Model {
229    PoseNet,
230    BlazePose,
231    MoveNet,
232}
233impl Into<JsValue> for Model {
234    fn into(self) -> JsValue {
235        let value: &'static str = self.into();
236        value.into()
237    }
238}
239
240pub enum ModelWithConfig {
241    PoseNet(Option<PoseNetModelConfig>),
242    BlazePose(Option<BlazePoseModelConfig>),
243    MoveNet(Option<MoveNetModelConfig>),
244}
245
246impl ModelWithConfig {
247    pub fn get_name(&self) -> &'static str {
248        match self {
249            Self::PoseNet(_) => "PoseNet",
250            Self::BlazePose(_) => "BlazePose",
251            Self::MoveNet(_) => "MoveNet",
252        }
253    }
254
255    pub fn get_config(self) -> JsValue {
256        match self {
257            Self::BlazePose(blaze_pose_model_config) => blaze_pose_model_config
258                .map(|config| config.into())
259                .unwrap_or(JsValue::UNDEFINED),
260            Self::MoveNet(move_net_model_config) => move_net_model_config
261                .map(|config| config.into())
262                .unwrap_or(JsValue::UNDEFINED),
263            _ => panic!("Not implemented. Make an issue :)"),
264        }
265    }
266}