use snafu::{Backtrace, OptionExt, ResultExt, Snafu};
use std::collections::HashMap;
use std::fmt::{Debug, Write as FmtWrite};
use std::io::{self, BufRead, BufReader, Write};
use std::net::{Shutdown, TcpStream, ToSocketAddrs};
use std::string::FromUtf8Error;
use std::time::Duration;
pub mod raw;
use io::Read;
use raw::*;
/// Errors produced by [`QueryClient`] operations.
#[derive(Snafu, Debug)]
pub enum Ts3Error {
/// A response line was not valid UTF-8.
#[snafu(display("Input was invalid UTF-8: {}", source))]
Utf8Error { source: FromUtf8Error },
/// An I/O operation (address resolution or socket read/write/config) failed.
///
/// `context` names the failing step and is printed directly before the source
/// error; it is empty when the error was converted via `From<io::Error>`.
#[snafu(display("IO Error: {}{}, kind: {:?}", context, source,source.kind()))]
Io {
context: &'static str,
source: io::Error,
},
/// The peer closed the connection: a read returned zero bytes.
#[snafu(display("IO Error: Connection closed"))]
ConnectionClosed { backtrace: Backtrace },
/// The supplied address resolved to no socket address at all.
#[snafu(display("No valid socket address provided."))]
InvalidSocketAddress { backtrace: Backtrace },
/// The server sent something that could not be parsed.
///
/// `context` describes what was expected; `data` holds the offending text.
#[snafu(display("Received invalid response, {}{:?}", context, data))]
InvalidResponse {
context: &'static str,
data: String,
},
/// The server's `error` status line carried a non-zero error id.
#[snafu(display("Server responded with error: {}", response))]
ServerError {
response: ErrorResponse,
backtrace: Backtrace,
},
/// A response exceeded the configured line-count or per-line byte limit.
///
/// `response` holds the lines that were read before the limit was hit.
#[snafu(display("Invalid response, DDOS limit reached: {:?}", response))]
ResponseLimit {
response: Vec<String>,
backtrace: Backtrace,
},
}
impl Ts3Error {
    /// Returns `true` if this is a [`Ts3Error::ServerError`], i.e. the server
    /// answered a command with a non-zero error id.
    pub fn is_error_response(&self) -> bool {
        self.error_response().is_some()
    }

    /// Returns the server's [`ErrorResponse`] for a [`Ts3Error::ServerError`],
    /// and `None` for every other variant.
    pub fn error_response(&self) -> Option<&ErrorResponse> {
        if let Ts3Error::ServerError { response, .. } = self {
            Some(response)
        } else {
            None
        }
    }
}
impl From<io::Error> for Ts3Error {
fn from(error: io::Error) -> Self {
Ts3Error::Io {
context: "",
source: error,
}
}
}
/// Error reported by the server in an `error id=… msg=…` status line whose
/// `id` is non-zero.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ErrorResponse {
    /// Numeric error code from the `id=` field.
    pub id: usize,
    /// Unescaped error text from the `msg=` field.
    pub msg: String,
}

