Skip to main content

sqlx_sqlserver/protocol/
token.rs

1use thiserror::Error;
2
/// One-byte identifier that introduces each token of a TDS tabular result.
///
/// Named constants exist only for the token types this module handles; any other byte can
/// still be wrapped through the `From<u8>` conversion.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenType(u8);

impl TokenType {
    /// Identifier of the ERROR token (`0xAA`).
    pub const ERROR: Self = TokenType(0xAA);
    /// Identifier of the INFO token (`0xAB`).
    pub const INFO: Self = TokenType(0xAB);
    /// Identifier of the LOGINACK token (`0xAD`).
    pub const LOGINACK: Self = TokenType(0xAD);
    /// Identifier of the ENVCHANGE token (`0xE3`).
    pub const ENVCHANGE: Self = TokenType(0xE3);
    /// Identifier of the DONE token (`0xFD`).
    pub const DONE: Self = TokenType(0xFD);

    /// Unwraps the token type into the byte it was built from.
    pub const fn code(self) -> u8 {
        let TokenType(raw) = self;
        raw
    }
}

impl From<u8> for TokenType {
    /// Wraps `raw` as a token type; no check is made that the byte is a known token.
    fn from(raw: u8) -> Self {
        TokenType(raw)
    }
}
30
/// Parsed subset of TDS tabular-result tokens needed during login.
///
/// `parse_tokens` produces one value per LOGINACK, ERROR, ENVCHANGE, or DONE token it
/// reads. INFO tokens are consumed by `parse_tokens` but have no variant here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Token {
    /// Server accepted the LOGIN7 request.
    LoginAck(LoginAck),
    /// Server returned an error.
    Error(ServerError),
    /// Server reported a connection environment change.
    EnvChange(EnvChange),
    /// Server completed the response stream.
    Done(Done),
}
43
/// LOGINACK token data.
///
/// Fields are listed in the order `parse_login_ack` reads them from the token body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoginAck {
    /// Accepted server interface.
    pub interface: u8,
    /// TDS protocol version selected by the server (decoded from four big-endian bytes).
    pub tds_version: u32,
    /// Server program name (decoded from a one-byte-length-prefixed UTF-16LE string).
    pub program_name: String,
    /// Server major version.
    pub major_version: u8,
    /// Server minor version.
    pub minor_version: u8,
    /// High byte of the server build number.
    pub build_number_high: u8,
    /// Low byte of the server build number.
    pub build_number_low: u8,
}
62
/// ERROR token data.
///
/// Fields are listed in the order `parse_error` reads them from the token body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerError {
    /// SQL Server error number (little-endian on the wire).
    pub number: i32,
    /// Error state.
    pub state: u8,
    /// Error class / severity.
    pub class: u8,
    /// Human-readable error message (two-byte-length-prefixed UTF-16LE on the wire).
    pub message: String,
    /// Server name (one-byte-length-prefixed UTF-16LE on the wire).
    pub server_name: String,
    /// Stored procedure name, when present; empty when the token carried a zero-length name.
    pub procedure_name: String,
    /// Line number reported by the server (little-endian on the wire).
    pub line_number: u32,
}
81
/// ENVCHANGE token data.
///
/// `parse_env_change` selects the variant from the change type byte at the start of the
/// token body (noted per variant below). For the string and collation variants only the
/// first value after the type byte is decoded; any bytes after it are left unread.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnvChange {
    /// Current database changed (type 1).
    Database(String),
    /// Current language changed (type 2).
    Language(String),
    /// Current character set changed (type 3).
    CharacterSet(String),
    /// Server accepted a packet size (type 4); sent as decimal text and parsed to an integer.
    PacketSize(u32),
    /// Unicode sorting locale changed (type 5).
    UnicodeDataSortingLocalId(String),
    /// Unicode sorting comparison flags changed (type 6).
    UnicodeDataSortingComparisonFlags(String),
    /// SQL collation changed (type 7); kept as the raw length-prefixed bytes.
    SqlCollation(Vec<u8>),
    /// Server started a transaction and returned its descriptor (type 8).
    BeginTransaction(u64),
    /// Server committed a transaction (type 9); carries the descriptor read from the payload.
    CommitTransaction(u64),
    /// Server rolled back a transaction (type 10); carries the descriptor read from the payload.
    RollbackTransaction(u64),
    /// ENVCHANGE type not currently interpreted by this driver.
    Ignored {
        /// Environment change type byte.
        change_type: u8,
        /// Raw change payload after the type byte.
        data: Vec<u8>,
    },
}
113
/// DONE token data.
///
/// Unlike the other tokens handled here, DONE has no length prefix: `parse_done` reads a
/// fixed 12 bytes (two little-endian `u16` values followed by a little-endian `u64`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Done {
    /// DONE status bit field.
    pub status: u16,
    /// Current command token.
    pub current_command: u16,
    /// Rows affected, valid when the DONE_COUNT status bit is set.
    pub row_count: u64,
}
124
/// Result of interpreting a LOGIN7 response token stream.
///
/// Produced by `parse_login_response`. A stream without any ERROR token that lacks
/// LOGINACK or DONE is reported as a `TokenParseError` rather than as a variant here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoginResponse {
    /// LOGINACK was received and no ERROR token was present.
    Success {
        /// Accepted login metadata returned by the server.
        login_ack: LoginAck,
        /// Environment changes received during login, in stream order.
        env_changes: Vec<EnvChange>,
    },
    /// Server returned at least one ERROR token; this is the first one in the stream.
    ServerError(ServerError),
}
138
139/// Parses the bounded token subset currently needed from a tabular-result payload.
140pub fn parse_tokens(mut input: &[u8]) -> Result<Vec<Token>, TokenParseError> {
141    let mut tokens = Vec::new();
142
143    while !input.is_empty() {
144        let token_type = TokenType::from(read_u8(&mut input)?);
145
146        let token = if token_type == TokenType::LOGINACK {
147            Token::LoginAck(parse_login_ack(read_len_prefixed_token(&mut input)?)?)
148        } else if token_type == TokenType::ERROR {
149            Token::Error(parse_error(read_len_prefixed_token(&mut input)?)?)
150        } else if token_type == TokenType::INFO {
151            let _ = read_len_prefixed_token(&mut input)?;
152            continue;
153        } else if token_type == TokenType::ENVCHANGE {
154            Token::EnvChange(parse_env_change(read_len_prefixed_token(&mut input)?)?)
155        } else if token_type == TokenType::DONE {
156            Token::Done(parse_done(&mut input)?)
157        } else {
158            return Err(TokenParseError::UnsupportedToken(token_type.code()));
159        };
160
161        tokens.push(token);
162    }
163
164    Ok(tokens)
165}
166
167/// Interprets a LOGIN7 response token stream as success or server failure.
168pub fn parse_login_response(input: &[u8]) -> Result<LoginResponse, TokenParseError> {
169    let tokens = parse_tokens(input)?;
170    let mut login_ack = None;
171    let mut done = false;
172    let mut env_changes = Vec::new();
173
174    for token in tokens {
175        match token {
176            Token::LoginAck(ack) => login_ack = Some(ack),
177            Token::Error(error) => return Ok(LoginResponse::ServerError(error)),
178            Token::Done(_) => done = true,
179            Token::EnvChange(change) => env_changes.push(change),
180        }
181    }
182
183    let login_ack = login_ack.ok_or(TokenParseError::MissingLoginAck)?;
184    if !done {
185        return Err(TokenParseError::MissingDone);
186    }
187
188    Ok(LoginResponse::Success {
189        login_ack,
190        env_changes,
191    })
192}
193
194fn parse_login_ack(mut input: &[u8]) -> Result<LoginAck, TokenParseError> {
195    let interface = read_u8(&mut input)?;
196    let tds_version = read_u32_be(&mut input)?;
197    let program_name = read_b_varchar(&mut input)?;
198    let major_version = read_u8(&mut input)?;
199    let minor_version = read_u8(&mut input)?;
200    let build_number_high = read_u8(&mut input)?;
201    let build_number_low = read_u8(&mut input)?;
202    expect_empty(input)?;
203
204    Ok(LoginAck {
205        interface,
206        tds_version,
207        program_name,
208        major_version,
209        minor_version,
210        build_number_high,
211        build_number_low,
212    })
213}
214
/// Crate-visible entry point for decoding an ERROR token body.
///
/// Delegates to `parse_error` unchanged; `input` is the token body after its length prefix.
pub(crate) fn parse_server_error(input: &[u8]) -> Result<ServerError, TokenParseError> {
    parse_error(input)
}
218
219fn parse_error(mut input: &[u8]) -> Result<ServerError, TokenParseError> {
220    let number = read_i32_le(&mut input)?;
221    let state = read_u8(&mut input)?;
222    let class = read_u8(&mut input)?;
223    let message = read_us_varchar(&mut input)?;
224    let server_name = read_b_varchar(&mut input)?;
225    let procedure_name = read_b_varchar(&mut input)?;
226    let line_number = read_u32_le(&mut input)?;
227    expect_empty(input)?;
228
229    Ok(ServerError {
230        number,
231        state,
232        class,
233        message,
234        server_name,
235        procedure_name,
236        line_number,
237    })
238}
239
240pub(crate) fn parse_env_change(mut input: &[u8]) -> Result<EnvChange, TokenParseError> {
241    let change_type = read_u8(&mut input)?;
242
243    Ok(match change_type {
244        1 => EnvChange::Database(read_b_varchar(&mut input)?),
245        2 => EnvChange::Language(read_b_varchar(&mut input)?),
246        3 => EnvChange::CharacterSet(read_b_varchar(&mut input)?),
247        4 => {
248            let size = read_b_varchar(&mut input)?;
249            EnvChange::PacketSize(
250                size.parse()
251                    .map_err(|_| TokenParseError::InvalidEnvChangePacketSize(size))?,
252            )
253        }
254        5 => EnvChange::UnicodeDataSortingLocalId(read_b_varchar(&mut input)?),
255        6 => EnvChange::UnicodeDataSortingComparisonFlags(read_b_varchar(&mut input)?),
256        7 => EnvChange::SqlCollation(read_b_varbyte(&mut input)?.to_vec()),
257        8 => EnvChange::BeginTransaction(read_b_varbyte_u64_le(&mut input)?),
258        9 => EnvChange::CommitTransaction(read_transaction_end_descriptor(&mut input)?),
259        10 => EnvChange::RollbackTransaction(read_transaction_end_descriptor(&mut input)?),
260        _ => EnvChange::Ignored {
261            change_type,
262            data: input.to_vec(),
263        },
264    })
265}
266
267fn parse_done(input: &mut &[u8]) -> Result<Done, TokenParseError> {
268    Ok(Done {
269        status: read_u16_le(input)?,
270        current_command: read_u16_le(input)?,
271        row_count: read_u64_le(input)?,
272    })
273}
274
275fn read_len_prefixed_token<'a>(input: &mut &'a [u8]) -> Result<&'a [u8], TokenParseError> {
276    let len = usize::from(read_u16_le(input)?);
277    take(input, len)
278}
279
280fn read_b_varchar(input: &mut &[u8]) -> Result<String, TokenParseError> {
281    let len_chars = usize::from(read_u8(input)?);
282    read_utf16_string(input, len_chars)
283}
284
285fn read_b_varbyte<'a>(input: &mut &'a [u8]) -> Result<&'a [u8], TokenParseError> {
286    let len = usize::from(read_u8(input)?);
287    take(input, len)
288}
289
290fn read_us_varchar(input: &mut &[u8]) -> Result<String, TokenParseError> {
291    let len_chars = usize::from(read_u16_le(input)?);
292    read_utf16_string(input, len_chars)
293}
294
295fn read_utf16_string(input: &mut &[u8], len_chars: usize) -> Result<String, TokenParseError> {
296    let len_bytes = len_chars
297        .checked_mul(2)
298        .ok_or(TokenParseError::LengthOverflow)?;
299    let bytes = take(input, len_bytes)?;
300    let units = bytes
301        .chunks_exact(2)
302        .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]));
303
304    String::from_utf16(&units.collect::<Vec<_>>()).map_err(|_| TokenParseError::InvalidUtf16)
305}
306
307fn read_u8(input: &mut &[u8]) -> Result<u8, TokenParseError> {
308    let bytes = take(input, 1)?;
309    Ok(bytes[0])
310}
311
312fn read_u16_le(input: &mut &[u8]) -> Result<u16, TokenParseError> {
313    let bytes = take(input, 2)?;
314    Ok(u16::from_le_bytes([bytes[0], bytes[1]]))
315}
316
317fn read_i32_le(input: &mut &[u8]) -> Result<i32, TokenParseError> {
318    let bytes = take(input, 4)?;
319    Ok(i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
320}
321
322fn read_u32_le(input: &mut &[u8]) -> Result<u32, TokenParseError> {
323    let bytes = take(input, 4)?;
324    Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
325}
326
327fn read_u32_be(input: &mut &[u8]) -> Result<u32, TokenParseError> {
328    let bytes = take(input, 4)?;
329    Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
330}
331
332fn read_u64_le(input: &mut &[u8]) -> Result<u64, TokenParseError> {
333    let bytes = take(input, 8)?;
334    Ok(u64::from_le_bytes([
335        bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
336    ]))
337}
338
339fn read_b_varbyte_u64_le(input: &mut &[u8]) -> Result<u64, TokenParseError> {
340    let mut bytes = read_b_varbyte(input)?;
341    read_u64_le(&mut bytes)
342}
343
344fn read_transaction_end_descriptor(input: &mut &[u8]) -> Result<u64, TokenParseError> {
345    let _new_descriptor = read_b_varbyte(input)?;
346    read_u64_le(input)
347}
348
349fn take<'a>(input: &mut &'a [u8], len: usize) -> Result<&'a [u8], TokenParseError> {
350    let bytes = input.get(..len).ok_or(TokenParseError::UnexpectedEof)?;
351    *input = &input[len..];
352    Ok(bytes)
353}
354
355fn expect_empty(input: &[u8]) -> Result<(), TokenParseError> {
356    if input.is_empty() {
357        Ok(())
358    } else {
359        Err(TokenParseError::TrailingTokenBytes(input.len()))
360    }
361}
362
/// Error returned while parsing a bounded TDS token stream.
///
/// Display messages come from the `#[error]` attributes via the `thiserror` derive.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum TokenParseError {
    /// The token stream ended in the middle of a token.
    ///
    /// Raised by `take` whenever more bytes are requested than remain in the input.
    #[error("TDS token stream ended before the current token was complete")]
    UnexpectedEof,
    /// A token advertised a length that cannot be represented safely.
    ///
    /// Raised by `read_utf16_string` when doubling the code-unit count overflows `usize`.
    #[error("TDS token length overflowed")]
    LengthOverflow,
    /// A token contained invalid UTF-16 string data.
    #[error("TDS token contained invalid UTF-16 string data")]
    InvalidUtf16,
    /// This bounded parser does not yet understand the token type.
    ///
    /// Carries the unrecognised token type byte.
    #[error("unsupported TDS token 0x{0:02x}")]
    UnsupportedToken(u8),
    /// A length-prefixed token contained extra bytes after its expected fields.
    ///
    /// Carries the number of unread bytes; raised for LOGINACK and ERROR token bodies.
    #[error("TDS token contained {0} trailing bytes")]
    TrailingTokenBytes(usize),
    /// A LOGIN7 response did not include LOGINACK.
    #[error("TDS login response did not include LOGINACK")]
    MissingLoginAck,
    /// A LOGIN7 response did not include DONE.
    #[error("TDS login response did not include DONE")]
    MissingDone,
    /// ENVCHANGE packet size was not a decimal integer.
    ///
    /// Carries the text that failed to parse.
    #[error("TDS ENVCHANGE packet size `{0}` is not a valid integer")]
    InvalidEnvChangePacketSize(String),
}
391
// Unit tests for the token parser. Fixtures are assembled byte-by-byte with the helper
// functions at the bottom of the module.
#[cfg(test)]
mod tests {
    use super::*;

