1use std::{
2 collections::HashMap,
3 net::IpAddr,
4 ops::{Deref, DerefMut},
5 sync::Arc,
6};
7
8use scylla::{
9 cluster::metadata::{ColumnType, NativeType},
10 errors::SerializationError,
11 serialize::{
12 row::{RowSerializationContext, SerializeRow},
13 value::SerializeValue,
14 writers::{CellWriter, RowWriter, WrittenCellProof},
15 },
16 value::{Counter, CqlDate, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid},
17};
18use sqlx::Arguments;
19use uuid::Uuid;
20
21use crate::{ScyllaDB, ScyllaDBTypeInfo};
22
23#[derive(Default)]
25pub struct ScyllaDBArguments {
26 pub(crate) types: Vec<ScyllaDBTypeInfo>,
27 pub(crate) buffer: ScyllaDBArgumentBuffer,
28}
29
30impl<'q> Arguments<'q> for ScyllaDBArguments {
31 type Database = ScyllaDB;
32
33 fn reserve(&mut self, additional: usize, size: usize) {
34 self.types.reserve(additional);
35 self.buffer.reserve(size);
36 }
37
38 fn add<T>(&mut self, value: T) -> Result<(), sqlx::error::BoxDynError>
39 where
40 T: 'q + sqlx::Encode<'q, Self::Database> + sqlx::Type<Self::Database>,
41 {
42 let ty = value.produces().unwrap_or_else(T::type_info);
43 let is_null = value.encode(&mut self.buffer)?;
44 if is_null.is_null() {
45 self.buffer.push(ScyllaDBArgument::Null);
46 }
47
48 self.types.push(ty);
49
50 Ok(())
51 }
52
53 #[inline(always)]
54 fn len(&self) -> usize {
55 self.buffer.len()
56 }
57}
58
59impl SerializeRow for ScyllaDBArguments {
60 fn serialize(
61 &self,
62 ctx: &RowSerializationContext<'_>,
63 writer: &mut RowWriter,
64 ) -> Result<(), SerializationError> {
65 let columns = ctx.columns();
66 for (i, column) in columns.iter().enumerate() {
67 if let Some(argument) = self.buffer.get(i) {
68 let cell_writer = writer.make_cell_writer();
69 let typ = column.typ();
70 argument.serialize(typ, cell_writer)?;
71 }
72 }
73
74 Ok(())
75 }
76
77 #[inline(always)]
78 fn is_empty(&self) -> bool {
79 self.buffer.is_empty()
80 }
81}
82
83#[derive(Default)]
85pub struct ScyllaDBArgumentBuffer {
86 pub(crate) buffer: Vec<ScyllaDBArgument>,
87}
88
89impl Deref for ScyllaDBArgumentBuffer {
90 type Target = Vec<ScyllaDBArgument>;
91
92 fn deref(&self) -> &Self::Target {
93 &self.buffer
94 }
95}
96
97impl<'q> DerefMut for ScyllaDBArgumentBuffer {
98 fn deref_mut(&mut self) -> &mut Self::Target {
99 &mut self.buffer
100 }
101}
102
103pub enum ScyllaDBArgument {
105 Null,
107 Unset,
109 Any(Arc<dyn SerializeValue + Send + Sync>),
111 Boolean(bool),
113 BooleanArray(Vec<bool>),
115 TinyInt(i8),
117 TinyIntArray(Vec<i8>),
119 SmallInt(i16),
121 SmallIntArray(Vec<i16>),
123 Int(i32),
125 IntArray(Vec<i32>),
127 BigInt(i64),
129 BigIntArray(Vec<i64>),
131 Float(f32),
133 FloatArray(Vec<f32>),
135 Double(f64),
137 DoubleArray(Vec<f64>),
139 Text(String),
141 TextArray(Vec<String>),
143 #[cfg(feature = "secrecy-08")]
145 SecretText(secrecy_08::SecretString),
146 #[cfg(feature = "secrecy-08")]
148 SecretTextArray(Vec<secrecy_08::SecretString>),
149 Blob(Vec<u8>),
151 BlobArray(Vec<Vec<u8>>),
153 #[cfg(feature = "secrecy-08")]
155 SecretBlob(secrecy_08::SecretVec<u8>),
156 #[cfg(feature = "secrecy-08")]
158 SecretBlobArray(Vec<secrecy_08::SecretVec<u8>>),
159 Uuid(Uuid),
161 UuidArray(Vec<Uuid>),
163 Timeuuid(CqlTimeuuid),
165 TimeuuidArray(Vec<CqlTimeuuid>),
167 IpAddr(IpAddr),
169 IpAddrArray(Vec<IpAddr>),
171 Duration(CqlDuration),
173 DurationArray(Vec<CqlDuration>),
175 #[cfg(feature = "bigdecimal-04")]
177 BigDecimal(bigdecimal_04::BigDecimal),
178 #[cfg(feature = "bigdecimal-04")]
180 BigDecimalArray(Vec<bigdecimal_04::BigDecimal>),
181 CqlTimestamp(CqlTimestamp),
183 CqlTimestampArray(Vec<CqlTimestamp>),
185 #[cfg(feature = "time-03")]
187 OffsetDateTime(time_03::OffsetDateTime),
188 #[cfg(feature = "time-03")]
190 OffsetDateTimeArray(Vec<time_03::OffsetDateTime>),
191 #[cfg(feature = "chrono-04")]
193 ChronoDateTimeUTC(chrono_04::DateTime<chrono_04::Utc>),
194 #[cfg(feature = "chrono-04")]
196 ChronoDateTimeUTCArray(Vec<chrono_04::DateTime<chrono_04::Utc>>),
197 CqlDate(CqlDate),
199 CqlDateArray(Vec<CqlDate>),
201 #[cfg(feature = "time-03")]
203 Date(time_03::Date),
204 #[cfg(feature = "time-03")]
206 DateArray(Vec<time_03::Date>),
207 #[cfg(feature = "chrono-04")]
209 ChronoNaiveDate(chrono_04::NaiveDate),
210 #[cfg(feature = "chrono-04")]
212 ChronoNaiveDateArray(Vec<chrono_04::NaiveDate>),
213 CqlTime(CqlTime),
215 CqlTimeArray(Vec<CqlTime>),
217 #[cfg(feature = "time-03")]
219 Time(time_03::Time),
220 #[cfg(feature = "time-03")]
222 TimeArray(Vec<time_03::Time>),
223 #[cfg(feature = "chrono-04")]
225 ChronoNaiveTime(chrono_04::NaiveTime),
226 #[cfg(feature = "chrono-04")]
228 ChronoNaiveTimeArray(Vec<chrono_04::NaiveTime>),
229 Tuple(Box<dyn SerializeValue + Send + Sync>),
231 UserDefinedType(Box<dyn SerializeValue + Send + Sync>),
233 UserDefinedTypeArray(Vec<Box<dyn SerializeValue + Send + Sync>>),
235 TextTextMap(HashMap<String, String>),
237 TextBooleanMap(HashMap<String, bool>),
239 TextTinyIntMap(HashMap<String, i8>),
241 TextSmallIntMap(HashMap<String, i16>),
243 TextIntMap(HashMap<String, i32>),
245 TextBigIntMap(HashMap<String, i64>),
247 TextFloatMap(HashMap<String, f32>),
249 TextDoubleMap(HashMap<String, f64>),
251 TextUuidMap(HashMap<String, Uuid>),
253 TextIpAddrMap(HashMap<String, IpAddr>),
255}
256
257impl SerializeValue for ScyllaDBArgument {
258 fn serialize<'b>(
259 &self,
260 typ: &ColumnType,
261 writer: CellWriter<'b>,
262 ) -> Result<WrittenCellProof<'b>, SerializationError> {
263 match self {
264 Self::Any(value) => <_ as SerializeValue>::serialize(value, typ, writer),
265 Self::Null => Ok(writer.set_null()),
266 Self::Unset => Ok(writer.set_unset()),
267 Self::Boolean(value) => <_ as SerializeValue>::serialize(value, typ, writer),
268 Self::BooleanArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
269 Self::TinyInt(value) => {
270 if ColumnType::Native(NativeType::Counter) == *typ {
271 <_ as SerializeValue>::serialize(&Counter(*value as i64), typ, writer)
272 } else {
273 <_ as SerializeValue>::serialize(value, typ, writer)
274 }
275 }
276 Self::TinyIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
277 Self::SmallInt(value) => {
278 if ColumnType::Native(NativeType::Counter) == *typ {
279 <_ as SerializeValue>::serialize(&Counter(*value as i64), typ, writer)
280 } else {
281 <_ as SerializeValue>::serialize(value, typ, writer)
282 }
283 }
284 Self::SmallIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
285 Self::Int(value) => {
286 if ColumnType::Native(NativeType::Counter) == *typ {
287 <_ as SerializeValue>::serialize(&Counter(*value as i64), typ, writer)
288 } else {
289 <_ as SerializeValue>::serialize(value, typ, writer)
290 }
291 }
292 Self::IntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
293 Self::BigInt(value) => {
294 if ColumnType::Native(NativeType::Counter) == *typ {
295 <_ as SerializeValue>::serialize(&Counter(*value as i64), typ, writer)
296 } else {
297 <_ as SerializeValue>::serialize(value, typ, writer)
298 }
299 }
300 Self::BigIntArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
301 Self::Float(value) => <_ as SerializeValue>::serialize(value, typ, writer),
302 Self::FloatArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
303 Self::Double(value) => <_ as SerializeValue>::serialize(value, typ, writer),
304 Self::DoubleArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
305 Self::Text(value) => <_ as SerializeValue>::serialize(value, typ, writer),
306 Self::TextArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
307 #[cfg(feature = "secrecy-08")]
308 Self::SecretText(value) => <_ as SerializeValue>::serialize(value, typ, writer),
309 #[cfg(feature = "secrecy-08")]
310 Self::SecretTextArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
311 Self::Blob(value) => <_ as SerializeValue>::serialize(value, typ, writer),
312 Self::BlobArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
313 #[cfg(feature = "secrecy-08")]
314 Self::SecretBlob(value) => <_ as SerializeValue>::serialize(value, typ, writer),
315 #[cfg(feature = "secrecy-08")]
316 Self::SecretBlobArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
317 Self::Uuid(uuid) => <_ as SerializeValue>::serialize(uuid, typ, writer),
318 Self::UuidArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
319 Self::Timeuuid(timeuuid) => <_ as SerializeValue>::serialize(timeuuid, typ, writer),
320 Self::TimeuuidArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
321 Self::IpAddr(ip_addr) => <_ as SerializeValue>::serialize(ip_addr, typ, writer),
322 Self::IpAddrArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
323 Self::Duration(value) => <_ as SerializeValue>::serialize(value, typ, writer),
324 Self::DurationArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
325 #[cfg(feature = "bigdecimal-04")]
326 Self::BigDecimal(value) => <_ as SerializeValue>::serialize(value, typ, writer),
327 #[cfg(feature = "bigdecimal-04")]
328 Self::BigDecimalArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
329 Self::CqlTimestamp(value) => <_ as SerializeValue>::serialize(value, typ, writer),
330 Self::CqlTimestampArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
331 #[cfg(feature = "time-03")]
332 Self::OffsetDateTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
333 #[cfg(feature = "time-03")]
334 Self::OffsetDateTimeArray(value) => {
335 <_ as SerializeValue>::serialize(value, typ, writer)
336 }
337 #[cfg(feature = "chrono-04")]
338 Self::ChronoDateTimeUTC(value) => <_ as SerializeValue>::serialize(value, typ, writer),
339 #[cfg(feature = "chrono-04")]
340 Self::ChronoDateTimeUTCArray(value) => {
341 <_ as SerializeValue>::serialize(value, typ, writer)
342 }
343 Self::CqlTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
344 Self::CqlTimeArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
345 #[cfg(feature = "time-03")]
346 Self::Time(value) => <_ as SerializeValue>::serialize(value, typ, writer),
347 #[cfg(feature = "time-03")]
348 Self::TimeArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
349 #[cfg(feature = "chrono-04")]
350 Self::ChronoNaiveTime(value) => <_ as SerializeValue>::serialize(value, typ, writer),
351 #[cfg(feature = "chrono-04")]
352 Self::ChronoNaiveTimeArray(value) => {
353 <_ as SerializeValue>::serialize(value, typ, writer)
354 }
355 Self::CqlDate(value) => <_ as SerializeValue>::serialize(value, typ, writer),
356 Self::CqlDateArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
357 #[cfg(feature = "time-03")]
358 Self::Date(value) => <_ as SerializeValue>::serialize(value, typ, writer),
359 #[cfg(feature = "time-03")]
360 Self::DateArray(value) => <_ as SerializeValue>::serialize(value, typ, writer),
361 #[cfg(feature = "chrono-04")]
362 Self::ChronoNaiveDate(value) => <_ as SerializeValue>::serialize(value, typ, writer),
363 #[cfg(feature = "chrono-04")]
364 Self::ChronoNaiveDateArray(value) => {
365 <_ as SerializeValue>::serialize(value, typ, writer)
366 }
367 Self::Tuple(value) => <_ as SerializeValue>::serialize(value, typ, writer),
368 Self::UserDefinedType(value) => <_ as SerializeValue>::serialize(value, typ, writer),
369 Self::UserDefinedTypeArray(value) => {
370 <_ as SerializeValue>::serialize(value, typ, writer)
371 }
372 Self::TextTextMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
373 Self::TextBooleanMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
374 Self::TextTinyIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
375 Self::TextSmallIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
376 Self::TextIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
377 Self::TextBigIntMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
378 Self::TextFloatMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
379 Self::TextDoubleMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
380 Self::TextUuidMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
381 Self::TextIpAddrMap(value) => <_ as SerializeValue>::serialize(value, typ, writer),
382 }
383 }
384}