sqlx_scylladb_core/
arguments.rs

1use std::{
2    collections::HashMap,
3    net::IpAddr,
4    ops::{Deref, DerefMut},
5    sync::Arc,
6};
7
8use scylla::{
9    cluster::metadata::{ColumnType, NativeType},
10    errors::SerializationError,
11    serialize::{
12        row::{RowSerializationContext, SerializeRow},
13        value::SerializeValue,
14        writers::{CellWriter, RowWriter, WrittenCellProof},
15    },
16    value::{Counter, CqlDate, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid},
17};
18use sqlx::Arguments;
19use uuid::Uuid;
20
21use crate::{ScyllaDB, ScyllaDBTypeInfo};
22
23/// Implementation of [sqlx::Arguments] for ScyllaDB.
24#[derive(Default)]
25pub struct ScyllaDBArguments {
26    pub(crate) types: Vec<ScyllaDBTypeInfo>,
27    pub(crate) buffer: ScyllaDBArgumentBuffer,
28}
29
30impl<'q> Arguments<'q> for ScyllaDBArguments {
31    type Database = ScyllaDB;
32
33    fn reserve(&mut self, additional: usize, size: usize) {
34        self.types.reserve(additional);
35        self.buffer.reserve(size);
36    }
37
38    fn add<T>(&mut self, value: T) -> Result<(), sqlx::error::BoxDynError>
39    where
40        T: 'q + sqlx::Encode<'q, Self::Database> + sqlx::Type<Self::Database>,
41    {
42        let ty = value.produces().unwrap_or_else(T::type_info);
43        let is_null = value.encode(&mut self.buffer)?;
44        if is_null.is_null() {
45            self.buffer.push(ScyllaDBArgument::Null);
46        }
47
48        self.types.push(ty);
49
50        Ok(())
51    }
52
53    #[inline(always)]
54    fn len(&self) -> usize {
55        self.buffer.len()
56    }
57}
58
59impl SerializeRow for ScyllaDBArguments {
60    fn serialize(
61        &self,
62        ctx: &RowSerializationContext<'_>,
63        writer: &mut RowWriter,
64    ) -> Result<(), SerializationError> {
65        let columns = ctx.columns();
66        for (i, column) in columns.iter().enumerate() {
67            if let Some(argument) = self.buffer.get(i) {
68                let cell_writer = writer.make_cell_writer();
69                let typ = column.typ();
70                argument.serialize(typ, cell_writer)?;
71            }
72        }
73
74        Ok(())
75    }
76
77    #[inline(always)]
78    fn is_empty(&self) -> bool {
79        self.buffer.is_empty()
80    }
81}
82
83/// An array of [ScyllaDBArguments] used during encoding.
84#[derive(Default)]
85pub struct ScyllaDBArgumentBuffer {
86    pub(crate) buffer: Vec<ScyllaDBArgument>,
87}
88
89impl Deref for ScyllaDBArgumentBuffer {
90    type Target = Vec<ScyllaDBArgument>;
91
92    fn deref(&self) -> &Self::Target {
93        &self.buffer
94    }
95}
96
97impl<'q> DerefMut for ScyllaDBArgumentBuffer {
98    fn deref_mut(&mut self) -> &mut Self::Target {
99        &mut self.buffer
100    }
101}
102
103/// The enum of data types that can be handled by scylla-rust-driver.
104pub enum ScyllaDBArgument {
105    /// Internally used NULL.
106    Null,
107    /// Internally used Unset.
108    Unset,
109    /// Any type can be used.
110    Any(Arc<dyn SerializeValue + Send + Sync>),
111    /// `boolean` type.
112    Boolean(bool),
113    /// array of `boolean` type.
114    BooleanArray(Vec<bool>),
115    /// `tinyint` type.
116    TinyInt(i8),
117    /// array of `tinyint` type.
118    TinyIntArray(Vec<i8>),
119    /// `smallint` type.
120    SmallInt(i16),
121    /// array of `smallint` type
122    SmallIntArray(Vec<i16>),
123    /// `int` type.
124    Int(i32),
125    /// array of `int` type.
126    IntArray(Vec<i32>),
127    /// `bigint` type.
128    BigInt(i64),
129    /// array of `bigint` type.
130    BigIntArray(Vec<i64>),
131    /// `float` type.
132    Float(f32),
133    /// array of `float` type.
134    FloatArray(Vec<f32>),
135    /// `double` type.
136    Double(f64),
137    /// array of `double` type.
138    DoubleArray(Vec<f64>),
139    /// `text` or `ascii` type.
140    Text(String),
141    /// array of `text` or `ascii` type.
142    TextArray(Vec<String>),
143    /// `text` or `ascii` type implemented with [secrecy_08] crate.
144    #[cfg(feature = "secrecy-08")]
145    SecretText(secrecy_08::SecretString),
146    /// array of `text` or `ascii` type implemented with [secrecy_08] crate.
147    #[cfg(feature = "secrecy-08")]
148    SecretTextArray(Vec<secrecy_08::SecretString>),
149    /// `blob` type.
150    Blob(Vec<u8>),
151    /// array of `blob` type.
152    BlobArray(Vec<Vec<u8>>),
153    /// `blob` type implemented with [secrecy_08] crate.
154    #[cfg(feature = "secrecy-08")]
155    SecretBlob(secrecy_08::SecretVec<u8>),
156    /// array of `blob` type implemented with [secrecy_08] crate.
157    #[cfg(feature = "secrecy-08")]
158    SecretBlobArray(Vec<secrecy_08::SecretVec<u8>>),
159    /// `uuid` type.
160    Uuid(Uuid),
161    /// array of `uuid` type.
162    UuidArray(Vec<Uuid>),
163    /// `timeuuid` type.
164    Timeuuid(CqlTimeuuid),
165    /// array of `timeuuid` type.
166    TimeuuidArray(Vec<CqlTimeuuid>),
167    /// `inet` type.
168    IpAddr(IpAddr),
169    /// array of `inet` type.
170    IpAddrArray(Vec<IpAddr>),
171    /// `duration` type.
172    Duration(CqlDuration),
173    /// array of `duration` type.
174    DurationArray(Vec<CqlDuration>),
175    /// `decimal` type.
176    #[cfg(feature = "bigdecimal-04")]
177    BigDecimal(bigdecimal_04::BigDecimal),
178    /// array of `decimal` type.
179    #[cfg(feature = "bigdecimal-04")]
180    BigDecimalArray(Vec<bigdecimal_04::BigDecimal>),
181    /// `timestamp` type.
182    CqlTimestamp(CqlTimestamp),
183    /// array of `timestamp` type.
184    CqlTimestampArray(Vec<CqlTimestamp>),
185    /// `timestamp` type implemented with [time_03] crate.
186    #[cfg(feature = "time-03")]
187    OffsetDateTime(time_03::OffsetDateTime),
188    /// array of `timestamp` type implemented with [time_03] crate.
189    #[cfg(feature = "time-03")]
190    OffsetDateTimeArray(Vec<time_03::OffsetDateTime>),
191    /// `timestamp` type implemented with [chrono_04] crate.
192    #[cfg(feature = "chrono-04")]
193    ChronoDateTimeUTC(chrono_04::DateTime<chrono_04::Utc>),
194    /// array of `timestamp` type implemented with [chrono_04] crate.
195    #[cfg(feature = "chrono-04")]
196    ChronoDateTimeUTCArray(Vec<chrono_04::DateTime<chrono_04::Utc>>),
197    /// `date` type.
198    CqlDate(CqlDate),
199    /// array of `date` type.
200    CqlDateArray(Vec<CqlDate>),
201    /// `date` type implemented with [time_03] crate.
202    #[cfg(feature = "time-03")]
203    Date(time_03::Date),
204    /// array of `date` type implemented with [time_03] crate.
205    #[cfg(feature = "time-03")]
206    DateArray(Vec<time_03::Date>),
207    /// `date` type implemented with [chrono_04] crate.
208    #[cfg(feature = "chrono-04")]
209    ChronoNaiveDate(chrono_04::NaiveDate),
210    /// array of `date` type implemented with [chrono_04] crate.
211    #[cfg(feature = "chrono-04")]
212    ChronoNaiveDateArray(Vec<chrono_04::NaiveDate>),
213    /// `time` type.
214    CqlTime(CqlTime),
215    /// array of `time` type.
216    CqlTimeArray(Vec<CqlTime>),
217    /// `time` type implemented with [time_03] crate.
218    #[cfg(feature = "time-03")]
219    Time(time_03::Time),
220    /// array of `time` type implemented with [time_03] crate.
221    #[cfg(feature = "time-03")]
222    TimeArray(Vec<time_03::Time>),
223    /// `time` type implemented with [chrono_04] crate.
224    #[cfg(feature = "chrono-04")]
225    ChronoNaiveTime(chrono_04::NaiveTime),
226    /// array of `time` type implemented with [chrono_04] crate.
227    #[cfg(feature = "chrono-04")]
228    ChronoNaiveTimeArray(Vec<chrono_04::NaiveTime>),
229    /// any tuple type.
230    Tuple(Box<dyn SerializeValue + Send + Sync>),
231    /// user-defined type.
232    UserDefinedType(Box<dyn SerializeValue + Send + Sync>),
233    /// array of user-defined type.
234    UserDefinedTypeArray(Vec<Box<dyn SerializeValue + Send + Sync>>),
235    /// map type for `text` and `text`.
236    TextTextMap(HashMap<String, String>),
237    /// map type for `text` and `boolean`.
238    TextBooleanMap(HashMap<String, bool>),
239    /// map type for `text` and `tinyint`.
240    TextTinyIntMap(HashMap<String, i8>),
241    /// map type for `text` and `smallint`.
242    TextSmallIntMap(HashMap<String, i16>),
243    /// map type for `text` and `int`.
244    TextIntMap(HashMap<String, i32>),
245    /// map type for `text` and `bigint`.
246    TextBigIntMap(HashMap<String, i64>),
247    /// map type for `text` and `float`.
248    TextFloatMap(HashMap<String, f32>),
249    /// map type for `text` and `double`.
250    TextDoubleMap(HashMap<String, f64>),
251    /// map type for `text` and `uuid`.
252    TextUuidMap(HashMap<String, Uuid>),
253    /// map type for `text` and `inet`.
254    TextIpAddrMap(HashMap<String, IpAddr>),
255}
256
257impl SerializeValue for ScyllaDBArgument {
258    fn serialize<'b>(
259        &self,
260        typ: &ColumnType,
261        writer: CellWriter<'b>,
262    ) -> Result<WrittenCellProof<'b>, SerializationError> {
263        match self {
264            Self::Any(value) => <_ as SerializeValue>::serialize(value, typ, writer),
265            Self::Null => Ok(writer.set_null()),
266            Self::Unset => Ok(writer.set_unset()),
267            Self::Boolean(value) => <_ as SerializeValue>::serialize(value, typ, writer),
268            Self::BooleanArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
269            Self::TinyInt(value) => {
270                if ColumnType::Native(NativeType::Counter) == *typ {
271                    <_ as SerializeValue>::serialize(&Counter(*value as i64), typ, writer)
272                } else {
273                    <_ as SerializeValue>::serialize(value, typ, writer)
274                }
275            }
276            Self::TinyIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
277            Self::SmallInt(value) => {
278                if ColumnType::Native(NativeType::Counter) == *typ {
279                    <_ as SerializeValue>::serialize(&Counter(*value as i64), typ, writer)
280                } else {
281                    <_ as SerializeValue>::serialize(value, typ, writer)
282                }
283            }
284            Self::SmallIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
285            Self::Int(value) => {
286                if ColumnType::Native(NativeType::Counter) == *typ {
287                    <_ as SerializeValue>::serialize(&Counter(*value as i64), typ, writer)
288                } else {
289                    <_ as SerializeValue>::serialize(value, typ, writer)
290                }
291            }
292            Self::IntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
293            Self::BigInt(value) => {
294                if ColumnType::Native(NativeType::Counter) == *typ {
295                    <_ as SerializeValue>::serialize(&Counter(*value as i64), typ, writer)
296                } else {
297                    <_ as SerializeValue>::serialize(value, typ, writer)
298                }
299            }
300            Self::BigIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
301            Self::Float(value) => <_ as SerializeValue>::serialize(value, typ, writer),
302            Self::FloatArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
303            Self::Double(value) => <_ as SerializeValue>::serialize(value, typ, writer),
304            Self::DoubleArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
305            Self::Text(value) => <_ as SerializeValue>::serialize(value, typ, writer),
306            Self::TextArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
307            #[cfg(feature = "secrecy-08")]
308            Self::SecretText(value) => <_ as SerializeValue>::serialize(value, typ, writer),
309            #[cfg(feature = "secrecy-08")]
310            Self::SecretTextArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
311            Self::Blob(value) => <_ as SerializeValue>::serialize(value, typ, writer),
312            Self::BlobArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
313            #[cfg(feature = "secrecy-08")]
314            Self::SecretBlob(value) => <_ as SerializeValue>::serialize(value, typ, writer),
315            #[cfg(feature = "secrecy-08")]
316            Self::SecretBlobArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
317            Self::Uuid(uuid) => <_ as SerializeValue>::serialize(uuid, typ, writer),
318            Self::UuidArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
319            Self::Timeuuid(timeuuid) => <_ as SerializeValue>::serialize(timeuuid, typ, writer),
320            Self::TimeuuidArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
321            Self::IpAddr(ip_addr) => <_ as SerializeValue>::serialize(ip_addr, typ, writer),
322            Self::IpAddrArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
323            Self::Duration(value) => <_ as SerializeValue>::serialize(value, typ, writer),
324            Self::DurationArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
325            #[cfg(feature = "bigdecimal-04")]
326            Self::BigDecimal(value) => <_ as SerializeValue>::serialize(value, typ, writer),
327            #[cfg(feature = "bigdecimal-04")]
328            Self::BigDecimalArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
329            Self::CqlTimestamp(value) => <_ as SerializeValue>::serialize(value, typ, writer),
330            Self::CqlTimestampArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
331            #[cfg(feature = "time-03")]
332            Self::OffsetDateTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
333            #[cfg(feature = "time-03")]
334            Self::OffsetDateTimeArray(value) => {
335                <_ as SerializeValue>::serialize(value, typ, writer)
336            }
337            #[cfg(feature = "chrono-04")]
338            Self::ChronoDateTimeUTC(value) => <_ as SerializeValue>::serialize(value, typ, writer),
339            #[cfg(feature = "chrono-04")]
340            Self::ChronoDateTimeUTCArray(value) => {
341                <_ as SerializeValue>::serialize(value, typ, writer)
342            }
343            Self::CqlTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
344            Self::CqlTimeArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
345            #[cfg(feature = "time-03")]
346            Self::Time(value) => <_ as SerializeValue>::serialize(value, typ, writer),
347            #[cfg(feature = "time-03")]
348            Self::TimeArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
349            #[cfg(feature = "chrono-04")]
350            Self::ChronoNaiveTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
351            #[cfg(feature = "chrono-04")]
352            Self::ChronoNaiveTimeArray(value) => {
353                <_ as SerializeValue>::serialize(value, typ, writer)
354            }
355            Self::CqlDate(value) => <_ as SerializeValue>::serialize(value, typ, writer),
356            Self::CqlDateArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
357            #[cfg(feature = "time-03")]
358            Self::Date(value) => <_ as SerializeValue>::serialize(value, typ, writer),
359            #[cfg(feature = "time-03")]
360            Self::DateArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
361            #[cfg(feature = "chrono-04")]
362            Self::ChronoNaiveDate(value) => <_ as SerializeValue>::serialize(value, typ, writer),
363            #[cfg(feature = "chrono-04")]
364            Self::ChronoNaiveDateArray(value) => {
365                <_ as SerializeValue>::serialize(value, typ, writer)
366            }
367            Self::Tuple(value) => <_ as SerializeValue>::serialize(value, typ, writer),
368            Self::UserDefinedType(value) => <_ as SerializeValue>::serialize(value, typ, writer),
369            Self::UserDefinedTypeArray(value) => {
370                <_ as SerializeValue>::serialize(value, typ, writer)
371            }
372            Self::TextTextMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
373            Self::TextBooleanMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
374            Self::TextTinyIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
375            Self::TextSmallIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
376            Self::TextIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
377            Self::TextBigIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
378            Self::TextFloatMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
379            Self::TextDoubleMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
380            Self::TextUuidMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
381            Self::TextIpAddrMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
382        }
383    }
384}