sqlx_scylladb_core/
arguments.rs

1use std::{
2    collections::HashMap,
3    net::IpAddr,
4    ops::{Deref, DerefMut},
5    sync::Arc,
6};
7
8use scylla::{
9    cluster::metadata::ColumnType,
10    errors::SerializationError,
11    serialize::{
12        row::{RowSerializationContext, SerializeRow},
13        value::SerializeValue,
14        writers::{CellWriter, RowWriter, WrittenCellProof},
15    },
16    value::{CqlDate, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid},
17};
18use sqlx::Arguments;
19use uuid::Uuid;
20
21use crate::{ScyllaDB, ScyllaDBTypeInfo};
22
23/// Implementation of [sqlx::Arguments] for ScyllaDB.
24#[derive(Default)]
25pub struct ScyllaDBArguments {
26    pub(crate) types: Vec<ScyllaDBTypeInfo>,
27    pub(crate) buffer: ScyllaDBArgumentBuffer,
28}
29
30impl<'q> Arguments<'q> for ScyllaDBArguments {
31    type Database = ScyllaDB;
32
33    fn reserve(&mut self, additional: usize, size: usize) {
34        self.types.reserve(additional);
35        self.buffer.reserve(size);
36    }
37
38    fn add<T>(&mut self, value: T) -> Result<(), sqlx::error::BoxDynError>
39    where
40        T: 'q + sqlx::Encode<'q, Self::Database> + sqlx::Type<Self::Database>,
41    {
42        let ty = value.produces().unwrap_or_else(T::type_info);
43        let is_null = value.encode(&mut self.buffer)?;
44        if is_null.is_null() {
45            self.buffer.push(ScyllaDBArgument::Null);
46        }
47
48        self.types.push(ty);
49
50        Ok(())
51    }
52
53    #[inline(always)]
54    fn len(&self) -> usize {
55        self.buffer.len()
56    }
57}
58
59impl SerializeRow for ScyllaDBArguments {
60    fn serialize(
61        &self,
62        ctx: &RowSerializationContext<'_>,
63        writer: &mut RowWriter,
64    ) -> Result<(), SerializationError> {
65        let columns = ctx.columns();
66        for (i, column) in columns.iter().enumerate() {
67            if let Some(argument) = self.buffer.get(i) {
68                let cell_writer = writer.make_cell_writer();
69                let typ = column.typ();
70                argument.serialize(typ, cell_writer)?;
71            }
72        }
73
74        Ok(())
75    }
76
77    #[inline(always)]
78    fn is_empty(&self) -> bool {
79        self.buffer.is_empty()
80    }
81}
82
83/// An array of [ScyllaDBArguments] used during encoding.
84#[derive(Default)]
85pub struct ScyllaDBArgumentBuffer {
86    pub(crate) buffer: Vec<ScyllaDBArgument>,
87}
88
89impl Deref for ScyllaDBArgumentBuffer {
90    type Target = Vec<ScyllaDBArgument>;
91
92    fn deref(&self) -> &Self::Target {
93        &self.buffer
94    }
95}
96
97impl<'q> DerefMut for ScyllaDBArgumentBuffer {
98    fn deref_mut(&mut self) -> &mut Self::Target {
99        &mut self.buffer
100    }
101}
102
103/// The enum of data types that can be handled by scylla-rust-driver.
104pub enum ScyllaDBArgument {
105    /// Internally used NULL.
106    Null,
107    /// Any type can be used.
108    Any(Arc<dyn SerializeValue + Send + Sync>),
109    /// `boolean` type.
110    Boolean(bool),
111    /// array of `boolean` type.
112    BooleanArray(Arc<Vec<bool>>),
113    /// `tinyint` type.
114    TinyInt(i8),
115    /// array of `tinyint` type.
116    TinyIntArray(Arc<Vec<i8>>),
117    /// `smallint` type.
118    SmallInt(i16),
119    /// array of `smallint` type
120    SmallIntArray(Arc<Vec<i16>>),
121    /// `int` type.
122    Int(i32),
123    /// array of `int` type.
124    IntArray(Arc<Vec<i32>>),
125    /// `bigint` type.
126    BigInt(i64),
127    /// array of `bigint` type.
128    BigIntArray(Arc<Vec<i64>>),
129    /// `float` type.
130    Float(f32),
131    /// array of `float` type.
132    FloatArray(Arc<Vec<f32>>),
133    /// `double` type.
134    Double(f64),
135    /// array of `double` type.
136    DoubleArray(Arc<Vec<f64>>),
137    /// `text` or `ascii` type.
138    Text(Arc<String>),
139    /// array of `text` or `ascii` type.
140    TextArray(Arc<Vec<String>>),
141    /// `text` or `ascii` type implemented with [secrecy_08] crate.
142    #[cfg(feature = "secrecy-08")]
143    SecretText(Arc<secrecy_08::SecretString>),
144    /// `blob` type.
145    Blob(Arc<Vec<u8>>),
146    /// array of `blob` type.
147    BlobArray(Arc<Vec<Vec<u8>>>),
148    /// `blob` type implemented with [secrecy_08] crate.
149    #[cfg(feature = "secrecy-08")]
150    SecretBlob(Arc<secrecy_08::SecretVec<u8>>),
151    /// `uuid` type.
152    Uuid(Uuid),
153    /// array of `uuid` type.
154    UuidArray(Arc<Vec<Uuid>>),
155    /// `timeuuid` type.
156    Timeuuid(CqlTimeuuid),
157    /// array of `timeuuid` type.
158    TimeuuidArray(Arc<Vec<CqlTimeuuid>>),
159    /// `inet` type.
160    IpAddr(IpAddr),
161    /// array of `inet` type.
162    IpAddrArray(Arc<Vec<IpAddr>>),
163    /// `duration` type.
164    Duration(CqlDuration),
165    /// array of `duration` type.
166    DurationArray(Arc<Vec<CqlDuration>>),
167    /// `decimal` type.
168    #[cfg(feature = "bigdecimal-04")]
169    BigDecimal(bigdecimal_04::BigDecimal),
170    /// array of `decimal` type.
171    #[cfg(feature = "bigdecimal-04")]
172    BigDecimalArray(Arc<Vec<bigdecimal_04::BigDecimal>>),
173    /// `timestamp` type.
174    CqlTimestamp(CqlTimestamp),
175    /// array of `timestamp` type.
176    CqlTimestampArray(Arc<Vec<CqlTimestamp>>),
177    /// `timestamp` type implemented with [time_03] crate.
178    #[cfg(feature = "time-03")]
179    OffsetDateTime(time_03::OffsetDateTime),
180    /// array of `timestamp` type implemented with [time_03] crate.
181    #[cfg(feature = "time-03")]
182    OffsetDateTimeArray(Arc<Vec<time_03::OffsetDateTime>>),
183    /// `timestamp` type implemented with [chrono_04] crate.
184    #[cfg(feature = "chrono-04")]
185    ChronoDateTimeUTC(chrono_04::DateTime<chrono_04::Utc>),
186    /// array of `timestamp` type implemented with [chrono_04] crate.
187    #[cfg(feature = "chrono-04")]
188    ChronoDateTimeUTCArray(Arc<Vec<chrono_04::DateTime<chrono_04::Utc>>>),
189    /// `date` type.
190    CqlDate(CqlDate),
191    /// array of `date` type.
192    CqlDateArray(Arc<Vec<CqlDate>>),
193    /// `date` type implemented with [time_03] crate.
194    #[cfg(feature = "time-03")]
195    Date(time_03::Date),
196    /// array of `date` type implemented with [time_03] crate.
197    #[cfg(feature = "time-03")]
198    DateArray(Arc<Vec<time_03::Date>>),
199    /// `date` type implemented with [chrono_04] crate.
200    #[cfg(feature = "chrono-04")]
201    ChronoNaiveDate(chrono_04::NaiveDate),
202    /// array of `date` type implemented with [chrono_04] crate.
203    #[cfg(feature = "chrono-04")]
204    ChronoNaiveDateArray(Arc<Vec<chrono_04::NaiveDate>>),
205    /// `time` type.
206    CqlTime(CqlTime),
207    /// array of `time` type.
208    CqlTimeArray(Arc<Vec<CqlTime>>),
209    /// `time` type implemented with [time_03] crate.
210    #[cfg(feature = "time-03")]
211    Time(time_03::Time),
212    /// array of `time` type implemented with [time_03] crate.
213    #[cfg(feature = "time-03")]
214    TimeArray(Arc<Vec<time_03::Time>>),
215    /// `time` type implemented with [chrono_04] crate.
216    #[cfg(feature = "chrono-04")]
217    ChronoNaiveTime(chrono_04::NaiveTime),
218    /// array of `time` type implemented with [chrono_04] crate.
219    #[cfg(feature = "chrono-04")]
220    ChronoNaiveTimeArray(Arc<Vec<chrono_04::NaiveTime>>),
221    /// any tuple type.
222    Tuple(Arc<dyn SerializeValue + Send + Sync>),
223    /// user-defined type.
224    UserDefinedType(Arc<dyn SerializeValue + Send + Sync>),
225    /// array of user-defined type.
226    UserDefinedTypeArray(Arc<dyn SerializeValue + Send + Sync>),
227    /// map type for `text` and `text`.
228    TextTextMap(Arc<HashMap<String, String>>),
229    /// map type for `text` and `boolean`.
230    TextBooleanMap(Arc<HashMap<String, bool>>),
231    /// map type for `text` and `tinyint`.
232    TextTinyIntMap(Arc<HashMap<String, i8>>),
233    /// map type for `text` and `smallint`.
234    TextSmallIntMap(Arc<HashMap<String, i16>>),
235    /// map type for `text` and `int`.
236    TextIntMap(Arc<HashMap<String, i32>>),
237    /// map type for `text` and `bigint`.
238    TextBigIntMap(Arc<HashMap<String, i64>>),
239    /// map type for `text` and `float`.
240    TextFloatMap(Arc<HashMap<String, f32>>),
241    /// map type for `text` and `double`.
242    TextDoubleMap(Arc<HashMap<String, f64>>),
243    /// map type for `text` and `uuid`.
244    TextUuidMap(Arc<HashMap<String, Uuid>>),
245    /// map type for `text` and `inet`.
246    TextIpAddrMap(Arc<HashMap<String, IpAddr>>),
247}
248
249impl SerializeValue for ScyllaDBArgument {
250    fn serialize<'b>(
251        &self,
252        typ: &ColumnType,
253        writer: CellWriter<'b>,
254    ) -> Result<WrittenCellProof<'b>, SerializationError> {
255        match self {
256            Self::Any(value) => <_ as SerializeValue>::serialize(value, typ, writer),
257            Self::Null => Ok(writer.set_null()),
258            Self::Boolean(value) => <_ as SerializeValue>::serialize(value, typ, writer),
259            Self::BooleanArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
260            Self::TinyInt(value) => <_ as SerializeValue>::serialize(value, typ, writer),
261            Self::TinyIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
262            Self::SmallInt(value) => <_ as SerializeValue>::serialize(value, typ, writer),
263            Self::SmallIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
264            Self::Int(value) => <_ as SerializeValue>::serialize(value, typ, writer),
265            Self::IntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
266            Self::BigInt(value) => <_ as SerializeValue>::serialize(value, typ, writer),
267            Self::BigIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
268            Self::Float(value) => <_ as SerializeValue>::serialize(value, typ, writer),
269            Self::FloatArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
270            Self::Double(value) => <_ as SerializeValue>::serialize(value, typ, writer),
271            Self::DoubleArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
272            Self::Text(value) => <_ as SerializeValue>::serialize(value, typ, writer),
273            Self::TextArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
274            #[cfg(feature = "secrecy-08")]
275            Self::SecretText(value) => <_ as SerializeValue>::serialize(value, typ, writer),
276            Self::Blob(value) => <_ as SerializeValue>::serialize(value, typ, writer),
277            Self::BlobArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
278            #[cfg(feature = "secrecy-08")]
279            Self::SecretBlob(value) => <_ as SerializeValue>::serialize(value, typ, writer),
280            Self::Uuid(uuid) => <_ as SerializeValue>::serialize(uuid, typ, writer),
281            Self::UuidArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
282            Self::Timeuuid(timeuuid) => <_ as SerializeValue>::serialize(timeuuid, typ, writer),
283            Self::TimeuuidArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
284            Self::IpAddr(ip_addr) => <_ as SerializeValue>::serialize(ip_addr, typ, writer),
285            Self::IpAddrArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
286            Self::Duration(value) => <_ as SerializeValue>::serialize(value, typ, writer),
287            Self::DurationArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
288            #[cfg(feature = "bigdecimal-04")]
289            Self::BigDecimal(value) => <_ as SerializeValue>::serialize(value, typ, writer),
290            #[cfg(feature = "bigdecimal-04")]
291            Self::BigDecimalArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
292            Self::CqlTimestamp(value) => <_ as SerializeValue>::serialize(value, typ, writer),
293            Self::CqlTimestampArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
294            #[cfg(feature = "time-03")]
295            Self::OffsetDateTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
296            #[cfg(feature = "time-03")]
297            Self::OffsetDateTimeArray(value) => {
298                <_ as SerializeValue>::serialize(value, typ, writer)
299            }
300            #[cfg(feature = "chrono-04")]
301            Self::ChronoDateTimeUTC(value) => <_ as SerializeValue>::serialize(value, typ, writer),
302            #[cfg(feature = "chrono-04")]
303            Self::ChronoDateTimeUTCArray(value) => {
304                <_ as SerializeValue>::serialize(value, typ, writer)
305            }
306            Self::CqlTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
307            Self::CqlTimeArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
308            #[cfg(feature = "time-03")]
309            Self::Time(value) => <_ as SerializeValue>::serialize(value, typ, writer),
310            #[cfg(feature = "time-03")]
311            Self::TimeArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
312            #[cfg(feature = "chrono-04")]
313            Self::ChronoNaiveTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
314            #[cfg(feature = "chrono-04")]
315            Self::ChronoNaiveTimeArray(value) => {
316                <_ as SerializeValue>::serialize(value, typ, writer)
317            }
318            Self::CqlDate(value) => <_ as SerializeValue>::serialize(value, typ, writer),
319            Self::CqlDateArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
320            #[cfg(feature = "time-03")]
321            Self::Date(value) => <_ as SerializeValue>::serialize(value, typ, writer),
322            #[cfg(feature = "time-03")]
323            Self::DateArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
324            #[cfg(feature = "chrono-04")]
325            Self::ChronoNaiveDate(value) => <_ as SerializeValue>::serialize(value, typ, writer),
326            #[cfg(feature = "chrono-04")]
327            Self::ChronoNaiveDateArray(value) => {
328                <_ as SerializeValue>::serialize(value, typ, writer)
329            }
330            Self::Tuple(value) => <_ as SerializeValue>::serialize(value, typ, writer),
331            Self::UserDefinedType(value) => <_ as SerializeValue>::serialize(value, typ, writer),
332            Self::UserDefinedTypeArray(value) => {
333                <_ as SerializeValue>::serialize(value, typ, writer)
334            }
335            Self::TextTextMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
336            Self::TextBooleanMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
337            Self::TextTinyIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
338            Self::TextSmallIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
339            Self::TextIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
340            Self::TextBigIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
341            Self::TextFloatMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
342            Self::TextDoubleMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
343            Self::TextUuidMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
344            Self::TextIpAddrMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
345        }
346    }
347}