pub mod reconnect;
pub mod tcp;
use std::collections::HashMap;
use std::io::Cursor;
use std::rc::Rc;
use std::sync::Arc;
use self::tcp::TcpStream;
use super::protocol::api::{Call, Eval, Execute, Ping, Request};
use super::protocol::{self, Protocol, SyncIndex};
use crate::error;
use crate::error::BoxError;
use crate::fiber;
use crate::fiber::r#async::oneshot;
use crate::fiber::r#async::IntoOnDrop as _;
use crate::fiber::FiberId;
use crate::fiber::NoYieldsRefCell;
use crate::tuple::{ToTupleBuffer, Tuple};
use crate::unwrap_ok_or;
use futures::{AsyncReadExt, AsyncWriteExt};
/// Former name of [`ClientError`], kept only so existing code keeps compiling.
#[deprecated(note = "use `ClientError` instead")]
pub type Error = ClientError;
/// Errors returned by [`Client`] operations.
///
/// Every variant displays as its inner error (`#[error("{0}")]`).
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
    /// The connection is not usable: either the TCP connection could not be
    /// established, or it was closed because of an error. The error is shared
    /// through an `Arc` because every request pending at that moment receives
    /// a copy of it.
    #[error("{0}")]
    ConnectionClosed(Arc<crate::error::Error>),
    /// The protocol layer failed to encode the request.
    #[error("{0}")]
    RequestEncode(crate::error::Error),
    /// Extracting the response failed with something other than a server-side
    /// error (presumably a decoding failure).
    #[error("{0}")]
    ResponseDecode(crate::error::Error),
    /// The server answered the request with an error.
    #[error("{0}")]
    ErrorResponse(BoxError),
}
impl From<ClientError> for crate::error::Error {
#[inline(always)]
fn from(err: ClientError) -> Self {
match err {
ClientError::ConnectionClosed(err) => crate::error::Error::ConnectionClosed(err),
ClientError::RequestEncode(err) => err,
ClientError::ResponseDecode(err) => err,
ClientError::ErrorResponse(err) => crate::error::Error::Remote(err),
}
}
}
/// Lifecycle of a connection shared by every [`Client`] handle.
#[derive(Clone, Debug)]
enum State {
    /// The connection can be used for requests.
    Alive,
    /// The last client handle was dropped and the connection was shut down.
    ClosedManually,
    /// An I/O or protocol error terminated the connection.
    ClosedWithError(Arc<error::Error>),
}

impl State {
    /// `true` while the connection has not been closed.
    fn is_alive(&self) -> bool {
        match self {
            State::Alive => true,
            State::ClosedManually | State::ClosedWithError(_) => false,
        }
    }

