1#![allow(clippy::cast_possible_truncation)]
23
24use std::io::{Read, Write};
25use std::net::TcpStream;
26#[cfg(feature = "console")]
27use std::sync::Arc;
28
29use sqlmodel_core::Error;
30use sqlmodel_core::error::{
31 ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
32};
33use sqlmodel_core::{Row, Value};
34
35#[cfg(feature = "console")]
36use sqlmodel_console::{ConsoleAware, SqlModelConsole};
37
38use crate::auth;
39use crate::config::MySqlConfig;
40use crate::protocol::{
41 Command, ErrPacket, MAX_PACKET_SIZE, PacketHeader, PacketReader, PacketType, PacketWriter,
42 capabilities, charset,
43};
44use crate::types::{ColumnDef, FieldType, decode_text_value, interpolate_params};
45
/// Lifecycle state of a MySQL connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No connection has been established.
    Disconnected,
    /// TCP connection open; the server handshake is being read.
    Connecting,
    /// Handshake received; the authentication exchange is in progress.
    Authenticating,
    /// Idle and able to accept a command, outside any transaction.
    Ready,
    /// A command has been sent and its response is being read.
    InQuery,
    /// Idle, but the server's status flags report an open transaction.
    InTransaction,
    /// Reserved for a connection left unusable by a failure.
    Error,
    /// The connection has been shut down.
    Closed,
}
66
/// Information advertised by the server in its initial handshake packet.
#[derive(Clone, Debug)]
pub struct ServerCapabilities {
    /// Capability flags (lower and upper 16-bit halves combined).
    pub capabilities: u32,
    /// Handshake protocol version (only 10 is accepted when parsing).
    pub protocol_version: u8,
    /// Server version string, e.g. `8.0.36`.
    pub server_version: String,
    /// Connection id assigned by the server.
    pub connection_id: u32,
    /// Name of the authentication plugin the server proposed.
    pub auth_plugin: String,
    /// Authentication scramble (both parts, trailing NUL removed).
    pub auth_data: Vec<u8>,
    /// Character set id announced by the server.
    pub charset: u8,
    /// Server status flags at handshake time.
    pub status_flags: u16,
}
87
88pub struct MySqlConnection {
93 stream: TcpStream,
95 state: ConnectionState,
97 server_caps: Option<ServerCapabilities>,
99 connection_id: u32,
101 status_flags: u16,
103 affected_rows: u64,
105 last_insert_id: u64,
107 warnings: u16,
109 config: MySqlConfig,
111 sequence_id: u8,
113 #[cfg(feature = "console")]
115 console: Option<Arc<SqlModelConsole>>,
116}
117
118impl std::fmt::Debug for MySqlConnection {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 f.debug_struct("MySqlConnection")
121 .field("state", &self.state)
122 .field("connection_id", &self.connection_id)
123 .field("host", &self.config.host)
124 .field("port", &self.config.port)
125 .field("database", &self.config.database)
126 .finish_non_exhaustive()
127 }
128}
129
130impl MySqlConnection {
131 #[allow(clippy::result_large_err)]
139 pub fn connect(config: MySqlConfig) -> Result<Self, Error> {
140 let stream = TcpStream::connect_timeout(
142 &config.socket_addr().parse().map_err(|e| {
143 Error::Connection(ConnectionError {
144 kind: ConnectionErrorKind::Connect,
145 message: format!("Invalid socket address: {}", e),
146 source: None,
147 })
148 })?,
149 config.connect_timeout,
150 )
151 .map_err(|e| {
152 let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
153 ConnectionErrorKind::Refused
154 } else {
155 ConnectionErrorKind::Connect
156 };
157 Error::Connection(ConnectionError {
158 kind,
159 message: format!("Failed to connect to {}: {}", config.socket_addr(), e),
160 source: Some(Box::new(e)),
161 })
162 })?;
163
164 stream.set_nodelay(true).ok();
166 stream.set_read_timeout(Some(config.connect_timeout)).ok();
167 stream.set_write_timeout(Some(config.connect_timeout)).ok();
168
169 let mut conn = Self {
170 stream,
171 state: ConnectionState::Connecting,
172 server_caps: None,
173 connection_id: 0,
174 status_flags: 0,
175 affected_rows: 0,
176 last_insert_id: 0,
177 warnings: 0,
178 config,
179 sequence_id: 0,
180 #[cfg(feature = "console")]
181 console: None,
182 };
183
184 let server_caps = conn.read_handshake()?;
186 conn.connection_id = server_caps.connection_id;
187 conn.server_caps = Some(server_caps);
188 conn.state = ConnectionState::Authenticating;
189
190 conn.send_handshake_response()?;
192
193 conn.handle_auth_result()?;
195
196 conn.state = ConnectionState::Ready;
197 Ok(conn)
198 }
199
200 pub fn state(&self) -> ConnectionState {
202 self.state
203 }
204
205 pub fn is_ready(&self) -> bool {
207 matches!(self.state, ConnectionState::Ready)
208 }
209
210 pub fn connection_id(&self) -> u32 {
212 self.connection_id
213 }
214
215 pub fn server_version(&self) -> Option<&str> {
217 self.server_caps
218 .as_ref()
219 .map(|caps| caps.server_version.as_str())
220 }
221
222 pub fn affected_rows(&self) -> u64 {
224 self.affected_rows
225 }
226
227 pub fn last_insert_id(&self) -> u64 {
229 self.last_insert_id
230 }
231
232 pub fn warnings(&self) -> u16 {
234 self.warnings
235 }
236
237 #[allow(clippy::result_large_err)]
239 fn read_handshake(&mut self) -> Result<ServerCapabilities, Error> {
240 let (payload, _) = self.read_packet()?;
241 let mut reader = PacketReader::new(&payload);
242
243 let protocol_version = reader
245 .read_u8()
246 .ok_or_else(|| protocol_error("Missing protocol version"))?;
247
248 if protocol_version != 10 {
249 return Err(protocol_error(format!(
250 "Unsupported protocol version: {}",
251 protocol_version
252 )));
253 }
254
255 let server_version = reader
257 .read_null_string()
258 .ok_or_else(|| protocol_error("Missing server version"))?;
259
260 let connection_id = reader
262 .read_u32_le()
263 .ok_or_else(|| protocol_error("Missing connection ID"))?;
264
265 let auth_data_1 = reader
267 .read_bytes(8)
268 .ok_or_else(|| protocol_error("Missing auth data"))?;
269
270 reader.skip(1);
272
273 let caps_lower = reader
275 .read_u16_le()
276 .ok_or_else(|| protocol_error("Missing capability flags"))?;
277
278 let charset = reader.read_u8().unwrap_or(charset::UTF8MB4_0900_AI_CI);
280
281 let status_flags = reader.read_u16_le().unwrap_or(0);
283
284 let caps_upper = reader.read_u16_le().unwrap_or(0);
286 let capabilities = u32::from(caps_lower) | (u32::from(caps_upper) << 16);
287
288 let auth_data_len = if capabilities & capabilities::CLIENT_PLUGIN_AUTH != 0 {
290 reader.read_u8().unwrap_or(0) as usize
291 } else {
292 0
293 };
294
295 reader.skip(10);
297
298 let mut auth_data = auth_data_1.to_vec();
300 if capabilities & capabilities::CLIENT_SECURE_CONNECTION != 0 {
301 let len2 = if auth_data_len > 8 {
302 auth_data_len - 8
303 } else {
304 13 };
306 if let Some(data2) = reader.read_bytes(len2) {
307 let data2_clean = if data2.last() == Some(&0) {
309 &data2[..data2.len() - 1]
310 } else {
311 data2
312 };
313 auth_data.extend_from_slice(data2_clean);
314 }
315 }
316
317 let auth_plugin = if capabilities & capabilities::CLIENT_PLUGIN_AUTH != 0 {
319 reader.read_null_string().unwrap_or_default()
320 } else {
321 auth::plugins::MYSQL_NATIVE_PASSWORD.to_string()
322 };
323
324 Ok(ServerCapabilities {
325 capabilities,
326 protocol_version,
327 server_version,
328 connection_id,
329 auth_plugin,
330 auth_data,
331 charset,
332 status_flags,
333 })
334 }
335
336 #[allow(clippy::result_large_err)]
338 fn send_handshake_response(&mut self) -> Result<(), Error> {
339 let server_caps = self
340 .server_caps
341 .as_ref()
342 .ok_or_else(|| protocol_error("No server handshake received"))?;
343
344 let client_caps = self.config.capability_flags() & server_caps.capabilities;
346
347 let auth_response =
349 self.compute_auth_response(&server_caps.auth_plugin, &server_caps.auth_data);
350
351 let mut writer = PacketWriter::new();
352
353 writer.write_u32_le(client_caps);
355
356 writer.write_u32_le(self.config.max_packet_size);
358
359 writer.write_u8(self.config.charset);
361
362 writer.write_zeros(23);
364
365 writer.write_null_string(&self.config.user);
367
368 if client_caps & capabilities::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 {
370 writer.write_lenenc_bytes(&auth_response);
371 } else if client_caps & capabilities::CLIENT_SECURE_CONNECTION != 0 {
372 #[allow(clippy::cast_possible_truncation)]
374 writer.write_u8(auth_response.len() as u8);
375 writer.write_bytes(&auth_response);
376 } else {
377 writer.write_bytes(&auth_response);
378 writer.write_u8(0); }
380
381 if client_caps & capabilities::CLIENT_CONNECT_WITH_DB != 0 {
383 if let Some(ref db) = self.config.database {
384 writer.write_null_string(db);
385 } else {
386 writer.write_u8(0); }
388 }
389
390 if client_caps & capabilities::CLIENT_PLUGIN_AUTH != 0 {
392 writer.write_null_string(&server_caps.auth_plugin);
393 }
394
395 if client_caps & capabilities::CLIENT_CONNECT_ATTRS != 0
397 && !self.config.attributes.is_empty()
398 {
399 let mut attrs_writer = PacketWriter::new();
400 for (key, value) in &self.config.attributes {
401 attrs_writer.write_lenenc_string(key);
402 attrs_writer.write_lenenc_string(value);
403 }
404 let attrs_data = attrs_writer.into_bytes();
405 writer.write_lenenc_bytes(&attrs_data);
406 }
407
408 self.write_packet(writer.as_bytes())?;
409
410 Ok(())
411 }
412
413 fn compute_auth_response(&self, plugin: &str, auth_data: &[u8]) -> Vec<u8> {
415 let password = self.config.password.as_deref().unwrap_or("");
416
417 match plugin {
418 auth::plugins::MYSQL_NATIVE_PASSWORD => {
419 auth::mysql_native_password(password, auth_data)
420 }
421 auth::plugins::CACHING_SHA2_PASSWORD => {
422 auth::caching_sha2_password(password, auth_data)
423 }
424 auth::plugins::MYSQL_CLEAR_PASSWORD => {
425 let mut result = password.as_bytes().to_vec();
427 result.push(0);
428 result
429 }
430 _ => {
431 auth::mysql_native_password(password, auth_data)
433 }
434 }
435 }
436
437 #[allow(clippy::result_large_err)]
439 fn handle_auth_result(&mut self) -> Result<(), Error> {
440 let (payload, _) = self.read_packet()?;
441
442 if payload.is_empty() {
443 return Err(protocol_error("Empty authentication response"));
444 }
445
446 match PacketType::from_first_byte(payload[0], payload.len() as u32) {
447 PacketType::Ok => {
448 let mut reader = PacketReader::new(&payload);
450 if let Some(ok) = reader.parse_ok_packet() {
451 self.status_flags = ok.status_flags;
452 self.affected_rows = ok.affected_rows;
453 }
454 Ok(())
455 }
456 PacketType::Error => {
457 let mut reader = PacketReader::new(&payload);
458 let err = reader
459 .parse_err_packet()
460 .ok_or_else(|| protocol_error("Invalid error packet"))?;
461 Err(auth_error(format!(
462 "Authentication failed: {} ({})",
463 err.error_message, err.error_code
464 )))
465 }
466 PacketType::Eof => {
467 self.handle_auth_switch(&payload[1..])
469 }
470 _ => {
471 self.handle_additional_auth(&payload)
473 }
474 }
475 }
476
477 #[allow(clippy::result_large_err)]
479 fn handle_auth_switch(&mut self, data: &[u8]) -> Result<(), Error> {
480 let mut reader = PacketReader::new(data);
481
482 let plugin = reader
484 .read_null_string()
485 .ok_or_else(|| protocol_error("Missing plugin name in auth switch"))?;
486
487 let auth_data = reader.read_rest();
489
490 let response = self.compute_auth_response(&plugin, auth_data);
492
493 self.write_packet(&response)?;
495
496 self.handle_auth_result()
498 }
499
500 #[allow(clippy::result_large_err)]
502 fn handle_additional_auth(&mut self, data: &[u8]) -> Result<(), Error> {
503 if data.is_empty() {
504 return Err(protocol_error("Empty additional auth data"));
505 }
506
507 match data[0] {
508 auth::caching_sha2::FAST_AUTH_SUCCESS => {
509 let (payload, _) = self.read_packet()?;
511 let mut reader = PacketReader::new(&payload);
512 if let Some(ok) = reader.parse_ok_packet() {
513 self.status_flags = ok.status_flags;
514 }
515 Ok(())
516 }
517 auth::caching_sha2::PERFORM_FULL_AUTH => {
518 let Some(server_caps) = self.server_caps.as_ref() else {
520 return Err(protocol_error("Missing server capabilities during auth"));
521 };
522
523 let password = self.config.password.clone().unwrap_or_default();
524 let seed = server_caps.auth_data.clone();
525 let server_version = server_caps.server_version.clone();
526
527 self.write_packet(&[auth::caching_sha2::REQUEST_PUBLIC_KEY])?;
529 let (payload, _) = self.read_packet()?;
530 if payload.is_empty() {
531 return Err(protocol_error("Empty public key response"));
532 }
533
534 let public_key = if payload[0] == 0x01 {
536 &payload[1..]
537 } else {
538 &payload[..]
539 };
540
541 let use_oaep = mysql_server_uses_oaep(&server_version);
542 let encrypted = auth::sha256_password_rsa(&password, &seed, public_key, use_oaep)
543 .map_err(auth_error)?;
544
545 self.write_packet(&encrypted)?;
546 self.handle_auth_result()
547 }
548 _ => {
549 let mut reader = PacketReader::new(data);
551 if let Some(ok) = reader.parse_ok_packet() {
552 self.status_flags = ok.status_flags;
553 Ok(())
554 } else {
555 Err(protocol_error(format!(
556 "Unknown auth response: {:02X}",
557 data[0]
558 )))
559 }
560 }
561 }
562 }
563
564 #[allow(clippy::result_large_err)]
569 pub fn query_sync(&mut self, sql: &str, params: &[Value]) -> Result<Vec<Row>, Error> {
570 #[cfg(feature = "console")]
571 let start = std::time::Instant::now();
572
573 let sql = interpolate_params(sql, params);
574 if !self.is_ready() && self.state != ConnectionState::InTransaction {
575 return Err(connection_error("Connection not ready for queries"));
576 }
577
578 self.state = ConnectionState::InQuery;
579 self.sequence_id = 0;
580
581 let mut writer = PacketWriter::new();
583 writer.write_u8(Command::Query as u8);
584 writer.write_bytes(sql.as_bytes());
585 self.write_packet(writer.as_bytes())?;
586
587 let (payload, _) = self.read_packet()?;
589
590 if payload.is_empty() {
591 self.state = ConnectionState::Ready;
592 return Err(protocol_error("Empty query response"));
593 }
594
595 match PacketType::from_first_byte(payload[0], payload.len() as u32) {
596 PacketType::Ok => {
597 let mut reader = PacketReader::new(&payload);
599 if let Some(ok) = reader.parse_ok_packet() {
600 self.affected_rows = ok.affected_rows;
601 self.last_insert_id = ok.last_insert_id;
602 self.status_flags = ok.status_flags;
603 self.warnings = ok.warnings;
604 }
605 self.state = if self.status_flags
606 & crate::protocol::server_status::SERVER_STATUS_IN_TRANS
607 != 0
608 {
609 ConnectionState::InTransaction
610 } else {
611 ConnectionState::Ready
612 };
613
614 #[cfg(feature = "console")]
615 {
616 let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
617 self.emit_execute_timing(&sql, elapsed_ms, self.affected_rows);
618 self.emit_warnings(self.warnings);
619 }
620
621 Ok(vec![])
622 }
623 PacketType::Error => {
624 self.state = ConnectionState::Ready;
625 let mut reader = PacketReader::new(&payload);
626 let err = reader
627 .parse_err_packet()
628 .ok_or_else(|| protocol_error("Invalid error packet"))?;
629 Err(query_error(&err))
630 }
631 PacketType::LocalInfile => {
632 self.state = ConnectionState::Ready;
633 Err(query_error_msg("LOCAL INFILE not supported"))
634 }
635 _ => {
636 #[cfg(feature = "console")]
638 let result = self.read_result_set_with_timing(&sql, &payload, start);
639 #[cfg(not(feature = "console"))]
640 let result = self.read_result_set(&payload);
641 result
642 }
643 }
644 }
645
646 #[allow(dead_code)] #[allow(clippy::result_large_err)]
649 fn read_result_set(&mut self, first_packet: &[u8]) -> Result<Vec<Row>, Error> {
650 let mut reader = PacketReader::new(first_packet);
651 let column_count = reader
652 .read_lenenc_int()
653 .ok_or_else(|| protocol_error("Invalid column count"))?
654 as usize;
655
656 let mut columns = Vec::with_capacity(column_count);
658 for _ in 0..column_count {
659 let (payload, _) = self.read_packet()?;
660 columns.push(self.parse_column_def(&payload)?);
661 }
662
663 let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
665 if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
666 let (payload, _) = self.read_packet()?;
667 if payload.first() == Some(&0xFE) {
668 }
670 }
671
672 let mut rows = Vec::new();
674 loop {
675 let (payload, _) = self.read_packet()?;
676
677 if payload.is_empty() {
678 break;
679 }
680
681 match PacketType::from_first_byte(payload[0], payload.len() as u32) {
682 PacketType::Eof | PacketType::Ok => {
683 let mut reader = PacketReader::new(&payload);
685 if payload[0] == 0x00 {
686 if let Some(ok) = reader.parse_ok_packet() {
687 self.status_flags = ok.status_flags;
688 self.warnings = ok.warnings;
689 }
690 } else if payload[0] == 0xFE {
691 if let Some(eof) = reader.parse_eof_packet() {
692 self.status_flags = eof.status_flags;
693 self.warnings = eof.warnings;
694 }
695 }
696 break;
697 }
698 PacketType::Error => {
699 let mut reader = PacketReader::new(&payload);
700 let err = reader
701 .parse_err_packet()
702 .ok_or_else(|| protocol_error("Invalid error packet"))?;
703 self.state = ConnectionState::Ready;
704 return Err(query_error(&err));
705 }
706 _ => {
707 let row = self.parse_text_row(&payload, &columns);
709 rows.push(row);
710 }
711 }
712 }
713
714 self.state =
715 if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
716 ConnectionState::InTransaction
717 } else {
718 ConnectionState::Ready
719 };
720
721 Ok(rows)
722 }
723
724 #[cfg(feature = "console")]
726 #[allow(clippy::result_large_err)]
727 fn read_result_set_with_timing(
728 &mut self,
729 sql: &str,
730 first_packet: &[u8],
731 start: std::time::Instant,
732 ) -> Result<Vec<Row>, Error> {
733 let mut reader = PacketReader::new(first_packet);
734 let column_count = reader
735 .read_lenenc_int()
736 .ok_or_else(|| protocol_error("Invalid column count"))?
737 as usize;
738
739 let mut columns = Vec::with_capacity(column_count);
741 let mut col_names = Vec::with_capacity(column_count);
742 for _ in 0..column_count {
743 let (payload, _) = self.read_packet()?;
744 let col = self.parse_column_def(&payload)?;
745 col_names.push(col.name.clone());
746 columns.push(col);
747 }
748
749 let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
751 if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
752 let (payload, _) = self.read_packet()?;
753 if payload.first() == Some(&0xFE) {
754 }
756 }
757
758 let mut rows = Vec::new();
760 loop {
761 let (payload, _) = self.read_packet()?;
762
763 if payload.is_empty() {
764 break;
765 }
766
767 match PacketType::from_first_byte(payload[0], payload.len() as u32) {
768 PacketType::Eof | PacketType::Ok => {
769 let mut reader = PacketReader::new(&payload);
770 if payload[0] == 0x00 {
771 if let Some(ok) = reader.parse_ok_packet() {
772 self.status_flags = ok.status_flags;
773 self.warnings = ok.warnings;
774 }
775 } else if payload[0] == 0xFE {
776 if let Some(eof) = reader.parse_eof_packet() {
777 self.status_flags = eof.status_flags;
778 self.warnings = eof.warnings;
779 }
780 }
781 break;
782 }
783 PacketType::Error => {
784 let mut reader = PacketReader::new(&payload);
785 let err = reader
786 .parse_err_packet()
787 .ok_or_else(|| protocol_error("Invalid error packet"))?;
788 self.state = ConnectionState::Ready;
789 return Err(query_error(&err));
790 }
791 _ => {
792 let row = self.parse_text_row(&payload, &columns);
793 rows.push(row);
794 }
795 }
796 }
797
798 self.state =
799 if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
800 ConnectionState::InTransaction
801 } else {
802 ConnectionState::Ready
803 };
804
805 let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
807 let sql_upper = sql.trim().to_uppercase();
808 if sql_upper.starts_with("SHOW") {
809 self.emit_show_results(sql, &col_names, &rows, elapsed_ms);
810 } else {
811 self.emit_query_timing(sql, elapsed_ms, rows.len());
812 }
813 self.emit_warnings(self.warnings);
814
815 Ok(rows)
816 }
817
818 #[allow(clippy::result_large_err)]
820 fn parse_column_def(&self, data: &[u8]) -> Result<ColumnDef, Error> {
821 let mut reader = PacketReader::new(data);
822
823 let catalog = reader
824 .read_lenenc_string()
825 .ok_or_else(|| protocol_error("Missing catalog"))?;
826 let schema = reader
827 .read_lenenc_string()
828 .ok_or_else(|| protocol_error("Missing schema"))?;
829 let table = reader
830 .read_lenenc_string()
831 .ok_or_else(|| protocol_error("Missing table"))?;
832 let org_table = reader
833 .read_lenenc_string()
834 .ok_or_else(|| protocol_error("Missing org_table"))?;
835 let name = reader
836 .read_lenenc_string()
837 .ok_or_else(|| protocol_error("Missing name"))?;
838 let org_name = reader
839 .read_lenenc_string()
840 .ok_or_else(|| protocol_error("Missing org_name"))?;
841
842 let _fixed_len = reader.read_lenenc_int();
844
845 let charset = reader
846 .read_u16_le()
847 .ok_or_else(|| protocol_error("Missing charset"))?;
848 let column_length = reader
849 .read_u32_le()
850 .ok_or_else(|| protocol_error("Missing column_length"))?;
851 let column_type = FieldType::from_u8(
852 reader
853 .read_u8()
854 .ok_or_else(|| protocol_error("Missing column_type"))?,
855 );
856 let flags = reader
857 .read_u16_le()
858 .ok_or_else(|| protocol_error("Missing flags"))?;
859 let decimals = reader
860 .read_u8()
861 .ok_or_else(|| protocol_error("Missing decimals"))?;
862
863 Ok(ColumnDef {
864 catalog,
865 schema,
866 table,
867 org_table,
868 name,
869 org_name,
870 charset,
871 column_length,
872 column_type,
873 flags,
874 decimals,
875 })
876 }
877
878 fn parse_text_row(&self, data: &[u8], columns: &[ColumnDef]) -> Row {
880 let mut reader = PacketReader::new(data);
881 let mut values = Vec::with_capacity(columns.len());
882
883 for col in columns {
884 if reader.peek() == Some(0xFB) {
887 reader.skip(1);
888 values.push(Value::Null);
889 } else if let Some(data) = reader.read_lenenc_bytes() {
890 let is_unsigned = col.is_unsigned();
891 let value = decode_text_value(col.column_type, &data, is_unsigned);
892 values.push(value);
893 } else {
894 values.push(Value::Null);
895 }
896 }
897
898 let column_names: Vec<String> = columns.iter().map(|c| c.name.clone()).collect();
900
901 Row::new(column_names, values)
902 }
903
904 #[allow(clippy::result_large_err)]
906 pub fn query_one_sync(&mut self, sql: &str, params: &[Value]) -> Result<Option<Row>, Error> {
907 let rows = self.query_sync(sql, params)?;
908 Ok(rows.into_iter().next())
909 }
910
911 #[allow(clippy::result_large_err)]
913 pub fn execute_sync(&mut self, sql: &str, params: &[Value]) -> Result<u64, Error> {
914 self.query_sync(sql, params)?;
915 Ok(self.affected_rows)
916 }
917
918 #[allow(clippy::result_large_err)]
920 pub fn insert_sync(&mut self, sql: &str, params: &[Value]) -> Result<i64, Error> {
921 self.query_sync(sql, params)?;
922 Ok(self.last_insert_id as i64)
923 }
924
925 #[allow(clippy::result_large_err)]
927 pub fn ping(&mut self) -> Result<(), Error> {
928 self.sequence_id = 0;
929
930 let mut writer = PacketWriter::new();
931 writer.write_u8(Command::Ping as u8);
932 self.write_packet(writer.as_bytes())?;
933
934 let (payload, _) = self.read_packet()?;
935
936 if payload.first() == Some(&0x00) {
937 Ok(())
938 } else {
939 Err(connection_error("Ping failed"))
940 }
941 }
942
943 #[allow(clippy::result_large_err)]
945 pub fn close(mut self) -> Result<(), Error> {
946 if self.state == ConnectionState::Closed {
947 return Ok(());
948 }
949
950 self.sequence_id = 0;
951
952 let mut writer = PacketWriter::new();
953 writer.write_u8(Command::Quit as u8);
954
955 let _ = self.write_packet(writer.as_bytes());
957
958 self.state = ConnectionState::Closed;
959 Ok(())
960 }
961
962 #[allow(clippy::result_large_err)]
964 fn read_packet(&mut self) -> Result<(Vec<u8>, u8), Error> {
965 let mut header_buf = [0u8; 4];
967 self.stream.read_exact(&mut header_buf).map_err(|e| {
968 Error::Connection(ConnectionError {
969 kind: ConnectionErrorKind::Disconnected,
970 message: format!("Failed to read packet header: {}", e),
971 source: Some(Box::new(e)),
972 })
973 })?;
974
975 let header = PacketHeader::from_bytes(&header_buf);
976 let payload_len = header.payload_length as usize;
977 self.sequence_id = header.sequence_id.wrapping_add(1);
978
979 let mut payload = vec![0u8; payload_len];
981 if payload_len > 0 {
982 self.stream.read_exact(&mut payload).map_err(|e| {
983 Error::Connection(ConnectionError {
984 kind: ConnectionErrorKind::Disconnected,
985 message: format!("Failed to read packet payload: {}", e),
986 source: Some(Box::new(e)),
987 })
988 })?;
989 }
990
991 if payload_len == MAX_PACKET_SIZE {
993 loop {
994 let mut header_buf = [0u8; 4];
995 self.stream.read_exact(&mut header_buf).map_err(|e| {
996 Error::Connection(ConnectionError {
997 kind: ConnectionErrorKind::Disconnected,
998 message: format!("Failed to read continuation header: {}", e),
999 source: Some(Box::new(e)),
1000 })
1001 })?;
1002
1003 let cont_header = PacketHeader::from_bytes(&header_buf);
1004 let cont_len = cont_header.payload_length as usize;
1005 self.sequence_id = cont_header.sequence_id.wrapping_add(1);
1006
1007 if cont_len > 0 {
1008 let mut cont_payload = vec![0u8; cont_len];
1009 self.stream.read_exact(&mut cont_payload).map_err(|e| {
1010 Error::Connection(ConnectionError {
1011 kind: ConnectionErrorKind::Disconnected,
1012 message: format!("Failed to read continuation payload: {}", e),
1013 source: Some(Box::new(e)),
1014 })
1015 })?;
1016 payload.extend_from_slice(&cont_payload);
1017 }
1018
1019 if cont_len < MAX_PACKET_SIZE {
1020 break;
1021 }
1022 }
1023 }
1024
1025 Ok((payload, header.sequence_id))
1026 }
1027
1028 #[allow(clippy::result_large_err)]
1030 fn write_packet(&mut self, payload: &[u8]) -> Result<(), Error> {
1031 let writer = PacketWriter::new();
1032 let packet = writer.build_packet_from_payload(payload, self.sequence_id);
1033 self.sequence_id = self.sequence_id.wrapping_add(1);
1034
1035 self.stream.write_all(&packet).map_err(|e| {
1036 Error::Connection(ConnectionError {
1037 kind: ConnectionErrorKind::Disconnected,
1038 message: format!("Failed to write packet: {}", e),
1039 source: Some(Box::new(e)),
1040 })
1041 })?;
1042
1043 self.stream.flush().map_err(|e| {
1044 Error::Connection(ConnectionError {
1045 kind: ConnectionErrorKind::Disconnected,
1046 message: format!("Failed to flush stream: {}", e),
1047 source: Some(Box::new(e)),
1048 })
1049 })?;
1050
1051 Ok(())
1052 }
1053}
1054
/// Lets callers attach or detach a console for timing and diagnostic output.
#[cfg(feature = "console")]
impl ConsoleAware for MySqlConnection {
    fn set_console(&mut self, console: Option<Arc<SqlModelConsole>>) {
        self.console = console;
    }

