1mod errors;
9use errors::debug;
10use errors::DriverError;
11
12mod results;
14use results::ResultSet;
15use results::Value;
16
17use std::convert::TryInto;
19use std::fmt;
20use std::fs;
21use std::io::BufRead;
22use std::net::Shutdown;
23use std::net::TcpStream;
24use std::os::unix::net::UnixStream;
25use std::str::from_utf8;
26use std::str::FromStr;
27
28use chrono::prelude::*;
30
31use url::Url;
33
34use bufstream::BufStream;
36use native_tls::{TlsConnector, TlsStream};
37use openssl::hash::MessageDigest;
38use openssl::pkey::PKey;
39use openssl::rsa::Rsa;
40use openssl::sign::Signer;
41use std::io::{Read, Write};
42
43use byteorder::{ByteOrder, LittleEndian};
45
46mod ClientProtocol;
48use ClientProtocol::{
49 BeginSessionRequest, BeginSessionResponse, ClientAuthenticationRequest,
50 ClientAuthenticationResponse, Command, ServerResponse,
51};
52mod ColumnDataType;
53mod CommonTypes;
54
55enum ConnBufStream {
56 PlainBufUnixSocket(BufStream<UnixStream>),
57 PlainBufStream(BufStream<TcpStream>),
58 TlsBufStream(BufStream<TlsStream<TcpStream>>),
59}
60
61impl Read for ConnBufStream {
62 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
63 match *self {
64 ConnBufStream::PlainBufUnixSocket(ref mut stream) => stream.read(buf),
65 ConnBufStream::PlainBufStream(ref mut stream) => stream.read(buf),
66 ConnBufStream::TlsBufStream(ref mut stream) => stream.read(buf),
67 }
68 }
69}
70
71impl BufRead for ConnBufStream {
72 fn fill_buf(&mut self) -> std::result::Result<&[u8], std::io::Error> {
73 match *self {
74 ConnBufStream::PlainBufUnixSocket(ref mut stream) => stream.fill_buf(),
75 ConnBufStream::PlainBufStream(ref mut stream) => stream.fill_buf(),
76 ConnBufStream::TlsBufStream(ref mut stream) => stream.fill_buf(),
77 }
78 }
79 fn consume(&mut self, amt: usize) {
80 match *self {
81 ConnBufStream::PlainBufUnixSocket(ref mut stream) => stream.consume(amt),
82 ConnBufStream::PlainBufStream(ref mut stream) => stream.consume(amt),
83 ConnBufStream::TlsBufStream(ref mut stream) => stream.consume(amt),
84 }
85 }
86}
87
88impl Write for ConnBufStream {
89 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
90 match *self {
91 ConnBufStream::PlainBufUnixSocket(ref mut stream) => stream.write(buf),
92 ConnBufStream::PlainBufStream(ref mut stream) => stream.write(buf),
93 ConnBufStream::TlsBufStream(ref mut stream) => stream.write(buf),
94 }
95 }
96
97 fn flush(&mut self) -> std::io::Result<()> {
98 match *self {
99 ConnBufStream::PlainBufUnixSocket(ref mut stream) => stream.flush(),
100 ConnBufStream::PlainBufStream(ref mut stream) => stream.flush(),
101 ConnBufStream::TlsBufStream(ref mut stream) => stream.flush(),
102 }
103 }
104}
105
/// Unbuffered duplicate of the connection's socket.
///
/// It is kept next to the buffered stream so that the connection can be shut
/// down (see `SiodbConn::close`) independently of the buffering layer.
enum ConnStream {
    /// Unix domain socket handle (`siodbu`).
    UnixSocketStream(UnixStream),
    /// TCP socket handle (`siodb`, and the raw socket underneath `siodbs`).
    TcpStream(TcpStream),
}
110
111impl ConnStream {
112 fn shutdown(&mut self) -> std::io::Result<()> {
113 match *self {
114 ConnStream::UnixSocketStream(ref mut stream) => stream.shutdown(Shutdown::Both),
115 ConnStream::TcpStream(ref mut stream) => stream.shutdown(Shutdown::Both),
116 }
117 }
118}
119
120pub struct SiodbConn {
128 scheme: String,
129 host: String,
130 port: u16,
131 unix_socket: String,
132 user: String,
133 pkfile: String,
134 trace: bool,
135 stream: Option<ConnStream>,
136 buf_stream: Option<ConnBufStream>,
137 result_set: Option<ResultSet>,
138}
139
140impl fmt::Debug for SiodbConn {
141 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
143 write!(
144 f,
145 "scheme: {} | host: {} | port: {} | user: {} | pkfile: {} | trace: {}",
146 self.scheme, self.host, self.port, self.user, self.pkfile, self.trace,
147 )
148 }
149}
150
151impl SiodbConn {
152 fn parse_uri(uri_str: &str) -> Result<SiodbConn, DriverError> {
153 let uri = Url::parse(uri_str).expect(&format!("Unable to parse URI"));
154
155 let pairs = uri.query_pairs();
156 let mut unix_socket = "/run/siodb/siodb.socket".to_string();
157 let mut pkfile = "~/.ssh/id_rsa".to_string();
158 let mut trace = false;
159 for pair in pairs {
160 match pair.0 {
161 _ if pair.0.to_string() == String::from("unix_socket") => {
162 unix_socket = pair.1.to_string()
163 }
164 _ if pair.0.to_string() == String::from("identity_file") => {
165 pkfile = pair.1.to_string()
166 }
167 _ if pair.0.to_string() == String::from("trace") => {
168 trace = bool::from_str(&pair.1.to_string()).unwrap_or(trace)
169 }
170 _ => return Err(DriverError::new(&format!("Unknow option: {}.", &pair.0))),
171 }
172 }
173
174 let mut scheme: String = "siodbs".to_string();
176 if uri.scheme().len() > 0 {
177 scheme = uri.scheme().to_string();
178 }
179 if scheme != "siodb" && scheme != "siodbs" && scheme != "siodbu" {
180 return Err(DriverError::new(&format!(
181 "Wrong protocol: '{}'. Should be 'siodb', 'siodbs' or 'siodbu'.",
182 scheme
183 )));
184 }
185 let mut host: String = "localhost".to_string();
186 if !uri.host().is_none() {
187 host = uri.host().unwrap().to_string();
188 }
189 let port = uri.port().unwrap_or(50000);
190 let mut user: String = "root".to_string();
191 if uri.username().len() > 0 {
192 user = uri.username().to_string();
193 }
194
195 Ok(SiodbConn {
196 scheme,
197 host,
198 port,
199 unix_socket: unix_socket,
200 user,
201 pkfile,
202 trace,
203 stream: None,
204 buf_stream: None,
205 result_set: None,
206 })
207 }
208
209 pub fn new(uri_str: &str) -> Result<SiodbConn, DriverError> {
211 let mut siodb_conn = SiodbConn::parse_uri(uri_str).unwrap();
212 debug(siodb_conn.trace, &format!("siodb_conn: {:?}", siodb_conn));
213 siodb_conn.connect()?;
214 siodb_conn.authenticate()?;
215 Ok(siodb_conn)
216 }
217 fn connect(&mut self) -> Result<(), DriverError> {
218 match self.scheme.as_str() {
219 "siodb" => {
220 let stream = TcpStream::connect(format!("{}:{}", self.host, self.port))
222 .expect(&format!("Cannot connect to '{}:{}'", self.host, self.port));
223 self.stream = Some(ConnStream::TcpStream(stream.try_clone().unwrap()));
224 self.buf_stream = Some(ConnBufStream::PlainBufStream(BufStream::new(stream)));
225 }
226 "siodbs" => {
227 let mut builder = TlsConnector::builder();
231 builder.danger_accept_invalid_hostnames(true);
232 builder.danger_accept_invalid_certs(true);
233 let tls_connector = builder.build().unwrap();
234 let stream = TcpStream::connect(format!("{}:{}", self.host, self.port))
235 .expect(&format!("Cannot connect to '{}:{}'", self.host, self.port));
236 self.stream = Some(ConnStream::TcpStream(stream.try_clone().unwrap()));
237 let stream = tls_connector.connect(&self.host, stream).unwrap();
238 self.buf_stream = Some(ConnBufStream::TlsBufStream(BufStream::new(stream)));
239 }
240 "siodbu" => {
241 let stream = UnixStream::connect(format!("{}", self.unix_socket))
243 .expect(&format!("Cannot connect to socket '{}'", self.unix_socket));
244 self.stream = Some(ConnStream::UnixSocketStream(stream.try_clone().unwrap()));
245 self.buf_stream = Some(ConnBufStream::PlainBufUnixSocket(BufStream::new(stream)));
246 }
247 _ => {
248 return Err(DriverError::new(&format!("Protocol unknown.")));
249 }
250 }
251
252 Ok(())
253 }
254
255 pub fn close(&mut self) -> Result<(), DriverError> {
257 self.stream
258 .as_mut()
259 .unwrap()
260 .shutdown()
261 .expect(&format!("Error while closing connection."));
262
263 Ok(())
264 }
265 fn authenticate(&mut self) -> Result<(), DriverError> {
266 let mut begin_session_request = BeginSessionRequest::new();
268 begin_session_request.set_user_name(self.user.as_str().to_string());
269 debug(
270 self.trace,
271 &format!("begin_session_request: {:?}", begin_session_request),
272 );
273 self.write_message(5, &begin_session_request)?;
274
275 let _begin_session_response = self.read_message::<BeginSessionResponse>(6).unwrap()?;
277
278 if !_begin_session_response.get_session_started() {
279 return Err(DriverError::new(&format!("Siodb session not started.")));
280 }
281
282 let pkey = &self.pkfile;
284 let contents =
285 fs::read_to_string(pkey).expect(&format!("Error reading private key '{}'", pkey));
286 let keypair = Rsa::private_key_from_pem(contents.as_bytes())
287 .expect(&format!("Error loading private key"));
288 let keypair =
289 PKey::from_rsa(keypair).expect(&format!("Error loading private {} key", "RSA"));
290 let mut signer = Signer::new(MessageDigest::sha512(), &keypair)
291 .expect(&format!("Error creating signer"));
292 let signature = signer
293 .sign_oneshot_to_vec(_begin_session_response.get_challenge())
294 .expect(&format!("Error signing challenge"));
295
296 let mut client_authentication_request = ClientAuthenticationRequest::new();
298 client_authentication_request.set_signature(signature);
299 debug(
300 self.trace,
301 &format!(
302 "client_authentication_request: {:?}",
303 client_authentication_request
304 ),
305 );
306 self.write_message(7, &client_authentication_request)?;
307
308 let _client_authentication_response = self
310 .read_message::<ClientAuthenticationResponse>(8)
311 .unwrap()?;
312
313 if !_client_authentication_response.get_authenticated() {
314 return Err(DriverError::new(&format!("Siodb session not started.")));
315 }
316
317 Ok(())
318 }
319 fn write_message(
320 &mut self,
321 message_type: u32,
322 message: &dyn protobuf::Message,
323 ) -> Result<(), DriverError> {
324 let mut output_stream = self.buf_stream.as_mut().unwrap();
325 let mut coded_output_stream = protobuf::CodedOutputStream::new(&mut output_stream);
326
327 coded_output_stream
328 .write_raw_varint32(message_type)
329 .expect(&format!("write_message | Codec error"));
330
331 coded_output_stream
332 .write_raw_varint32(message.compute_size())
333 .expect(&format!("write_message | Codec error"));
334 &message
335 .write_to_with_cached_sizes(&mut coded_output_stream)
336 .expect(&format!("write_message | Codec error"));
337 coded_output_stream
338 .flush()
339 .expect(&format!("write_message | Codec error"));
340
341 self.buf_stream
342 .as_mut()
343 .unwrap()
344 .flush()
345 .expect(&format!("write_message | Codec error"));
346
347 Ok(())
348 }
349 fn read_message<M: protobuf::Message>(
350 &mut self,
351 message_type: u32,
352 ) -> Result<protobuf::ProtobufResult<M>, DriverError> {
353 let mut input_stream = self.buf_stream.as_mut().unwrap();
354 let mut coded_input_stream =
355 protobuf::CodedInputStream::from_buffered_reader(&mut input_stream);
356
357 let message_type_received = coded_input_stream
358 .read_raw_varint32()
359 .expect(&format!("read_message | Codec error"));
360 debug(self.trace, &format!("message_type: {:?}", message_type));
361 if message_type != message_type_received {
362 return Err(DriverError::new(&format!(
363 "read_message | wrong message type received from Siodb: {}. Expected: {}.",
364 message_type_received, message_type
365 )));
366 }
367 let message = coded_input_stream
368 .read_message()
369 .expect(&format!("read_message | Codec error"));
370
371 Ok(Ok(message))
372 }
373 pub fn execute(&mut self, sql: String) -> Result<(), DriverError> {
375 if self.result_set.is_some() && !self.result_set.as_mut().unwrap().end_of_row {
376 return Err(DriverError::new(&format!(
377 "execute | There is still data in the buffer."
378 )));
379 }
380
381 let mut command = Command::new();
383 command.set_request_id(1);
384 command.set_text(sql);
385 debug(self.trace, &format!("command: {:?}", command));
386 self.write_message(1, &command)?;
387
388 self.result_set = Some(ResultSet::new(
390 self.read_message::<ServerResponse>(2).unwrap()?,
391 )?);
392 debug(
393 self.trace,
394 &format!(
395 "ServerResponse: {:?}",
396 self.result_set.as_ref().unwrap().server_response
397 ),
398 );
399
400 if self
402 .result_set
403 .as_ref()
404 .unwrap()
405 .server_response
406 .message
407 .len()
408 > 0
409 {
410 let mut error_messages = String::new();
411 for column in &self.result_set.as_ref().unwrap().server_response.message {
412 error_messages = error_messages + &column.text.to_string();
413 }
414 return Err(DriverError::new(&format!(
415 "execute | Error message(s) {}.",
416 error_messages
417 )));
418 }
419
420 let column_count = self
422 .result_set
423 .as_ref()
424 .unwrap()
425 .server_response
426 .get_column_description()
427 .len();
428
429 if column_count > 0 {
430 self.result_set.as_mut().unwrap().end_of_row = false;
431 debug(
432 self.trace,
433 &format!(
434 "Dataset present in the the server's response with {} colmuns.",
435 self.result_set
436 .as_ref()
437 .unwrap()
438 .server_response
439 .get_column_description()
440 .len()
441 ),
442 );
443
444 for column in &self
446 .result_set
447 .as_ref()
448 .unwrap()
449 .server_response
450 .column_description
451 {
452 if column.is_null {
453 self.result_set.as_mut().unwrap().null_bit_mask_present = true;
454 debug(self.trace, &format!("null_bit_mask_present: true."));
455 if column_count % 8 == 0 {
457 self.result_set.as_mut().unwrap().null_bit_mask_byte_size =
458 (column_count / 8).try_into().unwrap();
459 } else {
460 self.result_set.as_mut().unwrap().null_bit_mask_byte_size =
461 (column_count / 8 + 1).try_into().unwrap();
462 }
463 debug(
464 self.trace,
465 &format!(
466 "null_bit_mask_byte_size: {}.",
467 self.result_set.as_mut().unwrap().null_bit_mask_byte_size
468 ),
469 );
470 break;
471 }
472 }
473 }
474
475 Ok(())
476 }
477
478 pub fn query_row(&mut self, sql: String) -> Option<Vec<Option<Value>>> {
480 let mut row: Option<Vec<Option<Value>>> = None;
481 self.execute(sql).unwrap();
482 if self.next().unwrap() {
483 row = Some(self.scan().to_vec());
484 }
485 while self.next().unwrap() {}
487 row
488 }
489 pub fn query(&mut self, sql: String) -> Result<(), DriverError> {
491 self.execute(sql)
492 }
493 pub fn next(&mut self) -> Result<bool, DriverError> {
495 let mut row = Vec::<Option<Value>>::new();
496 let mut input_stream = self.buf_stream.as_mut().unwrap();
497 let mut coded_input_stream =
498 protobuf::CodedInputStream::from_buffered_reader(&mut input_stream);
499
500 debug(self.trace, &format!("ResultSet.next() | ---"));
501
502 if self.result_set.as_ref().unwrap().end_of_row {
503 return Ok(false);
504 }
505
506 let row_length = coded_input_stream
507 .read_raw_varint32()
508 .expect(&format!("Codec error"));
509 debug(self.trace, &format!("Row bytes row_length: {}", row_length));
510 if row_length == 0 {
511 self.result_set.as_mut().unwrap().end_of_row = true;
512 return Ok(false);
513 } else {
514 self.result_set.as_mut().unwrap().row_count += 1;
515 }
516
517 let mut bit_mask: Vec<u8> = Vec::new();
519 if self.result_set.as_ref().unwrap().null_bit_mask_present {
520 bit_mask = coded_input_stream
521 .read_raw_bytes(self.result_set.as_ref().unwrap().null_bit_mask_byte_size as u32)
522 .unwrap();
523 debug(
524 self.trace,
525 &format!("ResultSet.next() | Bitmask value: {:?}.", bit_mask),
526 );
527 }
528
529 let mut is_null: u8 = 0;
531 for (idx, column) in self
532 .result_set
533 .as_ref()
534 .unwrap()
535 .server_response
536 .column_description
537 .iter()
538 .enumerate()
539 {
540 if self.result_set.as_ref().unwrap().null_bit_mask_present {
541 let mask = 1 << (idx % 8);
542 is_null = (bit_mask[idx / 8] & mask) >> (idx % 8);
543 debug(
544 self.trace,
545 &format!(
546 "ResultSet.next() | Is that cell (id: {:?} ) null?: {:?}.",
547 idx, is_null
548 ),
549 );
550 }
551
552 if is_null == 1 {
553 row.push(None)
554 } else {
555 debug(
556 self.trace,
557 &format!("read_data | data type: {:?}.", column.field_type),
558 );
559 match column.field_type {
560 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_INT8 => row.push(Some(
561 Value::Int8(coded_input_stream.read_raw_bytes(1).unwrap()[0] as i8),
562 )),
563 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_UINT8 => row.push(Some(
564 Value::Uint8(coded_input_stream.read_raw_bytes(1).unwrap()[0] as u8),
565 )),
566 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_INT16 => {
567 row.push(Some(Value::Int16(LittleEndian::read_i16(
568 &coded_input_stream.read_raw_bytes(2).unwrap(),
569 ))))
570 }
571 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_UINT16 => {
572 row.push(Some(Value::Uint16(LittleEndian::read_u16(
573 &coded_input_stream.read_raw_bytes(2).unwrap(),
574 ))))
575 }
576 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_INT32 => row.push(Some(
577 Value::Int32(coded_input_stream.read_raw_varint32().unwrap() as i32),
578 )),
579 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_UINT32 => row.push(Some(
580 Value::Uint32(coded_input_stream.read_raw_varint32().unwrap()),
581 )),
582
583 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_FLOAT => {
584 row.push(Some(Value::Float(coded_input_stream.read_float().unwrap())))
585 }
586 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_DOUBLE => row.push(Some(
587 Value::Double(coded_input_stream.read_double().unwrap()),
588 )),
589
590 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_INT64 => row.push(Some(
591 Value::Int64(coded_input_stream.read_raw_varint64().unwrap() as i64),
592 )),
593 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_UINT64 => row.push(Some(
594 Value::Uint64(coded_input_stream.read_raw_varint64().unwrap()),
595 )),
596 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_TEXT => {
597 let data_length = coded_input_stream.read_raw_varint32().unwrap();
598 row.push(Some(Value::Text(
599 from_utf8(&coded_input_stream.read_raw_bytes(data_length).unwrap())
600 .unwrap()
601 .to_string(),
602 )));
603 }
604 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_BINARY => {
605 let data_length = coded_input_stream.read_raw_varint32().unwrap();
606 row.push(Some(Value::Binary(
607 coded_input_stream.read_raw_bytes(data_length).unwrap(),
608 )));
609 }
610 ColumnDataType::ColumnDataType::COLUMN_DATA_TYPE_TIMESTAMP => {
611 let has_time_part: u8;
612 let year: i32;
613 let month: u8;
614 let day_of_week: u8;
615 let day_of_month: u8;
616 let mut hours: u8 = 0;
617 let mut minutes: u8 = 0;
618 let mut seconds: u8 = 0;
619 let mut nano: u32 = 0;
620 let date = coded_input_stream.read_raw_bytes(4).unwrap();
622 debug(
623 self.trace,
624 &format!(
625 "Binary timestamp: {:08b} {:08b} {:08b} {:08b} ",
626 date[0], date[1], date[2], date[3]
627 ),
628 );
629 has_time_part = date[0] & 0b0000_0001;
630 day_of_week = (date[0] & 0b0000_1110) >> 1;
631 day_of_month =
632 (((date[0] & 0b1111_0000) >> 4) + ((date[1] & 0b0000_0001) << 4)) + 1;
633 month = ((date[1] & 0b0001_1110) >> 1) + 1;
634 let year_bytes = [
635 0b0000_0000,
636 (date[3] & 0b1110_0000) >> 5,
637 ((date[2] & 0b1110_0000) >> 5) + ((date[3] & 0b0001_1111) << 3),
638 ((date[1] & 0b1110_0000) >> 5) + ((date[2] & 0b0001_1111) << 3),
639 ];
640 year = unsafe { std::mem::transmute::<[u8; 4], i32>(year_bytes) }.to_be();
641 debug(
642 self.trace,
643 &format!(
644 "hasTimePart: {:?} | dayOfWeek: {:?} | dayOfMonth: {:?} | month: {:?} | year: {:?} ",
645 has_time_part, day_of_week, day_of_month, month, year
646 ),
647 );
648 if has_time_part == 1 {
649 let time = coded_input_stream.read_raw_bytes(6).unwrap();
651 let nano_bytes = [
652 ((time[3] & 0b0111_1110) >> 1),
653 ((time[2] & 0b1111_1110) >> 1) + ((time[3] & 0b0000_0001) << 7),
654 ((time[1] & 0b1111_1110) >> 1) + ((time[2] & 0b0000_0001) << 7),
655 ((time[0] & 0b1111_1110) >> 1) + ((time[1] & 0b0000_0001) << 7),
656 ];
657 nano =
658 unsafe { std::mem::transmute::<[u8; 4], u32>(nano_bytes) }.to_be();
659 seconds =
660 ((time[3] & 0b1000_0000) >> 7) + ((time[4] & 0b0001_1111) << 1);
661 minutes =
662 ((time[4] & 0b1110_0000) >> 5) + ((time[5] & 0b0000_0111) << 3);
663 hours = (time[5] & 0b1111_1000) >> 3;
664 debug(
665 self.trace,
666 &format!(
667 "hours: {:?} | minutes: {:?} | seconds: {:?} | nano: {:?} | nano_bytes: {:?}",
668 hours, minutes, seconds as u32, nano, nano_bytes
669 ),
670 );
671 }
672 row.push(Some(Value::Timestamp(
673 Utc.ymd(year, month.into(), day_of_month.into())
674 .and_hms_nano(hours.into(), minutes.into(), seconds.into(), nano),
675 )));
676 }
677 _ => {
678 return Err(DriverError::new(&format!(
679 "read_data | Unknow data type: {:?}.",
680 column.field_type
681 )))
682 }
683 }
684 }
685 }
686
687 self.result_set.as_mut().unwrap().current_row = Some(row);
688
689 Ok(true)
690 }
691
692 pub fn scan(&self) -> &Vec<Option<Value>> {
694 self.result_set
695 .as_ref()
696 .unwrap()
697 .current_row
698 .as_ref()
699 .unwrap()
700 }
701
702 pub fn get_row_count(&mut self) -> u64 {
704 self.result_set.as_ref().unwrap().row_count
705 }
706
707 pub fn get_affected_row_count(&mut self) -> u64 {
709 if self
710 .result_set
711 .as_ref()
712 .unwrap()
713 .server_response
714 .get_has_affected_row_count()
715 {
716 self.result_set
717 .as_ref()
718 .unwrap()
719 .server_response
720 .get_affected_row_count()
721 } else {
722 0
723 }
724 }
725}