1use std::collections::HashMap;
2use std::fmt::{Debug, Formatter};
3use std::io::{Read, Write};
4use std::net::{Ipv4Addr, Shutdown, TcpStream, ToSocketAddrs};
5use std::sync::atomic::{AtomicI64, Ordering};
6use std::sync::mpsc::{Receiver, Sender};
7use std::sync::{Arc, Mutex, RwLock, mpsc};
8use std::thread;
9use std::time::Duration;
10
11use bytebuffer::ByteBuffer;
12use byteorder::{BigEndian, ReadBytesExt};
13
14use crate::encode::{Value, VoltError};
15use crate::procedure_invocation::new_procedure_invocation;
16use crate::protocol::{PING_HANDLE, build_auth_message, parse_auth_response};
17use crate::response::VoltResponseInfo;
18use crate::table::{VoltTable, new_volt_table};
19use crate::volt_param;
20
21#[derive(Clone, Eq, PartialEq, Debug)]
22pub struct Opts(pub(crate) Box<InnerOpts>);
23
/// A host (IP literal or host name) paired with a TCP port.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IpPort {
    pub(crate) ip_host: String,
    pub(crate) port: u16,
}

impl IpPort {
    /// Creates an `IpPort` from an owned host string and a port number.
    pub fn new(ip_host: String, port: u16) -> Self {
        Self { ip_host, port }
    }
}
35
36impl Opts {
37 pub fn new(hosts: Vec<IpPort>) -> Opts {
38 Opts(Box::new(InnerOpts {
39 ip_ports: hosts,
40 user: None,
41 pass: None,
42 connect_timeout: None,
43 read_timeout: None,
44 }))
45 }
46
47 pub fn builder() -> OptsBuilder {
49 OptsBuilder::default()
50 }
51}
52
53#[derive(Debug, Clone, Default)]
69pub struct OptsBuilder {
70 hosts: Vec<IpPort>,
71 user: Option<String>,
72 pass: Option<String>,
73 connect_timeout: Option<Duration>,
74 read_timeout: Option<Duration>,
75}
76
77impl OptsBuilder {
78 pub fn host(mut self, ip: &str, port: u16) -> Self {
80 self.hosts.push(IpPort::new(ip.to_string(), port));
81 self
82 }
83
84 pub fn hosts(mut self, hosts: Vec<IpPort>) -> Self {
86 self.hosts.extend(hosts);
87 self
88 }
89
90 pub fn user(mut self, user: &str) -> Self {
92 self.user = Some(user.to_string());
93 self
94 }
95
96 pub fn password(mut self, pass: &str) -> Self {
98 self.pass = Some(pass.to_string());
99 self
100 }
101
102 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
104 self.connect_timeout = Some(timeout);
105 self
106 }
107
108 pub fn read_timeout(mut self, timeout: Duration) -> Self {
110 self.read_timeout = Some(timeout);
111 self
112 }
113
114 pub fn build(self) -> Result<Opts, VoltError> {
118 if self.hosts.is_empty() {
119 return Err(VoltError::InvalidConfig);
120 }
121 Ok(Opts(Box::new(InnerOpts {
122 ip_ports: self.hosts,
123 user: self.user,
124 pass: self.pass,
125 connect_timeout: self.connect_timeout,
126 read_timeout: self.read_timeout,
127 })))
128 }
129}
130
131#[derive(Debug, Clone, Eq, PartialEq)]
132pub(crate) struct InnerOpts {
133 pub(crate) ip_ports: Vec<IpPort>,
134 pub(crate) user: Option<String>,
135 pub(crate) pass: Option<String>,
136 pub(crate) connect_timeout: Option<Duration>,
137 pub(crate) read_timeout: Option<Duration>,
138}
139
140pub struct NodeOpt {
141 pub ip_port: IpPort,
142 pub user: Option<String>,
143 pub pass: Option<String>,
144}
145
146#[derive(Debug)]
147#[allow(dead_code)]
148pub(crate) struct NetworkRequest {
149 handle: i64,
150 query: bool,
151 sync: bool,
152 num_bytes: i32,
153 channel: Mutex<Sender<VoltTable>>,
154}
155
/// Marker trait for connection types that are `Send + Sync + 'static`, so
/// they can be shared across threads.
pub trait Connection: Send + Sync + 'static {}
157
158#[allow(dead_code)]
159pub struct Node {
160 tcp_stream: Box<Option<TcpStream>>,
161 info: ConnInfo,
162 requests: Arc<RwLock<HashMap<i64, NetworkRequest>>>,
163 stop: Arc<Mutex<bool>>,
164 counter: Mutex<AtomicI64>,
165}
166
167impl Debug for Node {
168 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
169 write!(f, "Pending request: {}", 1)
170 }
171}
172
173impl Drop for Node {
174 fn drop(&mut self) {
175 let res = self.shutdown();
176 match res {
177 Ok(_) => {}
178 Err(e) => {
179 eprintln!("{:?}", e);
180 }
181 }
182 }
183}
184
185impl Connection for Node {}
186
187impl Node {
188 pub fn new(opt: NodeOpt) -> Result<Node, VoltError> {
189 let ip_host = opt.ip_port;
190 let addr = format!("{}:{}", ip_host.ip_host, ip_host.port);
191
192 let auth_msg = build_auth_message(opt.user.as_deref(), opt.pass.as_deref())?;
194
195 let mut stream: TcpStream = TcpStream::connect(addr)?;
197
198 stream.write_all(&auth_msg)?;
200 stream.flush()?;
201
202 let read = stream.read_u32::<BigEndian>()?;
204 let mut all = vec![0; read as usize];
205 stream.read_exact(&mut all)?;
206
207 let info = parse_auth_response(&all)?;
209
210 let data = Arc::new(RwLock::new(HashMap::new()));
211 let mut res = Node {
212 stop: Arc::new(Mutex::new(false)),
213 tcp_stream: Box::new(Option::Some(stream)),
214 info,
215 requests: data,
216 counter: Mutex::new(AtomicI64::new(1)),
217 };
218 res.listen()?;
219 Ok(res)
220 }
221 pub fn get_sequence(&self) -> i64 {
222 let lock = self.counter.lock();
223 let seq = lock.unwrap();
224
225 seq.fetch_add(1, Ordering::Relaxed)
226 }
227
228 pub fn list_procedures(&mut self) -> Result<Receiver<VoltTable>, VoltError> {
229 self.call_sp("@SystemCatalog", volt_param!("PROCEDURES"))
230 }
231
232 pub fn call_sp(
233 &mut self,
234 query: &str,
235 param: Vec<&dyn Value>,
236 ) -> Result<Receiver<VoltTable>, VoltError> {
237 let req = self.get_sequence();
238 let mut proc = new_procedure_invocation(req, false, ¶m, query);
239 let (tx, rx): (Sender<VoltTable>, Receiver<VoltTable>) = mpsc::channel();
240 let shared_sender = Mutex::new(tx);
241 let seq = NetworkRequest {
242 query: true,
243 handle: req,
244 num_bytes: proc.slen,
245 sync: true,
246 channel: shared_sender,
247 };
248 self.requests.write()?.insert(req, seq);
249 let bs = proc.bytes();
250 let tcp_stream = self.tcp_stream.as_mut();
251 match tcp_stream {
252 None => {
253 return Err(VoltError::ConnectionNotAvailable);
254 }
255 Some(stream) => {
256 stream.write_all(&bs)?;
257 }
258 }
259 Ok(rx)
260 }
261
262 pub fn upload_jar(&mut self, bs: Vec<u8>) -> Result<Receiver<VoltTable>, VoltError> {
263 self.call_sp("@UpdateClasses", volt_param!(bs, ""))
264 }
265 pub fn query(&mut self, sql: &str) -> Result<Receiver<VoltTable>, VoltError> {
267 let zero_vec: Vec<&dyn Value> = vec![&sql];
268 self.call_sp("@AdHoc", zero_vec)
269 }
270
271 pub fn ping(&mut self) -> Result<(), VoltError> {
272 let zero_vec: Vec<&dyn Value> = Vec::new();
273 let mut proc = new_procedure_invocation(PING_HANDLE, false, &zero_vec, "@Ping");
274 let bs = proc.bytes();
275 let res = self.tcp_stream.as_mut();
276 match res {
277 None => {
278 return Err(VoltError::ConnectionNotAvailable);
279 }
280 Some(stream) => {
281 stream.write_all(&bs)?;
282 }
283 }
284 Ok(())
285 }
286
287 fn job(
288 mut tcp: &TcpStream,
289 requests: &Arc<RwLock<HashMap<i64, NetworkRequest>>>,
290 ) -> Result<(), VoltError> {
291 let read_res = tcp.read_u32::<BigEndian>();
292 match read_res {
293 Ok(read) => {
294 if read > 0 {
295 let mut all = vec![0; read as usize];
296 tcp.read_exact(&mut all)?;
297 let mut res = ByteBuffer::from_bytes(&all);
298 let _ = res.read_u8()?;
299 let handle = res.read_i64()?;
300 if handle == PING_HANDLE {
301 return Ok(());
302 }
303 if let Some(t) = requests.write()?.remove(&handle) {
304 let info = VoltResponseInfo::new(&mut res, handle)?;
305 let table = new_volt_table(&mut res, info)?;
306 let sender = t.channel.lock()?;
307 sender.send(table).unwrap();
308 }
309 }
310 }
311 Err(e) => {
312 return Err(VoltError::Io(e));
313 }
314 }
315 Ok(())
316 }
317 pub fn shutdown(&mut self) -> Result<(), VoltError> {
318 let mut stop = self.stop.lock().unwrap();
319 *stop = true;
320 let res = self.tcp_stream.as_mut();
321 match res {
322 None => {}
323 Some(stream) => {
324 stream.shutdown(Shutdown::Both)?;
325 }
326 }
327 *self.tcp_stream = Option::None;
328 Ok(())
329 }
330 fn listen(&mut self) -> Result<(), VoltError> {
332 let requests = Arc::clone(&self.requests);
333
334 let res = self.tcp_stream.as_mut();
335 match res {
336 None => Ok(()),
337 Some(res) => {
338 let tcp = res.try_clone()?;
339 let stopping = Arc::clone(&self.stop);
340 thread::spawn(move || {
341 loop {
342 if *stopping.lock().unwrap() {
343 break;
344 } else {
345 let res = crate::node::Node::job(&tcp, &requests);
346 match res {
347 Ok(_) => {}
348 Err(err) => {
349 if !*stopping.lock().unwrap() {
350 eprintln!("{} ", err)
351 }
352 }
353 }
354 }
355 }
356 });
357 Ok(())
358 }
359 }
360 }
361}
362
/// Connection details returned by `parse_auth_response` during the handshake.
#[derive(Clone, Debug)]
pub struct ConnInfo {
    pub host_id: i32,
    pub connection: i64,
    pub leader_addr: Ipv4Addr,
    pub build: String,
}

