1#![allow(clippy::cast_possible_truncation)]
33
34use std::io::{Read, Write};
35use std::net::TcpStream;
36#[cfg(feature = "console")]
37use std::sync::Arc;
38
39use sqlmodel_core::Error;
40use sqlmodel_core::error::{
41 ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
42};
43use sqlmodel_core::{Row, Value};
44
45#[cfg(feature = "console")]
46use sqlmodel_console::{ConsoleAware, SqlModelConsole};
47
48use crate::auth;
49use crate::config::MySqlConfig;
50use crate::protocol::{
51 Command, ErrPacket, MAX_PACKET_SIZE, PacketHeader, PacketReader, PacketType, PacketWriter,
52 capabilities, charset,
53};
54use crate::types::{ColumnDef, FieldType, decode_text_value, interpolate_params};
55
/// Lifecycle state of a [`MySqlConnection`].
///
/// Variant order is preserved deliberately; do not reorder.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No connection has been established.
    Disconnected,
    /// TCP connection open; waiting for / reading the server handshake.
    Connecting,
    /// Handshake received; authentication exchange in progress.
    Authenticating,
    /// Idle and able to accept a new command (outside a transaction).
    Ready,
    /// A command has been sent and its response is being read.
    InQuery,
    /// Idle, but the server reported an open transaction in its status flags.
    InTransaction,
    /// NOTE(review): not assigned anywhere in this file; presumably intended
    /// for a broken connection — confirm against other modules.
    Error,
    /// `close` has been called; the connection must not be used again.
    Closed,
}
76
/// Information extracted from the server's initial handshake packet.
#[derive(Clone, Debug)]
pub struct ServerCapabilities {
    /// Combined 32-bit capability flags (lower and upper halves).
    pub capabilities: u32,
    /// Handshake protocol version (only 10 is accepted by `read_handshake`).
    pub protocol_version: u8,
    /// Human-readable server version string.
    pub server_version: String,
    /// Server-assigned connection (thread) id.
    pub connection_id: u32,
    /// Name of the authentication plugin announced by the server.
    pub auth_plugin: String,
    /// Authentication scramble (part 1 followed by part 2, NUL stripped).
    pub auth_data: Vec<u8>,
    /// Server default character set id.
    pub charset: u8,
    /// Server status flags at handshake time.
    pub status_flags: u16,
}
97
98pub struct MySqlConnection {
103 stream: TcpStream,
105 state: ConnectionState,
107 server_caps: Option<ServerCapabilities>,
109 connection_id: u32,
111 status_flags: u16,
113 affected_rows: u64,
115 last_insert_id: u64,
117 warnings: u16,
119 config: MySqlConfig,
121 sequence_id: u8,
123 #[cfg(feature = "console")]
125 console: Option<Arc<SqlModelConsole>>,
126}
127
128impl std::fmt::Debug for MySqlConnection {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_struct("MySqlConnection")
131 .field("state", &self.state)
132 .field("connection_id", &self.connection_id)
133 .field("host", &self.config.host)
134 .field("port", &self.config.port)
135 .field("database", &self.config.database)
136 .finish_non_exhaustive()
137 }
138}
139
140impl MySqlConnection {
141 #[allow(clippy::result_large_err)]
149 pub fn connect(config: MySqlConfig) -> Result<Self, Error> {
150 let stream = TcpStream::connect_timeout(
152 &config.socket_addr().parse().map_err(|e| {
153 Error::Connection(ConnectionError {
154 kind: ConnectionErrorKind::Connect,
155 message: format!("Invalid socket address: {}", e),
156 source: None,
157 })
158 })?,
159 config.connect_timeout,
160 )
161 .map_err(|e| {
162 let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
163 ConnectionErrorKind::Refused
164 } else {
165 ConnectionErrorKind::Connect
166 };
167 Error::Connection(ConnectionError {
168 kind,
169 message: format!("Failed to connect to {}: {}", config.socket_addr(), e),
170 source: Some(Box::new(e)),
171 })
172 })?;
173
174 stream.set_nodelay(true).ok();
176 stream.set_read_timeout(Some(config.connect_timeout)).ok();
177 stream.set_write_timeout(Some(config.connect_timeout)).ok();
178
179 let mut conn = Self {
180 stream,
181 state: ConnectionState::Connecting,
182 server_caps: None,
183 connection_id: 0,
184 status_flags: 0,
185 affected_rows: 0,
186 last_insert_id: 0,
187 warnings: 0,
188 config,
189 sequence_id: 0,
190 #[cfg(feature = "console")]
191 console: None,
192 };
193
194 let server_caps = conn.read_handshake()?;
196 conn.connection_id = server_caps.connection_id;
197 conn.server_caps = Some(server_caps);
198 conn.state = ConnectionState::Authenticating;
199
200 conn.send_handshake_response()?;
202
203 conn.handle_auth_result()?;
205
206 conn.state = ConnectionState::Ready;
207 Ok(conn)
208 }
209
210 pub fn state(&self) -> ConnectionState {
212 self.state
213 }
214
215 pub fn is_ready(&self) -> bool {
217 matches!(self.state, ConnectionState::Ready)
218 }
219
220 pub fn connection_id(&self) -> u32 {
222 self.connection_id
223 }
224
225 pub fn server_version(&self) -> Option<&str> {
227 self.server_caps
228 .as_ref()
229 .map(|caps| caps.server_version.as_str())
230 }
231
232 pub fn affected_rows(&self) -> u64 {
234 self.affected_rows
235 }
236
237 pub fn last_insert_id(&self) -> u64 {
239 self.last_insert_id
240 }
241
242 pub fn warnings(&self) -> u16 {
244 self.warnings
245 }
246
247 #[allow(clippy::result_large_err)]
249 fn read_handshake(&mut self) -> Result<ServerCapabilities, Error> {
250 let (payload, _) = self.read_packet()?;
251 let mut reader = PacketReader::new(&payload);
252
253 let protocol_version = reader
255 .read_u8()
256 .ok_or_else(|| protocol_error("Missing protocol version"))?;
257
258 if protocol_version != 10 {
259 return Err(protocol_error(format!(
260 "Unsupported protocol version: {}",
261 protocol_version
262 )));
263 }
264
265 let server_version = reader
267 .read_null_string()
268 .ok_or_else(|| protocol_error("Missing server version"))?;
269
270 let connection_id = reader
272 .read_u32_le()
273 .ok_or_else(|| protocol_error("Missing connection ID"))?;
274
275 let auth_data_1 = reader
277 .read_bytes(8)
278 .ok_or_else(|| protocol_error("Missing auth data"))?;
279
280 reader.skip(1);
282
283 let caps_lower = reader
285 .read_u16_le()
286 .ok_or_else(|| protocol_error("Missing capability flags"))?;
287
288 let charset = reader.read_u8().unwrap_or(charset::UTF8MB4_0900_AI_CI);
290
291 let status_flags = reader.read_u16_le().unwrap_or(0);
293
294 let caps_upper = reader.read_u16_le().unwrap_or(0);
296 let capabilities = u32::from(caps_lower) | (u32::from(caps_upper) << 16);
297
298 let auth_data_len = if capabilities & capabilities::CLIENT_PLUGIN_AUTH != 0 {
300 reader.read_u8().unwrap_or(0) as usize
301 } else {
302 0
303 };
304
305 reader.skip(10);
307
308 let mut auth_data = auth_data_1.to_vec();
310 if capabilities & capabilities::CLIENT_SECURE_CONNECTION != 0 {
311 let len2 = if auth_data_len > 8 {
312 auth_data_len - 8
313 } else {
314 13 };
316 if let Some(data2) = reader.read_bytes(len2) {
317 let data2_clean = if data2.last() == Some(&0) {
319 &data2[..data2.len() - 1]
320 } else {
321 data2
322 };
323 auth_data.extend_from_slice(data2_clean);
324 }
325 }
326
327 let auth_plugin = if capabilities & capabilities::CLIENT_PLUGIN_AUTH != 0 {
329 reader.read_null_string().unwrap_or_default()
330 } else {
331 auth::plugins::MYSQL_NATIVE_PASSWORD.to_string()
332 };
333
334 Ok(ServerCapabilities {
335 capabilities,
336 protocol_version,
337 server_version,
338 connection_id,
339 auth_plugin,
340 auth_data,
341 charset,
342 status_flags,
343 })
344 }
345
346 #[allow(clippy::result_large_err)]
348 fn send_handshake_response(&mut self) -> Result<(), Error> {
349 let server_caps = self
350 .server_caps
351 .as_ref()
352 .ok_or_else(|| protocol_error("No server handshake received"))?;
353
354 let client_caps = self.config.capability_flags() & server_caps.capabilities;
356
357 let auth_response =
359 self.compute_auth_response(&server_caps.auth_plugin, &server_caps.auth_data);
360
361 let mut writer = PacketWriter::new();
362
363 writer.write_u32_le(client_caps);
365
366 writer.write_u32_le(self.config.max_packet_size);
368
369 writer.write_u8(self.config.charset);
371
372 writer.write_zeros(23);
374
375 writer.write_null_string(&self.config.user);
377
378 if client_caps & capabilities::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 {
380 writer.write_lenenc_bytes(&auth_response);
381 } else if client_caps & capabilities::CLIENT_SECURE_CONNECTION != 0 {
382 #[allow(clippy::cast_possible_truncation)]
384 writer.write_u8(auth_response.len() as u8);
385 writer.write_bytes(&auth_response);
386 } else {
387 writer.write_bytes(&auth_response);
388 writer.write_u8(0); }
390
391 if client_caps & capabilities::CLIENT_CONNECT_WITH_DB != 0 {
393 if let Some(ref db) = self.config.database {
394 writer.write_null_string(db);
395 } else {
396 writer.write_u8(0); }
398 }
399
400 if client_caps & capabilities::CLIENT_PLUGIN_AUTH != 0 {
402 writer.write_null_string(&server_caps.auth_plugin);
403 }
404
405 if client_caps & capabilities::CLIENT_CONNECT_ATTRS != 0
407 && !self.config.attributes.is_empty()
408 {
409 let mut attrs_writer = PacketWriter::new();
410 for (key, value) in &self.config.attributes {
411 attrs_writer.write_lenenc_string(key);
412 attrs_writer.write_lenenc_string(value);
413 }
414 let attrs_data = attrs_writer.into_bytes();
415 writer.write_lenenc_bytes(&attrs_data);
416 }
417
418 self.write_packet(writer.as_bytes())?;
419
420 Ok(())
421 }
422
423 fn compute_auth_response(&self, plugin: &str, auth_data: &[u8]) -> Vec<u8> {
425 let password = self.config.password.as_deref().unwrap_or("");
426
427 match plugin {
428 auth::plugins::MYSQL_NATIVE_PASSWORD => {
429 auth::mysql_native_password(password, auth_data)
430 }
431 auth::plugins::CACHING_SHA2_PASSWORD => {
432 auth::caching_sha2_password(password, auth_data)
433 }
434 auth::plugins::MYSQL_CLEAR_PASSWORD => {
435 let mut result = password.as_bytes().to_vec();
437 result.push(0);
438 result
439 }
440 _ => {
441 auth::mysql_native_password(password, auth_data)
443 }
444 }
445 }
446
447 #[allow(clippy::result_large_err)]
449 fn handle_auth_result(&mut self) -> Result<(), Error> {
450 let (payload, _) = self.read_packet()?;
451
452 if payload.is_empty() {
453 return Err(protocol_error("Empty authentication response"));
454 }
455
456 match PacketType::from_first_byte(payload[0], payload.len() as u32) {
457 PacketType::Ok => {
458 let mut reader = PacketReader::new(&payload);
460 if let Some(ok) = reader.parse_ok_packet() {
461 self.status_flags = ok.status_flags;
462 self.affected_rows = ok.affected_rows;
463 }
464 Ok(())
465 }
466 PacketType::Error => {
467 let mut reader = PacketReader::new(&payload);
468 let err = reader
469 .parse_err_packet()
470 .ok_or_else(|| protocol_error("Invalid error packet"))?;
471 Err(auth_error(format!(
472 "Authentication failed: {} ({})",
473 err.error_message, err.error_code
474 )))
475 }
476 PacketType::Eof => {
477 self.handle_auth_switch(&payload[1..])
479 }
480 _ => {
481 self.handle_additional_auth(&payload)
483 }
484 }
485 }
486
487 #[allow(clippy::result_large_err)]
489 fn handle_auth_switch(&mut self, data: &[u8]) -> Result<(), Error> {
490 let mut reader = PacketReader::new(data);
491
492 let plugin = reader
494 .read_null_string()
495 .ok_or_else(|| protocol_error("Missing plugin name in auth switch"))?;
496
497 let auth_data = reader.read_rest();
499
500 let response = self.compute_auth_response(&plugin, auth_data);
502
503 self.write_packet(&response)?;
505
506 self.handle_auth_result()
508 }
509
510 #[allow(clippy::result_large_err)]
512 fn handle_additional_auth(&mut self, data: &[u8]) -> Result<(), Error> {
513 if data.is_empty() {
514 return Err(protocol_error("Empty additional auth data"));
515 }
516
517 match data[0] {
518 auth::caching_sha2::FAST_AUTH_SUCCESS => {
519 let (payload, _) = self.read_packet()?;
521 let mut reader = PacketReader::new(&payload);
522 if let Some(ok) = reader.parse_ok_packet() {
523 self.status_flags = ok.status_flags;
524 }
525 Ok(())
526 }
527 auth::caching_sha2::PERFORM_FULL_AUTH => {
528 Err(auth_error(
531 "Full authentication required - please use TLS connection",
532 ))
533 }
534 _ => {
535 let mut reader = PacketReader::new(data);
537 if let Some(ok) = reader.parse_ok_packet() {
538 self.status_flags = ok.status_flags;
539 Ok(())
540 } else {
541 Err(protocol_error(format!(
542 "Unknown auth response: {:02X}",
543 data[0]
544 )))
545 }
546 }
547 }
548 }
549
550 #[allow(clippy::result_large_err)]
555 pub fn query_sync(&mut self, sql: &str, params: &[Value]) -> Result<Vec<Row>, Error> {
556 #[cfg(feature = "console")]
557 let start = std::time::Instant::now();
558
559 let sql = interpolate_params(sql, params);
560 if !self.is_ready() && self.state != ConnectionState::InTransaction {
561 return Err(connection_error("Connection not ready for queries"));
562 }
563
564 self.state = ConnectionState::InQuery;
565 self.sequence_id = 0;
566
567 let mut writer = PacketWriter::new();
569 writer.write_u8(Command::Query as u8);
570 writer.write_bytes(sql.as_bytes());
571 self.write_packet(writer.as_bytes())?;
572
573 let (payload, _) = self.read_packet()?;
575
576 if payload.is_empty() {
577 self.state = ConnectionState::Ready;
578 return Err(protocol_error("Empty query response"));
579 }
580
581 match PacketType::from_first_byte(payload[0], payload.len() as u32) {
582 PacketType::Ok => {
583 let mut reader = PacketReader::new(&payload);
585 if let Some(ok) = reader.parse_ok_packet() {
586 self.affected_rows = ok.affected_rows;
587 self.last_insert_id = ok.last_insert_id;
588 self.status_flags = ok.status_flags;
589 self.warnings = ok.warnings;
590 }
591 self.state = if self.status_flags
592 & crate::protocol::server_status::SERVER_STATUS_IN_TRANS
593 != 0
594 {
595 ConnectionState::InTransaction
596 } else {
597 ConnectionState::Ready
598 };
599
600 #[cfg(feature = "console")]
601 {
602 let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
603 self.emit_execute_timing(&sql, elapsed_ms, self.affected_rows);
604 self.emit_warnings(self.warnings);
605 }
606
607 Ok(vec![])
608 }
609 PacketType::Error => {
610 self.state = ConnectionState::Ready;
611 let mut reader = PacketReader::new(&payload);
612 let err = reader
613 .parse_err_packet()
614 .ok_or_else(|| protocol_error("Invalid error packet"))?;
615 Err(query_error(&err))
616 }
617 PacketType::LocalInfile => {
618 self.state = ConnectionState::Ready;
619 Err(query_error_msg("LOCAL INFILE not supported"))
620 }
621 _ => {
622 #[cfg(feature = "console")]
624 let result = self.read_result_set_with_timing(&sql, &payload, start);
625 #[cfg(not(feature = "console"))]
626 let result = self.read_result_set(&payload);
627 result
628 }
629 }
630 }
631
632 #[allow(dead_code)] #[allow(clippy::result_large_err)]
635 fn read_result_set(&mut self, first_packet: &[u8]) -> Result<Vec<Row>, Error> {
636 let mut reader = PacketReader::new(first_packet);
637 let column_count = reader
638 .read_lenenc_int()
639 .ok_or_else(|| protocol_error("Invalid column count"))?
640 as usize;
641
642 let mut columns = Vec::with_capacity(column_count);
644 for _ in 0..column_count {
645 let (payload, _) = self.read_packet()?;
646 columns.push(self.parse_column_def(&payload)?);
647 }
648
649 let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
651 if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
652 let (payload, _) = self.read_packet()?;
653 if payload.first() == Some(&0xFE) {
654 }
656 }
657
658 let mut rows = Vec::new();
660 loop {
661 let (payload, _) = self.read_packet()?;
662
663 if payload.is_empty() {
664 break;
665 }
666
667 match PacketType::from_first_byte(payload[0], payload.len() as u32) {
668 PacketType::Eof | PacketType::Ok => {
669 let mut reader = PacketReader::new(&payload);
671 if payload[0] == 0x00 {
672 if let Some(ok) = reader.parse_ok_packet() {
673 self.status_flags = ok.status_flags;
674 self.warnings = ok.warnings;
675 }
676 } else if payload[0] == 0xFE {
677 if let Some(eof) = reader.parse_eof_packet() {
678 self.status_flags = eof.status_flags;
679 self.warnings = eof.warnings;
680 }
681 }
682 break;
683 }
684 PacketType::Error => {
685 let mut reader = PacketReader::new(&payload);
686 let err = reader
687 .parse_err_packet()
688 .ok_or_else(|| protocol_error("Invalid error packet"))?;
689 self.state = ConnectionState::Ready;
690 return Err(query_error(&err));
691 }
692 _ => {
693 let row = self.parse_text_row(&payload, &columns);
695 rows.push(row);
696 }
697 }
698 }
699
700 self.state =
701 if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
702 ConnectionState::InTransaction
703 } else {
704 ConnectionState::Ready
705 };
706
707 Ok(rows)
708 }
709
/// Console-enabled variant of `read_result_set`: reads the same text-protocol
/// result set, then reports timing to the console. `SHOW` statements are
/// rendered as a table; others get a one-line summary. Warnings are
/// reported too.
///
/// Framing rules:
/// * the intermediate EOF is read only if `CLIENT_DEPRECATE_EOF` was not
///   negotiated;
/// * rows end at a short `0xFE` packet;
/// * an ERR packet aborts the read.
///
/// # Errors
/// * Protocol error for malformed column data.
/// * Query error if an ERR packet replaces a row.
/// * I/O errors from `read_packet`.
#[cfg(feature = "console")]
#[allow(clippy::result_large_err)]
fn read_result_set_with_timing(
    &mut self,
    sql: &str,
    first_packet: &[u8],
    start: std::time::Instant,
) -> Result<Vec<Row>, Error> {
    /// Cap on up-front column preallocation; the count comes off the wire,
    /// and the Vecs still grow if more columns actually arrive.
    const PREALLOC_LIMIT: usize = 4096;

    /// Idle state implied by the server's last reported status flags.
    fn idle_state(status_flags: u16) -> ConnectionState {
        if status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
            ConnectionState::InTransaction
        } else {
            ConnectionState::Ready
        }
    }

    let mut reader = PacketReader::new(first_packet);
    let column_count = reader
        .read_lenenc_int()
        .ok_or_else(|| protocol_error("Invalid column count"))?
        as usize;

    let capacity = column_count.min(PREALLOC_LIMIT);
    let mut columns = Vec::with_capacity(capacity);
    let mut col_names = Vec::with_capacity(capacity);
    for _ in 0..column_count {
        let (payload, _) = self.read_packet()?;
        let col = self.parse_column_def(&payload)?;
        col_names.push(col.name.clone());
        columns.push(col);
    }

    // The session only uses capabilities both sides agreed on: the same
    // client & server intersection sent in the handshake response.
    let negotiated = self.config.capability_flags()
        & self.server_caps.as_ref().map_or(0, |c| c.capabilities);
    let deprecate_eof = negotiated & capabilities::CLIENT_DEPRECATE_EOF != 0;

    if !deprecate_eof {
        // Discard the EOF packet separating column definitions from rows.
        self.read_packet()?;
    }

    let mut rows = Vec::new();
    loop {
        let (payload, _) = self.read_packet()?;
        let header = match payload.first() {
            Some(&b) => b,
            None => break,
        };

        match header {
            0xFF => {
                let mut reader = PacketReader::new(&payload);
                let err = reader
                    .parse_err_packet()
                    .ok_or_else(|| protocol_error("Invalid error packet"))?;
                self.state = idle_state(self.status_flags);
                return Err(query_error(&err));
            }
            // End of rows. A real row can only start with 0xFE if its first
            // value has an 8-byte length prefix, which forces a max-size
            // packet, so short 0xFE packets are always terminators.
            0xFE if payload.len() < MAX_PACKET_SIZE => {
                let as_ok = if deprecate_eof {
                    PacketReader::new(&payload)
                        .parse_ok_packet()
                        .map(|ok| (ok.status_flags, ok.warnings))
                } else {
                    None
                };
                let terminator = as_ok.or_else(|| {
                    PacketReader::new(&payload)
                        .parse_eof_packet()
                        .map(|eof| (eof.status_flags, eof.warnings))
                });
                if let Some((status_flags, warnings)) = terminator {
                    self.status_flags = status_flags;
                    self.warnings = warnings;
                }
                break;
            }
            // Everything else is a data row, including rows whose first
            // column is an empty string (leading 0x00) or NULL (0xFB).
            _ => rows.push(self.parse_text_row(&payload, &columns)),
        }
    }

    self.state = idle_state(self.status_flags);

    let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
    let sql_upper = sql.trim().to_uppercase();
    if sql_upper.starts_with("SHOW") {
        self.emit_show_results(sql, &col_names, &rows, elapsed_ms);
    } else {
        self.emit_query_timing(sql, elapsed_ms, rows.len());
    }
    self.emit_warnings(self.warnings);

    Ok(rows)
}
803
804 #[allow(clippy::result_large_err)]
806 fn parse_column_def(&self, data: &[u8]) -> Result<ColumnDef, Error> {
807 let mut reader = PacketReader::new(data);
808
809 let catalog = reader
810 .read_lenenc_string()
811 .ok_or_else(|| protocol_error("Missing catalog"))?;
812 let schema = reader
813 .read_lenenc_string()
814 .ok_or_else(|| protocol_error("Missing schema"))?;
815 let table = reader
816 .read_lenenc_string()
817 .ok_or_else(|| protocol_error("Missing table"))?;
818 let org_table = reader
819 .read_lenenc_string()
820 .ok_or_else(|| protocol_error("Missing org_table"))?;
821 let name = reader
822 .read_lenenc_string()
823 .ok_or_else(|| protocol_error("Missing name"))?;
824 let org_name = reader
825 .read_lenenc_string()
826 .ok_or_else(|| protocol_error("Missing org_name"))?;
827
828 let _fixed_len = reader.read_lenenc_int();
830
831 let charset = reader
832 .read_u16_le()
833 .ok_or_else(|| protocol_error("Missing charset"))?;
834 let column_length = reader
835 .read_u32_le()
836 .ok_or_else(|| protocol_error("Missing column_length"))?;
837 let column_type = FieldType::from_u8(
838 reader
839 .read_u8()
840 .ok_or_else(|| protocol_error("Missing column_type"))?,
841 );
842 let flags = reader
843 .read_u16_le()
844 .ok_or_else(|| protocol_error("Missing flags"))?;
845 let decimals = reader
846 .read_u8()
847 .ok_or_else(|| protocol_error("Missing decimals"))?;
848
849 Ok(ColumnDef {
850 catalog,
851 schema,
852 table,
853 org_table,
854 name,
855 org_name,
856 charset,
857 column_length,
858 column_type,
859 flags,
860 decimals,
861 })
862 }
863
864 fn parse_text_row(&self, data: &[u8], columns: &[ColumnDef]) -> Row {
866 let mut reader = PacketReader::new(data);
867 let mut values = Vec::with_capacity(columns.len());
868
869 for col in columns {
870 if reader.peek() == Some(0xFB) {
873 reader.skip(1);
874 values.push(Value::Null);
875 } else if let Some(data) = reader.read_lenenc_bytes() {
876 let is_unsigned = col.is_unsigned();
877 let value = decode_text_value(col.column_type, &data, is_unsigned);
878 values.push(value);
879 } else {
880 values.push(Value::Null);
881 }
882 }
883
884 let column_names: Vec<String> = columns.iter().map(|c| c.name.clone()).collect();
886
887 Row::new(column_names, values)
888 }
889
890 #[allow(clippy::result_large_err)]
892 pub fn query_one_sync(&mut self, sql: &str, params: &[Value]) -> Result<Option<Row>, Error> {
893 let rows = self.query_sync(sql, params)?;
894 Ok(rows.into_iter().next())
895 }
896
897 #[allow(clippy::result_large_err)]
899 pub fn execute_sync(&mut self, sql: &str, params: &[Value]) -> Result<u64, Error> {
900 self.query_sync(sql, params)?;
901 Ok(self.affected_rows)
902 }
903
904 #[allow(clippy::result_large_err)]
906 pub fn insert_sync(&mut self, sql: &str, params: &[Value]) -> Result<i64, Error> {
907 self.query_sync(sql, params)?;
908 Ok(self.last_insert_id as i64)
909 }
910
911 #[allow(clippy::result_large_err)]
913 pub fn ping(&mut self) -> Result<(), Error> {
914 self.sequence_id = 0;
915
916 let mut writer = PacketWriter::new();
917 writer.write_u8(Command::Ping as u8);
918 self.write_packet(writer.as_bytes())?;
919
920 let (payload, _) = self.read_packet()?;
921
922 if payload.first() == Some(&0x00) {
923 Ok(())
924 } else {
925 Err(connection_error("Ping failed"))
926 }
927 }
928
929 #[allow(clippy::result_large_err)]
931 pub fn close(mut self) -> Result<(), Error> {
932 if self.state == ConnectionState::Closed {
933 return Ok(());
934 }
935
936 self.sequence_id = 0;
937
938 let mut writer = PacketWriter::new();
939 writer.write_u8(Command::Quit as u8);
940
941 let _ = self.write_packet(writer.as_bytes());
943
944 self.state = ConnectionState::Closed;
945 Ok(())
946 }
947
948 #[allow(clippy::result_large_err)]
950 fn read_packet(&mut self) -> Result<(Vec<u8>, u8), Error> {
951 let mut header_buf = [0u8; 4];
953 self.stream.read_exact(&mut header_buf).map_err(|e| {
954 Error::Connection(ConnectionError {
955 kind: ConnectionErrorKind::Disconnected,
956 message: format!("Failed to read packet header: {}", e),
957 source: Some(Box::new(e)),
958 })
959 })?;
960
961 let header = PacketHeader::from_bytes(&header_buf);
962 let payload_len = header.payload_length as usize;
963 self.sequence_id = header.sequence_id.wrapping_add(1);
964
965 let mut payload = vec![0u8; payload_len];
967 if payload_len > 0 {
968 self.stream.read_exact(&mut payload).map_err(|e| {
969 Error::Connection(ConnectionError {
970 kind: ConnectionErrorKind::Disconnected,
971 message: format!("Failed to read packet payload: {}", e),
972 source: Some(Box::new(e)),
973 })
974 })?;
975 }
976
977 if payload_len == MAX_PACKET_SIZE {
979 loop {
980 let mut header_buf = [0u8; 4];
981 self.stream.read_exact(&mut header_buf).map_err(|e| {
982 Error::Connection(ConnectionError {
983 kind: ConnectionErrorKind::Disconnected,
984 message: format!("Failed to read continuation header: {}", e),
985 source: Some(Box::new(e)),
986 })
987 })?;
988
989 let cont_header = PacketHeader::from_bytes(&header_buf);
990 let cont_len = cont_header.payload_length as usize;
991 self.sequence_id = cont_header.sequence_id.wrapping_add(1);
992
993 if cont_len > 0 {
994 let mut cont_payload = vec![0u8; cont_len];
995 self.stream.read_exact(&mut cont_payload).map_err(|e| {
996 Error::Connection(ConnectionError {
997 kind: ConnectionErrorKind::Disconnected,
998 message: format!("Failed to read continuation payload: {}", e),
999 source: Some(Box::new(e)),
1000 })
1001 })?;
1002 payload.extend_from_slice(&cont_payload);
1003 }
1004
1005 if cont_len < MAX_PACKET_SIZE {
1006 break;
1007 }
1008 }
1009 }
1010
1011 Ok((payload, header.sequence_id))
1012 }
1013
1014 #[allow(clippy::result_large_err)]
1016 fn write_packet(&mut self, payload: &[u8]) -> Result<(), Error> {
1017 let writer = PacketWriter::new();
1018 let packet = writer.build_packet_from_payload(payload, self.sequence_id);
1019 self.sequence_id = self.sequence_id.wrapping_add(1);
1020
1021 self.stream.write_all(&packet).map_err(|e| {
1022 Error::Connection(ConnectionError {
1023 kind: ConnectionErrorKind::Disconnected,
1024 message: format!("Failed to write packet: {}", e),
1025 source: Some(Box::new(e)),
1026 })
1027 })?;
1028
1029 self.stream.flush().map_err(|e| {
1030 Error::Connection(ConnectionError {
1031 kind: ConnectionErrorKind::Disconnected,
1032 message: format!("Failed to flush stream: {}", e),
1033 source: Some(Box::new(e)),
1034 })
1035 })?;
1036
1037 Ok(())
1038 }
1039}
1040
/// Lets callers attach (or detach) a console for diagnostic output.
#[cfg(feature = "console")]
impl ConsoleAware for MySqlConnection {
    /// Replaces the attached console; `None` disables console output.
    fn set_console(&mut self, console: Option<Arc<SqlModelConsole>>) {
        self.console = console;
    }

