1use thiserror::Error;
2
/// One-byte TDS token type code.
///
/// This is a newtype over the raw byte rather than an enum so that any byte read off
/// the wire can be represented. The associated constants name the codes this module
/// understands. The derived `PartialEq`/`Eq` also allow the constants to be used as
/// `match` patterns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenType(u8);

impl TokenType {
    /// `ERROR` token code (0xAA).
    pub const ERROR: Self = Self(0xaa);
    /// `INFO` token code (0xAB).
    pub const INFO: Self = Self(0xab);
    /// `LOGINACK` token code (0xAD).
    pub const LOGINACK: Self = Self(0xad);
    /// `ENVCHANGE` token code (0xE3).
    pub const ENVCHANGE: Self = Self(0xe3);
    /// `DONE` token code (0xFD).
    pub const DONE: Self = Self(0xfd);

    /// Returns the raw byte this token type wraps.
    pub const fn code(self) -> u8 {
        let Self(raw) = self;
        raw
    }
}

impl From<u8> for TokenType {
    /// Wraps any byte as a token type.
    ///
    /// Codes without a named constant are accepted here; [`parse_tokens`] is what
    /// rejects them as unsupported.
    fn from(raw: u8) -> Self {
        TokenType(raw)
    }
}
30
/// A decoded TDS token, as produced by [`parse_tokens`].
///
/// `INFO` tokens are consumed and discarded by the parser, so there is no variant for
/// them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Token {
    /// Body of a `LOGINACK` token.
    LoginAck(LoginAck),
    /// Body of an `ERROR` token.
    Error(ServerError),
    /// Body of a single `ENVCHANGE` token.
    EnvChange(EnvChange),
    /// A fixed-size `DONE` token.
    Done(Done),
}
43
/// Decoded body of a `LOGINACK` token (see [`parse_login_ack`]).
///
/// Fields are listed in the order they appear on the wire.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoginAck {
    /// Interface byte; stored as-is and not interpreted by this module.
    pub interface: u8,
    /// TDS version, read big-endian (the tests use `0x7400_0004`).
    pub tds_version: u32,
    /// Server program name, decoded from a one-byte-counted UTF-16LE string.
    pub program_name: String,
    /// Server program major version.
    pub major_version: u8,
    /// Server program minor version.
    pub minor_version: u8,
    /// High byte of the server build number.
    pub build_number_high: u8,
    /// Low byte of the server build number.
    pub build_number_low: u8,
}
62
/// Decoded body of an `ERROR` token (see [`parse_server_error`]).
///
/// Fields are listed in the order they appear on the wire.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerError {
    /// Error number, read as a little-endian `i32`.
    pub number: i32,
    /// Error state byte.
    pub state: u8,
    /// Error class byte.
    pub class: u8,
    /// Message text, decoded from a `u16`-counted UTF-16LE string.
    pub message: String,
    /// Server name, decoded from a `u8`-counted UTF-16LE string (may be empty).
    pub server_name: String,
    /// Procedure name, decoded from a `u8`-counted UTF-16LE string (may be empty).
    pub procedure_name: String,
    /// Line number, read as a little-endian `u32`.
    pub line_number: u32,
}
81
/// A decoded `ENVCHANGE` token, produced by [`parse_env_change`].
///
/// Each variant notes the numeric change type that selects it. For types 1 through 8,
/// only the first value in the token body is decoded; anything after it is ignored.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnvChange {
    /// Type 1: database name.
    Database(String),
    /// Type 2: language.
    Language(String),
    /// Type 3: character set.
    CharacterSet(String),
    /// Type 4: packet size, parsed from its decimal text representation.
    PacketSize(u32),
    /// Type 5: Unicode data sorting local id, kept as the raw string.
    UnicodeDataSortingLocalId(String),
    /// Type 6: Unicode data sorting comparison flags, kept as the raw string.
    UnicodeDataSortingComparisonFlags(String),
    /// Type 7: SQL collation, kept as raw bytes.
    SqlCollation(Vec<u8>),
    /// Type 8: transaction descriptor, decoded as a little-endian `u64`.
    BeginTransaction(u64),
    /// Type 9: transaction descriptor, decoded as a little-endian `u64`.
    CommitTransaction(u64),
    /// Type 10: transaction descriptor, decoded as a little-endian `u64`.
    RollbackTransaction(u64),
    /// Any other change type, preserved undecoded.
    Ignored {
        /// The numeric change type byte.
        change_type: u8,
        /// Every body byte that followed the change type byte.
        data: Vec<u8>,
    },
}
113
/// A decoded `DONE` token (see [`parse_done`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Done {
    /// Status word, read little-endian; its bits are not interpreted by this module.
    pub status: u16,
    /// Current-command word, read little-endian; not interpreted by this module.
    pub current_command: u16,
    /// Row count, read as a little-endian `u64`.
    pub row_count: u64,
}
124
/// Outcome of interpreting a login response with [`parse_login_response`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoginResponse {
    /// The stream contained a `LOGINACK` and a `DONE` token and no `ERROR` token.
    Success {
        /// The (last) `LOGINACK` token seen in the stream.
        login_ack: LoginAck,
        /// Every `ENVCHANGE` token, in stream order.
        env_changes: Vec<EnvChange>,
    },
    /// The first `ERROR` token found in the stream.
    ServerError(ServerError),
}
138
139pub fn parse_tokens(mut input: &[u8]) -> Result<Vec<Token>, TokenParseError> {
141 let mut tokens = Vec::new();
142
143 while !input.is_empty() {
144 let token_type = TokenType::from(read_u8(&mut input)?);
145
146 let token = if token_type == TokenType::LOGINACK {
147 Token::LoginAck(parse_login_ack(read_len_prefixed_token(&mut input)?)?)
148 } else if token_type == TokenType::ERROR {
149 Token::Error(parse_error(read_len_prefixed_token(&mut input)?)?)
150 } else if token_type == TokenType::INFO {
151 let _ = read_len_prefixed_token(&mut input)?;
152 continue;
153 } else if token_type == TokenType::ENVCHANGE {
154 Token::EnvChange(parse_env_change(read_len_prefixed_token(&mut input)?)?)
155 } else if token_type == TokenType::DONE {
156 Token::Done(parse_done(&mut input)?)
157 } else {
158 return Err(TokenParseError::UnsupportedToken(token_type.code()));
159 };
160
161 tokens.push(token);
162 }
163
164 Ok(tokens)
165}
166
167pub fn parse_login_response(input: &[u8]) -> Result<LoginResponse, TokenParseError> {
169 let tokens = parse_tokens(input)?;
170 let mut login_ack = None;
171 let mut done = false;
172 let mut env_changes = Vec::new();
173
174 for token in tokens {
175 match token {
176 Token::LoginAck(ack) => login_ack = Some(ack),
177 Token::Error(error) => return Ok(LoginResponse::ServerError(error)),
178 Token::Done(_) => done = true,
179 Token::EnvChange(change) => env_changes.push(change),
180 }
181 }
182
183 let login_ack = login_ack.ok_or(TokenParseError::MissingLoginAck)?;
184 if !done {
185 return Err(TokenParseError::MissingDone);
186 }
187
188 Ok(LoginResponse::Success {
189 login_ack,
190 env_changes,
191 })
192}
193
194fn parse_login_ack(mut input: &[u8]) -> Result<LoginAck, TokenParseError> {
195 let interface = read_u8(&mut input)?;
196 let tds_version = read_u32_be(&mut input)?;
197 let program_name = read_b_varchar(&mut input)?;
198 let major_version = read_u8(&mut input)?;
199 let minor_version = read_u8(&mut input)?;
200 let build_number_high = read_u8(&mut input)?;
201 let build_number_low = read_u8(&mut input)?;
202 expect_empty(input)?;
203
204 Ok(LoginAck {
205 interface,
206 tds_version,
207 program_name,
208 major_version,
209 minor_version,
210 build_number_high,
211 build_number_low,
212 })
213}
214
/// Crate-visible entry point for decoding an `ERROR` token body.
///
/// `input` is the token body only (no type byte or length prefix). This delegates to
/// the private `parse_error`, so the same trailing-byte check applies.
pub(crate) fn parse_server_error(input: &[u8]) -> Result<ServerError, TokenParseError> {
    parse_error(input)
}
218
219fn parse_error(mut input: &[u8]) -> Result<ServerError, TokenParseError> {
220 let number = read_i32_le(&mut input)?;
221 let state = read_u8(&mut input)?;
222 let class = read_u8(&mut input)?;
223 let message = read_us_varchar(&mut input)?;
224 let server_name = read_b_varchar(&mut input)?;
225 let procedure_name = read_b_varchar(&mut input)?;
226 let line_number = read_u32_le(&mut input)?;
227 expect_empty(input)?;
228
229 Ok(ServerError {
230 number,
231 state,
232 class,
233 message,
234 server_name,
235 procedure_name,
236 line_number,
237 })
238}
239
240pub(crate) fn parse_env_change(mut input: &[u8]) -> Result<EnvChange, TokenParseError> {
241 let change_type = read_u8(&mut input)?;
242
243 Ok(match change_type {
244 1 => EnvChange::Database(read_b_varchar(&mut input)?),
245 2 => EnvChange::Language(read_b_varchar(&mut input)?),
246 3 => EnvChange::CharacterSet(read_b_varchar(&mut input)?),
247 4 => {
248 let size = read_b_varchar(&mut input)?;
249 EnvChange::PacketSize(
250 size.parse()
251 .map_err(|_| TokenParseError::InvalidEnvChangePacketSize(size))?,
252 )
253 }
254 5 => EnvChange::UnicodeDataSortingLocalId(read_b_varchar(&mut input)?),
255 6 => EnvChange::UnicodeDataSortingComparisonFlags(read_b_varchar(&mut input)?),
256 7 => EnvChange::SqlCollation(read_b_varbyte(&mut input)?.to_vec()),
257 8 => EnvChange::BeginTransaction(read_b_varbyte_u64_le(&mut input)?),
258 9 => EnvChange::CommitTransaction(read_transaction_end_descriptor(&mut input)?),
259 10 => EnvChange::RollbackTransaction(read_transaction_end_descriptor(&mut input)?),
260 _ => EnvChange::Ignored {
261 change_type,
262 data: input.to_vec(),
263 },
264 })
265}
266
267fn parse_done(input: &mut &[u8]) -> Result<Done, TokenParseError> {
268 Ok(Done {
269 status: read_u16_le(input)?,
270 current_command: read_u16_le(input)?,
271 row_count: read_u64_le(input)?,
272 })
273}
274
275fn read_len_prefixed_token<'a>(input: &mut &'a [u8]) -> Result<&'a [u8], TokenParseError> {
276 let len = usize::from(read_u16_le(input)?);
277 take(input, len)
278}
279
280fn read_b_varchar(input: &mut &[u8]) -> Result<String, TokenParseError> {
281 let len_chars = usize::from(read_u8(input)?);
282 read_utf16_string(input, len_chars)
283}
284
285fn read_b_varbyte<'a>(input: &mut &'a [u8]) -> Result<&'a [u8], TokenParseError> {
286 let len = usize::from(read_u8(input)?);
287 take(input, len)
288}
289
290fn read_us_varchar(input: &mut &[u8]) -> Result<String, TokenParseError> {
291 let len_chars = usize::from(read_u16_le(input)?);
292 read_utf16_string(input, len_chars)
293}
294
295fn read_utf16_string(input: &mut &[u8], len_chars: usize) -> Result<String, TokenParseError> {
296 let len_bytes = len_chars
297 .checked_mul(2)
298 .ok_or(TokenParseError::LengthOverflow)?;
299 let bytes = take(input, len_bytes)?;
300 let units = bytes
301 .chunks_exact(2)
302 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]));
303
304 String::from_utf16(&units.collect::<Vec<_>>()).map_err(|_| TokenParseError::InvalidUtf16)
305}
306
307fn read_u8(input: &mut &[u8]) -> Result<u8, TokenParseError> {
308 let bytes = take(input, 1)?;
309 Ok(bytes[0])
310}
311
312fn read_u16_le(input: &mut &[u8]) -> Result<u16, TokenParseError> {
313 let bytes = take(input, 2)?;
314 Ok(u16::from_le_bytes([bytes[0], bytes[1]]))
315}
316
317fn read_i32_le(input: &mut &[u8]) -> Result<i32, TokenParseError> {
318 let bytes = take(input, 4)?;
319 Ok(i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
320}
321
322fn read_u32_le(input: &mut &[u8]) -> Result<u32, TokenParseError> {
323 let bytes = take(input, 4)?;
324 Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
325}
326
327fn read_u32_be(input: &mut &[u8]) -> Result<u32, TokenParseError> {
328 let bytes = take(input, 4)?;
329 Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
330}
331
332fn read_u64_le(input: &mut &[u8]) -> Result<u64, TokenParseError> {
333 let bytes = take(input, 8)?;
334 Ok(u64::from_le_bytes([
335 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
336 ]))
337}
338
339fn read_b_varbyte_u64_le(input: &mut &[u8]) -> Result<u64, TokenParseError> {
340 let mut bytes = read_b_varbyte(input)?;
341 read_u64_le(&mut bytes)
342}
343
344fn read_transaction_end_descriptor(input: &mut &[u8]) -> Result<u64, TokenParseError> {
345 let _new_descriptor = read_b_varbyte(input)?;
346 read_u64_le(input)
347}
348
349fn take<'a>(input: &mut &'a [u8], len: usize) -> Result<&'a [u8], TokenParseError> {
350 let bytes = input.get(..len).ok_or(TokenParseError::UnexpectedEof)?;
351 *input = &input[len..];
352 Ok(bytes)
353}
354
355fn expect_empty(input: &[u8]) -> Result<(), TokenParseError> {
356 if input.is_empty() {
357 Ok(())
358 } else {
359 Err(TokenParseError::TrailingTokenBytes(input.len()))
360 }
361}
362
363#[derive(Debug, Error, PartialEq, Eq)]
365pub enum TokenParseError {
366 #[error("TDS token stream ended before the current token was complete")]
368 UnexpectedEof,
369 #[error("TDS token length overflowed")]
371 LengthOverflow,
372 #[error("TDS token contained invalid UTF-16 string data")]
374 InvalidUtf16,
375 #[error("unsupported TDS token 0x{0:02x}")]
377 UnsupportedToken(u8),
378 #[error("TDS token contained {0} trailing bytes")]
380 TrailingTokenBytes(usize),
381 #[error("TDS login response did not include LOGINACK")]
383 MissingLoginAck,
384 #[error("TDS login response did not include DONE")]
386 MissingDone,
387 #[error("TDS ENVCHANGE packet size `{0}` is not a valid integer")]
389 InvalidEnvChangePacketSize(String),
390}
391
392#[cfg(test)]
393mod tests {
394 use super::*;
395
396 #[test]
397 fn parses_login_ack_envchange_and_done_as_success() {
398 let bytes = [
399 login_ack("Microsoft SQL Server"),
400 env_change(
401 4,
402 &[
403 4, b'4', 0, b'0', 0, b'9', 0, b'6', 0, 3, b'5', 0, b'1', 0, b'2', 0,
404 ],
405 ),
406 done(0, 0, 0),
407 ]
408 .concat();
409
410 let tokens = parse_tokens(&bytes).unwrap();
411
412 assert_eq!(3, tokens.len());
413 assert_eq!(
414 LoginResponse::Success {
415 login_ack: LoginAck {
416 interface: 1,
417 tds_version: 0x7400_0004,
418 program_name: "Microsoft SQL Server".to_owned(),
419 major_version: 16,
420 minor_version: 0,
421 build_number_high: 0x10,
422 build_number_low: 0x4a,
423 },
424 env_changes: vec![EnvChange::PacketSize(4096)],
425 },
426 parse_login_response(&bytes).unwrap()
427 );
428 }
429
430 #[test]
431 fn parses_transaction_envchanges() {
432 assert_eq!(
433 EnvChange::BeginTransaction(0x0102_0304_0506_0708),
434 parse_env_change(&[8, 8, 8, 7, 6, 5, 4, 3, 2, 1,]).unwrap()
435 );
436 assert_eq!(
437 EnvChange::CommitTransaction(0x1112_1314_1516_1718),
438 parse_env_change(&[9, 0, 0x18, 0x17, 0x16, 0x15, 0x14, 0x13, 0x12, 0x11,]).unwrap()
439 );
440 assert_eq!(
441 EnvChange::RollbackTransaction(0x2122_2324_2526_2728),
442 parse_env_change(&[10, 0, 0x28, 0x27, 0x26, 0x25, 0x24, 0x23, 0x22, 0x21,]).unwrap()
443 );
444 }
445
446 #[test]
447 fn reports_server_error_before_done() {
448 let bytes = [
449 error(18456, 1, 14, "Login failed", "dbhost", "", 1),
450 done(0x0002, 0, 0),
451 ]
452 .concat();
453
454 assert_eq!(
455 LoginResponse::ServerError(ServerError {
456 number: 18456,
457 state: 1,
458 class: 14,
459 message: "Login failed".to_owned(),
460 server_name: "dbhost".to_owned(),
461 procedure_name: String::new(),
462 line_number: 1,
463 }),
464 parse_login_response(&bytes).unwrap()
465 );
466 }
467
468 #[test]
469 fn skips_info_tokens_during_login() {
470 let bytes = [
471 len_prefixed(
472 TokenType::INFO,
473 error_body(5701, 1, 10, "Changed database", "", "", 1),
474 ),
475 login_ack("Microsoft SQL Server"),
476 done(0, 0, 0),
477 ]
478 .concat();
479
480 assert!(matches!(
481 parse_login_response(&bytes).unwrap(),
482 LoginResponse::Success { .. }
483 ));
484 }
485
486 #[test]
487 fn rejects_truncated_login_ack() {
488 let bytes = [TokenType::LOGINACK.code(), 10, 0, 1, 0x74];
489
490 assert_eq!(
491 TokenParseError::UnexpectedEof,
492 parse_tokens(&bytes).unwrap_err()
493 );
494 }
495
496 #[test]
497 fn rejects_unsupported_tokens_in_bounded_parser() {
498 let bytes = [0xac, 0, 0];
499
500 assert_eq!(
501 TokenParseError::UnsupportedToken(0xac),
502 parse_tokens(&bytes).unwrap_err()
503 );
504 }
505
506 #[test]
507 fn login_response_requires_login_ack_when_no_error_is_present() {
508 let bytes = done(0, 0, 0);
509
510 assert_eq!(
511 TokenParseError::MissingLoginAck,
512 parse_login_response(&bytes).unwrap_err()
513 );
514 }
515
516 #[test]
517 fn login_response_success_requires_done() {
518 let bytes = login_ack("Microsoft SQL Server");
519
520 assert_eq!(
521 TokenParseError::MissingDone,
522 parse_login_response(&bytes).unwrap_err()
523 );
524 }
525
526 fn login_ack(program_name: &str) -> Vec<u8> {
527 let mut body = Vec::new();
528 body.push(1);
529 body.extend_from_slice(&0x7400_0004u32.to_be_bytes());
530 push_b_varchar(&mut body, program_name);
531 body.extend_from_slice(&[16, 0, 0x10, 0x4a]);
532
533 len_prefixed(TokenType::LOGINACK, body)
534 }
535
536 fn error(
537 number: i32,
538 state: u8,
539 class: u8,
540 message: &str,
541 server_name: &str,
542 procedure_name: &str,
543 line_number: u32,
544 ) -> Vec<u8> {
545 len_prefixed(
546 TokenType::ERROR,
547 error_body(
548 number,
549 state,
550 class,
551 message,
552 server_name,
553 procedure_name,
554 line_number,
555 ),
556 )
557 }
558
559 fn error_body(
560 number: i32,
561 state: u8,
562 class: u8,
563 message: &str,
564 server_name: &str,
565 procedure_name: &str,
566 line_number: u32,
567 ) -> Vec<u8> {
568 let mut body = Vec::new();
569 body.extend_from_slice(&number.to_le_bytes());
570 body.push(state);
571 body.push(class);
572 push_us_varchar(&mut body, message);
573 push_b_varchar(&mut body, server_name);
574 push_b_varchar(&mut body, procedure_name);
575 body.extend_from_slice(&line_number.to_le_bytes());
576 body
577 }
578
579 fn env_change(change_type: u8, data: &[u8]) -> Vec<u8> {
580 let mut body = Vec::with_capacity(1 + data.len());
581 body.push(change_type);
582 body.extend_from_slice(data);
583
584 len_prefixed(TokenType::ENVCHANGE, body)
585 }
586
587 fn done(status: u16, current_command: u16, row_count: u64) -> Vec<u8> {
588 let mut out = Vec::new();
589 out.push(TokenType::DONE.code());
590 out.extend_from_slice(&status.to_le_bytes());
591 out.extend_from_slice(¤t_command.to_le_bytes());
592 out.extend_from_slice(&row_count.to_le_bytes());
593 out
594 }
595
596 fn len_prefixed(token_type: TokenType, body: Vec<u8>) -> Vec<u8> {
597 let mut out = Vec::new();
598 out.push(token_type.code());
599 out.extend_from_slice(
600 &u16::try_from(body.len())
601 .expect("test token body fits in u16")
602 .to_le_bytes(),
603 );
604 out.extend_from_slice(&body);
605 out
606 }
607
608 fn push_b_varchar(out: &mut Vec<u8>, value: &str) {
609 out.push(u8::try_from(value.encode_utf16().count()).expect("test string fits in u8"));
610 push_utf16(out, value);
611 }
612
613 fn push_us_varchar(out: &mut Vec<u8>, value: &str) {
614 out.extend_from_slice(
615 &u16::try_from(value.encode_utf16().count())
616 .expect("test string fits in u16")
617 .to_le_bytes(),
618 );
619 push_utf16(out, value);
620 }
621
622 fn push_utf16(out: &mut Vec<u8>, value: &str) {
623 for unit in value.encode_utf16() {
624 out.extend_from_slice(&unit.to_le_bytes());
625 }
626 }
627}