    /// `true` once the connection has been closed for any reason.
    fn is_closed(&self) -> bool {
        !matches!(self, State::Alive)
    }
}
/// Connection state shared (via `Rc<NoYieldsRefCell<_>>`) between all
/// [`Client`] handles and the background sender/receiver fibers.
#[derive(Debug)]
struct ClientInner {
    /// Protocol layer: encodes requests, buffers outgoing data and decodes
    /// incoming responses (which requests later take out by sync id).
    protocol: Protocol,
    /// Requests still waiting for a response, keyed by sync id. Each sender is
    /// completed with `Ok(())` once the response is available in `protocol`,
    /// or with the error that closed the connection.
    awaiting_response: HashMap<SyncIndex, oneshot::Sender<Result<(), Arc<error::Error>>>>,
    /// Current connection state.
    state: State,
    /// Underlying TCP stream; closed when the last `Client` handle is dropped.
    stream: TcpStream,
    /// Fiber that writes outgoing data; woken when new data becomes ready.
    sender_fiber_id: Option<FiberId>,
    /// Fiber that reads and dispatches responses.
    receiver_fiber_id: Option<FiberId>,
    /// Number of live `Client` handles sharing this state.
    clients_count: usize,
}
impl ClientInner {
    /// Builds the shared state for a freshly connected `stream`.
    ///
    /// The connection starts out `Alive`, with no pending requests, no fibers
    /// recorded yet and exactly one client handle. With the `picodata` feature,
    /// a warning is logged when the LDAP auth method is configured, because it
    /// sends the password unencrypted.
    pub fn new(config: protocol::Config, stream: TcpStream) -> Self {
        #[cfg(feature = "picodata")]
        if config.auth_method == crate::auth::AuthMethod::Ldap {
            crate::say_warn!(
                "You're using the 'ldap' authentication method, which implies sending the password UNENCRYPTED over the TCP connection. TLS is not yet implemented for IPROTO connections so make sure your communication channel is secure by other means."
            );
        }
        let protocol = Protocol::with_config(config);
        Self {
            state: State::Alive,
            clients_count: 1,
            awaiting_response: HashMap::new(),
            sender_fiber_id: None,
            receiver_fiber_id: None,
            protocol,
            stream,
        }
    }
}
/// Wakes the sender fiber, but only if the protocol has outgoing data ready
/// and the sender fiber has been registered.
fn maybe_wake_sender(client: &ClientInner) {
    let has_outgoing = client.protocol.ready_outgoing_len() != 0;
    match client.sender_fiber_id {
        Some(id) if has_outgoing => {
            fiber::wakeup(id);
        }
        _ => {}
    }
}
/// Asynchronous IPROTO client.
///
/// Cloning a `Client` yields another handle to the same connection. When the
/// last handle is dropped, the TCP stream is closed and the background
/// sender/receiver fibers are cancelled.
#[derive(Debug)]
pub struct Client(Rc<NoYieldsRefCell<ClientInner>>);
impl Client {
    /// Connects to `url:port` using the default [`protocol::Config`].
    ///
    /// # Errors
    /// Returns [`ClientError::ConnectionClosed`] if the TCP connection cannot
    /// be established.
    pub async fn connect(url: &str, port: u16) -> Result<Self, ClientError> {
        let config = protocol::Config::default();
        Self::connect_with_config(url, port, config).await
    }

    /// Connects to `url:port` with an explicit protocol configuration.
    ///
    /// It then starts the background receiver (`iproto-in/<url>:<port>`) and
    /// sender (`iproto-out/<url>:<port>`) fibers that drive the connection.
    ///
    /// # Errors
    /// Returns [`ClientError::ConnectionClosed`] if the TCP connection cannot
    /// be established.
    ///
    /// # Panics
    /// Panics if either background fiber fails to start.
    pub async fn connect_with_config(
        url: &str,
        port: u16,
        config: protocol::Config,
    ) -> Result<Self, ClientError> {
        let stream = match TcpStream::connect(url, port) {
            Ok(stream) => stream,
            Err(e) => return Err(ClientError::ConnectionClosed(Arc::new(e.into()))),
        };
        let shared = Rc::new(NoYieldsRefCell::new(ClientInner::new(
            config,
            stream.clone(),
        )));

        let receiver_fiber_id = fiber::Builder::new()
            .func_async(receiver(Rc::clone(&shared), stream.clone()))
            .name(format!("iproto-in/{url}:{port}"))
            .start_non_joinable()
            .unwrap();
        let sender_fiber_id = fiber::Builder::new()
            .func_async(sender(Rc::clone(&shared), stream))
            .name(format!("iproto-out/{url}:{port}"))
            .start_non_joinable()
            .unwrap();

        // Record the fiber ids so `maybe_wake_sender` and `Drop` can reach them.
        let mut inner = shared.borrow_mut();
        inner.receiver_fiber_id = Some(receiver_fiber_id);
        inner.sender_fiber_id = Some(sender_fiber_id);
        drop(inner);

        Ok(Self(shared))
    }

