1use std::env;
14use std::io;
15use std::io::BufRead;
16use std::io::Read;
17use std::io::Write as _;
18use std::net::TcpListener;
19use std::net::TcpStream;
20use std::panic;
21use std::process;
22use std::process::Child;
23use std::process::Command;
24use std::process::ExitCode;
25use std::process::Stdio;
26use std::process::Termination;
27
28use crate::cmdline;
29use crate::error::Result;
30
31
32const OCCURS_ENV: &str = "TEST_FORK_OCCURS";
33const OCCURS_TERM_LENGTH: usize = 17; pub fn fork<F, T>(fork_id: &str, test_name: &str, test: F) -> Result<()>
69where
70 F: Fn() -> T,
74 T: Termination,
75{
76 fn supervise_child(child: &mut Child) {
77 let status = child.wait().expect("failed to wait for child");
78 assert!(
79 status.success(),
80 "child exited unsuccessfully with {}",
81 status
82 );
83 }
84
85 fn no_configure_child(_child: &mut Command) {}
86
87 fork_int(
88 test_name,
89 fork_id,
90 no_configure_child,
91 supervise_child,
92 test,
93 )
94}
95
96pub fn fork_in_out<F, T>(fork_id: &str, test_name: &str, test: F, data: &mut [u8]) -> Result<()>
101where
102 F: Fn(&mut [u8]) -> T,
103 T: Termination,
104{
105 let listener = TcpListener::bind("127.0.0.1:0").expect("failed to bind TCP socket");
106 let addr = listener.local_addr().unwrap();
107 let data_len = data.len();
108
109 fork_int(
110 test_name,
111 fork_id,
112 |cmd| {
113 cmd.env(fork_id, addr.to_string());
114 },
115 |child| {
116 let (mut stream, _addr) = listener
117 .accept()
118 .expect("failed to listen for child connection");
119 let () = stream
120 .write_all(data)
121 .expect("failed to send data to child");
122 let () = stream
123 .read_exact(data)
124 .expect("failed to receive data from child");
125 let status = child.wait().expect("failed to wait for child");
126 assert!(
127 status.success(),
128 "child exited unsuccessfully with {}",
129 status
130 );
131 },
132 || {
133 let addr = env::var(fork_id).unwrap_or_else(|err| {
134 panic!("failed to retrieve {fork_id} environment variable: {err}")
135 });
136 let mut stream =
137 TcpStream::connect(addr).expect("failed to establish connection with parent");
138
139 let mut data = Vec::with_capacity(data_len);
140 let () = unsafe { data.set_len(data_len) };
144
145 let () = stream
146 .read_exact(&mut data)
147 .expect("failed to receive data from parent");
148 let status = test(&mut data);
149 let () = stream
150 .write_all(&data)
151 .expect("failed to send data to parent");
152 status
153 },
154 )
155}
156
157pub(crate) fn fork_int<M, P, C, R, T>(
158 test_name: &str,
159 fork_id: &str,
160 process_modifier: M,
161 in_parent: P,
162 in_child: C,
163) -> Result<R>
164where
165 M: FnOnce(&mut process::Command),
166 P: FnOnce(&mut Child) -> R,
167 T: Termination,
168 C: FnOnce() -> T,
169{
170 let mut return_value = None;
173 let mut process_modifier = Some(process_modifier);
174 let mut in_parent = Some(in_parent);
175 let mut in_child = Some(in_child);
176
177 fork_impl(
178 test_name,
179 fork_id,
180 &mut |cmd| process_modifier.take().unwrap()(cmd),
181 &mut |child| return_value = Some(in_parent.take().unwrap()(child)),
182 &mut || in_child.take().unwrap()(),
183 )
184 .map(|_| return_value.unwrap())
185}
186
187fn fork_impl<T: Termination>(
188 test_name: &str,
189 fork_id: &str,
190 process_modifier: &mut dyn FnMut(&mut process::Command),
191 in_parent: &mut dyn FnMut(&mut Child),
192 in_child: &mut dyn FnMut() -> T,
193) -> Result<()> {
194 let mut occurs = env::var(OCCURS_ENV).unwrap_or_else(|_| String::new());
195 if occurs.contains(fork_id) {
196 match panic::catch_unwind(panic::AssertUnwindSafe(in_child)) {
197 Ok(test_result) => {
198 let rc = if test_result.report() == ExitCode::SUCCESS {
199 0
200 } else {
201 70
202 };
203 process::exit(rc)
204 }
205 Err(_) => process::exit(70 ),
211 }
212 } else {
213 if occurs.len() > 16 * OCCURS_TERM_LENGTH {
215 panic!("test-fork: Not forking due to >=16 levels of recursion");
216 }
217
218 struct KillOnDrop(Child);
219 impl Drop for KillOnDrop {
220 fn drop(&mut self) {
221 let _ = self.0.kill();
223
224 fn drain(read: &mut dyn Read, stderr: bool) {
232 let mut buf = Vec::new();
233 let mut br = io::BufReader::new(read);
234 loop {
235 if br.read_until(b'\n', &mut buf).is_err() {
240 break;
241 }
242 if buf.is_empty() {
243 break;
244 }
245
246 let s = String::from_utf8_lossy(&buf);
249 if stderr {
250 eprint!("{s}");
251 } else {
252 print!("{s}");
253 }
254 buf.clear();
255 }
256 }
257
258 if let Some(stdout) = self.0.stdout.as_mut() {
259 let () = drain(stdout, false);
260 }
261
262 if let Some(stderr) = self.0.stderr.as_mut() {
263 let () = drain(stderr, true);
264 }
265 }
266 }
267
268 occurs.push_str(fork_id);
269 let mut command =
270 process::Command::new(env::current_exe().expect("current_exe() failed, cannot fork"));
271 command
272 .args(cmdline::strip_cmdline(env::args())?)
273 .args(cmdline::RUN_TEST_ARGS)
274 .arg(test_name)
275 .env(OCCURS_ENV, &occurs)
276 .stdin(Stdio::null())
277 .stdout(Stdio::piped())
278 .stderr(Stdio::piped());
279 process_modifier(&mut command);
280
281 let mut child = command.spawn().map(KillOnDrop)?;
282 let () = in_parent(&mut child.0);
283
284 Ok(())
285 }
286}
287
288
#[cfg(test)]
mod test {
    use super::*;

