wasm_tensorflow_models_pose_detection/
pose_detector.rs1use crate::{call_method::call_method, pose::Pose};
2use serde::{Deserialize, Serialize};
3use wasm_bindgen::{JsCast, JsValue};
4use wasm_bindgen_futures::{
5 js_sys::{Array, Object, Promise},
6 JsFuture,
7};
8
9#[derive(Clone)]
10pub struct PoseDetector {
11 js_value: JsValue,
12}
13
14impl JsCast for PoseDetector {
15 fn instanceof(val: &JsValue) -> bool {
16 Object::instanceof(val)
18 }
19
20 fn unchecked_from_js(val: JsValue) -> Self {
21 PoseDetector { js_value: val }
22 }
23
24 fn unchecked_from_js_ref(_val: &JsValue) -> &Self {
25 panic!("unchecked_from_js_ref not implemented for PoseDetector");
26 }
27}
28
29impl AsRef<JsValue> for PoseDetector {
30 fn as_ref(&self) -> &JsValue {
31 &self.js_value
32 }
33}
34
35impl From<JsValue> for PoseDetector {
36 fn from(value: JsValue) -> Self {
37 PoseDetector { js_value: value }
38 }
39}
40
41impl Into<JsValue> for PoseDetector {
42 fn into(self) -> JsValue {
43 self.js_value
44 }
45}
46
47#[derive(Serialize, Deserialize, Debug)]
48#[serde(rename_all = "camelCase")]
49pub struct CommonEstimationConfig {
50 pub max_poses: Option<u32>,
51 pub flip_horizontal: Option<bool>,
52}
53impl Into<JsValue> for CommonEstimationConfig {
54 fn into(self) -> JsValue {
55 serde_wasm_bindgen::to_value(&self).unwrap()
56 }
57}
58
59#[derive(Serialize, Deserialize, Debug)]
60#[serde(rename_all = "camelCase")]
61pub struct PoseNetEstimationConfig {
62 #[serde(skip_serializing)]
63 pub common_config: CommonEstimationConfig,
64 pub score_threshold: Option<f64>,
65 pub nms_radius: Option<f64>,
66}
67impl Into<JsValue> for PoseNetEstimationConfig {
68 fn into(self) -> JsValue {
69 let common_config: Object = serde_wasm_bindgen::to_value(&self.common_config)
70 .unwrap()
71 .dyn_into()
72 .unwrap();
73 let pose_net_config: Object = serde_wasm_bindgen::to_value(&self)
74 .unwrap()
75 .dyn_into()
76 .unwrap();
77 Object::assign2(&Object::new(), &common_config, &pose_net_config).into()
78 }
79}
80
81#[derive(Serialize, Deserialize, Debug)]
82pub enum EstimationConfig {
83 PoseNet(PoseNetEstimationConfig),
84 BlazePoseOrMoveNet(CommonEstimationConfig),
85}
86impl Into<JsValue> for EstimationConfig {
87 fn into(self) -> JsValue {
88 match self {
89 EstimationConfig::PoseNet(config) => config.into(),
90 EstimationConfig::BlazePoseOrMoveNet(config) => config.into(),
91 }
92 }
93}
94
95impl PoseDetector {
96 pub async fn estimate_poses(
97 &self,
98 image: &JsValue,
99 config: EstimationConfig,
100 timestamp: Option<i32>,
101 ) -> Result<Vec<Pose>, JsValue> {
102 let inputs = Array::from_iter(vec![image, &config.into(), ×tamp.into()].iter());
103 let poses = JsFuture::from(Promise::from(call_method(
104 &self.js_value,
105 &"estimatePoses".into(),
106 &inputs,
107 )?))
108 .await?;
109 let poses = Array::from(&poses)
110 .to_vec()
111 .into_iter()
112 .map(|pose| Pose::try_from(pose).unwrap())
113 .collect::<Vec<_>>();
114 Ok(poses)
115 }
116
117 pub fn dispose(&self) {
119 call_method(&self.js_value, &"dispose".into(), &Array::new()).unwrap();
120 }
121
122 pub fn reset(&self) {
123 call_method(&self.js_value, &"reset".into(), &Array::new()).unwrap();
124 }
125}