    /// A LOGINACK + ENVCHANGE + DONE stream yields three tokens and a successful login.
    #[test]
    fn parses_login_ack_envchange_and_done_as_success() {
        let bytes = [
            login_ack("Microsoft SQL Server"),
            // Type 4 (packet size): first value is the text "4096", followed by the text
            // "512", which the parser leaves unread.
            env_change(
                4,
                &[
                    4, b'4', 0, b'0', 0, b'9', 0, b'6', 0, 3, b'5', 0, b'1', 0, b'2', 0,
                ],
            ),
            done(0, 0, 0),
        ]
        .concat();

        let tokens = parse_tokens(&bytes).unwrap();

        assert_eq!(3, tokens.len());
        assert_eq!(
            LoginResponse::Success {
                login_ack: LoginAck {
                    interface: 1,
                    tds_version: 0x7400_0004,
                    program_name: "Microsoft SQL Server".to_owned(),
                    major_version: 16,
                    minor_version: 0,
                    build_number_high: 0x10,
                    build_number_low: 0x4a,
                },
                env_changes: vec![EnvChange::PacketSize(4096)],
            },
            parse_login_response(&bytes).unwrap()
        );
    }

    /// Transaction ENVCHANGE bodies (types 8, 9, 10) decode to little-endian descriptors.
    ///
    /// NOTE(review): MS-TDS describes the old value of commit/rollback changes as a
    /// B_VARBYTE (length byte followed by the descriptor). The type 9 and 10 fixtures
    /// below pass the eight descriptor bytes without that length byte — confirm against a
    /// real server capture.
    #[test]
    fn parses_transaction_envchanges() {
        assert_eq!(
            EnvChange::BeginTransaction(0x0102_0304_0506_0708),
            parse_env_change(&[8, 8, 8, 7, 6, 5, 4, 3, 2, 1,]).unwrap()
        );
        assert_eq!(
            EnvChange::CommitTransaction(0x1112_1314_1516_1718),
            parse_env_change(&[9, 0, 0x18, 0x17, 0x16, 0x15, 0x14, 0x13, 0x12, 0x11,]).unwrap()
        );
        assert_eq!(
            EnvChange::RollbackTransaction(0x2122_2324_2526_2728),
            parse_env_change(&[10, 0, 0x28, 0x27, 0x26, 0x25, 0x24, 0x23, 0x22, 0x21,]).unwrap()
        );
    }