impl std::fmt::Display for ErrorResponse {
    /// Formats as `Error code <id>, msg: <msg>`.
    ///
    /// No trailing newline is emitted, so the text can be embedded in other
    /// messages (e.g. `Ts3Error::ServerError`'s display) without a stray line break.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Error code {}, msg: {}", self.id, self.msg)
    }
}
/// Blocking client for a TS3 query connection.
///
/// Commands are written to `tx` and responses are read back line by line from
/// `rx`; both refer to the same TCP connection (`rx` wraps a `try_clone` of it).
pub struct QueryClient {
/// Buffered read side of the connection, used for responses.
rx: BufReader<TcpStream>,
/// Write side of the connection, used for commands.
tx: TcpStream,
/// Maximum number of lines read for one response; set via `limit_lines`.
limit_lines: usize,
/// Byte budget used when reading response lines; set via `limit_line_bytes`.
limit_lines_bytes: u64,
}
/// Default for `QueryClient::limit_lines`: how many lines may be read for one
/// response before giving up with `Ts3Error::ResponseLimit`.
const LIMIT_READ_LINES: usize = 100;
/// Default for `QueryClient::limit_lines_bytes`: byte budget for reading a
/// response line.
const LIMIT_LINE_BYTES: u64 = 64_000;
/// Result alias used throughout this module, with [`Ts3Error`] as the error type.
type Result<T> = ::std::result::Result<T, Ts3Error>;
impl Drop for QueryClient {
    /// Ends the session on a best-effort basis when the client goes away.
    ///
    /// A `quit` command is sent first, then both directions of the socket are
    /// shut down. Failures are deliberately discarded: `drop` cannot report
    /// errors, and the connection is being torn down either way.
    fn drop(&mut self) {
        self.quit();
        if let Err(_shutdown_err) = self.tx.shutdown(Shutdown::Both) {
            // Nothing useful can be done with the error while dropping.
        }
    }
}
impl QueryClient {
    /// Connects to `addr` with no connect timeout and no socket read/write timeout.
    ///
    /// # Errors
    ///
    /// Fails if `addr` resolves to nothing, the connection cannot be made, the
    /// socket cannot be configured, or reading the server's greeting fails.
    pub fn new<A: ToSocketAddrs>(addr: A) -> Result<Self> {
        Self::with_timeout(addr, None, None)
    }

    /// Connects to `addr` with an optional connect timeout (`t_connect`) and an
    /// optional read/write timeout (`timeout`) applied to the socket.
    ///
    /// # Errors
    ///
    /// Same as [`QueryClient::new`].
    pub fn with_timeout<A: ToSocketAddrs>(
        addr: A,
        t_connect: Option<Duration>,
        timeout: Option<Duration>,
    ) -> Result<Self> {
        let (rx, tx) = Self::new_inner(addr, timeout, t_connect)?;
        Ok(Self {
            rx,
            tx,
            limit_lines: LIMIT_READ_LINES,
            limit_lines_bytes: LIMIT_LINE_BYTES,
        })
    }

    /// Sets the maximum number of lines accepted for a single response.
    pub fn limit_lines(&mut self, limit: usize) {
        self.limit_lines = limit;
    }

    /// Sets the byte budget applied to every line of a response.
    pub fn limit_line_bytes(&mut self, limit: u64) {
        self.limit_lines_bytes = limit;
    }

    /// Sets this client's nickname (`clientupdate client_nickname=…`).
    pub fn rename(&mut self, name: &str) -> Result<()> {
        self.execute(&format!(
            "clientupdate client_nickname={}",
            escape_arg(name)
        ))
    }

    /// Sets this client's description (`clientupdate CLIENT_DESCRIPTION=…`).
    // NOTE(review): every other command uses lowercase parameter names; confirm
    // the server accepts the upper-case `CLIENT_DESCRIPTION` spelling.
    pub fn update_description(&mut self, descr: &str) -> Result<()> {
        // `execute` appends the `\n` terminator; without it the command is never
        // completed and waiting for the response blocks.
        self.execute(&format!(
            "clientupdate CLIENT_DESCRIPTION={}",
            escape_arg(descr)
        ))
    }

    /// Sends `quit`, ignoring any write failure (used while dropping).
    fn quit(&mut self) {
        let _ = self.send_line("quit");
    }