    /// Returns the error that closed the connection, if any.
    ///
    /// # Panics
    /// Panics on `State::ClosedManually`. That state is only set when the last
    /// handle is dropped, so it cannot be observed through `&self`.
    fn check_state(&self) -> Result<(), Arc<error::Error>> {
        let inner = self.0.borrow();
        match &inner.state {
            State::ClosedWithError(err) => Err(Arc::clone(err)),
            State::ClosedManually => unreachable!("All client handles are dropped at this point"),
            State::Alive => Ok(()),
        }
    }
}
/// Operations available on an IPROTO client.
///
/// Implementors only provide [`AsClient::send`]. The other methods build the
/// corresponding request and delegate to it.
#[async_trait::async_trait(?Send)]
pub trait AsClient {
    /// Sends `request` and waits for its response.
    async fn send<R: Request>(&self, request: &R) -> Result<R::Response, ClientError>;

    /// Sends a [`Ping`] request.
    async fn ping(&self) -> Result<(), ClientError> {
        let request = Ping;
        self.send(&request).await
    }

    /// Calls the stored procedure `fn_name` with `args`.
    async fn call<T>(&self, fn_name: &str, args: &T) -> Result<Tuple, ClientError>
    where
        T: ToTupleBuffer + ?Sized,
    {
        let request = Call { fn_name, args };
        self.send(&request).await
    }

    /// Evaluates the Lua expression `expr` with `args` on the server.
    async fn eval<T>(&self, expr: &str, args: &T) -> Result<Tuple, ClientError>
    where
        T: ToTupleBuffer + ?Sized,
    {
        let request = Eval { expr, args };
        self.send(&request).await
    }