    use std::thread;
    use std::time::Duration;

    use crate::fork_id;


    /// Block the current thread for `ms` milliseconds.
    fn sleep(ms: u64) {
        let pause = Duration::from_millis(ms);
        thread::sleep(pause);
    }

    /// Configure `cmd` so that its stdout is piped back to the spawning
    /// process while its stderr is passed through.
    fn capturing_output(cmd: &mut process::Command) {
        cmd.stdout(Stdio::piped());
        cmd.stderr(Stdio::inherit());
    }

    /// Configure `cmd` so that both stdout and stderr are passed through.
    fn inherit_output(cmd: &mut process::Command) {
        cmd.stdout(Stdio::inherit());
        cmd.stderr(Stdio::inherit());
    }

    /// Read the child's entire stdout, then wait for it and require a
    /// successful exit. Returns the collected output.
    fn wait_for_child_output(child: &mut Child) -> String {
        let mut output = String::new();
        let stdout = child.stdout.as_mut().unwrap();
        stdout.read_to_string(&mut output).unwrap();

        let status = child.wait().unwrap();
        assert!(status.success());
        output
    }

    /// Wait for the child and require a successful exit.
    fn wait_for_child(child: &mut Child) {
        let status = child.wait().unwrap();
        assert!(status.success());
    }

    /// A plain fork runs the child body and the child exits successfully.
    #[test]
    fn fork_basically_works() {
        let result = fork_int(
            "fork::test::fork_basically_works",
            fork_id!(),
            |_cmd| {},
            |child| child.wait().unwrap(),
            || println!("hello from child"),
        );
        let status = result.unwrap();
        assert!(status.success());
    }

    /// Output of a nested child shows up in the output of the process
    /// that forked it.
    #[test]
    fn child_output_captured_and_repeated() {
        let output = fork_int(
            "fork::test::child_output_captured_and_repeated",
            fork_id!(),
            capturing_output,
            wait_for_child_output,
            || {
                let nested = fork_int(
                    "fork::test::child_output_captured_and_repeated",
                    fork_id!(),
                    |_cmd| {},
                    wait_for_child,
                    || println!("hello from child"),
                );
                nested.unwrap()
            },
        )
        .unwrap();

        let repeated = output.contains("hello from child");
        assert!(repeated);
    }

    /// A nested child that outlives its parent's interest in it is killed
    /// before it gets to print anything.
    #[test]
    fn child_killed_if_parent_exits_first() {
        let output = fork_int(
            "fork::test::child_killed_if_parent_exits_first",
            fork_id!(),
            capturing_output,
            wait_for_child_output,
            || {
                let nested = fork_int(
                    "fork::test::child_killed_if_parent_exits_first",
                    fork_id!(),
                    inherit_output,
                    // Return right away without waiting for the child.
                    |_child| {},
                    || {
                        sleep(100);
                        println!("hello from child");
                    },
                );
                nested.unwrap()
            },
        )
        .unwrap();

        // Give a surviving child ample time to print before checking.
        sleep(200);
        let leaked = output.contains("hello from child");
        assert!(
            !leaked,
            "Had unexpected output:\n{}",
            output
        );
    }

    /// A nested child is also killed when the parent-side callback
    /// panics.
    #[test]
    fn child_killed_if_parent_panics_first() {
        let output = fork_int(
            "fork::test::child_killed_if_parent_panics_first",
            fork_id!(),
            capturing_output,
            wait_for_child_output,
            || {
                let attempt = panic::catch_unwind(panic::AssertUnwindSafe(|| {
                    fork_int(
                        "fork::test::child_killed_if_parent_panics_first",
                        fork_id!(),
                        inherit_output,
                        |_| panic!("testing a panic, nothing to see here"),
                        || {
                            sleep(100);
                            println!("hello from child");
                        },
                    )
                    .unwrap()
                }));
                assert!(attempt.is_err());
            },
        )
        .unwrap();

        // Give a surviving child ample time to print before checking.
        sleep(200);
        let leaked = output.contains("hello from child");
        assert!(
            !leaked,
            "Had unexpected output:\n{}",
            output
        );
    }

    /// A panicking child body results in exit status 70.
    #[test]
    fn child_aborted_if_panics() {
        let result = fork_int::<_, _, _, _, ()>(
            "fork::test::child_aborted_if_panics",
            fork_id!(),
            |_cmd| {},
            |child| child.wait().unwrap(),
            || panic!("testing a panic, nothing to see here"),
        );
        let status = result.unwrap();
        let code = status.code().unwrap();
        assert_eq!(70, code);
    }

    /// Data handed to the child is modified there and arrives back in
    /// the parent's buffer.
    #[test]
    fn data_exchange() {
        let mut data = [1, 2, 3, 4, 5];

        fork_in_out(
            fork_id!(),
            "fork::test::data_exchange",
            |data| {
                assert_eq!(data.len(), 5);
                for x in data.iter_mut() {
                    *x += 1;
                }
            },
            data.as_mut_slice(),
        )
        .unwrap();

        assert_eq!(data, [2, 3, 4, 5, 6]);
    }
}