    /// Resolves `addr`, connects, configures the socket and consumes the
    /// greeting, returning the buffered read side and the write side.
    fn new_inner<A: ToSocketAddrs>(
        addr: A,
        timeout: Option<Duration>,
        conn_timeout: Option<Duration>,
    ) -> Result<(BufReader<TcpStream>, TcpStream)> {
        let addr = addr
            .to_socket_addrs()
            .context(Io {
                context: "invalid socket address: ",
            })?
            .next()
            .context(InvalidSocketAddress {})?;
        let stream = match conn_timeout {
            Some(dur) => TcpStream::connect_timeout(&addr, dur),
            None => TcpStream::connect(addr),
        }
        .context(Io {
            context: "while connecting: ",
        })?;
        stream.set_write_timeout(timeout).context(Io {
            context: "setting write timeout: ",
        })?;
        stream.set_read_timeout(timeout).context(Io {
            context: "setting read timeout: ",
        })?;
        stream.set_nodelay(true).context(Io {
            context: "setting nodelay: ",
        })?;
        let mut reader = BufReader::new(stream.try_clone().context(Io {
            context: "splitting connection: ",
        })?);

        // Consume the two `\r`-terminated greeting lines sent after connecting,
        // so the first command's response is not mixed with them.
        let mut buffer = Vec::new();
        let read = reader.read_until(b'\r', &mut buffer).context(Io {
            context: "reading response: ",
        })?;
        if read == 0 {
            // EOF before any greeting: the peer already closed the connection.
            return ConnectionClosed {}.fail();
        }
        buffer.clear();
        // A timeout while waiting for the second greeting line is tolerated;
        // any other I/O error aborts the connection attempt.
        if let Err(e) = reader.read_until(b'\r', &mut buffer) {
            match e.kind() {
                io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => {}
                _ => return Err(e.into()),
            }
        }
        Ok((reader, stream))
    }

    /// Sends `command` verbatim as one line and returns the response's data
    /// lines (everything before the `error …` status line).
    ///
    /// Arguments inside `command` must already be escaped by the caller.
    ///
    /// # Errors
    ///
    /// Any error from writing the command or from reading the response,
    /// including [`Ts3Error::ServerError`] for a non-zero status id.
    pub fn raw_command(&mut self, command: &str) -> Result<Vec<String>> {
        self.send_line(command)?;
        self.read_response()
    }

    /// Sends `command` and waits for its status line, discarding data lines.
    fn execute(&mut self, command: &str) -> Result<()> {
        self.raw_command(command).map(|_| ())
    }

    /// Writes `line` followed by `\n`.
    ///
    /// The line is assembled first and handed to the socket in a single
    /// `write_all` call. This way every command is terminated, and the socket
    /// (which has `TCP_NODELAY` set) does not get one write per formatted
    /// fragment.
    fn send_line(&mut self, line: &str) -> Result<()> {
        let mut buf = Vec::with_capacity(line.len() + 1);
        buf.extend_from_slice(line.as_bytes());
        buf.push(b'\n');
        self.tx.write_all(&buf)?;
        Ok(())
    }

    /// Runs `whoami` and parses the reply into key/value pairs via
    /// `parse_hashmap`; `unescape` is forwarded to it.
    pub fn whoami(&mut self, unescape: bool) -> Result<HashMap<String, String>> {
        let lines = self.raw_command("whoami")?;
        Ok(parse_hashmap(lines, unescape))
    }

    /// Sends `logout`.
    pub fn logout(&mut self) -> Result<()> {
        self.execute("logout")
    }

    /// Sends `login <user> <password>` with both arguments escaped.
    pub fn login(&mut self, user: &str, password: &str) -> Result<()> {
        self.execute(&format!(
            "login {} {}",
            escape_arg(user),
            escape_arg(password)
        ))
    }

    /// Selects the virtual server listening on `port` (`use port=…`).
    pub fn select_server_by_port(&mut self, port: u16) -> Result<()> {
        self.execute(&format!("use port={}", port))
    }

    /// Creates directory `path` in channel `cid` via `ftcreatedir`, sending an
    /// empty channel password.
    pub fn create_dir(&mut self, cid: usize, path: &str) -> Result<()> {
        self.execute(&format!(
            "ftcreatedir cid={} cpw= dirname={}",
            cid,
            escape_arg(path)
        ))
    }

    /// Deletes file `path` in channel `cid` via `ftdeletefile`, sending an
    /// empty channel password.
    pub fn delete_file(&mut self, cid: usize, path: &str) -> Result<()> {
        self.execute(&format!(
            "ftdeletefile cid={} cpw= name={}",
            cid,
            escape_arg(path)
        ))
    }

    /// Checks that the connection is alive by running `whoami` and discarding
    /// the result.
    pub fn ping(&mut self) -> Result<()> {
        self.execute("whoami")
    }

    /// Selects the virtual server with id `sid` (`use sid=…`).
    pub fn select_server_by_id(&mut self, sid: usize) -> Result<()> {
        self.execute(&format!("use sid={}", sid))
    }

