vineyard/client/
ipc_client.rs

1// Copyright 2020-2023 Alibaba Group Holding Limited.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::io;
17use std::net::Shutdown;
18use std::os::unix::net::UnixStream;
19use std::sync::Arc;
20use std::sync::Mutex;
21
22use arrow_buffer::Buffer;
23use parking_lot::ReentrantMutex;
24use parking_lot::ReentrantMutexGuard;
25
26use crate::common::util::arrow::*;
27use crate::common::util::protocol::*;
28use crate::common::util::status::*;
29use crate::common::util::typename::*;
30use crate::common::util::uuid::*;
31
32use super::client::*;
33use super::ds::blob::{Blob, BlobWriter};
34use super::ds::object::*;
35use super::ds::object_meta::ObjectMeta;
36use super::io::*;
37
38mod memory {
39
40    use std::collections::{hash_map, HashMap, HashSet};
41    use std::fs::File;
42    use std::os::fd::{AsRawFd, FromRawFd};
43    use std::os::unix::net::UnixStream;
44
45    use memmap2::{Mmap, MmapMut, MmapOptions};
46
47    use crate::common::memory::fling::recv_fd;
48    use crate::common::util::status::*;
49
50    #[derive(Debug)]
51    pub struct MmapEntry {
52        fd: File,
53        ro_pointer: *const u8,
54        rw_pointer: *mut u8,
55        length: usize,
56
57        mmap_readonly: Option<Mmap>,
58        mmap_readwrite: Option<MmapMut>,
59    }
60
61    impl MmapEntry {
62        pub fn new(fd: i32, map_size: usize, realign: bool) -> Self {
63            let size = if realign {
64                map_size - std::mem::size_of::<usize>()
65            } else {
66                map_size
67            };
68            return MmapEntry {
69                fd: unsafe { File::from_raw_fd(fd) },
70                ro_pointer: std::ptr::null(),
71                rw_pointer: std::ptr::null_mut(),
72                length: size,
73                mmap_readonly: None,
74                mmap_readwrite: None,
75            };
76        }
77
78        #[allow(dead_code)]
79        pub fn fd(&self) -> i32 {
80            return self.fd.as_raw_fd();
81        }
82
83        pub fn map(&mut self) -> Result<*const u8> {
84            if self.ro_pointer.is_null() {
85                let mmap = unsafe { MmapOptions::new().len(self.length).offset(0).map(&self.fd) }?;
86                self.ro_pointer = mmap.as_ptr();
87                self.mmap_readonly = Some(mmap);
88            }
89            return Ok(self.ro_pointer);
90        }
91
92        pub fn map_mut(&mut self) -> Result<*mut u8> {
93            if self.rw_pointer.is_null() {
94                let mut mmap = unsafe { MmapOptions::new().len(self.length).map_mut(&self.fd) }?;
95                self.rw_pointer = mmap.as_mut_ptr();
96                self.mmap_readwrite = Some(mmap);
97            }
98            return Ok(self.rw_pointer);
99        }
100    }
101
102    #[derive(Debug)]
103    pub struct MmapManager {
104        entries: HashMap<i32, MmapEntry>,
105    }
106
107    impl MmapManager {
108        pub fn new() -> Self {
109            return MmapManager {
110                entries: HashMap::new(),
111            };
112        }
113
114        pub fn mmap(
115            &mut self,
116            stream: &UnixStream,
117            fd: i32,
118            map_size: usize,
119            realign: bool,
120        ) -> Result<*const u8> {
121            if let hash_map::Entry::Vacant(entry) = self.entries.entry(fd) {
122                entry.insert(MmapEntry::new(recv_fd(stream)?, map_size, realign));
123            }
124            match self.entries.get_mut(&fd) {
125                Some(entry) => {
126                    return entry.map();
127                }
128                None => {
129                    return Err(VineyardError::invalid(format!(
130                        "Failed to find mmap entry for fd even after insert: {}",
131                        fd
132                    )));
133                }
134            }
135        }
136
137        pub fn mmap_mut(
138            &mut self,
139            stream: &UnixStream,
140            fd: i32,
141            map_size: usize,
142            realign: bool,
143        ) -> Result<*mut u8> {
144            if let hash_map::Entry::Vacant(entry) = self.entries.entry(fd) {
145                entry.insert(MmapEntry::new(recv_fd(stream)?, map_size, realign));
146            }
147            match self.entries.get_mut(&fd) {
148                Some(entry) => {
149                    return entry.map_mut();
150                }
151                None => {
152                    return Err(VineyardError::invalid(format!(
153                        "Failed to find mmap entry for fd even after insert: {}",
154                        fd
155                    )));
156                }
157            }
158        }
159
160        #[allow(dead_code)]
161        pub fn exists(&self, fd: i32) -> i32 {
162            if self.entries.contains_key(&fd) {
163                return -1;
164            } else {
165                return fd;
166            }
167        }
168
169        #[allow(dead_code)]
170        pub fn deduplicate(&self, fd: i32, fds: &mut Vec<i32>, dedup: &mut HashSet<i32>) {
171            if !dedup.contains(&fd) && !self.entries.contains_key(&fd) {
172                fds.push(fd);
173                dedup.insert(fd);
174            }
175        }
176    }
177}
178
/// A vineyard client that talks to the server over a Unix domain socket.
///
/// Instances are created with `IPCClient::connect` (or `IPCClient::default`),
/// which fills the public fields from the server's register reply.
#[derive(Debug)]
pub struct IPCClient {
    /// Set to `true` by `connect` and cleared by `disconnect`.
    connected: bool,
    /// IPC socket path as reported by the server's register reply.
    pub ipc_socket: String,
    /// RPC endpoint as reported by the server's register reply.
    pub rpc_endpoint: String,
    /// Instance id as reported by the server's register reply.
    pub instance_id: InstanceID,
    /// Server version string as reported by the server's register reply.
    pub server_version: String,
    /// RPC-compression flag as reported by the server's register reply.
    pub support_rpc_compression: bool,

