1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
use std::collections::HashMap;

// TODO: Should we use a separate struct to represent domain names? It would make them harder to misuse.


/// Parse a DNS domain name from a raw DNS packet, taking into account DNS message compression
/// (RFC 1035 section 4.1.4).
///
/// `start` is the offset of the name's first length octet; it must be less than both `buf.len()`
/// and `limit`. On success this returns the name in dotted form with a trailing `'.'` (the root
/// name is returned as `"."`) together with the offset of the first byte after the name as it
/// appears at `start`: just past its terminating zero octet, or just past the first compression
/// pointer if the name is compressed.
///
/// # Errors
///
/// Returns an error if `buf` is empty or `start` is out of range, if a label or pointer runs past
/// the end of `buf`, if the name is not terminated by a zero octet, if a compression pointer does
/// not point strictly before the run of labels containing it (this rules out pointer loops), if
/// the encoded name exceeds 255 octets, or if the decoded name fails [`is_domain_name_valid`].
pub fn parse_domain_name(
    buf: &Vec<u8>,
    start: usize,
    limit: usize,
) -> Result<(String, usize), String> {
    // RFC 1035 section 3.1: a name, including its length octets, is at most 255 octets long.
    const MAX_NAME_OCTETS: usize = 255;

    if buf.is_empty() {
        return Err("buf is empty".into());
    } else if start >= buf.len() || start >= limit {
        return Err("start is past buf's end or limit".into());
    }

    let mut domain_name = String::new();
    let mut curr = start;
    // Offset where the current run of labels began. Every compression pointer must jump strictly
    // before it, so each jump moves backwards and a crafted packet cannot make us loop forever.
    let mut run_start = start;
    // Offset just past the name at its original position; fixed when the first pointer is taken.
    let mut end: Option<usize> = None;
    // Encoded length so far: each label plus its length octet, plus the final zero octet.
    let mut name_octets = 1;

    loop {
        let label_length = match buf.get(curr) {
            Some(&octet) => octet as usize,
            None => return Err("domain name is not terminated by a null label".into()),
        };

        if label_length == 0 {
            // Null label = root zone: the name is complete.
            curr += 1;
            break;
        } else if (label_length & 0xC0) == 0xC0 {
            // Message compression pointer: a 14-bit offset spread over this octet and the next.
            let low_octet = match buf.get(curr + 1) {
                Some(&octet) => octet as usize,
                None => return Err("domain name pointer appears at end of buf".into()),
            };
            let offset = ((label_length & 0x3F) << 8) | low_octet;
            if offset >= run_start {
                return Err(format!(
                    "error parsing domain name pointer: offset {} does not point before {}",
                    offset, run_start
                ));
            }
            // Only the first pointer determines where the name ends in the original buffer.
            if end.is_none() {
                end = Some(curr + 2);
            }
            run_start = offset;
            curr = offset;
        } else if curr + label_length >= buf.len() {
            return Err("domain name label length octet too large".into());
        } else {
            name_octets += label_length + 1;
            if name_octets > MAX_NAME_OCTETS {
                return Err("domain name exceeds 255 octets".into());
            }
            // Bytes map one-to-one onto chars; non-ASCII bytes are rejected by validation below.
            domain_name.extend(buf[curr + 1..=curr + label_length].iter().map(|&b| b as char));
            domain_name.push('.');
            curr += label_length + 1;
        }
    }

    if domain_name.is_empty() {
        // A name consisting only of the null label is the root.
        domain_name.push('.');
    }

    if !is_domain_name_valid(&domain_name) {
        return Err(format!("invalid domain name: {}", domain_name));
    }

    Ok((domain_name, end.unwrap_or(curr)))
}