    /// Removes the given client database ids from server group `group`.
    /// Does nothing (and sends nothing) when `cldbid` is empty.
    pub fn server_group_del_clients(&mut self, group: usize, cldbid: &[usize]) -> Result<()> {
        if cldbid.is_empty() {
            return Ok(());
        }
        self.execute(&format!(
            "servergroupdelclient sgid={} {}",
            group,
            Self::format_cldbids(cldbid)
        ))
    }

    /// Adds the given client database ids to server group `group`.
    /// Does nothing (and sends nothing) when `cldbid` is empty.
    pub fn server_group_add_clients(&mut self, group: usize, cldbid: &[usize]) -> Result<()> {
        if cldbid.is_empty() {
            return Ok(());
        }
        self.execute(&format!(
            "servergroupaddclient sgid={} {}",
            group,
            Self::format_cldbids(cldbid)
        ))
    }

    /// Formats ids as `cldbid=1|cldbid=2|…`; an empty slice yields `""`.
    fn format_cldbids(it: &[usize]) -> String {
        let mut res = String::new();
        let mut it = it.iter();
        if let Some(n) = it.next() {
            write!(res, "cldbid={}", n).expect("writing to a String cannot fail");
        }
        for n in it {
            write!(res, "|cldbid={}", n).expect("writing to a String cannot fail");
        }
        res
    }

    /// Reads response lines until the `error …` status line and returns the
    /// non-empty data lines that preceded it.
    ///
    /// Lines are delimited by `\r`, and a `\n` directly before the `\r` is
    /// stripped. Empty lines are skipped but still count toward `limit_lines`.
    /// Every line, not just the first, is limited to `limit_lines_bytes` bytes.
    ///
    /// # Errors
    ///
    /// - [`Ts3Error::ServerError`] for a non-zero status id.
    /// - [`Ts3Error::ConnectionClosed`] on EOF.
    /// - [`Ts3Error::ResponseLimit`] when a limit is exceeded.
    /// - [`Ts3Error::InvalidResponse`] for an unterminated line.
    /// - [`Ts3Error::Utf8Error`] / [`Ts3Error::Io`] for decode and read failures.
    fn read_response(&mut self) -> Result<Vec<String>> {
        let max_lines = self.limit_lines;
        let max_line_bytes = self.limit_lines_bytes;
        let mut result: Vec<String> = Vec::new();
        let mut lr = (&mut self.rx).take(max_line_bytes);
        for _ in 0..max_lines {
            let mut buffer = Vec::new();
            let read = lr.read_until(b'\r', &mut buffer).context(Io {
                context: "reading response: ",
            })?;
            if read == 0 {
                // With an exhausted (zero) byte budget nothing can be read; that
                // is a limit violation, not a closed connection.
                if lr.limit() == 0 {
                    return ResponseLimit { response: result }.fail();
                }
                return ConnectionClosed {}.fail();
            }
            if buffer.ends_with(b"\r") {
                buffer.pop();
                if buffer.ends_with(b"\n") {
                    buffer.pop();
                }
            } else if lr.limit() == 0 {
                // The line did not fit into the byte budget.
                return ResponseLimit { response: result }.fail();
            } else {
                return InvalidResponse {
                    context: "expected \\r delimiter, got: ",
                    data: String::from_utf8_lossy(&buffer),
                }
                .fail();
            }
            if !buffer.is_empty() {
                let line = String::from_utf8(buffer).context(Utf8Error)?;
                #[cfg(feature = "debug_response")]
                println!("Read: {:?}", &line);
                if line.starts_with("error ") {
                    Self::check_ok(&line)?;
                    return Ok(result);
                }
                result.push(line);
            }
            // Give the next line a fresh budget, honouring the configured limit.
            lr.set_limit(max_line_bytes);
        }
        ResponseLimit { response: result }.fail()
    }