    /// The socket every request and reply travels over.
    stream: UnixStream,
    /// Lock handed out by `ensure_connect`.
    lock: ReentrantMutex<()>,
    /// Memory mappings of the file descriptors received over `stream`.
    mmap: memory::MmapManager,
}
192
193impl Drop for IPCClient {
194    fn drop(&mut self) {
195        self.disconnect();
196    }
197}
198
// The raw-pointer fields of `memory::MmapEntry` suppress the automatic
// `Send`/`Sync` implementations, so both are asserted manually here.
// NOTE(review): soundness presumably relies on every method that touches the
// socket or the mapping table taking `&mut self` — confirm whenever a `&self`
// method that reaches those fields is added.
unsafe impl Send for IPCClient {}
unsafe impl Sync for IPCClient {}
201
202impl Client for IPCClient {
203    fn disconnect(&mut self) {
204        if !self.connected() {
205            return;
206        }
207        self.connected = false;
208        if let Ok(message_out) = write_exit_request() {
209            if let Err(err) = self.do_write(&message_out) {
210                error!("Failed to disconnect the client: {}", err);
211            }
212        }
213        self.stream.shutdown(Shutdown::Both).unwrap_or_else(|e| {
214            error!("Failed to shutdown IPCClient stream: {}", e);
215        });
216    }
217
218    #[cfg(not(feature = "nightly"))]
219    fn connected(&mut self) -> bool {
220        return self.connected;
221    }
222
223    #[cfg(feature = "nightly")]
224    fn connected(&mut self) -> bool {
225        if self.stream.set_nonblocking(true).is_err() {
226            return false;
227        }
228        match self.stream.peek(&mut [0]) {
229            Ok(_) => {
230                let _ = self.stream.set_nonblocking(false);
231                return true;
232            }
233            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
234                let _ = self.stream.set_nonblocking(false);
235                return true;
236            }
237            Err(_) => {
238                self.connected = false;
239                return false;
240            }
241        }
242    }
243
244    fn ensure_connect(&mut self) -> Result<ReentrantMutexGuard<'_, ()>> {
245        if !self.connected() {
246            return Err(VineyardError::io_error("client not connected"));
247        }
248        return Ok(self.lock.lock());
249    }
250
251    fn do_read(&mut self) -> Result<String> {
252        return do_read(&mut self.stream);
253    }
254
255    fn do_write(&mut self, message_out: &str) -> Result<()> {
256        return do_write(&mut self.stream, message_out);
257    }
258
259    fn instance_id(&self) -> InstanceID {
260        return self.instance_id;
261    }
262
263    fn create_metadata(&mut self, metadata: &ObjectMeta) -> Result<ObjectMeta> {
264        let mut meta = metadata.clone();
265        meta.set_instance_id(self.instance_id());
266        meta.set_transient(true);
267        if !meta.has_key("nbytes") {
268            meta.set_nbytes(0usize);
269        }
270        if meta.is_incomplete() {
271            let _ = self.sync_metadata();
272        }
273        let (id, signature, instance_id) = self.create_data(meta.meta_data())?;
274        meta.set_id(id);
275        meta.set_signature(signature);
276        meta.set_instance_id(instance_id);
277        if meta.is_incomplete() {
278            meta = self.get_metadata(id)?;
279        }
280        meta.set_client(self);
281        return Ok(meta);
282    }
283
284    fn get_metadata(&mut self, id: ObjectID) -> Result<ObjectMeta> {
285        let data = self.get_data(id, false, false)?;
286        let mut meta = ObjectMeta::new(self, data)?;
287
288        let buffer_id_vec: Vec<ObjectID> =
289            meta.get_buffers().buffer_ids().iter().cloned().collect();
290        let buffers = self.get_buffers(&buffer_id_vec, false)?;
291        for (buffer_id, buffer) in buffers {
292            meta.set_buffer(buffer_id, buffer)?;
293        }
294        return Ok(meta);
295    }
296
297    fn get_metadata_batch(&mut self, ids: &[ObjectID]) -> Result<Vec<ObjectMeta>> {
298        let data_vec = self.get_data_batch(ids)?;
299        let mut metadatas = Vec::new();
300        let mut buffer_id_vec: Vec<ObjectID> = Vec::new();
301        for data in data_vec {
302            let meta = ObjectMeta::new(self, data)?;
303            buffer_id_vec.extend(meta.get_buffers().buffer_ids());
304            metadatas.push(meta);
305        }
306
307        let buffers = self.get_buffers(&buffer_id_vec, false)?;
308        for meta in metadatas.iter_mut() {
309            for buffer_id in meta.get_buffers().buffer_ids().clone() {
310                if let Some(buffer) = buffers.get(&buffer_id) {
311                    meta.set_buffer(buffer_id, buffer.clone())?;
312                }
313            }
314        }
315        return Ok(metadatas);
316    }
317}
318
319impl IPCClient {
320    #[allow(clippy::should_implement_trait)]
321    pub fn default() -> Result<IPCClient> {
322        let default_ipc_socket = std::env::var(VINEYARD_IPC_SOCKET_KEY)?;
323        return IPCClient::connect(&default_ipc_socket);
324    }
325
326    pub fn connect(socket: &str) -> Result<IPCClient> {
327        let mut stream = connect_ipc_socket_retry(&socket)?;
328        let message_out = write_register_request(RegisterRequest {
329            version: VERSION.into(),
330            store_type: "Normal".into(),
331            session_id: 0,
332            username: String::new(),
333            password: String::new(),
334            support_rpc_compression: false,
335        })?;
336        do_write(&mut stream, &message_out)?;
337        let reply = read_register_reply(&do_read(&mut stream)?)?;
338        return Ok(IPCClient {
339            connected: true,
340            ipc_socket: reply.ipc_socket,
341            rpc_endpoint: reply.rpc_endpoint,
342            instance_id: reply.instance_id,
343            server_version: reply.version,
344            support_rpc_compression: reply.support_rpc_compression,
345            stream: stream,
346            lock: ReentrantMutex::new(()),
347            mmap: memory::MmapManager::new(),
348        });
349    }
350
351    pub fn create_blob(&mut self, size: usize) -> Result<BlobWriter> {
352        let (id, buffer) = self.create_buffer(size)?;
353        return Ok(BlobWriter::new(id, buffer));
354    }
355
356    pub fn get_blob(&mut self, id: ObjectID) -> Result<Blob> {
357        let buffer = self.get_buffer(id, false)?;
358        let size = match &buffer {
359            Some(buffer) => buffer.len(),
360            None => 0,
361        };
362        let mut meta = ObjectMeta::new_from_typename(typename::<Blob>());
363        meta.set_id(id);
364        meta.set_instance_id(self.instance_id());
365        meta.set_or_add_buffer(id, buffer.clone())?;
366        return Ok(Blob::new(meta, size, buffer));
367    }
368
369    fn create_buffer(&mut self, size: usize) -> Result<(ObjectID, Option<Buffer>)> {
370        if size == 0 {
371            return Ok((empty_blob_id(), Some(arrow_buffer_null())));
372        }
373        let _ = self.ensure_connect()?;
374        let message_out = write_create_buffer_request(size)?;
375        self.do_write(&message_out)?;
376        let reply = read_create_buffer_reply(&self.do_read()?)?;
377        if reply.payload.data_size == 0 {
378            return Ok((reply.id, Some(arrow_buffer_null())));
379        }
380        let pointer = self.mmap.mmap_mut(
381            &self.stream,
382            reply.payload.store_fd,
383            reply.payload.map_size,
384            true,
385        )?;
386        let buffer =
387            arrow_buffer_with_offset(pointer, reply.payload.data_offset, reply.payload.data_size);
388        return Ok((reply.id, Some(buffer)));
389    }
390
391    fn get_buffer(&mut self, id: ObjectID, unsafe_: bool) -> Result<Option<Buffer>> {
392        let buffers = self.get_buffers(&[id], unsafe_)?;
393        return buffers
394            .get(&id)
395            .cloned()
396            .ok_or(VineyardError::object_not_exists(format!(
397                "buffer {} doesn't exist",
398                id
399            )));
400    }
401
402    fn get_buffers(
403        &mut self,
404        ids: &[ObjectID],
405        unsafe_: bool,
406    ) -> Result<HashMap<ObjectID, Option<Buffer>>> {
407        let _ = self.ensure_connect()?;
408        let message_out = write_get_buffers_request(&ids, unsafe_)?;
409        self.do_write(&message_out)?;
410        let reply = read_get_buffers_reply(&self.do_read()?)?;
411
412        let mut buffers = HashMap::new();
413        for payload in reply.payloads {
414            if payload.data_size == 0 {
415                buffers.insert(payload.object_id, Some(arrow_buffer_null()));
416                continue;
417            }
418            let pointer = self
419                .mmap
420                .mmap(&self.stream, payload.store_fd, payload.map_size, true)?;
421            let buffer = arrow_buffer_with_offset(pointer, payload.data_offset, payload.data_size);
422            buffers.insert(payload.object_id, Some(buffer));
423        }
424        return Ok(buffers);
425    }
426
427    pub fn get<T: Object + Create>(&mut self, id: ObjectID) -> Result<Box<T>> {
428        let meta = self.get_metadata(id)?;
429        let mut object = T::create();
430        object.construct(meta)?;
431        return downcast_object(object);
432    }
433
434    pub fn fetch_and_get<T: Object + Create>(&mut self, id: ObjectID) -> Result<Box<dyn Object>> {
435        let meta = self.fetch_and_get_metadata(id)?;
436        let mut object = T::create();
437        object.construct(meta)?;
438        return Ok(object);
439    }
440}
441
442pub struct IPCClientManager {}
443
444impl IPCClientManager {
445    pub fn get_default() -> Result<Arc<Mutex<IPCClient>>> {
446        let default_ipc_socket = std::env::var(VINEYARD_IPC_SOCKET_KEY)?;
447        return IPCClientManager::get(default_ipc_socket);
448    }
449
450    pub fn get<S: Into<String>>(socket: S) -> Result<Arc<Mutex<IPCClient>>> {
451        let mut clients = IPCClientManager::get_clients().lock()?;
452        let socket = socket.into();
453        if let Some(client) = clients.get(&socket) {
454            if client.lock()?.connected() {
455                return Ok(client.clone());
456            }
457        }
458        let client = Arc::new(Mutex::new(IPCClient::connect(&socket)?));
459        clients.insert(socket, client.clone());
460        return Ok(client);
461    }
462
463    pub fn close<S: Into<String>>(socket: S) -> Result<()> {
464        let mut clients = IPCClientManager::get_clients().lock()?;
465        let socket = socket.into();
466        if let Some(client) = clients.get(&socket) {
467            if Arc::strong_count(client) == 1 {
468                clients.remove(&socket);
469            }
470            return Ok(());
471        } else {
472            return Err(VineyardError::invalid(format!(
473                "Failed to close the client due to the unknown socket: {}",
474                socket
475            )));
476        }
477    }
478
479    fn get_clients() -> &'static Arc<Mutex<HashMap<String, Arc<Mutex<IPCClient>>>>> {
480        lazy_static! {
481            static ref CONNECTED_CLIENTS: Arc<Mutex<HashMap<String, Arc<Mutex<IPCClient>>>>> =
482                Arc::new(Mutex::new(HashMap::new()));
483        }
484        return &CONNECTED_CLIENTS;
485    }
486}
487
/// Fetches an object through a client.
///
/// `get!(client, Type, id)` expands to `client.get::<Type>(id)`. The client
/// must be given as a plain identifier.
#[macro_export]
macro_rules! get {
    ($client: ident, $object_ty: ty, $object_id: expr) => {
        $client.get::<$object_ty>($object_id)
    };
}
494
/// Builds an object and seals it through a client.
///
/// `put!(client, Builder, args...)` calls `<Builder>::new(client, args...)`
/// and, on success, `builder.seal(client)`; an error from `new` is returned
/// unchanged. A trailing comma after the arguments is accepted.
///
/// The `$client` expression appears twice in the expansion, so it is
/// evaluated once for `new` and once more for `seal`.
#[macro_export]
macro_rules! put {
    ($client: expr, $builder_ty: ty, $($arg: expr),* $(,)?) => {
        match <$builder_ty>::new($client, $($arg),*) {
            Ok(builder) => builder.seal($client),
            Err(e) => Err(e),
        }
    };
}
503}