rxqlite_tests_common/
lib.rs

1use rcgen::generate_simple_self_signed;
2use std::collections::HashMap;
3use std::path::{Path, PathBuf};
4use std::process::{/*Stdio ,*/ Child, Command};
5
6#[cfg(target_os = "linux")]
7use nix::sys::signal::{kill, Signal};
8#[cfg(target_os = "linux")]
9use nix::unistd::Pid;
10
11use std::sync::{
12    atomic::{AtomicU16, Ordering},
13    Arc,
14};
15
/// First port handed out by [`PortManager`] on Windows hosts.
#[cfg(target_os = "windows")]
const BASE_PORT: u16 = 21000;

/// First port handed out by [`PortManager`] on Linux hosts.
#[cfg(target_os = "linux")]
const BASE_PORT: u16 = 22000;

// Fallback for every other target (e.g. macOS). Without it `BASE_PORT` is undefined there and
// the crate does not compile, even though the rest of the file has non-Linux code paths.
#[cfg(not(any(target_os = "windows", target_os = "linux")))]
const BASE_PORT: u16 = 22000;

/// Hands out disjoint, consecutive port ranges so that several test clusters started by the
/// same process do not collide.
pub struct PortManager {
    next_port: Arc<AtomicU16>,
}

impl Default for PortManager {
    /// Creates a manager whose first reservation starts at the platform's `BASE_PORT`.
    fn default() -> Self {
        Self {
            next_port: Arc::new(AtomicU16::new(BASE_PORT)),
        }
    }
}

impl PortManager {
    /// Reserves `port_count` consecutive ports and returns the first one.
    ///
    /// Successive calls return non-overlapping ranges. Reserving `0` ports returns the next
    /// free port without consuming it.
    ///
    /// # Panics
    ///
    /// Panics if `port_count` does not fit in a `u16`, or if the reservation would run past the
    /// end of the port space. The port after the reserved range must itself be representable
    /// as a `u16`. The old `as`-cast/`fetch_add` combination silently truncated and wrapped
    /// instead, handing out overlapping ranges or port 0.
    pub fn reserve(&self, port_count: usize) -> u16 {
        let count = u16::try_from(port_count)
            .unwrap_or_else(|_| panic!("cannot reserve {} ports: too many", port_count));
        // Only uniqueness of the handed-out ranges matters, which the atomic read-modify-write
        // guarantees on its own; no other memory is synchronised through this counter.
        self.next_port
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |next| {
                next.checked_add(count)
            })
            .unwrap_or_else(|next| {
                panic!(
                    "port space exhausted: cannot reserve {} ports starting at {}",
                    port_count, next
                )
            })
    }
}
39
40pub static PORT_MANAGER: state::InitCell<PortManager> = state::InitCell::new();
41
/// TLS options for a test cluster.
///
/// When a cluster is created with `Some(TestTlsConfig)`, a self-signed certificate and key are
/// generated and passed to every node via `--cert-path` / `--key-path`.
#[derive(Debug, Default, Clone)]
pub struct TestTlsConfig {
    /// When `true`, every node is started with `--accept-invalid-certificates`.
    /// Presumably this is required for nodes to trust each other's self-signed
    /// certificate — NOTE(review): confirm against the node's TLS client setup.
    pub accept_invalid_certificates: bool,
}