    /// An ERROR token is reported as a server error even though DONE follows and no
    /// LOGINACK is present.
    #[test]
    fn reports_server_error_before_done() {
        let bytes = [
            error(18456, 1, 14, "Login failed", "dbhost", "", 1),
            done(0x0002, 0, 0),
        ]
        .concat();

        assert_eq!(
            LoginResponse::ServerError(ServerError {
                number: 18456,
                state: 1,
                class: 14,
                message: "Login failed".to_owned(),
                server_name: "dbhost".to_owned(),
                procedure_name: String::new(),
                line_number: 1,
            }),
            parse_login_response(&bytes).unwrap()
        );
    }

    /// INFO tokens (same body layout as ERROR in this fixture) do not affect the outcome.
    #[test]
    fn skips_info_tokens_during_login() {
        let bytes = [
            len_prefixed(
                TokenType::INFO,
                error_body(5701, 1, 10, "Changed database", "", "", 1),
            ),
            login_ack("Microsoft SQL Server"),
            done(0, 0, 0),
        ]
        .concat();

        assert!(matches!(
            parse_login_response(&bytes).unwrap(),
            LoginResponse::Success { .. }
        ));
    }

    /// A LOGINACK that declares a 10-byte body but supplies only 2 bytes is rejected.
    #[test]
    fn rejects_truncated_login_ack() {
        let bytes = [TokenType::LOGINACK.code(), 10, 0, 1, 0x74];

        assert_eq!(
            TokenParseError::UnexpectedEof,
            parse_tokens(&bytes).unwrap_err()
        );
    }

