Skip to main content

worker/
vectorize.rs

1use crate::send::SendFuture;
2use crate::{env::EnvBinding, Error, Result};
3use serde::{de::DeserializeOwned, Deserialize, Serialize};
4use wasm_bindgen::{JsCast, JsValue};
5use wasm_bindgen_futures::JsFuture;
6use worker_sys::VectorizeIndex as VectorizeIndexSys;
7
8#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
9pub struct VectorizeQueryOptions {
10    #[serde(rename = "topK", skip_serializing_if = "Option::is_none")]
11    pub top_k: Option<u32>,
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub namespace: Option<String>,
14    #[serde(rename = "returnValues", skip_serializing_if = "Option::is_none")]
15    pub return_values: Option<bool>,
16    #[serde(rename = "returnMetadata", skip_serializing_if = "Option::is_none")]
17    pub return_metadata: Option<bool>,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
21pub struct VectorizeGetByIdsRequest {
22    pub ids: Vec<String>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub namespace: Option<String>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
28pub struct VectorizeDeleteByIdsRequest {
29    pub ids: Vec<String>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub namespace: Option<String>,
32}
33
34#[derive(Debug, Clone)]
35pub struct VectorizeIndex(VectorizeIndexSys);
36
37unsafe impl Send for VectorizeIndex {}
38unsafe impl Sync for VectorizeIndex {}
39
40impl VectorizeIndex {
41    async fn invoke_without_input<U: DeserializeOwned>(
42        &self,
43        method: impl FnOnce(&VectorizeIndexSys) -> std::result::Result<js_sys::Promise, JsValue>,
44    ) -> Result<U> {
45        let promise = method(&self.0)?;
46        let output = SendFuture::new(JsFuture::from(promise)).await;
47        let value = output.map_err(Error::from)?;
48        Ok(serde_wasm_bindgen::from_value(value)?)
49    }
50
51    async fn invoke_with_input<T: Serialize, U: DeserializeOwned>(
52        &self,
53        input: T,
54        method: impl FnOnce(
55            &VectorizeIndexSys,
56            JsValue,
57        ) -> std::result::Result<js_sys::Promise, JsValue>,
58    ) -> Result<U> {
59        let promise = method(&self.0, serde_wasm_bindgen::to_value(&input)?)?;
60        let output = SendFuture::new(JsFuture::from(promise)).await;
61        let value = output.map_err(Error::from)?;
62        Ok(serde_wasm_bindgen::from_value(value)?)
63    }
64
65    pub async fn upsert<T: Serialize, U: DeserializeOwned>(&self, input: T) -> Result<U> {
66        self.invoke_with_input(input, |inner, value| inner.upsert(value))
67            .await
68    }
69
70    pub async fn describe<U: DeserializeOwned>(&self) -> Result<U> {
71        self.invoke_without_input(|inner| inner.describe()).await
72    }
73
74    pub async fn query<T: Serialize, U: DeserializeOwned>(&self, input: T) -> Result<U> {
75        self.invoke_with_input(input, |inner, value| inner.query(value))
76            .await
77    }
78
79    pub async fn get_by_ids<T: Serialize, U: DeserializeOwned>(&self, input: T) -> Result<U> {
80        self.invoke_with_input(input, |inner, value| inner.get_by_ids(value))
81            .await
82    }
83
84    pub async fn delete<T: Serialize, U: DeserializeOwned>(&self, input: T) -> Result<U> {
85        self.invoke_with_input(input, |inner, value| inner.delete(value))
86            .await
87    }
88
89    pub async fn delete_by_ids<T: Serialize, U: DeserializeOwned>(&self, input: T) -> Result<U> {
90        self.invoke_with_input(input, |inner, value| inner.delete_by_ids(value))
91            .await
92    }
93}
94
95impl EnvBinding for VectorizeIndex {
96    const TYPE_NAME: &'static str = "Object";
97
98    fn get(val: JsValue) -> Result<Self> {
99        if !val.is_object() {
100            return Err("Binding cannot be cast to VectorizeIndex from non-object value".into());
101        }
102
103        let has_query = js_sys::Reflect::has(&val, &JsValue::from("query"))?;
104        if !has_query {
105            return Err("Binding cannot be cast to VectorizeIndex: missing `query` method".into());
106        }
107
108        Ok(Self(val.unchecked_into()))
109    }
110}
111
112impl JsCast for VectorizeIndex {
113    fn instanceof(val: &JsValue) -> bool {
114        val.is_object()
115    }
116
117    fn unchecked_from_js(val: JsValue) -> Self {
118        Self(val.unchecked_into())
119    }
120
121    fn unchecked_from_js_ref(val: &JsValue) -> &Self {
122        unsafe { &*(val as *const JsValue as *const Self) }
123    }
124}
125
126impl AsRef<JsValue> for VectorizeIndex {
127    fn as_ref(&self) -> &JsValue {
128        &self.0
129    }
130}
131
132impl From<JsValue> for VectorizeIndex {
133    fn from(val: JsValue) -> Self {
134        Self(val.unchecked_into())
135    }
136}
137
138impl From<VectorizeIndex> for JsValue {
139    fn from(value: VectorizeIndex) -> Self {
140        value.0.into()
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::{VectorizeDeleteByIdsRequest, VectorizeGetByIdsRequest, VectorizeQueryOptions};
    use serde_json::json;

    /// Query options serialize with camelCase keys, and `None` fields are omitted entirely.
    /// The whole-object comparison also catches stray or duplicated keys
    /// (e.g. a snake_case `top_k`).
    #[test]
    fn query_options_uses_top_k_field_name() {
        let options = VectorizeQueryOptions {
            top_k: Some(8),
            namespace: Some("ns-1".to_string()),
            return_values: Some(true),
            return_metadata: None,
        };

        let serialized = serde_json::to_value(options).expect("query options should serialize");
        assert_eq!(
            serialized,
            json!({ "topK": 8, "namespace": "ns-1", "returnValues": true })
        );
    }

    /// With every field `None`, nothing is emitted.
    #[test]
    fn default_query_options_serialize_to_empty_object() {
        let serialized = serde_json::to_value(VectorizeQueryOptions::default())
            .expect("default query options should serialize");
        assert_eq!(serialized, json!({}));
    }

    /// Deserialization uses the same camelCase keys and treats absent keys as `None`.
    #[test]
    fn query_options_deserialize_from_camel_case() {
        let parsed: VectorizeQueryOptions =
            serde_json::from_value(json!({ "topK": 3, "returnMetadata": false }))
                .expect("query options should deserialize");
        assert_eq!(
            parsed,
            VectorizeQueryOptions {
                top_k: Some(3),
                namespace: None,
                return_values: None,
                return_metadata: Some(false),
            }
        );
    }

    /// Id requests serialize `ids` as an array and include `namespace` only when set.
    #[test]
    fn id_request_types_serialize_ids_and_namespace() {
        let get_by_ids = VectorizeGetByIdsRequest {
            ids: vec!["id-1".to_string(), "id-2".to_string()],
            namespace: Some("ns-2".to_string()),
        };
        let delete_by_ids = VectorizeDeleteByIdsRequest {
            ids: vec!["id-3".to_string()],
            namespace: None,
        };

        let get_serialized = serde_json::to_value(get_by_ids).expect("request should serialize");
        let delete_serialized =
            serde_json::to_value(delete_by_ids).expect("request should serialize");

        assert_eq!(
            get_serialized,
            json!({ "ids": ["id-1", "id-2"], "namespace": "ns-2" })
        );
        assert_eq!(delete_serialized, json!({ "ids": ["id-3"] }));
    }

    /// A payload without `namespace` deserializes with `namespace: None`.
    #[test]
    fn id_requests_deserialize_without_namespace() {
        let parsed: VectorizeDeleteByIdsRequest = serde_json::from_value(json!({ "ids": ["a"] }))
            .expect("request should deserialize");
        assert_eq!(
            parsed,
            VectorizeDeleteByIdsRequest {
                ids: vec!["a".to_string()],
                namespace: None,
            }
        );
    }
}