    /// Executes the SQL statement `sql` with `bind_params`.
    async fn execute<T>(&self, sql: &str, bind_params: &T) -> Result<Vec<Tuple>, ClientError>
    where
        T: ToTupleBuffer + ?Sized,
    {
        let request = Execute { sql, bind_params };
        self.send(&request).await
    }
}
#[async_trait::async_trait(?Send)]
impl AsClient for Client {
async fn send<R: Request>(&self, request: &R) -> Result<R::Response, ClientError> {
if let Err(e) = self.check_state() {
return Err(ClientError::ConnectionClosed(e));
}
let res = self.0.borrow_mut().protocol.send_request(request);
let sync = unwrap_ok_or!(res,
Err(e) => {
return Err(ClientError::RequestEncode(e));
}
);
let (tx, rx) = oneshot::channel();
self.0.borrow_mut().awaiting_response.insert(sync, tx);
maybe_wake_sender(&self.0.borrow());
let res = rx
.on_drop(|| {
let _ = self.0.borrow_mut().awaiting_response.remove(&sync);
})
.await
.expect("Channel should be open");
if let Err(e) = res {
return Err(ClientError::ConnectionClosed(e));
}
let res = self
.0
.borrow_mut()
.protocol
.take_response::<R>(sync)
.expect("Is present at this point");
let response = unwrap_ok_or!(res,
Err(error::Error::Remote(response)) => {
return Err(ClientError::ErrorResponse(response));
}
Err(e) => {
return Err(ClientError::ResponseDecode(e));
}
);
Ok(response)
}
}
impl Drop for Client {
    /// Releases this handle.
    ///
    /// Dropping the last handle marks the connection as closed manually and
    /// closes the TCP stream. It then cancels and wakes both background fibers
    /// so they observe the closed state and exit.
    fn drop(&mut self) {
        let mut inner = self.0.borrow_mut();
        if inner.clients_count != 1 {
            // Other handles still share this connection.
            inner.clients_count -= 1;
            return;
        }

        inner.state = State::ClosedManually;
        if let Err(e) = inner.stream.close() {
            crate::say_error!("Client::drop: failed closing tcp stream: {e}");
        }
        let receiver_fiber_id = inner.receiver_fiber_id;
        let sender_fiber_id = inner.sender_fiber_id;
        // Release the borrow before touching the fibers, which access the
        // shared state themselves when they run.
        drop(inner);

        for id in receiver_fiber_id.into_iter().chain(sender_fiber_id) {
            fiber::cancel(id);
            fiber::wakeup(id);
        }
    }
}
impl Clone for Client {
fn clone(&self) -> Self {
self.0.borrow_mut().clients_count += 1;
Self(self.0.clone())
}
}
/// Evaluates to the `Ok` value of `$e`.
///
/// On `Err` it does four things:
/// - converts the error into a shared `error::Error`;
/// - fails every request still awaiting a response with that error;
/// - marks the connection as `State::ClosedWithError`;
/// - `return`s from the enclosing function.
///
/// `$client` must deref mutably to `ClientInner`. It may be a fresh
/// `borrow_mut()` call; each use below then takes and releases its own borrow.
macro_rules! handle_result {
    ($client:expr, $e:expr) => {
        match $e {
            Ok(value) => value,
            Err(err) => {
                let err = Arc::new(error::Error::from(err));
                // Move all pending subscriptions out in one step. This avoids
                // re-hashing them into a temporary map, and it ends any borrow
                // taken by `$client` before the subscribers are notified.
                let subscriptions = std::mem::take(&mut $client.awaiting_response);
                for subscription in subscriptions.into_values() {
                    // A closed receiving side means nobody waits anymore.
                    let _ = subscription.send(Err(err.clone()));
                }
                $client.state = State::ClosedWithError(err);
                return;
            }
        }
    };
}
/// Background fiber body that flushes the protocol's outgoing data to the TCP
/// stream.
///
/// When nothing is ready to send, it yields the fiber. It is woken by
/// `maybe_wake_sender` or by `Client::drop`. It exits once the connection is
/// closed or the fiber is cancelled. A failed write fails every pending
/// request and ends the fiber through `handle_result!`.
async fn sender(client: Rc<NoYieldsRefCell<ClientInner>>, mut writer: TcpStream) {
    while client.borrow().state.is_alive() && !fiber::is_cancelled() {
        let outgoing = client.borrow_mut().protocol.take_outgoing_data();
        if outgoing.is_empty() {
            fiber::fiber_yield();
            continue;
        }
        let written = writer.write_all(&outgoing).await;
        handle_result!(client.borrow_mut(), written);
    }
}
/// Background fiber body that reads responses from the TCP stream and
/// dispatches them to the requests waiting for them.
///
/// Each iteration works as follows:
/// - Ask the protocol how many bytes it needs next (`read_size_hint`).
/// - Read exactly that many bytes and feed them to `process_incoming`.
/// - When that yields a sync id, notify the matching request with `Ok(())`.
///   The response itself stays in `protocol` until the request takes it.
///
/// A read or processing error fails all pending requests and ends the fiber
/// through `handle_result!`. The fiber also exits once the connection is
/// closed or the fiber is cancelled.
// The `Ref` taken at the top of each iteration is explicitly dropped before
// the `.await`; the allow presumably exists because the lint does not account
// for that explicit `drop`.
#[allow(clippy::await_holding_refcell_ref)]
async fn receiver(client_cell: Rc<NoYieldsRefCell<ClientInner>>, mut reader: TcpStream) {
    // Reusable read buffer: it grows to the largest size hint seen and is
    // never shrunk afterwards.
    let mut buf = vec![0_u8; 4096];
    loop {
        let client = client_cell.borrow();
        if client.state.is_closed() || fiber::is_cancelled() {
            return;
        }
        let size = client.protocol.read_size_hint();
        if buf.len() < size {
            buf.resize(size, 0);
        }
        let buf_slice = &mut buf[0..size];
        // Release the borrow before waiting on I/O so other fibers can use
        // the shared state meanwhile.
        drop(client);
        let res = reader.read_exact(buf_slice).await;
        let mut client = client_cell.borrow_mut();
        handle_result!(client, res);
        let result = client
            .protocol
            .process_incoming(&mut Cursor::new(buf_slice));
        let result = handle_result!(client, result);
        if let Some(sync) = result {
            let subscription = client.awaiting_response.remove(&sync);
            if let Some(subscription) = subscription {
                subscription
                    .send(Ok(()))
                    .expect("cannot be closed at this point");
            } else {
                // NOTE(review): there is no waiter, presumably because the
                // request's future was dropped. Its response seems to remain
                // in `protocol`; confirm whether it is ever reclaimed.
                crate::say_warn!("received unwaited message for {sync:?}");
            }
        }
        // Wake the sender in case outgoing data is ready after processing.
        maybe_wake_sender(&client);
    }
}
#[cfg(feature = "internal_test")]
mod tests {
    use super::*;
    use crate::error::TarantoolErrorCode;
    use crate::fiber::r#async::timeout::IntoTimeout as _;
    use crate::space::Space;
    use crate::test::util::listen_port;
    use std::time::Duration;