    /// A token type byte without a handler (0xac here) is reported with its raw value.
    #[test]
    fn rejects_unsupported_tokens_in_bounded_parser() {
        let bytes = [0xac, 0, 0];

        assert_eq!(
            TokenParseError::UnsupportedToken(0xac),
            parse_tokens(&bytes).unwrap_err()
        );
    }

    /// A stream with only DONE is a parse error, not a successful login.
    #[test]
    fn login_response_requires_login_ack_when_no_error_is_present() {
        let bytes = done(0, 0, 0);

        assert_eq!(
            TokenParseError::MissingLoginAck,
            parse_login_response(&bytes).unwrap_err()
        );
    }

    /// A stream with LOGINACK but no DONE is a parse error.
    #[test]
    fn login_response_success_requires_done() {
        let bytes = login_ack("Microsoft SQL Server");

        assert_eq!(
            TokenParseError::MissingDone,
            parse_login_response(&bytes).unwrap_err()
        );
    }

    /// Builds a complete LOGINACK token: interface 1, TDS version 0x74000004, the given
    /// program name, and version bytes 16.0 with build bytes 0x10 / 0x4a.
    fn login_ack(program_name: &str) -> Vec<u8> {
        let mut body = Vec::new();
        body.push(1);
        body.extend_from_slice(&0x7400_0004u32.to_be_bytes());
        push_b_varchar(&mut body, program_name);
        body.extend_from_slice(&[16, 0, 0x10, 0x4a]);

        len_prefixed(TokenType::LOGINACK, body)
    }