    /// Returns the attached console, if any.
    fn console(&self) -> Option<&Arc<SqlModelConsole>> {
        Option::as_ref(&self.console)
    }
}
1052
1053#[cfg(feature = "console")]
1054impl MySqlConnection {
1055 #[allow(dead_code)]
1059 fn emit_connection_progress(&self, stage: &str, status: &str, is_final: bool) {
1060 if let Some(console) = &self.console {
1061 let mode = console.mode();
1062 match mode {
1063 sqlmodel_console::OutputMode::Plain => {
1064 if is_final {
1065 console.status(&format!("[MySQL] {}: {}", stage, status));
1066 }
1067 }
1068 sqlmodel_console::OutputMode::Rich => {
1069 let status_icon = if status.starts_with("OK") || status.starts_with("Connected")
1070 {
1071 "✓"
1072 } else if status.starts_with("Error") || status.starts_with("Failed") {
1073 "✗"
1074 } else {
1075 "…"
1076 };
1077 console.status(&format!(" {} {}: {}", status_icon, stage, status));
1078 }
1079 sqlmodel_console::OutputMode::Json => {
1080 }
1082 }
1083 }
1084 }
1085
1086 fn emit_query_timing(&self, sql: &str, elapsed_ms: f64, row_count: usize) {
1088 if let Some(console) = &self.console {
1089 let mode = console.mode();
1090 let sql_preview: String = sql.chars().take(60).collect();
1091 let sql_display = if sql.len() > 60 {
1092 format!("{}...", sql_preview)
1093 } else {
1094 sql_preview
1095 };
1096
1097 match mode {
1098 sqlmodel_console::OutputMode::Plain => {
1099 console.status(&format!(
1100 "[MySQL] Query: {:.2}ms, {} rows | {}",
1101 elapsed_ms, row_count, sql_display
1102 ));
1103 }
1104 sqlmodel_console::OutputMode::Rich => {
1105 let time_color = if elapsed_ms < 10.0 {
1106 "\x1b[32m" } else if elapsed_ms < 100.0 {
1108 "\x1b[33m" } else {
1110 "\x1b[31m" };
1112 console.status(&format!(
1113 " ⏱ {}{:.2}ms\x1b[0m ({} rows) {}",
1114 time_color, elapsed_ms, row_count, sql_display
1115 ));
1116 }
1117 sqlmodel_console::OutputMode::Json => {
1118 }
1120 }
1121 }
1122 }
1123
1124 fn emit_execute_timing(&self, sql: &str, elapsed_ms: f64, affected_rows: u64) {
1126 if let Some(console) = &self.console {
1127 let mode = console.mode();
1128 let sql_preview: String = sql.chars().take(60).collect();
1129 let sql_display = if sql.len() > 60 {
1130 format!("{}...", sql_preview)
1131 } else {
1132 sql_preview
1133 };
1134
1135 match mode {
1136 sqlmodel_console::OutputMode::Plain => {
1137 console.status(&format!(
1138 "[MySQL] Execute: {:.2}ms, {} affected | {}",
1139 elapsed_ms, affected_rows, sql_display
1140 ));
1141 }
1142 sqlmodel_console::OutputMode::Rich => {
1143 let time_color = if elapsed_ms < 10.0 {
1144 "\x1b[32m"
1145 } else if elapsed_ms < 100.0 {
1146 "\x1b[33m"
1147 } else {
1148 "\x1b[31m"
1149 };
1150 console.status(&format!(
1151 " ⏱ {}{:.2}ms\x1b[0m ({} affected) {}",
1152 time_color, elapsed_ms, affected_rows, sql_display
1153 ));
1154 }
1155 sqlmodel_console::OutputMode::Json => {}
1156 }
1157 }
1158 }
1159
1160 fn emit_warnings(&self, warning_count: u16) {
1162 if warning_count == 0 {
1163 return;
1164 }
1165 if let Some(console) = &self.console {
1166 let mode = console.mode();
1167 match mode {
1168 sqlmodel_console::OutputMode::Plain => {
1169 console.warning(&format!("[MySQL] {} warning(s)", warning_count));
1170 }
1171 sqlmodel_console::OutputMode::Rich => {
1172 console.warning(&format!("{} warning(s)", warning_count));
1173 }
1174 sqlmodel_console::OutputMode::Json => {}
1175 }
1176 }
1177 }
1178
1179 fn emit_show_results(&self, sql: &str, col_names: &[String], rows: &[Row], elapsed_ms: f64) {
1181 if let Some(console) = &self.console {
1182 let mode = console.mode();
1183 let sql_upper = sql.trim().to_uppercase();
1184
1185 if !sql_upper.starts_with("SHOW") {
1187 self.emit_query_timing(sql, elapsed_ms, rows.len());
1188 return;
1189 }
1190
1191 match mode {
1192 sqlmodel_console::OutputMode::Plain | sqlmodel_console::OutputMode::Rich => {
1193 let mut widths: Vec<usize> = col_names.iter().map(|n| n.len()).collect();
1195 for row in rows {
1196 for (i, val) in row.values().enumerate() {
1197 if i < widths.len() {
1198 let val_str = format_value(val);
1199 widths[i] = widths[i].max(val_str.len());
1200 }
1201 }
1202 }
1203
1204 let header: String = col_names
1206 .iter()
1207 .zip(&widths)
1208 .map(|(name, width)| format!("{:width$}", name, width = width))
1209 .collect::<Vec<_>>()
1210 .join(" | ");
1211
1212 let separator: String = widths
1213 .iter()
1214 .map(|w| "-".repeat(*w))
1215 .collect::<Vec<_>>()
1216 .join("-+-");
1217
1218 console.status(&header);
1219 console.status(&separator);
1220
1221 for row in rows {
1222 let row_str: String = row
1223 .values()
1224 .zip(&widths)
1225 .map(|(val, width)| {
1226 format!("{:width$}", format_value(val), width = width)
1227 })
1228 .collect::<Vec<_>>()
1229 .join(" | ");
1230 console.status(&row_str);
1231 }
1232
1233 console.status(&format!("({} rows, {:.2}ms)\n", rows.len(), elapsed_ms));
1234 }
1235 sqlmodel_console::OutputMode::Json => {}
1236 }
1237 }
1238 }
1239}
1240
/// Renders a [`Value`] as a short human-readable string for console tables.
///
/// Formatting rules:
/// * scalars use their natural text form, and floats show six decimals;
/// * byte strings and arrays are summarised by length;
/// * temporal values are prefixed with their kind (`date:`, `time:`, `ts:`,
///   `tstz:`);
/// * UUIDs use the lowercase 8-4-4-4-12 hex layout.
#[cfg(feature = "console")]
fn format_value(value: &Value) -> String {
    match value {
        Value::Null => String::from("NULL"),
        Value::Bool(flag) => flag.to_string(),
        Value::TinyInt(n) => n.to_string(),
        Value::SmallInt(n) => n.to_string(),
        Value::Int(n) => n.to_string(),
        Value::BigInt(n) => n.to_string(),
        Value::Float(x) => format!("{:.6}", x),
        Value::Double(x) => format!("{:.6}", x),
        Value::Decimal(text) | Value::Text(text) => text.clone(),
        Value::Bytes(bytes) => format!("<{} bytes>", bytes.len()),
        Value::Date(d) => format!("date:{}", d),
        Value::Time(t) => format!("time:{}", t),
        Value::Timestamp(ts) => format!("ts:{}", ts),
        Value::TimestampTz(ts) => format!("tstz:{}", ts),
        Value::Uuid(raw) => {
            let mut text = String::with_capacity(36);
            for (idx, byte) in raw[..16].iter().enumerate() {
                // Hyphens precede bytes 4, 6, 8 and 10 (8-4-4-4-12 layout).
                if matches!(idx, 4 | 6 | 8 | 10) {
                    text.push('-');
                }
                text.push_str(&format!("{:02x}", byte));
            }
            text
        }
        Value::Json(j) => j.to_string(),
        Value::Array(items) => format!("[{} items]", items.len()),
        Value::Default => String::from("DEFAULT"),
    }
}
1286
1287fn protocol_error(msg: impl Into<String>) -> Error {
1290 Error::Protocol(ProtocolError {
1291 message: msg.into(),
1292 raw_data: None,
1293 source: None,
1294 })
1295}
1296
1297fn auth_error(msg: impl Into<String>) -> Error {
1298 Error::Connection(ConnectionError {
1299 kind: ConnectionErrorKind::Authentication,
1300 message: msg.into(),
1301 source: None,
1302 })
1303}
1304
1305fn connection_error(msg: impl Into<String>) -> Error {
1306 Error::Connection(ConnectionError {
1307 kind: ConnectionErrorKind::Connect,
1308 message: msg.into(),
1309 source: None,
1310 })
1311}
1312
1313fn query_error(err: &ErrPacket) -> Error {
1314 let kind = if err.is_duplicate_key() || err.is_foreign_key_violation() {
1315 QueryErrorKind::Constraint
1316 } else {
1317 QueryErrorKind::Syntax
1318 };
1319
1320 Error::Query(QueryError {
1321 kind,
1322 message: err.error_message.clone(),
1323 sqlstate: Some(err.sql_state.clone()),
1324 sql: None,
1325 detail: None,
1326 hint: None,
1327 position: None,
1328 source: None,
1329 })
1330}
1331
1332fn query_error_msg(msg: impl Into<String>) -> Error {
1333 Error::Query(QueryError {
1334 kind: QueryErrorKind::Syntax,
1335 message: msg.into(),
1336 sqlstate: None,
1337 sql: None,
1338 detail: None,
1339 hint: None,
1340 position: None,
1341 source: None,
1342 })
1343}
1344
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_state_default() {
        // `ConnectionState` derives `Copy` and `Eq`: a copied value compares
        // equal to its source, and distinct lifecycle states compare unequal.
        let state = ConnectionState::Disconnected;
        let copied = state;
        assert_eq!(state, copied);
        assert_ne!(ConnectionState::Disconnected, ConnectionState::Ready);
        assert_ne!(ConnectionState::Ready, ConnectionState::InTransaction);
        assert_ne!(ConnectionState::Error, ConnectionState::Closed);
    }

    #[test]
    fn test_error_helpers() {
        match protocol_error("test error") {
            Error::Protocol(p) => assert_eq!(p.message, "test error"),
            _ => panic!("Expected protocol error"),
        }

        // Both connection helpers produce `Error::Connection`; the `kind`
        // is what distinguishes them.
        match auth_error("auth failed") {
            Error::Connection(c) => {
                assert!(matches!(c.kind, ConnectionErrorKind::Authentication));
                assert_eq!(c.message, "auth failed");
            }
            _ => panic!("Expected connection error"),
        }

        match connection_error("connection failed") {
            Error::Connection(c) => {
                assert!(matches!(c.kind, ConnectionErrorKind::Connect));
                assert_eq!(c.message, "connection failed");
            }
            _ => panic!("Expected connection error"),
        }

        match query_error_msg("bad response") {
            Error::Query(q) => {
                assert_eq!(q.kind, QueryErrorKind::Syntax);
                assert_eq!(q.message, "bad response");
                assert!(q.sqlstate.is_none());
            }
            _ => panic!("Expected query error"),
        }
    }

    #[test]
    fn test_query_error_duplicate_key() {
        let err_packet = ErrPacket {
            error_code: 1062,
            sql_state: "23000".to_string(),
            error_message: "Duplicate entry".to_string(),
        };

        match query_error(&err_packet) {
            Error::Query(q) => {
                assert_eq!(q.kind, QueryErrorKind::Constraint);
                assert_eq!(q.message, "Duplicate entry");
                assert_eq!(q.sqlstate.as_deref(), Some("23000"));
            }
            _ => panic!("Expected query error"),
        }
    }

    #[test]
    fn test_query_error_non_constraint_is_syntax() {
        // 1064 (ER_PARSE_ERROR) is neither a duplicate-key nor a foreign-key
        // error code, so it falls through to the default kind.
        let err_packet = ErrPacket {
            error_code: 1064,
            sql_state: "42000".to_string(),
            error_message: "You have an error in your SQL syntax".to_string(),
        };

        match query_error(&err_packet) {
            Error::Query(q) => {
                assert_eq!(q.kind, QueryErrorKind::Syntax);
                assert_eq!(q.sqlstate.as_deref(), Some("42000"));
            }
            _ => panic!("Expected query error"),
        }
    }

    #[cfg(feature = "console")]
    mod console_tests {
        use super::*;
        use sqlmodel_console::{ConsoleAware, OutputMode, SqlModelConsole};

        fn assert_console_aware<T: ConsoleAware>() {}

        #[test]
        fn test_console_aware_trait_impl() {
            let config = MySqlConfig::new()
                .host("localhost")
                .port(13306)
                .user("test")
                .password("test");

            assert_console_aware::<MySqlConnection>();

            assert_eq!(config.host, "localhost");
            assert_eq!(config.port, 13306);
        }

        #[test]
        fn test_format_value_all_types() {
            assert_eq!(format_value(&Value::Null), "NULL");
            assert_eq!(format_value(&Value::Bool(true)), "true");
            assert_eq!(format_value(&Value::Bool(false)), "false");
            assert_eq!(format_value(&Value::TinyInt(42)), "42");
            assert_eq!(format_value(&Value::SmallInt(1000)), "1000");
            assert_eq!(format_value(&Value::Int(123_456)), "123456");
            assert_eq!(format_value(&Value::BigInt(9_999_999_999)), "9999999999");
            assert!(format_value(&Value::Float(1.5)).starts_with("1.5"));
            assert!(format_value(&Value::Double(1.234_567_890)).starts_with("1.23456"));
            assert_eq!(
                format_value(&Value::Decimal("123.45".to_string())),
                "123.45"
            );
            assert_eq!(format_value(&Value::Text("hello".to_string())), "hello");
            assert_eq!(format_value(&Value::Bytes(vec![1, 2, 3])), "<3 bytes>");
            assert!(format_value(&Value::Date(19000)).contains("date:"));
            assert!(format_value(&Value::Time(43_200_000_000)).contains("time:"));
            assert!(format_value(&Value::Timestamp(1_700_000_000_000_000)).contains("ts:"));
            assert!(format_value(&Value::TimestampTz(1_700_000_000_000_000)).contains("tstz:"));

            let uuid = [0u8; 16];
            let uuid_str = format_value(&Value::Uuid(uuid));
            assert_eq!(uuid_str, "00000000-0000-0000-0000-000000000000");

            let json = serde_json::json!({"key": "value"});
            let json_str = format_value(&Value::Json(json));
            assert!(json_str.contains("key"));

            let arr = vec![Value::Int(1), Value::Int(2)];
            assert_eq!(format_value(&Value::Array(arr)), "[2 items]");
        }

        #[test]
        fn test_plain_mode_output_format() {
            let plain_console = SqlModelConsole::with_mode(OutputMode::Plain);
            assert!(plain_console.is_plain());

            let rich_console = SqlModelConsole::with_mode(OutputMode::Rich);
            assert!(rich_console.is_rich());

            let json_console = SqlModelConsole::with_mode(OutputMode::Json);
            assert!(json_console.is_json());
        }

        #[test]
        fn test_console_mode_detection() {
            let console = SqlModelConsole::with_mode(OutputMode::Plain);
            assert!(console.is_plain());
            assert!(!console.is_rich());
            assert!(!console.is_json());

            assert_eq!(console.mode(), OutputMode::Plain);
        }

        #[test]
        fn test_format_value_uuid() {
            let uuid: [u8; 16] = [
                0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc,
                0xde, 0xf0,
            ];
            let result = format_value(&Value::Uuid(uuid));
            assert_eq!(result, "12345678-9abc-def0-1234-56789abcdef0");
        }

        #[test]
        fn test_format_value_nested_json() {
            let json = serde_json::json!({
                "users": [
                    {"name": "Alice", "age": 30},
                    {"name": "Bob", "age": 25}
                ]
            });
            let result = format_value(&Value::Json(json));
            assert!(result.contains("users"));
            assert!(result.contains("Alice"));
        }
    }
}