1pub mod basic_types;
6pub mod commands;
7pub mod message_parser;
8pub mod messages;
9
10use basic_types::{Compression, PasswordHashAlgo};
11use commands::{Command, CommandType, DynCommand, HandshakeCommand};
12use message_parser::ParseMessageError;
13use messages::{Message, WHashtable, WString};
14use std::io::Write;
15use std::net::TcpStream;
16use std::string::String;
17
/// Concrete `nom` error type used by this module when parsing messages.
///
/// The error's input parameter is an owned `Vec<u8>`.
type NomError = nom::error::Error<Vec<u8>>;
19
/// Errors produced by this module's connection handling and message parsing.
#[derive(Debug)]
pub enum WeechatError {
    /// A command that is sent without escaping contained a newline somewhere
    /// other than at its very end.
    NewlineInArgument,
    /// An I/O error reported by the underlying stream.
    IOError(std::io::Error),
    /// The message parser failed on an incoming message.
    ParserError(ParseMessageError<NomError>),
    /// The peer sent something other than what was expected; the string
    /// describes what was wrong.
    UnexpectedResponse(String),
    /// The handshake response contained a parameter value that could not be
    /// interpreted.
    FailedHandshake,
}
33
34impl std::fmt::Display for WeechatError {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 Self::NewlineInArgument => writeln!(f, "newline found in unescaped argument"),
38 Self::IOError(e) => e.fmt(f),
39 Self::ParserError(e) => e.fmt(f),
40 Self::UnexpectedResponse(s) => writeln!(f, "received unexpected message: {}", s),
41 Self::FailedHandshake => writeln!(f, "handshake failed to negotiate parameters"),
42 }
43 }
44}
45
// Marker impl: relies on the trait's default methods, so no `source()` is
// reported; wrapped errors are surfaced through `Display` instead.
impl std::error::Error for WeechatError {}
47
48impl From<std::io::Error> for WeechatError {
49 fn from(error: std::io::Error) -> Self {
50 Self::IOError(error)
51 }
52}
53
54impl From<ParseMessageError<NomError>> for WeechatError {
55 fn from(error: ParseMessageError<NomError>) -> Self {
56 Self::ParserError(error)
57 }
58}
59
60impl From<std::string::FromUtf8Error> for WeechatError {
61 fn from(_error: std::string::FromUtf8Error) -> Self {
62 Self::UnexpectedResponse("non-UTF-8 message".to_string())
63 }
64}
65
66impl From<std::str::Utf8Error> for WeechatError {
67 fn from(_error: std::str::Utf8Error) -> Self {
68 Self::UnexpectedResponse("non-UTF-8 message".to_string())
69 }
70}
71
/// A TCP connection together with the parameters obtained from the handshake.
///
/// Constructed by `Connection::new`, which fills these fields either from the
/// handshake response or, when no handshake is performed, with fixed defaults.
#[derive(Debug)]
pub struct Connection {
    /// Stream that commands are written to and messages are read from.
    pub stream: TcpStream,
    /// Password hash algorithm taken from the handshake response;
    /// `PasswordHashAlgo::Plain` when no handshake was performed.
    pub password_hash_algo: PasswordHashAlgo,
    /// Hash iteration count; set to 0 when no handshake was performed or when
    /// the selected algorithm is not one of the PBKDF2 variants.
    pub password_hash_iterations: u32,
    /// `true` when the handshake response reported `totp` as `on`.
    pub totp: bool,
    /// Nonce decoded from the hexadecimal `nonce` value of the handshake
    /// response; empty when none was provided.
    pub nonce: Vec<u8>,
    /// Compression taken from the handshake response; `Compression::Off` when
    /// the response had none or no handshake was performed.
    pub compression: Compression,
    /// `true` when the handshake response reported `escape_commands` as `on`;
    /// selects whether outgoing commands are sent in their escaped form.
    pub escape_commands: bool,
}
85
86impl Connection {
87 pub fn new(
93 mut stream: TcpStream,
94 handshake: Option<HandshakeCommand>,
95 ) -> Result<Self, WeechatError> {
96 let Some(handshake) = handshake else {
97 return Ok(Self {
98 stream,
99 password_hash_algo: PasswordHashAlgo::Plain,
100 password_hash_iterations: 0,
101 totp: false,
102 nonce: vec![],
103 compression: Compression::Off,
104 escape_commands: false,
105 });
106 };
107
108 stream.write_all(&Vec::<u8>::from(handshake.to_string()))?;
109 stream.flush()?;
110
111 let messages::Object::Htb(response) = message_parser::get_message::<NomError>(&mut stream)?
112 .objects
113 .into_iter()
114 .next()
115 .expect("shouldn't return without a response")
116 else {
117 return Err(WeechatError::UnexpectedResponse(
118 "non-htb handshake".to_string(),
119 ));
120 };
121
122 let WHashtable { keys, vals } = response;
123 let messages::WArray::Str(skeys) = keys else {
124 return Err(WeechatError::UnexpectedResponse(
125 "non-str handshake keys".to_string(),
126 ));
127 };
128 let messages::WArray::Str(svals) = vals else {
129 return Err(WeechatError::UnexpectedResponse(
130 "non-str handshake vals".to_string(),
131 ));
132 };
133 let config = messages::to_hashmap(skeys, svals);
134 let password_hash_algo = config
135 .get(&WString::from_ref(b"password_hash_algo"))
136 .ok_or(WeechatError::UnexpectedResponse(
137 "handshake did not return a password_hash_algo".to_string(),
138 ))?
139 .bytes()
140 .clone()
141 .map(String::from_utf8)
142 .transpose()?
143 .and_then(|s| PasswordHashAlgo::parse(&s))
144 .ok_or(WeechatError::FailedHandshake)?;
145
146 let password_hash_iterations = match password_hash_algo {
147 PasswordHashAlgo::Pbkdf2Sha256 | PasswordHashAlgo::Pbkdf2Sha512 => {
148 let bytes = config
149 .get(&WString::from_ref(b"password_hash_algo"))
150 .and_then(|s| s.bytes().clone())
151 .ok_or(WeechatError::UnexpectedResponse(
152 "iterated hash selected, but no iteration count returned in handshake"
153 .to_string(),
154 ))?;
155 let s = String::from_utf8(bytes)?;
156 s.parse().or(Err(WeechatError::UnexpectedResponse(
157 "password_hash_iterations was non-numerical".to_string(),
158 )))?
159 }
160 _ => 0,
161 };
162
163 let totp = config.get(&WString::from_ref(b"totp")) == Some(&WString::from_ref(b"on"));
164
165 let nonce_hex = config
166 .get(&WString::from_ref(b"nonce"))
167 .and_then(|w| w.bytes().clone());
168 let nonce = if let Some(hex) = nonce_hex {
169 bytes_from_hex(&hex)?
170 } else {
171 vec![]
172 };
173
174 let compression = config
175 .get(&WString::from_ref(b"compression"))
176 .and_then(|w| w.bytes().clone())
177 .map(String::from_utf8)
178 .transpose()?;
179 let compression = if let Some(compression) = compression {
180 Compression::parse(&compression).ok_or(WeechatError::FailedHandshake)?
181 } else {
182 Compression::Off
183 };
184
185 let escape_commands =
186 config.get(&WString::from_ref(b"escape_commands")) == Some(&WString::from_ref(b"on"));
187
188 Ok(Self {
189 stream,
190 password_hash_algo,
191 password_hash_iterations,
192 totp,
193 nonce,
194 compression,
195 escape_commands,
196 })
197 }
198
199 fn check_unescaped_arg(arg: String) -> Result<String, WeechatError> {
200 if !arg.is_empty() && arg[..arg.len() - 1].contains('\n') {
201 return Err(WeechatError::NewlineInArgument);
202 }
203 Ok(arg)
204 }
205
206 pub fn send_command<T: CommandType>(
208 &mut self,
209 command: &Command<T>,
210 ) -> Result<(), WeechatError> {
211 let string = if self.escape_commands {
212 command.escaped()
213 } else {
214 Connection::check_unescaped_arg(command.to_string())?
215 };
216 self.stream.write_all(&Vec::<u8>::from(string))?;
217 Ok(self.stream.flush()?)
218 }
219
220 pub fn send_commands(
222 &mut self,
223 commands: &mut dyn Iterator<Item = &DynCommand>,
224 ) -> Result<(), WeechatError> {
225 let commands: String = if self.escape_commands {
226 commands.map(DynCommand::escaped).collect()
227 } else {
228 commands
229 .map(|c| Connection::check_unescaped_arg(c.to_string()))
230 .collect::<Result<Vec<_>, _>>()?
231 .into_iter()
232 .collect()
233 };
234 self.stream.write_all(&Vec::<u8>::from(commands))?;
235 Ok(self.stream.flush()?)
236 }
237
238 pub fn get_message(&mut self) -> Result<Message, ParseMessageError<NomError>> {
240 message_parser::get_message::<NomError>(&mut self.stream)
241 }
242}
243
244fn bytes_from_hex(ascii_hex: &[u8]) -> Result<Vec<u8>, WeechatError> {
245 let s = std::str::from_utf8(ascii_hex)?;
246 (0..s.len())
247 .step_by(2)
248 .map(|i| u8::from_str_radix(&s[i..i + 2], 16))
249 .collect::<Result<Vec<_>, _>>()
250 .or(Err(WeechatError::UnexpectedResponse(
251 "expected valid hexidecimal".to_string(),
252 )))
253}