    /// Builds a complete ERROR token (type byte, length prefix, body) from its fields.
    fn error(
        number: i32,
        state: u8,
        class: u8,
        message: &str,
        server_name: &str,
        procedure_name: &str,
        line_number: u32,
    ) -> Vec<u8> {
        len_prefixed(
            TokenType::ERROR,
            error_body(
                number,
                state,
                class,
                message,
                server_name,
                procedure_name,
                line_number,
            ),
        )
    }

    /// Builds only the body of an ERROR-shaped token, in the field order `parse_error`
    /// reads; also used for the INFO fixture.
    fn error_body(
        number: i32,
        state: u8,
        class: u8,
        message: &str,
        server_name: &str,
        procedure_name: &str,
        line_number: u32,
    ) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend_from_slice(&number.to_le_bytes());
        body.push(state);
        body.push(class);
        push_us_varchar(&mut body, message);
        push_b_varchar(&mut body, server_name);
        push_b_varchar(&mut body, procedure_name);
        body.extend_from_slice(&line_number.to_le_bytes());
        body
    }

    /// Builds a complete ENVCHANGE token from a change type byte and its raw payload.
    fn env_change(change_type: u8, data: &[u8]) -> Vec<u8> {
        let mut body = Vec::with_capacity(1 + data.len());
        body.push(change_type);
        body.extend_from_slice(data);

        len_prefixed(TokenType::ENVCHANGE, body)
    }

    /// Builds a complete DONE token; it has no length prefix, only the three fixed fields.
    fn done(status: u16, current_command: u16, row_count: u64) -> Vec<u8> {
        let mut out = Vec::new();
        out.push(TokenType::DONE.code());
        out.extend_from_slice(&status.to_le_bytes());
        out.extend_from_slice(&current_command.to_le_bytes());
        out.extend_from_slice(&row_count.to_le_bytes());
        out
    }

    /// Frames `body` as a token: type byte, little-endian `u16` body length, then the body.
    fn len_prefixed(token_type: TokenType, body: Vec<u8>) -> Vec<u8> {
        let mut out = Vec::new();
        out.push(token_type.code());
        out.extend_from_slice(
            &u16::try_from(body.len())
                .expect("test token body fits in u16")
                .to_le_bytes(),
        );
        out.extend_from_slice(&body);
        out
    }

    /// Appends `value` as a one-byte code-unit count followed by UTF-16LE data.
    fn push_b_varchar(out: &mut Vec<u8>, value: &str) {
        out.push(u8::try_from(value.encode_utf16().count()).expect("test string fits in u8"));
        push_utf16(out, value);
    }

    /// Appends `value` as a little-endian `u16` code-unit count followed by UTF-16LE data.
    fn push_us_varchar(out: &mut Vec<u8>, value: &str) {
        out.extend_from_slice(
            &u16::try_from(value.encode_utf16().count())
                .expect("test string fits in u16")
                .to_le_bytes(),
        );
        push_utf16(out, value);
    }

    /// Appends `value` encoded as UTF-16 with each code unit in little-endian byte order.
    fn push_utf16(out: &mut Vec<u8>, value: &str) {
        for unit in value.encode_utf16() {
            out.extend_from_slice(&unit.to_le_bytes());
        }
    }
}