1#![allow(clippy::cast_possible_truncation)]
30
31use super::{Command, PacketWriter};
32use crate::types::{ColumnDef, FieldType};
33use sqlmodel_core::Value;
34
/// Decoded `COM_STMT_PREPARE_OK` response, as produced by `parse_stmt_prepare_ok`.
///
/// All fields are plain integers, so the struct also derives `PartialEq`/`Eq` to let callers
/// compare parsed responses directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StmtPrepareOk {
    /// Statement id the server assigned to the prepared statement.
    pub statement_id: u32,
    /// Number of result-set columns reported by the server.
    pub num_columns: u16,
    /// Number of `?` parameters reported by the server.
    pub num_params: u16,
    /// Warning count reported by the server (0 when the payload omits it).
    pub warnings: u16,
}
49
/// A statement prepared on the server together with its parameter and column metadata.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
    /// Statement id used by the execute/reset/close packet builders
    /// (presumably the id from `StmtPrepareOk` — confirm at the call site).
    pub statement_id: u32,
    /// The SQL text this statement was prepared from.
    pub sql: String,
    /// Definitions of the statement's parameters (one per `?`).
    pub params: Vec<ColumnDef>,
    /// Definitions of the result-set columns.
    pub columns: Vec<ColumnDef>,
}
65
66impl PreparedStatement {
67 pub fn new(
69 statement_id: u32,
70 sql: String,
71 params: Vec<ColumnDef>,
72 columns: Vec<ColumnDef>,
73 ) -> Self {
74 Self {
75 statement_id,
76 sql,
77 params,
78 columns,
79 }
80 }
81
82 #[must_use]
84 pub fn param_count(&self) -> usize {
85 self.params.len()
86 }
87
88 #[must_use]
90 pub fn column_count(&self) -> usize {
91 self.columns.len()
92 }
93}
94
95pub fn build_stmt_prepare_packet(sql: &str, sequence_id: u8) -> Vec<u8> {
106 let mut writer = PacketWriter::with_capacity(1 + sql.len());
107 writer.write_u8(Command::StmtPrepare as u8);
108 writer.write_bytes(sql.as_bytes());
109 writer.build_packet(sequence_id)
110}
111
112pub fn build_stmt_execute_packet(
132 statement_id: u32,
133 params: &[Value],
134 param_types: Option<&[FieldType]>,
135 sequence_id: u8,
136) -> Vec<u8> {
137 let mut writer = PacketWriter::with_capacity(64 + params.len() * 16);
138
139 writer.write_u8(Command::StmtExecute as u8);
141
142 writer.write_u32_le(statement_id);
144
145 writer.write_u8(0x00);
147
148 writer.write_u32_le(1);
150
151 if !params.is_empty() {
152 let null_bitmap_len = params.len().div_ceil(8);
154 let mut null_bitmap = vec![0u8; null_bitmap_len];
155
156 for (i, param) in params.iter().enumerate() {
157 if matches!(param, Value::Null) {
158 null_bitmap[i / 8] |= 1 << (i % 8);
159 }
160 }
161 writer.write_bytes(&null_bitmap);
162
163 writer.write_u8(1);
165
166 for (i, param) in params.iter().enumerate() {
168 let field_type = if let Some(types) = param_types {
169 if i < types.len() {
170 types[i]
171 } else {
172 value_to_field_type(param)
173 }
174 } else {
175 value_to_field_type(param)
176 };
177
178 writer.write_u8(field_type as u8);
180 let flags = if is_unsigned_value(param) { 0x80 } else { 0x00 };
182 writer.write_u8(flags);
183 }
184
185 for param in params {
187 if !matches!(param, Value::Null) {
188 encode_binary_param(&mut writer, param);
189 }
190 }
191 }
192
193 writer.build_packet(sequence_id)
194}
195
196pub fn build_stmt_close_packet(statement_id: u32, sequence_id: u8) -> Vec<u8> {
207 let mut writer = PacketWriter::with_capacity(5);
208 writer.write_u8(Command::StmtClose as u8);
209 writer.write_u32_le(statement_id);
210 writer.build_packet(sequence_id)
211}
212
213pub fn build_stmt_reset_packet(statement_id: u32, sequence_id: u8) -> Vec<u8> {
223 let mut writer = PacketWriter::with_capacity(5);
224 writer.write_u8(Command::StmtReset as u8);
225 writer.write_u32_le(statement_id);
226 writer.build_packet(sequence_id)
227}
228
229pub fn parse_stmt_prepare_ok(data: &[u8]) -> Option<StmtPrepareOk> {
244 if data.len() < 12 {
245 return None;
246 }
247
248 if data[0] != 0x00 {
250 return None;
251 }
252
253 let statement_id = u32::from_le_bytes([data[1], data[2], data[3], data[4]]);
254 let num_columns = u16::from_le_bytes([data[5], data[6]]);
255 let num_params = u16::from_le_bytes([data[7], data[8]]);
256 let warnings = if data.len() >= 12 {
258 u16::from_le_bytes([data[10], data[11]])
259 } else {
260 0
261 };
262
263 Some(StmtPrepareOk {
264 statement_id,
265 num_columns,
266 num_params,
267 warnings,
268 })
269}
270
271fn value_to_field_type(value: &Value) -> FieldType {
273 match value {
274 Value::Null => FieldType::Null,
275 Value::Bool(_) => FieldType::Tiny,
276 Value::TinyInt(_) => FieldType::Tiny,
277 Value::SmallInt(_) => FieldType::Short,
278 Value::Int(_) => FieldType::Long,
279 Value::BigInt(_) => FieldType::LongLong,
280 Value::Float(_) => FieldType::Float,
281 Value::Double(_) => FieldType::Double,
282 Value::Decimal(_) => FieldType::NewDecimal,
283 Value::Text(_) => FieldType::VarString,
284 Value::Bytes(_) => FieldType::Blob,
285 Value::Json(_) => FieldType::Json,
286 Value::Date(_) => FieldType::Date,
287 Value::Time(_) => FieldType::Time,
288 Value::Timestamp(_) | Value::TimestampTz(_) => FieldType::DateTime,
289 Value::Uuid(_) => FieldType::Blob,
290 Value::Array(_) => FieldType::Json,
291 Value::Default => FieldType::Null,
292 }
293}
294
295fn is_unsigned_value(value: &Value) -> bool {
297 matches!(value, Value::BigInt(i) if *i > i64::MAX / 2)
300}
301
302fn encode_binary_param(writer: &mut PacketWriter, value: &Value) {
304 match value {
305 Value::Null => {
306 }
308 Value::Bool(b) => {
309 writer.write_u8(if *b { 1 } else { 0 });
310 }
311 Value::TinyInt(i) => {
312 writer.write_u8(*i as u8);
313 }
314 Value::SmallInt(i) => {
315 writer.write_u16_le(*i as u16);
316 }
317 Value::Int(i) => {
318 writer.write_u32_le(*i as u32);
319 }
320 Value::BigInt(i) => {
321 writer.write_u64_le(*i as u64);
322 }
323 Value::Float(f) => {
324 writer.write_bytes(&f.to_le_bytes());
325 }
326 Value::Double(f) => {
327 writer.write_bytes(&f.to_le_bytes());
328 }
329 Value::Decimal(s) => {
330 write_length_encoded_string(writer, s);
331 }
332 Value::Text(s) => {
333 write_length_encoded_string(writer, s);
334 }
335 Value::Bytes(b) => {
336 write_length_encoded_bytes(writer, b);
337 }
338 Value::Json(j) => {
339 let s = j.to_string();
340 write_length_encoded_string(writer, &s);
341 }
342 Value::Date(days) => {
343 encode_binary_date(writer, *days);
346 }
347 Value::Time(micros) => {
348 encode_binary_time(writer, *micros);
350 }
351 Value::Timestamp(micros) | Value::TimestampTz(micros) => {
352 encode_binary_datetime(writer, *micros);
354 }
355 Value::Uuid(bytes) => {
356 write_length_encoded_bytes(writer, bytes);
357 }
358 Value::Array(arr) => {
359 let s = serde_json::to_string(arr).unwrap_or_default();
361 write_length_encoded_string(writer, &s);
362 }
363 Value::Default => {
364 }
366 }
367}
368
369fn write_length_encoded_string(writer: &mut PacketWriter, s: &str) {
371 write_length_encoded_bytes(writer, s.as_bytes());
372}
373
374fn write_length_encoded_bytes(writer: &mut PacketWriter, data: &[u8]) {
376 let len = data.len();
377 if len < 251 {
378 writer.write_u8(len as u8);
379 } else if len < 0x10000 {
380 writer.write_u8(0xFC);
381 writer.write_u16_le(len as u16);
382 } else if len < 0x0100_0000 {
383 writer.write_u8(0xFD);
384 writer.write_u8((len & 0xFF) as u8);
385 writer.write_u8(((len >> 8) & 0xFF) as u8);
386 writer.write_u8(((len >> 16) & 0xFF) as u8);
387 } else {
388 writer.write_u8(0xFE);
389 writer.write_u64_le(len as u64);
390 }
391 writer.write_bytes(data);
392}
393
394fn encode_binary_date(writer: &mut PacketWriter, days: i32) {
396 let (year, month, day) = days_to_ymd(days);
399
400 if year == 0 && month == 0 && day == 0 {
401 writer.write_u8(0);
403 } else {
404 writer.write_u8(4); writer.write_u16_le(year as u16);
406 writer.write_u8(month as u8);
407 writer.write_u8(day as u8);
408 }
409}
410
411fn encode_binary_time(writer: &mut PacketWriter, micros: i64) {
413 let is_negative = micros < 0;
414 let micros = micros.unsigned_abs();
415
416 let total_seconds = micros / 1_000_000;
417 let microseconds = (micros % 1_000_000) as u32;
418
419 let hours = total_seconds / 3600;
420 let minutes = (total_seconds % 3600) / 60;
421 let seconds = total_seconds % 60;
422
423 let days = hours / 24;
425 let hours = hours % 24;
426
427 if days == 0 && hours == 0 && minutes == 0 && seconds == 0 && microseconds == 0 {
428 writer.write_u8(0); } else if microseconds == 0 {
430 writer.write_u8(8); writer.write_u8(if is_negative { 1 } else { 0 });
432 writer.write_u32_le(days as u32);
433 writer.write_u8(hours as u8);
434 writer.write_u8(minutes as u8);
435 writer.write_u8(seconds as u8);
436 } else {
437 writer.write_u8(12); writer.write_u8(if is_negative { 1 } else { 0 });
439 writer.write_u32_le(days as u32);
440 writer.write_u8(hours as u8);
441 writer.write_u8(minutes as u8);
442 writer.write_u8(seconds as u8);
443 writer.write_u32_le(microseconds);
444 }
445}
446
447fn encode_binary_datetime(writer: &mut PacketWriter, micros: i64) {
449 let total_seconds = micros / 1_000_000;
451 let microseconds = (micros % 1_000_000).unsigned_abs() as u32;
452
453 let days = (total_seconds / 86400) as i32;
455 let time_of_day = (total_seconds % 86400).unsigned_abs();
456
457 let (year, month, day) = days_to_ymd(days);
458 let hour = (time_of_day / 3600) as u8;
459 let minute = ((time_of_day % 3600) / 60) as u8;
460 let second = (time_of_day % 60) as u8;
461
462 if year == 0
463 && month == 0
464 && day == 0
465 && hour == 0
466 && minute == 0
467 && second == 0
468 && microseconds == 0
469 {
470 writer.write_u8(0); } else if hour == 0 && minute == 0 && second == 0 && microseconds == 0 {
472 writer.write_u8(4); writer.write_u16_le(year as u16);
474 writer.write_u8(month as u8);
475 writer.write_u8(day as u8);
476 } else if microseconds == 0 {
477 writer.write_u8(7); writer.write_u16_le(year as u16);
479 writer.write_u8(month as u8);
480 writer.write_u8(day as u8);
481 writer.write_u8(hour);
482 writer.write_u8(minute);
483 writer.write_u8(second);
484 } else {
485 writer.write_u8(11); writer.write_u16_le(year as u16);
487 writer.write_u8(month as u8);
488 writer.write_u8(day as u8);
489 writer.write_u8(hour);
490 writer.write_u8(minute);
491 writer.write_u8(second);
492 writer.write_u32_le(microseconds);
493 }
494}
495
/// Convert a day count relative to 1970-01-01 into a proleptic-Gregorian `(year, month, day)`.
///
/// This is Howard Hinnant's `civil_from_days` algorithm:
/// - it works in 400-year eras of 146 097 days;
/// - it uses years that start on March 1 so the leap day falls at the end of the year.
///
/// Month is `1..=12` and day is `1..=31`. Year 0 is 1 BC.
///
/// Arithmetic is done in `i64`, so every `i32` input is valid. The old `days + 719_468` in
/// `i32` overflowed (and panicked in debug builds) near `i32::MAX`.
fn days_to_ymd(days: i32) -> (i32, i32, i32) {
    // Shift the epoch to 0000-03-01.
    let z = i64::from(days) + 719_468;
    let era = z.div_euclid(146_097);
    // Day of era: [0, 146096]
    let doe = z.rem_euclid(146_097);
    // Year of era: [0, 399]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365;
    // Day of the March-based year: [0, 365]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
    // March-based month: [0, 11]
    let mp = (5 * doy + 2) / 153;
    let d = doy - (153 * mp + 2) / 5 + 1;
    let m = if mp < 10 { mp + 3 } else { mp - 9 };
    // January and February belong to the next civil year.
    let y = yoe + era * 400 + i64::from(m <= 2);

    // |y| stays below ~6 million for any i32 input, so the cast cannot truncate.
    (y as i32, m as i32, d as i32)
}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `encode` against a fresh writer and returns only the payload bytes
    /// (everything after the 4-byte packet header).
    fn encoded_payload(encode: impl FnOnce(&mut PacketWriter)) -> Vec<u8> {
        let mut writer = PacketWriter::with_capacity(16);
        encode(&mut writer);
        writer.build_packet(0)[4..].to_vec()
    }

    #[test]
    fn test_build_stmt_prepare_packet() {
        let packet = build_stmt_prepare_packet("SELECT * FROM users WHERE id = ?", 0);

        // Sequence id, then the command byte.
        assert_eq!(packet[3], 0);
        assert_eq!(packet[4], Command::StmtPrepare as u8);

        assert_eq!(&packet[5..], b"SELECT * FROM users WHERE id = ?");
    }

    #[test]
    fn test_build_stmt_close_packet() {
        let packet = build_stmt_close_packet(42, 0);

        // 4-byte header + command byte + 4-byte statement id.
        assert_eq!(packet.len(), 9);
        assert_eq!(packet[4], Command::StmtClose as u8);

        let stmt_id = u32::from_le_bytes([packet[5], packet[6], packet[7], packet[8]]);
        assert_eq!(stmt_id, 42);
    }

    #[test]
    fn test_build_stmt_reset_packet() {
        let packet = build_stmt_reset_packet(7, 0);

        assert_eq!(packet.len(), 9);
        assert_eq!(packet[4], Command::StmtReset as u8);

        let stmt_id = u32::from_le_bytes([packet[5], packet[6], packet[7], packet[8]]);
        assert_eq!(stmt_id, 7);
    }

    #[test]
    fn test_parse_stmt_prepare_ok() {
        let data = [
            0x00, // status
            0x01, 0x00, 0x00, 0x00, // statement_id = 1
            0x03, 0x00, // num_columns = 3
            0x02, 0x00, // num_params = 2
            0x00, // reserved
            0x00, 0x00, // warnings = 0
        ];

        let result = parse_stmt_prepare_ok(&data).unwrap();
        assert_eq!(result.statement_id, 1);
        assert_eq!(result.num_columns, 3);
        assert_eq!(result.num_params, 2);
        assert_eq!(result.warnings, 0);
    }

    #[test]
    fn test_parse_stmt_prepare_ok_invalid() {
        // Too short.
        assert!(parse_stmt_prepare_ok(&[0x00, 0x01]).is_none());

        // Wrong status byte.
        let data = [
            0xFF, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00,
        ];
        assert!(parse_stmt_prepare_ok(&data).is_none());
    }

    #[test]
    fn test_build_stmt_execute_no_params() {
        let packet = build_stmt_execute_packet(1, &[], None, 0);

        assert_eq!(packet[4], Command::StmtExecute as u8);

        let stmt_id = u32::from_le_bytes([packet[5], packet[6], packet[7], packet[8]]);
        assert_eq!(stmt_id, 1);

        // Flags.
        assert_eq!(packet[9], 0x00);

        let iter_count = u32::from_le_bytes([packet[10], packet[11], packet[12], packet[13]]);
        assert_eq!(iter_count, 1);
    }

    #[test]
    fn test_build_stmt_execute_with_params() {
        let params = vec![Value::Int(42), Value::Text("hello".to_string())];
        let packet = build_stmt_execute_packet(1, &params, None, 0);

        assert_eq!(packet[4], Command::StmtExecute as u8);

        let stmt_id = u32::from_le_bytes([packet[5], packet[6], packet[7], packet[8]]);
        assert_eq!(stmt_id, 1);

        // Flags.
        assert_eq!(packet[9], 0x00);

        let iter_count = u32::from_le_bytes([packet[10], packet[11], packet[12], packet[13]]);
        assert_eq!(iter_count, 1);

        // NULL bitmap (no NULLs), then the new-params-bound flag.
        assert_eq!(packet[14], 0x00);
        assert_eq!(packet[15], 0x01);

        // (type, flags) pairs.
        assert_eq!(packet[16], FieldType::Long as u8);
        assert_eq!(packet[17], 0x00);
        assert_eq!(packet[18], FieldType::VarString as u8);
        assert_eq!(packet[19], 0x00);
    }

    #[test]
    fn test_build_stmt_execute_with_null() {
        let params = vec![Value::Null, Value::Int(42)];
        let packet = build_stmt_execute_packet(1, &params, None, 0);

        // First parameter is NULL: bit 0 of the NULL bitmap is set.
        assert_eq!(packet[14], 0x01);
    }

    #[test]
    fn test_value_to_field_type() {
        assert_eq!(value_to_field_type(&Value::Null), FieldType::Null);
        assert_eq!(value_to_field_type(&Value::Bool(true)), FieldType::Tiny);
        assert_eq!(value_to_field_type(&Value::TinyInt(1)), FieldType::Tiny);
        assert_eq!(value_to_field_type(&Value::SmallInt(1)), FieldType::Short);
        assert_eq!(value_to_field_type(&Value::Int(1)), FieldType::Long);
        assert_eq!(value_to_field_type(&Value::BigInt(1)), FieldType::LongLong);
        assert_eq!(value_to_field_type(&Value::Float(1.0)), FieldType::Float);
        assert_eq!(value_to_field_type(&Value::Double(1.0)), FieldType::Double);
        assert_eq!(
            value_to_field_type(&Value::Text(String::new())),
            FieldType::VarString
        );
        assert_eq!(value_to_field_type(&Value::Bytes(vec![])), FieldType::Blob);
    }

    #[test]
    fn test_days_to_ymd() {
        assert_eq!(days_to_ymd(0), (1970, 1, 1));

        assert_eq!(days_to_ymd(10957), (2000, 1, 1));

        // Leap day.
        assert_eq!(days_to_ymd(19782), (2024, 2, 29));

        // Day before the epoch.
        assert_eq!(days_to_ymd(-1), (1969, 12, 31));
    }

    #[test]
    fn test_write_length_encoded_bytes_prefixes() {
        assert_eq!(
            encoded_payload(|w| write_length_encoded_bytes(w, b"abc")),
            b"\x03abc".to_vec()
        );

        let long = vec![0xAAu8; 251];
        let out = encoded_payload(|w| write_length_encoded_bytes(w, &long));
        assert_eq!(&out[..3], &[0xFC, 251, 0x00]);
        assert_eq!(out.len(), 3 + 251);
    }

    #[test]
    fn test_encode_binary_time() {
        assert_eq!(encoded_payload(|w| encode_binary_time(w, 0)), vec![0]);
        // 01:02:03
        assert_eq!(
            encoded_payload(|w| encode_binary_time(w, 3_723_000_000)),
            vec![8, 0, 0, 0, 0, 0, 1, 2, 3]
        );
        // -(1 day 01:00:01.5)
        assert_eq!(
            encoded_payload(|w| encode_binary_time(w, -90_001_500_000)),
            vec![12, 1, 1, 0, 0, 0, 1, 0, 1, 0x20, 0xA1, 0x07, 0x00]
        );
    }
}