wasm_tensorflow_models_pose_detection/pose_detector.rs

1use crate::{call_method::call_method, pose::Pose};
2use serde::{Deserialize, Serialize};
3use wasm_bindgen::{JsCast, JsValue};
4use wasm_bindgen_futures::{
5    js_sys::{Array, Object, Promise},
6    JsFuture,
7};
8
9#[derive(Clone)]
10pub struct PoseDetector {
11    js_value: JsValue,
12}
13
14impl JsCast for PoseDetector {
15    fn instanceof(val: &JsValue) -> bool {
16        // I'm pretty sure there is no `PoseDetector` class in JavaScript.
17        Object::instanceof(val)
18    }
19
20    fn unchecked_from_js(val: JsValue) -> Self {
21        PoseDetector { js_value: val }
22    }
23
24    fn unchecked_from_js_ref(_val: &JsValue) -> &Self {
25        panic!("unchecked_from_js_ref not implemented for PoseDetector");
26    }
27}
28
29impl AsRef<JsValue> for PoseDetector {
30    fn as_ref(&self) -> &JsValue {
31        &self.js_value
32    }
33}
34
35impl From<JsValue> for PoseDetector {
36    fn from(value: JsValue) -> Self {
37        PoseDetector { js_value: value }
38    }
39}
40
41impl Into<JsValue> for PoseDetector {
42    fn into(self) -> JsValue {
43        self.js_value
44    }
45}
46
47#[derive(Serialize, Deserialize, Debug)]
48#[serde(rename_all = "camelCase")]
49pub struct CommonEstimationConfig {
50    pub max_poses: Option<u32>,
51    pub flip_horizontal: Option<bool>,
52}
53impl Into<JsValue> for CommonEstimationConfig {
54    fn into(self) -> JsValue {
55        serde_wasm_bindgen::to_value(&self).unwrap()
56    }
57}
58
59#[derive(Serialize, Deserialize, Debug)]
60#[serde(rename_all = "camelCase")]
61pub struct PoseNetEstimationConfig {
62    #[serde(skip_serializing)]
63    pub common_config: CommonEstimationConfig,
64    pub score_threshold: Option<f64>,
65    pub nms_radius: Option<f64>,
66}
67impl Into<JsValue> for PoseNetEstimationConfig {
68    fn into(self) -> JsValue {
69        let common_config: Object = serde_wasm_bindgen::to_value(&self.common_config)
70            .unwrap()
71            .dyn_into()
72            .unwrap();
73        let pose_net_config: Object = serde_wasm_bindgen::to_value(&self)
74            .unwrap()
75            .dyn_into()
76            .unwrap();
77        Object::assign2(&Object::new(), &common_config, &pose_net_config).into()
78    }
79}
80
81#[derive(Serialize, Deserialize, Debug)]
82pub enum EstimationConfig {
83    PoseNet(PoseNetEstimationConfig),
84    BlazePoseOrMoveNet(CommonEstimationConfig),
85}
86impl Into<JsValue> for EstimationConfig {
87    fn into(self) -> JsValue {
88        match self {
89            EstimationConfig::PoseNet(config) => config.into(),
90            EstimationConfig::BlazePoseOrMoveNet(config) => config.into(),
91        }
92    }
93}
94
95impl PoseDetector {
96    pub async fn estimate_poses(
97        &self,
98        image: &JsValue,
99        config: EstimationConfig,
100        timestamp: Option<i32>,
101    ) -> Result<Vec<Pose>, JsValue> {
102        let inputs = Array::from_iter(vec![image, &config.into(), &timestamp.into()].iter());
103        let poses = JsFuture::from(Promise::from(call_method(
104            &self.js_value,
105            &"estimatePoses".into(),
106            &inputs,
107        )?))
108        .await?;
109        let poses = Array::from(&poses)
110            .to_vec()
111            .into_iter()
112            .map(|pose| Pose::try_from(pose).unwrap())
113            .collect::<Vec<_>>();
114        Ok(poses)
115    }
116
117    /// The `Drop` trait is implemented to call this function, so in Rust you can just drop it instead of calling this function directly.
118    pub fn dispose(&self) {
119        call_method(&self.js_value, &"dispose".into(), &Array::new()).unwrap();
120    }
121
122    pub fn reset(&self) {
123        call_method(&self.js_value, &"reset".into(), &Array::new()).unwrap();
124    }
125}