    fn console(&self) -> Option<&Arc<SqlModelConsole>> {
        Option::as_ref(&self.console)
    }
}
1066
1067#[cfg(feature = "console")]
1068impl MySqlConnection {
1069 #[allow(dead_code)]
1073 fn emit_connection_progress(&self, stage: &str, status: &str, is_final: bool) {
1074 if let Some(console) = &self.console {
1075 let mode = console.mode();
1076 match mode {
1077 sqlmodel_console::OutputMode::Plain => {
1078 if is_final {
1079 console.status(&format!("[MySQL] {}: {}", stage, status));
1080 }
1081 }
1082 sqlmodel_console::OutputMode::Rich => {
1083 let status_icon = if status.starts_with("OK") || status.starts_with("Connected")
1084 {
1085 "✓"
1086 } else if status.starts_with("Error") || status.starts_with("Failed") {
1087 "✗"
1088 } else {
1089 "…"
1090 };
1091 console.status(&format!(" {} {}: {}", status_icon, stage, status));
1092 }
1093 sqlmodel_console::OutputMode::Json => {
1094 }
1096 }
1097 }
1098 }
1099
1100 fn emit_query_timing(&self, sql: &str, elapsed_ms: f64, row_count: usize) {
1102 if let Some(console) = &self.console {
1103 let mode = console.mode();
1104 let sql_preview: String = sql.chars().take(60).collect();
1105 let sql_display = if sql.len() > 60 {
1106 format!("{}...", sql_preview)
1107 } else {
1108 sql_preview
1109 };
1110
1111 match mode {
1112 sqlmodel_console::OutputMode::Plain => {
1113 console.status(&format!(
1114 "[MySQL] Query: {:.2}ms, {} rows | {}",
1115 elapsed_ms, row_count, sql_display
1116 ));
1117 }
1118 sqlmodel_console::OutputMode::Rich => {
1119 let time_color = if elapsed_ms < 10.0 {
1120 "\x1b[32m" } else if elapsed_ms < 100.0 {
1122 "\x1b[33m" } else {
1124 "\x1b[31m" };
1126 console.status(&format!(
1127 " ⏱ {}{:.2}ms\x1b[0m ({} rows) {}",
1128 time_color, elapsed_ms, row_count, sql_display
1129 ));
1130 }
1131 sqlmodel_console::OutputMode::Json => {
1132 }
1134 }
1135 }
1136 }
1137
1138 fn emit_execute_timing(&self, sql: &str, elapsed_ms: f64, affected_rows: u64) {
1140 if let Some(console) = &self.console {
1141 let mode = console.mode();
1142 let sql_preview: String = sql.chars().take(60).collect();
1143 let sql_display = if sql.len() > 60 {
1144 format!("{}...", sql_preview)
1145 } else {
1146 sql_preview
1147 };
1148
1149 match mode {
1150 sqlmodel_console::OutputMode::Plain => {
1151 console.status(&format!(
1152 "[MySQL] Execute: {:.2}ms, {} affected | {}",
1153 elapsed_ms, affected_rows, sql_display
1154 ));
1155 }
1156 sqlmodel_console::OutputMode::Rich => {
1157 let time_color = if elapsed_ms < 10.0 {
1158 "\x1b[32m"
1159 } else if elapsed_ms < 100.0 {
1160 "\x1b[33m"
1161 } else {
1162 "\x1b[31m"
1163 };
1164 console.status(&format!(
1165 " ⏱ {}{:.2}ms\x1b[0m ({} affected) {}",
1166 time_color, elapsed_ms, affected_rows, sql_display
1167 ));
1168 }
1169 sqlmodel_console::OutputMode::Json => {}
1170 }
1171 }
1172 }
1173
1174 fn emit_warnings(&self, warning_count: u16) {
1176 if warning_count == 0 {
1177 return;
1178 }
1179 if let Some(console) = &self.console {
1180 let mode = console.mode();
1181 match mode {
1182 sqlmodel_console::OutputMode::Plain => {
1183 console.warning(&format!("[MySQL] {} warning(s)", warning_count));
1184 }
1185 sqlmodel_console::OutputMode::Rich => {
1186 console.warning(&format!("{} warning(s)", warning_count));
1187 }
1188 sqlmodel_console::OutputMode::Json => {}
1189 }
1190 }
1191 }
1192
1193 fn emit_show_results(&self, sql: &str, col_names: &[String], rows: &[Row], elapsed_ms: f64) {
1195 if let Some(console) = &self.console {
1196 let mode = console.mode();
1197 let sql_upper = sql.trim().to_uppercase();
1198
1199 if !sql_upper.starts_with("SHOW") {
1201 self.emit_query_timing(sql, elapsed_ms, rows.len());
1202 return;
1203 }
1204
1205 match mode {
1206 sqlmodel_console::OutputMode::Plain | sqlmodel_console::OutputMode::Rich => {
1207 let mut widths: Vec<usize> = col_names.iter().map(|n| n.len()).collect();
1209 for row in rows {
1210 for (i, val) in row.values().enumerate() {
1211 if i < widths.len() {
1212 let val_str = format_value(val);
1213 widths[i] = widths[i].max(val_str.len());
1214 }
1215 }
1216 }
1217
1218 let header: String = col_names
1220 .iter()
1221 .zip(&widths)
1222 .map(|(name, width)| format!("{:width$}", name, width = width))
1223 .collect::<Vec<_>>()
1224 .join(" | ");
1225
1226 let separator: String = widths
1227 .iter()
1228 .map(|w| "-".repeat(*w))
1229 .collect::<Vec<_>>()
1230 .join("-+-");
1231
1232 console.status(&header);
1233 console.status(&separator);
1234
1235 for row in rows {
1236 let row_str: String = row
1237 .values()
1238 .zip(&widths)
1239 .map(|(val, width)| {
1240 format!("{:width$}", format_value(val), width = width)
1241 })
1242 .collect::<Vec<_>>()
1243 .join(" | ");
1244 console.status(&row_str);
1245 }
1246
1247 console.status(&format!("({} rows, {:.2}ms)\n", rows.len(), elapsed_ms));
1248 }
1249 sqlmodel_console::OutputMode::Json => {}
1250 }
1251 }
1252 }
1253}
1254
/// Renders a `Value` as a short human-readable string for console tables.
///
/// Binary data and arrays are summarised by size, temporal values are prefixed
/// with their kind, and UUIDs use the canonical 8-4-4-4-12 hex layout.
#[cfg(feature = "console")]
fn format_value(value: &Value) -> String {
    match value {
        Value::Null => "NULL".to_string(),
        Value::Default => "DEFAULT".to_string(),
        Value::Bool(flag) => flag.to_string(),
        Value::TinyInt(n) => n.to_string(),
        Value::SmallInt(n) => n.to_string(),
        Value::Int(n) => n.to_string(),
        Value::BigInt(n) => n.to_string(),
        Value::Float(x) => format!("{:.6}", x),
        Value::Double(x) => format!("{:.6}", x),
        Value::Decimal(text) | Value::Text(text) => text.clone(),
        Value::Bytes(bytes) => format!("<{} bytes>", bytes.len()),
        Value::Date(d) => format!("date:{}", d),
        Value::Time(t) => format!("time:{}", t),
        Value::Timestamp(ts) => format!("ts:{}", ts),
        Value::TimestampTz(ts) => format!("tstz:{}", ts),
        Value::Uuid(bytes) => {
            let mut out = String::with_capacity(36);
            for (i, byte) in bytes[..16].iter().enumerate() {
                // Hyphens separate the 4-2-2-2-6 byte groups.
                if matches!(i, 4 | 6 | 8 | 10) {
                    out.push('-');
                }
                out.push_str(&format!("{:02x}", byte));
            }
            out
        }
        Value::Json(j) => j.to_string(),
        Value::Array(items) => format!("[{} items]", items.len()),
    }
}
1300
1301fn protocol_error(msg: impl Into<String>) -> Error {
1304 Error::Protocol(ProtocolError {
1305 message: msg.into(),
1306 raw_data: None,
1307 source: None,
1308 })
1309}
1310
1311fn auth_error(msg: impl Into<String>) -> Error {
1312 Error::Connection(ConnectionError {
1313 kind: ConnectionErrorKind::Authentication,
1314 message: msg.into(),
1315 source: None,
1316 })
1317}
1318
1319fn connection_error(msg: impl Into<String>) -> Error {
1320 Error::Connection(ConnectionError {
1321 kind: ConnectionErrorKind::Connect,
1322 message: msg.into(),
1323 source: None,
1324 })
1325}
1326
/// Decide whether the server's reported version is 8.0.5 or newer.
///
/// Only the leading run of ASCII digits and dots is considered, so suffixes
/// such as `-log` are ignored. Empty dot-separated components are skipped.
/// A missing or unparseable minor/patch component counts as `0`.
///
/// If no major version can be parsed at all, this returns `true`, treating an
/// unrecognized server as modern.
fn mysql_server_uses_oaep(server_version: &str) -> bool {
    // Byte offset of the first character that is neither a digit nor a dot.
    // Everything before it is ASCII, so the slice boundary is always valid.
    let end = server_version
        .find(|ch: char| !(ch.is_ascii_digit() || ch == '.'))
        .unwrap_or(server_version.len());
    let mut components = server_version[..end]
        .split('.')
        .filter(|part| !part.is_empty());

    let Some(major) = components.next().and_then(|part| part.parse::<u64>().ok()) else {
        return true;
    };
    let mut next_or_zero = || {
        components
            .next()
            .and_then(|part| part.parse::<u64>().ok())
            .unwrap_or(0)
    };
    let minor = next_or_zero();
    let patch = next_or_zero();

    (major, minor, patch) >= (8, 0, 5)
}
1345
1346fn query_error(err: &ErrPacket) -> Error {
1347 let kind = if err.is_duplicate_key() || err.is_foreign_key_violation() {
1348 QueryErrorKind::Constraint
1349 } else {
1350 QueryErrorKind::Syntax
1351 };
1352
1353 Error::Query(QueryError {
1354 kind,
1355 message: err.error_message.clone(),
1356 sqlstate: Some(err.sql_state.clone()),
1357 sql: None,
1358 detail: None,
1359 hint: None,
1360 position: None,
1361 source: None,
1362 })
1363}
1364
1365fn query_error_msg(msg: impl Into<String>) -> Error {
1366 Error::Query(QueryError {
1367 kind: QueryErrorKind::Syntax,
1368 message: msg.into(),
1369 sqlstate: None,
1370 sql: None,
1371 detail: None,
1372 hint: None,
1373 position: None,
1374 source: None,
1375 })
1376}
1377
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_state_default() {
        // `ConnectionState` is `Copy` + `Eq`: a copy compares equal to its
        // source, and distinct variants compare unequal.
        let state = ConnectionState::Disconnected;
        let copied = state;
        assert_eq!(state, copied);
        assert_ne!(ConnectionState::Disconnected, ConnectionState::Ready);
        assert_ne!(ConnectionState::Ready, ConnectionState::InTransaction);
        assert_ne!(ConnectionState::Error, ConnectionState::Closed);
    }

    #[test]
    fn test_error_helpers() {
        let err = protocol_error("test error");
        assert!(matches!(&err, Error::Protocol(p) if p.message == "test error"));

        let err = auth_error("auth failed");
        assert!(matches!(
            &err,
            Error::Connection(ConnectionError {
                kind: ConnectionErrorKind::Authentication,
                ..
            })
        ));

        let err = connection_error("connection failed");
        assert!(matches!(
            &err,
            Error::Connection(ConnectionError {
                kind: ConnectionErrorKind::Connect,
                ..
            })
        ));
    }

    #[test]
    fn test_query_error_msg() {
        let err = query_error_msg("bad input");
        assert!(matches!(
            &err,
            Error::Query(q)
                if q.kind == QueryErrorKind::Syntax
                    && q.message == "bad input"
                    && q.sqlstate.is_none()
        ));
    }

    #[test]
    fn test_query_error_duplicate_key() {
        let err_packet = ErrPacket {
            error_code: 1062,
            sql_state: "23000".to_string(),
            error_message: "Duplicate entry".to_string(),
        };

        let err = query_error(&err_packet);
        assert!(matches!(err, Error::Query(_)), "Expected query error");
        let Error::Query(q) = err else { return };
        assert_eq!(q.kind, QueryErrorKind::Constraint);
        assert_eq!(q.message, "Duplicate entry");
        assert_eq!(q.sqlstate.as_deref(), Some("23000"));
    }

    #[test]
    fn test_query_error_syntax() {
        let err_packet = ErrPacket {
            error_code: 1064,
            sql_state: "42000".to_string(),
            error_message: "You have an error in your SQL syntax".to_string(),
        };

        let err = query_error(&err_packet);
        assert!(matches!(err, Error::Query(_)), "Expected query error");
        let Error::Query(q) = err else { return };
        assert_eq!(q.kind, QueryErrorKind::Syntax);
        assert_eq!(q.sqlstate.as_deref(), Some("42000"));
    }

    #[test]
    fn test_mysql_server_uses_oaep() {
        // OAEP padding starts with MySQL 8.0.5.
        assert!(mysql_server_uses_oaep("8.0.5"));
        assert!(mysql_server_uses_oaep("8.0.34-log"));
        assert!(mysql_server_uses_oaep("8.4.0"));
        assert!(mysql_server_uses_oaep("9.0.1"));
        assert!(!mysql_server_uses_oaep("8.0.4"));
        assert!(!mysql_server_uses_oaep("5.7.44-log"));
        // Missing minor/patch components count as zero.
        assert!(!mysql_server_uses_oaep("8"));
        assert!(mysql_server_uses_oaep("8.1"));
        // Versions with no parseable major default to OAEP.
        assert!(mysql_server_uses_oaep(""));
        assert!(mysql_server_uses_oaep("unknown"));
    }

    #[cfg(feature = "console")]
    mod console_tests {
        use super::*;
        use sqlmodel_console::{ConsoleAware, OutputMode, SqlModelConsole};

        fn assert_console_aware<T: ConsoleAware>() {}

        #[test]
        fn test_console_aware_trait_impl() {
            let config = MySqlConfig::new()
                .host("localhost")
                .port(13306)
                .user("test")
                .password("test");

            assert_console_aware::<MySqlConnection>();

            assert_eq!(config.host, "localhost");
            assert_eq!(config.port, 13306);
        }

        #[test]
        fn test_format_value_all_types() {
            assert_eq!(format_value(&Value::Null), "NULL");
            assert_eq!(format_value(&Value::Bool(true)), "true");
            assert_eq!(format_value(&Value::Bool(false)), "false");
            assert_eq!(format_value(&Value::TinyInt(42)), "42");
            assert_eq!(format_value(&Value::SmallInt(1000)), "1000");
            assert_eq!(format_value(&Value::Int(123_456)), "123456");
            assert_eq!(format_value(&Value::BigInt(9_999_999_999)), "9999999999");
            assert!(format_value(&Value::Float(1.5)).starts_with("1.5"));
            assert!(format_value(&Value::Double(1.234_567_890)).starts_with("1.23456"));
            assert_eq!(
                format_value(&Value::Decimal("123.45".to_string())),
                "123.45"
            );
            assert_eq!(format_value(&Value::Text("hello".to_string())), "hello");
            assert_eq!(format_value(&Value::Bytes(vec![1, 2, 3])), "<3 bytes>");
            assert!(format_value(&Value::Date(19000)).contains("date:"));
            assert!(format_value(&Value::Time(43_200_000_000)).contains("time:"));
            assert!(format_value(&Value::Timestamp(1_700_000_000_000_000)).contains("ts:"));
            assert!(format_value(&Value::TimestampTz(1_700_000_000_000_000)).contains("tstz:"));

            let uuid = [0u8; 16];
            let uuid_str = format_value(&Value::Uuid(uuid));
            assert_eq!(uuid_str, "00000000-0000-0000-0000-000000000000");

            let json = serde_json::json!({"key": "value"});
            let json_str = format_value(&Value::Json(json));
            assert!(json_str.contains("key"));

            let arr = vec![Value::Int(1), Value::Int(2)];
            assert_eq!(format_value(&Value::Array(arr)), "[2 items]");
        }

        #[test]
        fn test_plain_mode_output_format() {
            let plain_console = SqlModelConsole::with_mode(OutputMode::Plain);
            assert!(plain_console.is_plain());

            let rich_console = SqlModelConsole::with_mode(OutputMode::Rich);
            assert!(rich_console.is_rich());

            let json_console = SqlModelConsole::with_mode(OutputMode::Json);
            assert!(json_console.is_json());
        }

        #[test]
        fn test_console_mode_detection() {
            let console = SqlModelConsole::with_mode(OutputMode::Plain);
            assert!(console.is_plain());
            assert!(!console.is_rich());
            assert!(!console.is_json());

            assert_eq!(console.mode(), OutputMode::Plain);
        }

        #[test]
        fn test_format_value_uuid() {
            let uuid: [u8; 16] = [
                0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc,
                0xde, 0xf0,
            ];
            let result = format_value(&Value::Uuid(uuid));
            assert_eq!(result, "12345678-9abc-def0-1234-56789abcdef0");
        }

        #[test]
        fn test_format_value_nested_json() {
            let json = serde_json::json!({
                "users": [
                    {"name": "Alice", "age": 30},
                    {"name": "Bob", "age": 25}
                ]
            });
            let result = format_value(&Value::Json(json));
            assert!(result.contains("users"));
            assert!(result.contains("Alice"));
        }
    }
}