/// Returns true if `domain_name` is a valid, fully-qualified DNS domain name in dotted form.
///
/// The root name `"."` is valid. Any other name must end with `'.'`, must not exceed 255 octets
/// once encoded on the wire (RFC 1035 section 3.1), and each of its labels must:
///  1. start with a letter or digit (RFC 1123 section 2.1 relaxed RFC 1035's letters-only rule),
///  2. end with a letter or digit,
///  3. contain only letters, digits, and hyphens, and
///  4. be 63 characters or less (so the two high bits of a label length octet are always 0).
///
/// IDNA (RFC 3490) only applies to application code, so punycode is not interpreted here; such
/// labels are checked against the rules above like any other label.
pub fn is_domain_name_valid(domain_name: &String) -> bool {
    // A dotted ASCII name with a trailing '.' encodes to exactly one octet more than its text
    // length (each '.' becomes a length octet, plus the leading length octet).
    const MAX_NAME_OCTETS: usize = 255;

    if domain_name == "." {
        return true;
    } else if domain_name.is_empty()
        || domain_name.starts_with('.')
        || !domain_name.ends_with('.')
        || domain_name.len() + 1 > MAX_NAME_OCTETS
    {
        return false;
    }

    // split_terminator drops the empty piece after the trailing '.', but empty labels in the
    // middle ("a..b.") still appear and fail the first-character check.
    domain_name.split_terminator('.').all(|label| {
        label.len() <= 63
            && label.starts_with(|c: char| c.is_ascii_alphanumeric())
            && label.ends_with(|c: char| c.is_ascii_alphanumeric())
            && label.chars().all(|c| c == '-' || c.is_ascii_alphanumeric())
    })
}

// TODO: Should this be extended to be more aggressive/convert to IDNA?
// Should this remove whitespace at all? Maybe just DNS-specific things like dots.

/// Massages `domain_name` toward a valid fully-qualified name by removing every leading `'.'`
/// and appending a trailing `'.'` if one is missing.
///
/// The empty string is returned unchanged, and a name made only of dots becomes the root `"."`.
/// Nothing else is altered, so the result can still fail [`is_domain_name_valid`] (for example
/// when it contains whitespace or empty labels).
pub fn normalize_domain_name(domain_name: &String) -> String {
    if domain_name.is_empty() {
        return domain_name.clone();
    }

    // Strip all leading dots, not just one: "..example.com" must not become ".example.com.",
    // which would still be invalid.
    let mut normalized = domain_name.trim_start_matches('.').to_string();
    if !normalized.ends_with('.') {
        normalized.push('.');
    }
    normalized
}

