1use std::{
2 collections::HashMap,
3 net::IpAddr,
4 ops::{Deref, DerefMut},
5 sync::Arc,
6};
7
8use scylla::{
9 cluster::metadata::ColumnType,
10 errors::SerializationError,
11 serialize::{
12 row::{RowSerializationContext, SerializeRow},
13 value::SerializeValue,
14 writers::{CellWriter, RowWriter, WrittenCellProof},
15 },
16 value::{CqlDate, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid},
17};
18use sqlx::Arguments;
19use uuid::Uuid;
20
21use crate::{ScyllaDB, ScyllaDBTypeInfo};
22
23#[derive(Default)]
25pub struct ScyllaDBArguments {
26 pub(crate) types: Vec<ScyllaDBTypeInfo>,
27 pub(crate) buffer: ScyllaDBArgumentBuffer,
28}
29
30impl<'q> Arguments<'q> for ScyllaDBArguments {
31 type Database = ScyllaDB;
32
33 fn reserve(&mut self, additional: usize, size: usize) {
34 self.types.reserve(additional);
35 self.buffer.reserve(size);
36 }
37
38 fn add<T>(&mut self, value: T) -> Result<(), sqlx::error::BoxDynError>
39 where
40 T: 'q + sqlx::Encode<'q, Self::Database> + sqlx::Type<Self::Database>,
41 {
42 let ty = value.produces().unwrap_or_else(T::type_info);
43 let is_null = value.encode(&mut self.buffer)?;
44 if is_null.is_null() {
45 self.buffer.push(ScyllaDBArgument::Null);
46 }
47
48 self.types.push(ty);
49
50 Ok(())
51 }
52
53 #[inline(always)]
54 fn len(&self) -> usize {
55 self.buffer.len()
56 }
57}
58
59impl SerializeRow for ScyllaDBArguments {
60 fn serialize(
61 &self,
62 ctx: &RowSerializationContext<'_>,
63 writer: &mut RowWriter,
64 ) -> Result<(), SerializationError> {
65 let columns = ctx.columns();
66 for (i, column) in columns.iter().enumerate() {
67 if let Some(argument) = self.buffer.get(i) {
68 let cell_writer = writer.make_cell_writer();
69 let typ = column.typ();
70 argument.serialize(typ, cell_writer)?;
71 }
72 }
73
74 Ok(())
75 }
76
77 #[inline(always)]
78 fn is_empty(&self) -> bool {
79 self.buffer.is_empty()
80 }
81}
82
83#[derive(Default)]
85pub struct ScyllaDBArgumentBuffer {
86 pub(crate) buffer: Vec<ScyllaDBArgument>,
87}
88
89impl Deref for ScyllaDBArgumentBuffer {
90 type Target = Vec<ScyllaDBArgument>;
91
92 fn deref(&self) -> &Self::Target {
93 &self.buffer
94 }
95}
96
97impl<'q> DerefMut for ScyllaDBArgumentBuffer {
98 fn deref_mut(&mut self) -> &mut Self::Target {
99 &mut self.buffer
100 }
101}
102
103pub enum ScyllaDBArgument {
105 Null,
107 Any(Arc<dyn SerializeValue + Send + Sync>),
109 Boolean(bool),
111 BooleanArray(Arc<Vec<bool>>),
113 TinyInt(i8),
115 TinyIntArray(Arc<Vec<i8>>),
117 SmallInt(i16),
119 SmallIntArray(Arc<Vec<i16>>),
121 Int(i32),
123 IntArray(Arc<Vec<i32>>),
125 BigInt(i64),
127 BigIntArray(Arc<Vec<i64>>),
129 Float(f32),
131 FloatArray(Arc<Vec<f32>>),
133 Double(f64),
135 DoubleArray(Arc<Vec<f64>>),
137 Text(Arc<String>),
139 TextArray(Arc<Vec<String>>),
141 #[cfg(feature = "secrecy-08")]
143 SecretText(Arc<secrecy_08::SecretString>),
144 Blob(Arc<Vec<u8>>),
146 BlobArray(Arc<Vec<Vec<u8>>>),
148 #[cfg(feature = "secrecy-08")]
150 SecretBlob(Arc<secrecy_08::SecretVec<u8>>),
151 Uuid(Uuid),
153 UuidArray(Arc<Vec<Uuid>>),
155 Timeuuid(CqlTimeuuid),
157 TimeuuidArray(Arc<Vec<CqlTimeuuid>>),
159 IpAddr(IpAddr),
161 IpAddrArray(Arc<Vec<IpAddr>>),
163 Duration(CqlDuration),
165 DurationArray(Arc<Vec<CqlDuration>>),
167 #[cfg(feature = "bigdecimal-04")]
169 BigDecimal(bigdecimal_04::BigDecimal),
170 #[cfg(feature = "bigdecimal-04")]
172 BigDecimalArray(Arc<Vec<bigdecimal_04::BigDecimal>>),
173 CqlTimestamp(CqlTimestamp),
175 CqlTimestampArray(Arc<Vec<CqlTimestamp>>),
177 #[cfg(feature = "time-03")]
179 OffsetDateTime(time_03::OffsetDateTime),
180 #[cfg(feature = "time-03")]
182 OffsetDateTimeArray(Arc<Vec<time_03::OffsetDateTime>>),
183 #[cfg(feature = "chrono-04")]
185 ChronoDateTimeUTC(chrono_04::DateTime<chrono_04::Utc>),
186 #[cfg(feature = "chrono-04")]
188 ChronoDateTimeUTCArray(Arc<Vec<chrono_04::DateTime<chrono_04::Utc>>>),
189 CqlDate(CqlDate),
191 CqlDateArray(Arc<Vec<CqlDate>>),
193 #[cfg(feature = "time-03")]
195 Date(time_03::Date),
196 #[cfg(feature = "time-03")]
198 DateArray(Arc<Vec<time_03::Date>>),
199 #[cfg(feature = "chrono-04")]
201 ChronoNaiveDate(chrono_04::NaiveDate),
202 #[cfg(feature = "chrono-04")]
204 ChronoNaiveDateArray(Arc<Vec<chrono_04::NaiveDate>>),
205 CqlTime(CqlTime),
207 CqlTimeArray(Arc<Vec<CqlTime>>),
209 #[cfg(feature = "time-03")]
211 Time(time_03::Time),
212 #[cfg(feature = "time-03")]
214 TimeArray(Arc<Vec<time_03::Time>>),
215 #[cfg(feature = "chrono-04")]
217 ChronoNaiveTime(chrono_04::NaiveTime),
218 #[cfg(feature = "chrono-04")]
220 ChronoNaiveTimeArray(Arc<Vec<chrono_04::NaiveTime>>),
221 Tuple(Arc<dyn SerializeValue + Send + Sync>),
223 UserDefinedType(Arc<dyn SerializeValue + Send + Sync>),
225 UserDefinedTypeArray(Arc<dyn SerializeValue + Send + Sync>),
227 TextTextMap(Arc<HashMap<String, String>>),
229 TextBooleanMap(Arc<HashMap<String, bool>>),
231 TextTinyIntMap(Arc<HashMap<String, i8>>),
233 TextSmallIntMap(Arc<HashMap<String, i16>>),
235 TextIntMap(Arc<HashMap<String, i32>>),
237 TextBigIntMap(Arc<HashMap<String, i64>>),
239 TextFloatMap(Arc<HashMap<String, f32>>),
241 TextDoubleMap(Arc<HashMap<String, f64>>),
243 TextUuidMap(Arc<HashMap<String, Uuid>>),
245 TextIpAddrMap(Arc<HashMap<String, IpAddr>>),
247}
248
249impl SerializeValue for ScyllaDBArgument {
250 fn serialize<'b>(
251 &self,
252 typ: &ColumnType,
253 writer: CellWriter<'b>,
254 ) -> Result<WrittenCellProof<'b>, SerializationError> {
255 match self {
256 Self::Any(value) => <_ as SerializeValue>::serialize(value, typ, writer),
257 Self::Null => Ok(writer.set_null()),
258 Self::Boolean(value) => <_ as SerializeValue>::serialize(value, typ, writer),
259 Self::BooleanArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
260 Self::TinyInt(value) => <_ as SerializeValue>::serialize(value, typ, writer),
261 Self::TinyIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
262 Self::SmallInt(value) => <_ as SerializeValue>::serialize(value, typ, writer),
263 Self::SmallIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
264 Self::Int(value) => <_ as SerializeValue>::serialize(value, typ, writer),
265 Self::IntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
266 Self::BigInt(value) => <_ as SerializeValue>::serialize(value, typ, writer),
267 Self::BigIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
268 Self::Float(value) => <_ as SerializeValue>::serialize(value, typ, writer),
269 Self::FloatArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
270 Self::Double(value) => <_ as SerializeValue>::serialize(value, typ, writer),
271 Self::DoubleArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
272 Self::Text(value) => <_ as SerializeValue>::serialize(value, typ, writer),
273 Self::TextArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
274 #[cfg(feature = "secrecy-08")]
275 Self::SecretText(value) => <_ as SerializeValue>::serialize(value, typ, writer),
276 Self::Blob(value) => <_ as SerializeValue>::serialize(value, typ, writer),
277 Self::BlobArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
278 #[cfg(feature = "secrecy-08")]
279 Self::SecretBlob(value) => <_ as SerializeValue>::serialize(value, typ, writer),
280 Self::Uuid(uuid) => <_ as SerializeValue>::serialize(uuid, typ, writer),
281 Self::UuidArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
282 Self::Timeuuid(timeuuid) => <_ as SerializeValue>::serialize(timeuuid, typ, writer),
283 Self::TimeuuidArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
284 Self::IpAddr(ip_addr) => <_ as SerializeValue>::serialize(ip_addr, typ, writer),
285 Self::IpAddrArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
286 Self::Duration(value) => <_ as SerializeValue>::serialize(value, typ, writer),
287 Self::DurationArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
288 #[cfg(feature = "bigdecimal-04")]
289 Self::BigDecimal(value) => <_ as SerializeValue>::serialize(value, typ, writer),
290 #[cfg(feature = "bigdecimal-04")]
291 Self::BigDecimalArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
292 Self::CqlTimestamp(value) => <_ as SerializeValue>::serialize(value, typ, writer),
293 Self::CqlTimestampArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
294 #[cfg(feature = "time-03")]
295 Self::OffsetDateTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
296 #[cfg(feature = "time-03")]
297 Self::OffsetDateTimeArray(value) => {
298 <_ as SerializeValue>::serialize(value, typ, writer)
299 }
300 #[cfg(feature = "chrono-04")]
301 Self::ChronoDateTimeUTC(value) => <_ as SerializeValue>::serialize(value, typ, writer),
302 #[cfg(feature = "chrono-04")]
303 Self::ChronoDateTimeUTCArray(value) => {
304 <_ as SerializeValue>::serialize(value, typ, writer)
305 }
306 Self::CqlTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
307 Self::CqlTimeArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
308 #[cfg(feature = "time-03")]
309 Self::Time(value) => <_ as SerializeValue>::serialize(value, typ, writer),
310 #[cfg(feature = "time-03")]
311 Self::TimeArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
312 #[cfg(feature = "chrono-04")]
313 Self::ChronoNaiveTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
314 #[cfg(feature = "chrono-04")]
315 Self::ChronoNaiveTimeArray(value) => {
316 <_ as SerializeValue>::serialize(value, typ, writer)
317 }
318 Self::CqlDate(value) => <_ as SerializeValue>::serialize(value, typ, writer),
319 Self::CqlDateArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
320 #[cfg(feature = "time-03")]
321 Self::Date(value) => <_ as SerializeValue>::serialize(value, typ, writer),
322 #[cfg(feature = "time-03")]
323 Self::DateArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
324 #[cfg(feature = "chrono-04")]
325 Self::ChronoNaiveDate(value) => <_ as SerializeValue>::serialize(value, typ, writer),
326 #[cfg(feature = "chrono-04")]
327 Self::ChronoNaiveDateArray(value) => {
328 <_ as SerializeValue>::serialize(value, typ, writer)
329 }
330 Self::Tuple(value) => <_ as SerializeValue>::serialize(value, typ, writer),
331 Self::UserDefinedType(value) => <_ as SerializeValue>::serialize(value, typ, writer),
332 Self::UserDefinedTypeArray(value) => {
333 <_ as SerializeValue>::serialize(value, typ, writer)
334 }
335 Self::TextTextMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
336 Self::TextBooleanMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
337 Self::TextTinyIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
338 Self::TextSmallIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
339 Self::TextIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
340 Self::TextBigIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
341 Self::TextFloatMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
342 Self::TextDoubleMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
343 Self::TextUuidMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
344 Self::TextIpAddrMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
345 }
346 }
347}