    // Connects to the local test instance as `test_user`. Panics if the
    // connection fails or takes longer than 3 seconds.
    async fn test_client() -> Client {
        Client::connect_with_config(
            "localhost",
            listen_port(),
            protocol::Config {
                creds: Some(("test_user".into(), "password".into())),
                ..Default::default()
            },
        )
        .timeout(Duration::from_secs(3))
        .await
        .unwrap()
    }

    #[crate::test(tarantool = "crate")]
    async fn connect() {
        let _client = Client::connect("localhost", listen_port()).await.unwrap();
    }

    // Connecting to port 0 is expected to fail with `ConnectionClosed`.
    #[crate::test(tarantool = "crate")]
    async fn connect_failure() {
        let err = Client::connect("localhost", 0).await.unwrap_err();
        assert!(matches!(dbg!(err), ClientError::ConnectionClosed(_)))
    }

    #[crate::test(tarantool = "crate")]
    async fn ping() {
        let client = test_client().await;
        for _ in 0..5 {
            client.ping().timeout(Duration::from_secs(3)).await.unwrap();
        }
    }

    // Two fibers share one client and ping concurrently.
    #[crate::test(tarantool = "crate")]
    fn ping_concurrent() {
        let client = fiber::block_on(test_client());
        let fiber_a = fiber::start_async(async {
            client.ping().timeout(Duration::from_secs(3)).await.unwrap()
        });
        let fiber_b = fiber::start_async(async {
            client.ping().timeout(Duration::from_secs(3)).await.unwrap()
        });
        fiber_a.join();
        fiber_b.join();
    }