    /// Runs `servergroupclientlist` for `server_group` and parses the
    /// `cldbid=…|cldbid=…` data line into ids.
    ///
    /// Returns an empty vector when the response carries no data line.
    ///
    /// # Errors
    ///
    /// [`Ts3Error::InvalidResponse`] when an entry has no `=` or its value is
    /// not a `usize`, plus any error from [`QueryClient::raw_command`].
    pub fn get_servergroup_client_list(&mut self, server_group: usize) -> Result<Vec<usize>> {
        let resp = self.raw_command(&format!("servergroupclientlist sgid={}", server_group))?;
        let line = match resp.first() {
            Some(line) => line,
            None => return Ok(Vec::new()),
        };
        line.split('|')
            .map(|entry| {
                let value = entry
                    .split('=')
                    .nth(1)
                    .ok_or_else(|| Ts3Error::InvalidResponse {
                        context: "expected data of cldbid=1, got ",
                        data: line.to_string(),
                    })?;
                value
                    .parse::<usize>()
                    .map_err(|_| Ts3Error::InvalidResponse {
                        context: "expected usize, got ",
                        data: line.to_string(),
                    })
            })
            .collect()
    }

    /// Returns the text after the first `=` of a `key=value` token, if any.
    fn token_value(token: &str) -> Option<&str> {
        token.find('=').map(|pos| &token[pos + 1..])
    }

    /// Interprets an `error id=<n> msg=<text> …` status line.
    ///
    /// Returns `Ok(())` for `id=0`, and [`Ts3Error::ServerError`] (with the
    /// message unescaped) for any other id.
    ///
    /// # Errors
    ///
    /// [`Ts3Error::InvalidResponse`] when `id`/`msg` are missing or `id` is not
    /// a `usize`.
    fn check_ok(msg: &str) -> Result<()> {
        let fields: Vec<&str> = msg.split(' ').collect();
        debug_assert_eq!(
            fields.first(),
            Some(&"error"),
            "check_ok invoked on non-error line"
        );
        // Only the first `=` separates key and value; the value itself may
        // contain `=` (it is not escaped), so splitting on every `=` would
        // truncate it.
        let raw_id = fields.get(1).copied().and_then(Self::token_value);
        let raw_msg = fields.get(2).copied().and_then(Self::token_value);
        if let (Some(raw_id), Some(raw_msg)) = (raw_id, raw_msg) {
            let id = raw_id
                .parse::<usize>()
                .map_err(|_| Ts3Error::InvalidResponse {
                    context: "expected usize, got ",
                    data: raw_id.to_string(),
                })?;
            if id == 0 {
                return Ok(());
            }
            return ServerError {
                response: ErrorResponse {
                    id,
                    msg: unescape_val(raw_msg),
                },
            }
            .fail();
        }
        Err(Ts3Error::InvalidResponse {
            context: "expected id and msg, got ",
            data: msg.to_string(),
        })
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_format_cldbids() {
        let ids = [0, 1, 2, 3];
        assert_eq!(
            "cldbid=0|cldbid=1|cldbid=2|cldbid=3",
            QueryClient::format_cldbids(&ids)
        );
        assert_eq!("", QueryClient::format_cldbids(&[]));
        assert_eq!("cldbid=0", QueryClient::format_cldbids(&ids[0..1]));
    }

    // A status line with id 0 means success.
    #[test]
    fn test_check_ok_success() {
        assert!(QueryClient::check_ok("error id=0 msg=ok").is_ok());
    }

    // A non-zero id becomes a ServerError carrying that id.
    #[test]
    fn test_check_ok_server_error() {
        let err = QueryClient::check_ok("error id=1281 msg=database\\sempty\\sresult\\sset")
            .unwrap_err();
        assert!(err.is_error_response());
        assert_eq!(err.error_response().map(|r| r.id), Some(1281));
    }

    // Missing fields or a non-numeric id are reported as InvalidResponse.
    #[test]
    fn test_check_ok_malformed() {
        let missing_msg = QueryClient::check_ok("error id=0");
        assert!(matches!(missing_msg, Err(Ts3Error::InvalidResponse { .. })));
        let bad_id = QueryClient::check_ok("error id=abc msg=ok");
        assert!(matches!(bad_id, Err(Ts3Error::InvalidResponse { .. })));
    }
}