/// Appends the wire encoding of `domain_name` to `buf`, compressing it against names written
/// earlier (RFC 1035 section 4.1.4).
///
/// `domain_name_offsets` maps each previously written suffix (e.g. `"google.com."`) to the offset
/// in `buf` where it begins. As soon as a suffix of `domain_name` is found there, a two-octet
/// compression pointer to it is written and encoding stops. Otherwise each label is written as a
/// length octet followed by its characters, and the name is closed with a zero octet. Each newly
/// written suffix whose offset fits in a pointer's 14 bits (at most 0x3FFF) is recorded in the
/// map. Offsets are positions in `buf`, so `buf` presumably starts at the beginning of the DNS
/// message — confirm with callers.
///
/// # Errors
///
/// Returns an error, without writing anything, if `domain_name` fails [`is_domain_name_valid`].
pub fn serialize_domain_name(
    domain_name: &String,
    buf: &mut Vec<u8>,
    domain_name_offsets: &mut HashMap<String, u16>,
) -> Result<(), String> {
    if !is_domain_name_valid(domain_name) {
        return Err(format!("invalid domain name: {}", domain_name));
    }

    let mut rest: &str = domain_name.as_str();

    while !(rest.is_empty() || rest == ".") {
        if let Some(&target) = domain_name_offsets.get(rest) {
            // Pointer: top two bits set, remaining 14 bits hold the target offset.
            buf.push(0xC0 | ((target >> 8) & 0x3F) as u8);
            buf.push((target & 0xFF) as u8);
            return Ok(());
        }

        let dot = rest.find('.').ok_or("subdomain had no \'.\'")?;
        let (label, tail) = (&rest[..dot], &rest[dot + 1..]);

        // Pointers can only address offsets up to 0x3FFF; later positions are not worth storing.
        if buf.len() <= 0x3FFF {
            domain_name_offsets.insert(rest.to_string(), buf.len() as u16);
        }

        buf.push(label.len() as u8);
        buf.extend(label.chars().map(|ch| ch as u8));

        rest = tail;
    }

    // Zero-length root label terminates an uncompressed name.
    buf.push(0);
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{query_examples::*, DNS_HEADER_SIZE};

    #[test]
    fn test_parse_domain_name() -> Result<(), String> {
        let query = &Vec::from(BASIC_QUERY);
        let correct_domain_name = String::from("www.google.com.");

        let (domain_name, end) = parse_domain_name(query, DNS_HEADER_SIZE, query.len())?;

        assert_eq!(domain_name, correct_domain_name);
        assert_eq!(end, 28);

        let correct_domain_name = String::from("pointer.www.google.com.");
        let query = &Vec::from(NAME_COMPRESSION_QUERY);

        let (domain_name, end) = parse_domain_name(query, 32, query.len())?;

        assert_eq!(domain_name, correct_domain_name);
        assert_eq!(end, 42);

        Ok(())
    }

    #[test]
    fn test_is_domain_name_valid() -> Result<(), String> {
        let domain_name = String::from("www.google.com.");

        assert!(is_domain_name_valid(&domain_name));

        let domain_name = String::from(
            "reallylongdomainnamelabelistoolongreallylongdomainnamelabelistoolong.google.com.",
        );
        assert!(!is_domain_name_valid(&domain_name));

        let domain_name = String::from("www.space inlabel.google.com.");
        assert!(!is_domain_name_valid(&domain_name));

        let domain_name = String::from(".beginswith.");
        assert!(!is_domain_name_valid(&domain_name));

        let domain_name = String::from("hasnoending.dot");
        assert!(!is_domain_name_valid(&domain_name));

        let domain_name = String::from("");
        assert!(!is_domain_name_valid(&domain_name));

        let domain_name = String::from(".");
        assert!(is_domain_name_valid(&domain_name));

        Ok(())
    }

    #[test]
    fn test_normalize_domain_name() -> Result<(), String> {
        // Empty input and the root name are returned untouched.
        assert_eq!(normalize_domain_name(&String::new()), "");
        assert_eq!(normalize_domain_name(&String::from(".")), ".");

        // A missing trailing dot is appended and a single leading dot is removed.
        assert_eq!(normalize_domain_name(&String::from("www.google.com")), "www.google.com.");
        assert_eq!(normalize_domain_name(&String::from(".www.google.com")), "www.google.com.");
        assert_eq!(normalize_domain_name(&String::from("www.google.com.")), "www.google.com.");

        Ok(())
    }

    #[test]
    fn test_serialize_domain_name_round_trip() -> Result<(), String> {
        let mut buf = Vec::new();
        let mut offsets = std::collections::HashMap::new();

        serialize_domain_name(&String::from("www.google.com."), &mut buf, &mut offsets)?;
        assert_eq!(buf, b"\x03www\x06google\x03com\x00".to_vec());

        // The second name shares "google.com." (written at offset 4) and is compressed.
        serialize_domain_name(&String::from("mail.google.com."), &mut buf, &mut offsets)?;
        assert_eq!(&buf[16..], b"\x04mail\xC0\x04");

        assert_eq!(
            parse_domain_name(&buf, 0, buf.len())?,
            (String::from("www.google.com."), 16)
        );
        assert_eq!(
            parse_domain_name(&buf, 16, buf.len())?,
            (String::from("mail.google.com."), 23)
        );

        Ok(())
    }
}