    #[crate::test(tarantool = "crate")]
    async fn execute() {
        Space::find("test_s1")
            .unwrap()
            .insert(&(6001, "6001"))
            .unwrap();
        Space::find("test_s1")
            .unwrap()
            .insert(&(6002, "6002"))
            .unwrap();
        let client = test_client().await;
        let lua = crate::lua_state();
        _ = lua.exec("require'compat'.sql_seq_scan_default = 'old'");
        let result = client
            .execute(r#"SELECT * FROM "test_s1""#, &())
            .timeout(Duration::from_secs(3))
            .await
            .unwrap();
        assert!(result.len() >= 2);
        let result = client
            .execute(r#"SELECT * FROM "test_s1" WHERE "id" = ?"#, &(6002,))
            .timeout(Duration::from_secs(3))
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(
            result.first().unwrap().decode::<(u64, String)>().unwrap(),
            (6002, "6002".into())
        );
    }

    #[crate::test(tarantool = "crate")]
    async fn call() {
        let client = test_client().await;
        let result = client
            .call("test_stored_proc", &(1, 2))
            .timeout(Duration::from_secs(3))
            .await
            .unwrap();
        assert_eq!(result.decode::<(i32,)>().unwrap(), (3,));
    }

    #[crate::test(tarantool = "crate")]
    async fn invalid_call() {
        let client = test_client().await;
        let err = client
            .call("unexistent_proc", &())
            .timeout(Duration::from_secs(3))
            .await
            .unwrap_err();
        let err = error::Error::from(err);
        let error::Error::Remote(err) = err else {
            panic!()
        };
        assert_eq!(err.error_code(), TarantoolErrorCode::NoSuchProc as u32);
        #[rustfmt::skip]
        assert_eq!(err.to_string(), "NoSuchProc: Procedure 'unexistent_proc' is not defined");
    }

    #[crate::test(tarantool = "crate")]
    async fn eval() {
        let client = test_client().await;
        let result = client
            .eval("return ...", &(1, 2))
            .timeout(Duration::from_secs(3))
            .await
            .unwrap();
        assert_eq!(result.decode::<(i32, i32)>().unwrap(), (1, 2));
        let err = client
            .eval("box.error(420)", &())
            .timeout(Duration::from_secs(3))
            .await
            .unwrap_err();
        let err = error::Error::from(err);
        let error::Error::Remote(err) = err else {
            panic!()
        };
        assert_eq!(err.error_code(), 420);
    }

    // Closes the stream behind the client's back. Once the background fibers
    // observe the failure and exit, only the test's handle may still hold the
    // shared state, and cloning/dropping must keep the counts consistent.
    #[crate::test(tarantool = "crate")]
    async fn client_count_regression() {
        let client = test_client().await;
        client.0.borrow_mut().stream.close().unwrap();
        fiber::reschedule();
        let fiber_id = client.0.borrow().sender_fiber_id.unwrap();
        let fiber_exists = fiber::wakeup(fiber_id);
        // A plain `assert!` so the check also runs in release test builds.
        assert!(fiber_exists);
        fiber::reschedule();
        assert_eq!(Rc::strong_count(&client.0), 1);
        let client_clone = client.clone();
        assert_eq!(Rc::strong_count(&client.0), 2);
        drop(client_clone);
        assert_eq!(Rc::strong_count(&client.0), 1);
        client.check_state().unwrap_err();
    }

    // Many requests in flight from a single fiber.
    #[crate::test(tarantool = "crate")]
    async fn concurrent_messages_one_fiber() {
        let client = test_client().await;
        let mut ping_futures = vec![];
        for _ in 0..10 {
            ping_futures.push(client.ping());
        }
        for res in futures::future::join_all(ping_futures).await {
            res.unwrap();
        }
    }

    #[crate::test(tarantool = "crate")]
    async fn data_always_present_in_response() {
        let client = test_client().await;
        client.eval("return", &()).await.unwrap();
        client.call("LUA", &("return",)).await.unwrap();
    }

    #[crate::test(tarantool = "crate")]
    async fn big_data() {
        use crate::tuple::RawByteBuf;

        #[crate::proc(tarantool = "crate")]
        fn proc_big_data<'a>(s: &'a serde_bytes::Bytes) -> usize {
            s.len() + 17
        }

        let proc = crate::define_stored_proc_for_tests!(proc_big_data);
        let client = test_client().await;

        // Payload size; smaller on macOS (presumably due to memory limits there).
        #[cfg(target_os = "macos")]
        const N: u32 = 0x1fff_ff69;
        #[cfg(not(target_os = "macos"))]
        const N: u32 = 0x6fff_ff69;

        let s = {
            let buf_size = (N + 6) as usize;
            // Zero-initialised: handing uninitialized bytes to safe code (as the
            // old `set_len` version did) is undefined behavior. For an allocation
            // this size the zeroing is typically done lazily by the OS.
            let mut data = vec![0_u8; buf_size];
            // msgpack header: fixarray of 1 element, then bin32 with a
            // big-endian u32 length of N.
            data[0] = b'\x91';
            data[1] = b'\xc6';
            data[2..6].copy_from_slice(&N.to_be_bytes());
            RawByteBuf::from(data)
        };

        let t0 = std::time::Instant::now();
        let t = client.call(&proc, &s).await.unwrap();
        dbg!(t0.elapsed());
        if let Ok((len,)) = t.decode::<(u32,)>() {
            assert_eq!(len, N + 17);
        } else {
            let ((len,),): ((u32,),) = t.decode().unwrap();
            assert_eq!(len, N + 17);
        }
    }

    #[cfg(feature = "picodata")]
    #[crate::test(tarantool = "crate")]
    async fn md5_auth_method() {
        use crate::auth::AuthMethod;
        use std::time::Duration;
        let username = "Johnny";
        let password = "B. Goode";
        crate::lua_state()
            .exec_with(
                "local username, password = ...
box.cfg { }
box.schema.user.create(username, { if_not_exists = true, auth_type = 'md5', password = password })
box.schema.user.grant(username, 'super', nil, nil, { if_not_exists = true })",
                (username, password),
            )
            .unwrap();

        // Correct password with the matching auth method succeeds.
        {
            let client = Client::connect_with_config(
                "localhost",
                listen_port(),
                protocol::Config {
                    creds: Some((username.into(), password.into())),
                    auth_method: AuthMethod::Md5,
                    ..Default::default()
                },
            )
            .timeout(Duration::from_secs(3))
            .await
            .unwrap();
            client
                .eval("print('\\x1b[32mit works!\\x1b[0m')", &())
                .await
                .unwrap();
        }

        // Wrong password is rejected.
        {
            let client = Client::connect_with_config(
                "localhost",
                listen_port(),
                protocol::Config {
                    creds: Some((username.into(), "wrong password".into())),
                    auth_method: AuthMethod::Md5,
                    ..Default::default()
                },
            )
            .timeout(Duration::from_secs(3))
            .await
            .unwrap();
            let err = client.eval("return", &()).await.unwrap_err().to_string();
            #[rustfmt::skip]
            assert_eq!(err, "server responded with error: PasswordMismatch: User not found or supplied credentials are invalid");
        }

        // Correct password with a mismatching auth method is rejected.
        {
            let client = Client::connect_with_config(
                "localhost",
                listen_port(),
                protocol::Config {
                    creds: Some((username.into(), password.into())),
                    auth_method: AuthMethod::ChapSha1,
                    ..Default::default()
                },
            )
            .timeout(Duration::from_secs(3))
            .await
            .unwrap();
            let err = client.eval("return", &()).await.unwrap_err().to_string();
            #[rustfmt::skip]
            assert_eq!(err, "server responded with error: PasswordMismatch: User not found or supplied credentials are invalid");
        }

        crate::lua_state()
            .exec_with(
                "local username = ...
box.cfg { auth_type = 'chap-sha1' }
box.schema.user.drop(username)",
                username,
            )
            .unwrap();
    }

    #[cfg(feature = "picodata")]
    #[crate::test(tarantool = "crate")]
    async fn ldap_auth_method() {
        use crate::auth::AuthMethod;
        use std::time::Duration;
        let username = "Johnny";
        let password = "B. Goode";
        // Skip the test when the LDAP setup is unavailable.
        let _guard = crate::unwrap_ok_or!(
            crate::test::util::setup_ldap_auth(username, password),
            Err(e) => {
                println!("{e}, skipping ldap test");
                return;
            }
        );

        // Correct password with the matching auth method succeeds.
        {
            let client = Client::connect_with_config(
                "localhost",
                listen_port(),
                protocol::Config {
                    creds: Some((username.into(), password.into())),
                    auth_method: AuthMethod::Ldap,
                    ..Default::default()
                },
            )
            .timeout(Duration::from_secs(3))
            .await
            .unwrap();
            client
                .eval("print('\\x1b[32mit works!\\x1b[0m')", &())
                .await
                .unwrap();
        }

        // Wrong password is rejected.
        {
            let client = Client::connect_with_config(
                "localhost",
                listen_port(),
                protocol::Config {
                    creds: Some((username.into(), "wrong password".into())),
                    auth_method: AuthMethod::Ldap,
                    ..Default::default()
                },
            )
            .timeout(Duration::from_secs(3))
            .await
            .unwrap();
            let err = client.eval("return", &()).await.unwrap_err().to_string();
            #[rustfmt::skip]
            assert_eq!(err, "server responded with error: PasswordMismatch: User not found or supplied credentials are invalid");
        }

        // Correct password with a mismatching auth method is rejected.
        {
            let client = Client::connect_with_config(
                "localhost",
                listen_port(),
                protocol::Config {
                    creds: Some((username.into(), password.into())),
                    auth_method: AuthMethod::ChapSha1,
                    ..Default::default()
                },
            )
            .timeout(Duration::from_secs(3))
            .await
            .unwrap();
            let err = client.eval("return", &()).await.unwrap_err().to_string();
            #[rustfmt::skip]
            assert_eq!(err, "server responded with error: PasswordMismatch: User not found or supplied credentials are invalid");
        }
    }

    // A chain of three Lua errors must arrive with codes, messages, types,
    // locations, fields and causes intact. The Lua line numbers asserted below
    // come from the line breaks in the eval'd chunk.
    #[crate::test(tarantool = "crate")]
    async fn extended_error_info() {
        let client = test_client().await;
        let res = client
            .eval(
                "error1 = box.error.new(box.error.UNSUPPORTED, 'this', 'that')
error2 = box.error.new('MyCode', 'my message')
error3 = box.error.new('MyOtherCode', 'my other message')
error2:set_prev(error3)
error1:set_prev(error2)
error1:raise()",
                &(),
            )
            .timeout(Duration::from_secs(3))
            .await;
        let error::Error::Remote(e) = error::Error::from(res.unwrap_err()) else {
            panic!();
        };
        assert_eq!(e.error_code(), TarantoolErrorCode::Unsupported as u32);
        assert_eq!(e.message(), "this does not support that");
        assert_eq!(e.error_type(), "ClientError");
        assert_eq!(e.file(), Some("eval"));
        assert_eq!(e.line(), Some(1));
        assert_eq!(e.fields().len(), 0);
        let e = e.cause().unwrap();
        assert_eq!(e.error_code(), 0);
        assert_eq!(e.message(), "my message");
        assert_eq!(e.error_type(), "CustomError");
        assert_eq!(e.file(), Some("eval"));
        assert_eq!(e.line(), Some(2));
        assert_eq!(e.fields().len(), 1);
        assert_eq!(e.fields()["custom_type"], rmpv::Value::from("MyCode"));
        let e = e.cause().unwrap();
        assert_eq!(e.error_code(), 0);
        assert_eq!(e.message(), "my other message");
        assert_eq!(e.error_type(), "CustomError");
        assert_eq!(e.file(), Some("eval"));
        assert_eq!(e.line(), Some(3));
        assert_eq!(e.fields().len(), 1);
        assert_eq!(e.fields()["custom_type"], rmpv::Value::from("MyOtherCode"));
        assert!(e.cause().is_none());
    }

    // `error_line` is computed as two lines below the `Err(...)` line inside
    // the proc, so no lines may be inserted between those two.
    #[crate::test(tarantool = "crate")]
    async fn custom_error_code_from_proc() {
        #[crate::proc(tarantool = "crate")]
        fn proc_custom_error_code() -> Result<(), crate::error::Error> {
            Err(BoxError::new(666666_u32, "na ah").into())
        }
        let error_line = line!() - 2;
        let proc = crate::define_stored_proc_for_tests!(proc_custom_error_code);
        let client = test_client().await;
        let res = client
            .call(&proc, &())
            .timeout(Duration::from_secs(3))
            .await;
        let e = match error::Error::from(res.unwrap_err()) {
            error::Error::Remote(e) => e,
            other => {
                panic!("unexpected error: {}", other);
            }
        };
        assert_eq!(e.error_code(), 666666);
        assert_eq!(e.message(), "na ah");
        assert_eq!(e.error_type(), "ClientError");
        assert_eq!(e.file(), Some(file!()));
        assert_eq!(e.line(), Some(error_line));
    }

    // The reported location is expected to be the `#[crate::proc]` attribute
    // line. That is the line right after the `line!()` below, so keep the two
    // adjacent.
    #[crate::test(tarantool = "crate")]
    async fn check_error_location() {
        let error_line = line!() + 1;
        #[crate::proc(tarantool = "crate")]
        fn proc_check_error_location_implicit() -> Result<(), error::Error> {
            Err(error::Error::other("not good"))
        }
        let proc = crate::define_stored_proc_for_tests!(proc_check_error_location_implicit);
        let client = test_client().await;
        let res = client
            .call(&proc, &())
            .timeout(Duration::from_secs(3))
            .await;
        let e = match error::Error::from(res.unwrap_err()) {
            error::Error::Remote(e) => e,
            other => {
                panic!("unexpected error: {}", other);
            }
        };
        assert_eq!(e.file(), Some(file!()));
        assert_eq!(e.line(), Some(error_line));
    }
}