Skip to main content

wasm_sql/postgres/bindings/
codecs.rs

1use crate::core::bindings::{
2    SqlHostState,
3    codec_utils,
4    generated::wasm_sql::core::{
5        codecs::{PushResult, ValuePosition},
6        query::QueryResults,
7        query_types::SqlArguments,
8        util_types::Error,
9    },
10};
11use crate::postgres::bindings::generated::wasm_sql::postgres::codecs::{
12    Date, Hstore, Inet, IpAddr, Macaddr, Numeric, PgInterval, Time, Timestamp, Timestamptz, Uuid,
13};
14use bigdecimal::BigDecimal;
15use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
16use ipnet::IpNet;
17use mac_address::MacAddress;
18use sqlx::postgres::types::PgInterval as SqlxPgInterval;
19use sqlx::types::JsonRawValue;
20use sqlx::{postgres::types::PgHstore, types::Json};
21use std::net::{Ipv4Addr, Ipv6Addr};
22use std::str::FromStr;
23
24impl crate::postgres::bindings::generated::wasm_sql::postgres::codecs::Host for SqlHostState {
25    fn push_int16(
26        &mut self,
27        value: Option<i16>,
28        to: wasmtime::component::Resource<SqlArguments>,
29    ) -> PushResult {
30        codec_utils::encode(value, self.table.get(&to)?)
31    }
32
33    fn get_int16(
34        &mut self,
35        result: wasmtime::component::Resource<QueryResults>,
36        position: ValuePosition,
37    ) -> Result<Option<i16>, Error> {
38        codec_utils::decode(self.table.get(&result)?, position)
39    }
40
41    fn push_int32(
42        &mut self,
43        value: Option<i32>,
44        to: wasmtime::component::Resource<SqlArguments>,
45    ) -> PushResult {
46        codec_utils::encode(value, self.table.get(&to)?)
47    }
48
49    fn get_int32(
50        &mut self,
51        result: wasmtime::component::Resource<QueryResults>,
52        position: ValuePosition,
53    ) -> Result<Option<i32>, Error> {
54        codec_utils::decode(self.table.get(&result)?, position)
55    }
56
57    fn push_int64(
58        &mut self,
59        value: Option<i64>,
60        to: wasmtime::component::Resource<SqlArguments>,
61    ) -> PushResult {
62        codec_utils::encode(value, self.table.get(&to)?)
63    }
64
65    fn get_int64(
66        &mut self,
67        result: wasmtime::component::Resource<QueryResults>,
68        position: ValuePosition,
69    ) -> Result<Option<i64>, Error> {
70        codec_utils::decode(self.table.get(&result)?, position)
71    }
72
73    fn push_float32(
74        &mut self,
75        value: Option<f32>,
76        to: wasmtime::component::Resource<SqlArguments>,
77    ) -> PushResult {
78        codec_utils::encode(value, self.table.get(&to)?)
79    }
80
81    fn get_float32(
82        &mut self,
83        result: wasmtime::component::Resource<QueryResults>,
84        position: ValuePosition,
85    ) -> Result<Option<f32>, Error> {
86        codec_utils::decode(self.table.get(&result)?, position)
87    }
88
89    fn push_float64(
90        &mut self,
91        value: Option<f64>,
92        to: wasmtime::component::Resource<SqlArguments>,
93    ) -> PushResult {
94        codec_utils::encode(value, self.table.get(&to)?)
95    }
96
97    fn get_float64(
98        &mut self,
99        result: wasmtime::component::Resource<QueryResults>,
100        position: ValuePosition,
101    ) -> Result<Option<f64>, Error> {
102        codec_utils::decode(self.table.get(&result)?, position)
103    }
104
105    fn push_string(
106        &mut self,
107        value: Option<wasmtime::component::__internal::String>,
108        to: wasmtime::component::Resource<SqlArguments>,
109    ) -> PushResult {
110        codec_utils::encode(value, self.table.get(&to)?)
111    }
112
113    fn get_string(
114        &mut self,
115        result: wasmtime::component::Resource<QueryResults>,
116        position: ValuePosition,
117    ) -> Result<Option<wasmtime::component::__internal::String>, Error> {
118        codec_utils::decode(self.table.get(&result)?, position)
119    }
120
121    fn push_bool(
122        &mut self,
123        value: Option<bool>,
124        to: wasmtime::component::Resource<SqlArguments>,
125    ) -> PushResult {
126        codec_utils::encode(value, self.table.get(&to)?)
127    }
128
129    fn get_bool(
130        &mut self,
131        result: wasmtime::component::Resource<QueryResults>,
132        position: ValuePosition,
133    ) -> Result<Option<bool>, Error> {
134        codec_utils::decode(self.table.get(&result)?, position)
135    }
136
137    fn push_json(
138        &mut self,
139        value: Option<wasmtime::component::__internal::String>,
140        to: wasmtime::component::Resource<SqlArguments>,
141    ) -> PushResult {
142        let to = self.table.get(&to)?;
143
144        match value {
145            Some(value) => {
146                let raw_value =
147                    JsonRawValue::from_string(value).map_err(|e| Error::Encode(e.to_string()))?;
148
149                codec_utils::encode(Json(raw_value), to)
150            }
151            None => codec_utils::encode(None::<Json<Box<JsonRawValue>>>, to),
152        }
153    }
154
155    fn get_json(
156        &mut self,
157        result: wasmtime::component::Resource<QueryResults>,
158        position: ValuePosition,
159    ) -> Result<Option<wasmtime::component::__internal::String>, Error> {
160        let a = codec_utils::decode::<Option<&JsonRawValue>>(self.table.get(&result)?, position)?;
161
162        Ok(a.map(|x| x.get().to_string()))
163    }
164
165    fn push_uuid(
166        &mut self,
167        value: Option<Uuid>,
168        to: wasmtime::component::Resource<SqlArguments>,
169    ) -> PushResult {
170        let value = value
171            .map(|v| sqlx::types::Uuid::try_parse(v.as_str()))
172            .transpose()
173            .map_err(|e| Error::Encode(e.to_string()))?;
174
175        codec_utils::encode(value, self.table.get(&to)?)
176    }
177
178    fn get_uuid(
179        &mut self,
180        result: wasmtime::component::Resource<QueryResults>,
181        position: ValuePosition,
182    ) -> Result<Option<Uuid>, Error> {
183        let value: Option<sqlx::types::Uuid> =
184            codec_utils::decode(self.table.get(&result)?, position)?;
185
186        Ok(value.map(|v| v.to_string()))
187    }
188
189    fn push_hstore(
190        &mut self,
191        value: Option<Hstore>,
192        to: wasmtime::component::Resource<SqlArguments>,
193    ) -> PushResult {
194        let value = value.map(|v: Vec<(String, Option<String>)>| PgHstore(v.into_iter().collect()));
195
196        codec_utils::encode(value, self.table.get(&to)?)
197    }
198
199    fn get_hstore(
200        &mut self,
201        result: wasmtime::component::Resource<QueryResults>,
202        position: ValuePosition,
203    ) -> Result<Option<Hstore>, Error> {
204        let value: Option<PgHstore> = codec_utils::decode(self.table.get(&result)?, position)?;
205
206        Ok(value.map(|v| v.into_iter().collect()))
207    }
208
209    // === Date/Time codecs ===
210
211    fn push_date(
212        &mut self,
213        value: Option<Date>,
214        to: wasmtime::component::Resource<SqlArguments>,
215    ) -> PushResult {
216        let value = value
217            .map(|v| NaiveDate::parse_from_str(&v, "%Y-%m-%d"))
218            .transpose()
219            .map_err(|e| Error::Encode(e.to_string()))?;
220
221        codec_utils::encode(value, self.table.get(&to)?)
222    }
223
224    fn get_date(
225        &mut self,
226        result: wasmtime::component::Resource<QueryResults>,
227        position: ValuePosition,
228    ) -> Result<Option<Date>, Error> {
229        let value: Option<NaiveDate> = codec_utils::decode(self.table.get(&result)?, position)?;
230
231        Ok(value.map(|v| v.format("%Y-%m-%d").to_string()))
232    }
233
234    fn push_time(
235        &mut self,
236        value: Option<Time>,
237        to: wasmtime::component::Resource<SqlArguments>,
238    ) -> PushResult {
239        let value = value
240            .map(|v| NaiveTime::parse_from_str(&v, "%H:%M:%S%.f"))
241            .transpose()
242            .map_err(|e| Error::Encode(e.to_string()))?;
243
244        codec_utils::encode(value, self.table.get(&to)?)
245    }
246
247    fn get_time(
248        &mut self,
249        result: wasmtime::component::Resource<QueryResults>,
250        position: ValuePosition,
251    ) -> Result<Option<Time>, Error> {
252        let value: Option<NaiveTime> = codec_utils::decode(self.table.get(&result)?, position)?;
253
254        Ok(value.map(|v| v.format("%H:%M:%S%.f").to_string()))
255    }
256
257    fn push_timestamp(
258        &mut self,
259        value: Option<Timestamp>,
260        to: wasmtime::component::Resource<SqlArguments>,
261    ) -> PushResult {
262        let value = value
263            .map(|v| NaiveDateTime::parse_from_str(&v, "%Y-%m-%dT%H:%M:%S%.f"))
264            .transpose()
265            .map_err(|e| Error::Encode(e.to_string()))?;
266
267        codec_utils::encode(value, self.table.get(&to)?)
268    }
269
270    fn get_timestamp(
271        &mut self,
272        result: wasmtime::component::Resource<QueryResults>,
273        position: ValuePosition,
274    ) -> Result<Option<Timestamp>, Error> {
275        let value: Option<NaiveDateTime> =
276            codec_utils::decode(self.table.get(&result)?, position)?;
277
278        Ok(value.map(|v| v.format("%Y-%m-%dT%H:%M:%S%.f").to_string()))
279    }
280
281    fn push_timestamptz(
282        &mut self,
283        value: Option<Timestamptz>,
284        to: wasmtime::component::Resource<SqlArguments>,
285    ) -> PushResult {
286        let value = value
287            .map(|v| DateTime::parse_from_rfc3339(&v).map(|dt| dt.with_timezone(&Utc)))
288            .transpose()
289            .map_err(|e| Error::Encode(e.to_string()))?;
290
291        codec_utils::encode(value, self.table.get(&to)?)
292    }
293
294    fn get_timestamptz(
295        &mut self,
296        result: wasmtime::component::Resource<QueryResults>,
297        position: ValuePosition,
298    ) -> Result<Option<Timestamptz>, Error> {
299        let value: Option<DateTime<Utc>> =
300            codec_utils::decode(self.table.get(&result)?, position)?;
301
302        Ok(value.map(|v| v.to_rfc3339()))
303    }
304
305    fn push_interval(
306        &mut self,
307        value: Option<PgInterval>,
308        to: wasmtime::component::Resource<SqlArguments>,
309    ) -> PushResult {
310        let value = value.map(|v| SqlxPgInterval {
311            months: v.months,
312            days: v.days,
313            microseconds: v.microseconds,
314        });
315
316        codec_utils::encode(value, self.table.get(&to)?)
317    }
318
319    fn get_interval(
320        &mut self,
321        result: wasmtime::component::Resource<QueryResults>,
322        position: ValuePosition,
323    ) -> Result<Option<PgInterval>, Error> {
324        let value: Option<SqlxPgInterval> =
325            codec_utils::decode(self.table.get(&result)?, position)?;
326
327        Ok(value.map(|v| PgInterval {
328            months: v.months,
329            days: v.days,
330            microseconds: v.microseconds,
331        }))
332    }
333
334    // === Network codecs ===
335
336    fn push_inet(
337        &mut self,
338        value: Option<Inet>,
339        to: wasmtime::component::Resource<SqlArguments>,
340    ) -> PushResult {
341        let value: Option<IpNet> = value.as_ref().map(Into::into);
342
343        codec_utils::encode(value, self.table.get(&to)?)
344    }
345
346    fn get_inet(
347        &mut self,
348        result: wasmtime::component::Resource<QueryResults>,
349        position: ValuePosition,
350    ) -> Result<Option<Inet>, Error> {
351        let value: Option<IpNet> = codec_utils::decode(self.table.get(&result)?, position)?;
352
353        Ok(value.as_ref().map(Into::into))
354    }
355
356    fn push_cidr(
357        &mut self,
358        value: Option<Inet>,
359        to: wasmtime::component::Resource<SqlArguments>,
360    ) -> PushResult {
361        let value: Option<IpNet> = value.as_ref().map(Into::into);
362
363        codec_utils::encode(value, self.table.get(&to)?)
364    }
365
366    fn get_cidr(
367        &mut self,
368        result: wasmtime::component::Resource<QueryResults>,
369        position: ValuePosition,
370    ) -> Result<Option<Inet>, Error> {
371        let value: Option<IpNet> = codec_utils::decode(self.table.get(&result)?, position)?;
372
373        Ok(value.as_ref().map(Into::into))
374    }
375
376    fn push_macaddr(
377        &mut self,
378        value: Option<Macaddr>,
379        to: wasmtime::component::Resource<SqlArguments>,
380    ) -> PushResult {
381        let value = value.map(|v| MacAddress::new(v.into()));
382
383        codec_utils::encode(value, self.table.get(&to)?)
384    }
385
386    fn get_macaddr(
387        &mut self,
388        result: wasmtime::component::Resource<QueryResults>,
389        position: ValuePosition,
390    ) -> Result<Option<Macaddr>, Error> {
391        let value: Option<MacAddress> = codec_utils::decode(self.table.get(&result)?, position)?;
392
393        Ok(value.map(|v| {
394            let b = v.bytes();
395            (b[0], b[1], b[2], b[3], b[4], b[5])
396        }))
397    }
398
399    // === Numeric codec ===
400
401    fn push_numeric(
402        &mut self,
403        value: Option<Numeric>,
404        to: wasmtime::component::Resource<SqlArguments>,
405    ) -> PushResult {
406        let value = value
407            .map(|v| BigDecimal::from_str(&v))
408            .transpose()
409            .map_err(|e| Error::Encode(e.to_string()))?;
410
411        codec_utils::encode(value, self.table.get(&to)?)
412    }
413
414    fn get_numeric(
415        &mut self,
416        result: wasmtime::component::Resource<QueryResults>,
417        position: ValuePosition,
418    ) -> Result<Option<Numeric>, Error> {
419        let value: Option<BigDecimal> = codec_utils::decode(self.table.get(&result)?, position)?;
420
421        Ok(value.map(|v| v.to_string()))
422    }
423}
424
425// Conversions for IP address types
426
427impl From<&Inet> for IpNet {
428    fn from(inet: &Inet) -> Self {
429        match &inet.addr {
430            IpAddr::V4(v4) => {
431                let addr = Ipv4Addr::new(v4.0, v4.1, v4.2, v4.3);
432                IpNet::V4(ipnet::Ipv4Net::new(addr, inet.prefix_len).unwrap_or_else(|_| {
433                    ipnet::Ipv4Net::new(addr, 32).unwrap()
434                }))
435            }
436            IpAddr::V6(v6) => {
437                let addr = Ipv6Addr::new(v6.0, v6.1, v6.2, v6.3, v6.4, v6.5, v6.6, v6.7);
438                IpNet::V6(ipnet::Ipv6Net::new(addr, inet.prefix_len).unwrap_or_else(|_| {
439                    ipnet::Ipv6Net::new(addr, 128).unwrap()
440                }))
441            }
442        }
443    }
444}
445
446impl From<&IpNet> for Inet {
447    fn from(ipnet: &IpNet) -> Self {
448        match ipnet {
449            IpNet::V4(v4) => {
450                let octets = v4.addr().octets();
451                Inet {
452                    addr: IpAddr::V4((octets[0], octets[1], octets[2], octets[3])),
453                    prefix_len: v4.prefix_len(),
454                }
455            }
456            IpNet::V6(v6) => {
457                let segments = v6.addr().segments();
458                Inet {
459                    addr: IpAddr::V6((
460                        segments[0],
461                        segments[1],
462                        segments[2],
463                        segments[3],
464                        segments[4],
465                        segments[5],
466                        segments[6],
467                        segments[7],
468                    )),
469                    prefix_len: v6.prefix_len(),
470                }
471            }
472        }
473    }
474}