1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
use nix::unistd::getuid;
use std::io::{Read, Write};
use std::os::unix::net::UnixStream;

fn write_message(msg: &str, stream: &mut UnixStream) -> std::io::Result<()> {
    let mut buf = Vec::new();
    buf.extend(msg.bytes());
    buf.push(b'\r');
    buf.push(b'\n');
    stream.write_all(&buf)?;
    Ok(())
}

fn has_line_ending(buf: &[u8]) -> bool {
    for idx in 1..buf.len() {
        if buf[idx - 1] == b'\r' && buf[idx] == b'\n' {
            return true;
        }
    }
    false
}

fn find_line_ending(buf: &[u8]) -> Option<usize> {
    for idx in 1..buf.len() {
        if buf[idx - 1] == b'\r' && buf[idx] == b'\n' {
            return Some(idx - 1);
        }
    }
    None
}

fn read_message(stream: &mut UnixStream, buf: &mut Vec<u8>) -> std::io::Result<String> {
    let mut tmpbuf = [0u8; 512];
    while !has_line_ending(&buf) {
        let bytes = stream.read(&mut tmpbuf[..])?;
        buf.extend(&tmpbuf[..bytes])
    }
    let idx = find_line_ending(&buf).unwrap();
    let line = buf.drain(0..idx).collect::<Vec<_>>();
    Ok(String::from_utf8(line).unwrap())
}

fn get_uid_as_hex() -> String {
    let uid = getuid();
    let mut tmp = uid.as_raw();
    let mut numbers = Vec::new();
    if tmp == 0 {
        return "30".to_owned();
    }
    while tmp > 0 {
        numbers.push(tmp % 10);
        tmp /= 10;
    }
    let mut hex = String::new();
    for idx in 0..numbers.len() {
        hex.push_str(match numbers[numbers.len() - 1 - idx] {
            0 => "30",
            1 => "31",
            2 => "32",
            3 => "33",
            4 => "34",
            5 => "35",
            6 => "36",
            7 => "37",
            8 => "38",
            9 => "39",
            _ => unreachable!(),
        })
    }

    hex
}

pub enum AuthResult {
    Ok,
    Rejected,
}

pub fn do_auth(stream: &mut UnixStream) -> std::io::Result<AuthResult> {
    // send a null byte as the first thing
    stream.write_all(&[0])?;
    write_message(&format!("AUTH EXTERNAL {}", get_uid_as_hex()), stream)?;

    let mut read_buf = Vec::new();
    let msg = read_message(stream, &mut read_buf)?;
    if msg.starts_with("OK") {
        Ok(AuthResult::Ok)
    } else {
        Ok(AuthResult::Rejected)
    }
}

pub fn negotiate_unix_fds(stream: &mut UnixStream) -> std::io::Result<AuthResult> {
    write_message("NEGOTIATE_UNIX_FD", stream)?;

    let mut read_buf = Vec::new();
    let msg = read_message(stream, &mut read_buf)?;
    if msg.starts_with("AGREE_UNIX_FD") {
        Ok(AuthResult::Ok)
    } else {
        Ok(AuthResult::Rejected)
    }
}

pub fn send_begin(stream: &mut UnixStream) -> std::io::Result<()> {
    write_message("BEGIN", stream)?;
    Ok(())
}