impl TestTlsConfig {
    /// Builder-style setter for [`accept_invalid_certificates`](Self::accept_invalid_certificates).
    ///
    /// Consumes and returns `self`, so the result must be used; it is `#[must_use]` so that a
    /// discarded call (which would silently drop the setting) is flagged by the compiler.
    #[must_use]
    pub fn accept_invalid_certificates(mut self, accept_invalid_certificates: bool) -> Self {
        self.accept_invalid_certificates = accept_invalid_certificates;
        self
    }
}
53
/// One node process of a test cluster.
#[derive(Debug)]
pub struct Instance {
    /// 1-based node id passed to the node as `--id`.
    pub node_id: u64,
    /// Handle of the running node process; `None` once the process has been reaped.
    pub child: Option<Child>,
    /// `<working_directory>/<data_dir_base_name><node_id>`. It is presumably where the node
    /// keeps its data — NOTE(review): the path is computed here, not passed to the node.
    pub data_path: PathBuf,
    /// `host:port` the node was told to serve HTTP on (`--http-addr`).
    pub http_addr: String,
    /// `host:port` the node was told to serve notifications on (`--notifications-addr`).
    pub notifications_addr: String,
}
61
62pub struct TestClusterManager {
63    pub instances: HashMap<u64, Instance>,
64    pub tls_config: Option<TestTlsConfig>,
65    pub working_directory: std::path::PathBuf,
66    pub executable: String,
67    pub keep_temp_directories: bool,
68    pub key_path: String, // empty if not used
69    pub cert_path: String, // empty if not used
70}
71
72impl TestClusterManager {
73    pub fn new<P: AsRef<Path>>(
74        instance_count: usize,
75        working_directory: P,
76        executable_path: P,
77        host: &str,
78        tls_config: Option<TestTlsConfig>,
79        data_dir_base_name: Option<&str>,
80    ) -> anyhow::Result<Self> {
81        let data_dir_base_name = data_dir_base_name.map(|x|x.to_string()).unwrap_or(String::from("data-"));
82        
83        println!("creating working_directory:{}",working_directory.as_ref().display());
84        std::fs::create_dir_all(&working_directory)?;
85        println!("created working_directory:{}",working_directory.as_ref().display());
86        let base_port = PORT_MANAGER
87            .get_or_init(Default::default)
88            .reserve(instance_count * 3);
89
90        let (cert_path, key_path, accept_invalid_certificates) =
91            if let Some(tls_config) = tls_config.as_ref() {
92                let certs_path = working_directory.as_ref().join("certs-test");
93                std::fs::create_dir_all(&certs_path)?;
94                let subject_alt_names = vec![host.to_string()];
95
96                let cert = generate_simple_self_signed(subject_alt_names)?;
97                let key = cert.serialize_private_key_pem();
98                let cert = cert.serialize_pem()?;
99                let key_path = certs_path.join("rxqlited.key").to_path_buf();
100                let cert_path = certs_path.join("rxqlited.cert").to_path_buf();
101
102                std::fs::write(&key_path, key.as_bytes())?;
103                std::fs::write(&cert_path, cert.as_bytes())?;
104                (
105                    cert_path.to_str().unwrap().to_string(),
106                    key_path.to_str().unwrap().to_string(),
107                    tls_config.accept_invalid_certificates,
108                )
109            } else {
110                (Default::default(), Default::default(), false)
111            };
112
113        let mut instances: HashMap<u64, Instance> = Default::default();
114
115        let executable = executable_path.as_ref();
116        let executable = executable.to_str().unwrap().to_string();
117        for i in 0..instance_count {
118            let http_port = base_port + (i * 3) as u16;
119            let rpc_port = base_port + ((i * 3) + 1) as u16;
120            let notifications_port = base_port + ((i * 3) + 2) as u16;
121            let http_addr = format!("{}:{}", host, http_port);
122            let rpc_addr = format!("{}:{}", host, rpc_port);
123            let notifications_addr = format!("{}:{}", host, notifications_port);
124
125            let mut cmd = Command::new(&executable);
126
127            cmd
128                //.stderr(Stdio::null())
129                //.stdout(Stdio::null())
130                //.env_clear()
131                .arg("--test-node")
132                .arg("--id")
133                .arg(&(i + 1).to_string())
134                .arg("--http-addr")
135                .arg(&http_addr)
136                .arg("--rpc-addr")
137                .arg(&rpc_addr)
138                .arg("--notifications-addr")
139                .arg(&notifications_addr)
140                .current_dir(&working_directory);
141
142            if tls_config.is_some() {
143                cmd.arg("--key-path")
144                    .arg(&key_path)
145                    .arg("--cert-path")
146                    .arg(&cert_path);
147                if accept_invalid_certificates {
148                    cmd.arg("--accept-invalid-certificates");
149                }
150            }
151            if i == 0 {
152                cmd.arg("--leader");
153                for j in 1..instance_count {
154                    cmd.arg("--member");
155                    let http_port = base_port + (j * 3) as u16;
156                    let rpc_port = base_port + ((j * 3) + 1) as u16;
157                    let http_addr = format!("{}:{}", host, http_port);
158                    let rpc_addr = format!("{}:{}", host, rpc_port);
159
160                    cmd.arg(format!("{};{};{}", j + 1, http_addr, rpc_addr));
161                }
162            }
163            println!("spawning child {}:{:?}",i,cmd);
164            let child = cmd.spawn()?;
165            println!("spawned child {}:{:?}",i,cmd);
166            let node_id: u64 = (i + 1) as _;
167
168            instances.insert(
169                node_id,
170                Instance {
171                    http_addr,
172                    notifications_addr,
173                    node_id,
174                    child: Some(child),
175                    data_path: working_directory.as_ref().join(format!("{}{}",data_dir_base_name, i + 1)),
176                },
177            );
178        }
179        Ok(Self {
180            instances,
181            tls_config,
182            working_directory: working_directory.as_ref().to_path_buf(),
183            executable,
184            keep_temp_directories: false,
185            key_path,
186            cert_path,
187        })
188    }
189    pub fn kill_all(&mut self) -> anyhow::Result<()> {
190        for (_, instance) in self.instances.iter_mut() {
191            if let Some(child) = instance.child.as_mut() {
192#[cfg(target_os = "linux")]
193{
194              let pid = child.id() as i32;
195              kill(Pid::from_raw(pid), Signal::SIGINT)?;
196}
197#[cfg(not(target_os = "linux"))]
198{
199                child.kill()?;
200}
201            }
202        }
203        loop {
204            let mut done = true;
205            for (_, instance) in self.instances.iter_mut() {
206                if let Some(child) = instance.child.as_mut() {
207                    if let Ok(Some(_exit_status)) = child.try_wait() {
208                        instance.child.take();
209                    } else {
210                        done = false;
211                    }
212                }
213            }
214            if done {
215                break;
216            }
217            std::thread::sleep(std::time::Duration::from_millis(250));
218        }
219        Ok(())
220    }
221    pub fn start(&mut self) -> anyhow::Result<()> {
222        for (node_id, instance) in self.instances.iter_mut() {
223            let mut cmd = Command::new(&self.executable);
224
225            cmd
226              .arg("--test-node")
227              .arg("--id")
228                .arg(&node_id.to_string())
229                .current_dir(&self.working_directory);
230            let child = cmd.spawn()?;
231            instance.child = Some(child);
232        }
233        Ok(())
234    }
235    pub fn clean_directories(&self) -> anyhow::Result<()> {
236        if self.keep_temp_directories {
237            return Ok(());
238        }
239        if let Err(err) = std::fs::remove_dir_all(&self.working_directory) {
240            eprintln!(
241                "error removing directory : {}({})",
242                self.working_directory.display(),
243                err
244            );
245            Err(err.into())
246        } else {
247            Ok(())
248        }
249    }
250}
251
252impl Drop for TestClusterManager {
253    fn drop(&mut self) {
254        let _ = self.kill_all();
255        let _ = self.clean_directories();
256    }
257}