impl Default for ConnInfo {
    /// Returns zeroed ids, an empty build string and a leader address of
    /// `127.0.0.1`.
    fn default() -> Self {
        ConnInfo {
            leader_addr: Ipv4Addr::LOCALHOST,
            build: String::new(),
            host_id: 0,
            connection: 0,
        }
    }
}
380
381pub fn block_for_result(res: &Receiver<VoltTable>) -> Result<VoltTable, VoltError> {
383 let mut table = res.recv()?;
384 let err = table.has_error();
385 match err {
386 None => Ok(table),
387 Some(err) => Err(err),
388 }
389}
390
/// Currently does nothing.
pub fn reset() {
    // Intentionally empty.
}
392
393pub fn get_node(addr: &str) -> Result<Node, VoltError> {
395 let addrs = addr
396 .to_socket_addrs()
397 .map_err(|_| VoltError::InvalidConfig)?;
398
399 let socket_addr = addrs
400 .into_iter()
401 .find(|s| s.is_ipv4()) .ok_or(VoltError::InvalidConfig)?;
403
404 let ip_port = IpPort::new(socket_addr.ip().to_string(), socket_addr.port());
405
406 let opt = NodeOpt {
407 ip_port,
408 user: None,
409 pass: None,
410 };
411 Node::new(opt)
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_opts_builder_basic() {
        let built = Opts::builder().host("localhost", 21212).build().unwrap();
        let inner = &built.0;

        assert_eq!(inner.ip_ports.len(), 1);
        let only = &inner.ip_ports[0];
        assert_eq!(only.ip_host, "localhost");
        assert_eq!(only.port, 21212);
        assert!(inner.user.is_none() && inner.pass.is_none());
    }

    #[test]
    fn test_opts_builder_with_auth() {
        let built = Opts::builder()
            .host("127.0.0.1", 21211)
            .user("admin")
            .password("secret")
            .build()
            .unwrap();

        assert_eq!(built.0.user.as_deref(), Some("admin"));
        assert_eq!(built.0.pass.as_deref(), Some("secret"));
    }

    #[test]
    fn test_opts_builder_multiple_hosts() {
        let built = ["host1", "host2", "host3"]
            .into_iter()
            .fold(Opts::builder(), |builder, name| builder.host(name, 21212))
            .build()
            .unwrap();

        let names: Vec<&str> = built
            .0
            .ip_ports
            .iter()
            .map(|entry| entry.ip_host.as_str())
            .collect();
        assert_eq!(names, ["host1", "host2", "host3"]);
    }

    #[test]
    fn test_opts_builder_with_hosts_vec() {
        let nodes: Vec<IpPort> = [("node1", 21212), ("node2", 21213)]
            .into_iter()
            .map(|(host, port)| IpPort::new(host.to_string(), port))
            .collect();
        let built = Opts::builder().hosts(nodes).build().unwrap();

        assert_eq!(built.0.ip_ports.len(), 2);
    }

    #[test]
    fn test_opts_builder_with_timeouts() {
        let (connect, read) = (Duration::from_secs(10), Duration::from_secs(30));
        let built = Opts::builder()
            .host("localhost", 21212)
            .connect_timeout(connect)
            .read_timeout(read)
            .build()
            .unwrap();

        assert_eq!(built.0.connect_timeout, Some(connect));
        assert_eq!(built.0.read_timeout, Some(read));
    }

    #[test]
    fn test_opts_builder_no_hosts_fails() {
        let outcome = Opts::builder().build();
        assert!(
            matches!(outcome, Err(VoltError::InvalidConfig)),
            "Expected InvalidConfig error"
        );
    }

    #[test]
    fn test_opts_new_compatibility() {
        let built = Opts::new(vec![IpPort::new("localhost".to_string(), 21212)]);
        let inner = &built.0;

        assert_eq!(inner.ip_ports.len(), 1);
        assert!(inner.user.is_none());
        assert!(inner.connect_timeout.is_none());
    }

    #[test]
    fn test_ip_port_new() {
        let IpPort { ip_host, port } = IpPort::new("192.168.1.1".to_string(), 8080);
        assert_eq!(ip_host, "192.168.1.1");
        assert_eq!(port, 8080);
    }
}