test_fork_core/
fork.rs

1// Copyright (C) 2025 Daniel Mueller <deso@posteo.net>
2// SPDX-License-Identifier: (Apache-2.0 OR MIT)
3
4//-
5// Copyright 2018 Jason Lingle
6//
7// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
8// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
9// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
10// option. This file may not be copied, modified, or distributed
11// except according to those terms.
12
13use std::env;
14use std::io;
15use std::io::BufRead;
16use std::io::Read;
17use std::io::Write as _;
18use std::net::TcpListener;
19use std::net::TcpStream;
20use std::panic;
21use std::process;
22use std::process::Child;
23use std::process::Command;
24use std::process::ExitCode;
25use std::process::Stdio;
26use std::process::Termination;
27
28use crate::cmdline;
29use crate::error::Result;
30
31
/// Environment variable in which the ids of all fork points a process was
/// spawned for are accumulated; a process finding a fork point's id in here
/// takes the child branch of that fork.
const OCCURS_ENV: &str = "TEST_FORK_OCCURS";
/// Length of a single fork id term stored in `OCCURS_ENV`: a ':' followed
/// by 16 hexadecimal digits.
const OCCURS_TERM_LENGTH: usize = 1 + 16;
34
35
36/// Simulate a process fork.
37///
38/// Since this is not a true process fork, the calling code must be structured
39/// to ensure that the child process, upon starting from the same entry point,
40/// also reaches this same `fork()` call. Recursive forks are supported; the
41/// child branch is taken from all child processes of the fork even if it is
42/// not directly the child of a particular branch. However, encountering the
43/// same fork point more than once in a single execution sequence of a child
44/// process is not (e.g., putting this call in a recursive function) and
45/// results in unspecified behaviour.
46///
47/// `fork_id` is a unique identifier identifying this particular fork location.
48/// This *must* be stable across processes of the same executable; pointers are
49/// not suitable stable, and string constants may not be suitably unique. The
50/// [`fork_id!()`] macro is the recommended way to supply this
51/// parameter.
52///
53/// `test_name` must exactly match the full path of the test function being
54/// run.
55///
56/// If `test` panics, the child process exits with a failure code immediately
57/// rather than let the panic propagate out of the `fork()` call.
58///
59/// ## Panics
60///
61/// Panics if the environment indicates that there are already at least 16
62/// levels of fork nesting.
63///
64/// Panics if `std::env::current_exe()` fails determine the path to the current
65/// executable.
66///
67/// Panics if any argument to the current process is not valid UTF-8.
68pub fn fork<F, T>(fork_id: &str, test_name: &str, test: F) -> Result<()>
69where
70    // NB: We use `Fn` here, because `FnMut` and `FnOnce` would allow
71    //     for modification of captured variables, but that will not
72    //     work across process boundaries.
73    F: Fn() -> T,
74    T: Termination,
75{
76    fn supervise_child(child: &mut Child) {
77        let status = child.wait().expect("failed to wait for child");
78        assert!(
79            status.success(),
80            "child exited unsuccessfully with {}",
81            status
82        );
83    }
84
85    fn no_configure_child(_child: &mut Command) {}
86
87    fork_int(
88        test_name,
89        fork_id,
90        no_configure_child,
91        supervise_child,
92        test,
93    )
94}
95
96/// Simulate a process fork.
97///
98/// This function is similar to [`fork`], except that it allows for data
99/// exchange with the child process.
100pub fn fork_in_out<F, T>(fork_id: &str, test_name: &str, test: F, data: &mut [u8]) -> Result<()>
101where
102    F: Fn(&mut [u8]) -> T,
103    T: Termination,
104{
105    let listener = TcpListener::bind("127.0.0.1:0").expect("failed to bind TCP socket");
106    let addr = listener.local_addr().unwrap();
107    let data_len = data.len();
108
109    fork_int(
110        test_name,
111        fork_id,
112        |cmd| {
113            cmd.env(fork_id, addr.to_string());
114        },
115        |child| {
116            let (mut stream, _addr) = listener
117                .accept()
118                .expect("failed to listen for child connection");
119            let () = stream
120                .write_all(data)
121                .expect("failed to send data to child");
122            let () = stream
123                .read_exact(data)
124                .expect("failed to receive data from child");
125            let status = child.wait().expect("failed to wait for child");
126            assert!(
127                status.success(),
128                "child exited unsuccessfully with {}",
129                status
130            );
131        },
132        || {
133            let addr = env::var(fork_id).unwrap_or_else(|err| {
134                panic!("failed to retrieve {fork_id} environment variable: {err}")
135            });
136            let mut stream =
137                TcpStream::connect(addr).expect("failed to establish connection with parent");
138
139            let mut data = Vec::with_capacity(data_len);
140            // SAFETY: The `Vec` contains `data_len` `u8` values, which
141            //         are valid for any bit pattern, so we can safely
142            //         adjust the length.
143            let () = unsafe { data.set_len(data_len) };
144
145            let () = stream
146                .read_exact(&mut data)
147                .expect("failed to receive data from parent");
148            let status = test(&mut data);
149            let () = stream
150                .write_all(&data)
151                .expect("failed to send data to parent");
152            status
153        },
154    )
155}
156
157pub(crate) fn fork_int<M, P, C, R, T>(
158    test_name: &str,
159    fork_id: &str,
160    process_modifier: M,
161    in_parent: P,
162    in_child: C,
163) -> Result<R>
164where
165    M: FnOnce(&mut process::Command),
166    P: FnOnce(&mut Child) -> R,
167    T: Termination,
168    C: FnOnce() -> T,
169{
170    // Erase the generics so we don't instantiate the actual implementation for
171    // every single test
172    let mut return_value = None;
173    let mut process_modifier = Some(process_modifier);
174    let mut in_parent = Some(in_parent);
175    let mut in_child = Some(in_child);
176
177    fork_impl(
178        test_name,
179        fork_id,
180        &mut |cmd| process_modifier.take().unwrap()(cmd),
181        &mut |child| return_value = Some(in_parent.take().unwrap()(child)),
182        &mut || in_child.take().unwrap()(),
183    )
184    .map(|_| return_value.unwrap())
185}
186
187fn fork_impl<T: Termination>(
188    test_name: &str,
189    fork_id: &str,
190    process_modifier: &mut dyn FnMut(&mut process::Command),
191    in_parent: &mut dyn FnMut(&mut Child),
192    in_child: &mut dyn FnMut() -> T,
193) -> Result<()> {
194    let mut occurs = env::var(OCCURS_ENV).unwrap_or_else(|_| String::new());
195    if occurs.contains(fork_id) {
196        match panic::catch_unwind(panic::AssertUnwindSafe(in_child)) {
197            Ok(test_result) => {
198                let rc = if test_result.report() == ExitCode::SUCCESS {
199                    0
200                } else {
201                    70
202                };
203                process::exit(rc)
204            }
205            // Assume that the default panic handler already printed something
206            //
207            // We don't use process::abort() since it produces core dumps on
208            // some systems and isn't something more special than a normal
209            // panic.
210            Err(_) => process::exit(70 /* EX_SOFTWARE */),
211        }
212    } else {
213        // Prevent misconfiguration creating a fork bomb
214        if occurs.len() > 16 * OCCURS_TERM_LENGTH {
215            panic!("test-fork: Not forking due to >=16 levels of recursion");
216        }
217
218        struct KillOnDrop(Child);
219        impl Drop for KillOnDrop {
220            fn drop(&mut self) {
221                // Kill the child if it hasn't exited yet
222                let _ = self.0.kill();
223
224                // Copy the child's output to our own
225                // Awkwardly, `print!()` and `println!()` are our only gateway
226                // to putting things in the captured output. Generally test
227                // output really is text, so work on that assumption and read
228                // line-by-line, converting lossily into UTF-8 so we can
229                // println!() it.
230
231                fn drain(read: &mut dyn Read, stderr: bool) {
232                    let mut buf = Vec::new();
233                    let mut br = io::BufReader::new(read);
234                    loop {
235                        // We can't use read_line() or lines() since they break if
236                        // there's any non-UTF-8 output at all. \n occurs at the
237                        // end of the line endings on all major platforms, so we
238                        // can just use that as a delimiter.
239                        if br.read_until(b'\n', &mut buf).is_err() {
240                            break;
241                        }
242                        if buf.is_empty() {
243                            break;
244                        }
245
246                        // not println!() because we already have a line ending
247                        // from above.
248                        let s = String::from_utf8_lossy(&buf);
249                        if stderr {
250                            eprint!("{s}");
251                        } else {
252                            print!("{s}");
253                        }
254                        buf.clear();
255                    }
256                }
257
258                if let Some(stdout) = self.0.stdout.as_mut() {
259                    let () = drain(stdout, false);
260                }
261
262                if let Some(stderr) = self.0.stderr.as_mut() {
263                    let () = drain(stderr, true);
264                }
265            }
266        }
267
268        occurs.push_str(fork_id);
269        let mut command =
270            process::Command::new(env::current_exe().expect("current_exe() failed, cannot fork"));
271        command
272            .args(cmdline::strip_cmdline(env::args())?)
273            .args(cmdline::RUN_TEST_ARGS)
274            .arg(test_name)
275            .env(OCCURS_ENV, &occurs)
276            .stdin(Stdio::null())
277            .stdout(Stdio::piped())
278            .stderr(Stdio::piped());
279        process_modifier(&mut command);
280
281        let mut child = command.spawn().map(KillOnDrop)?;
282        let () = in_parent(&mut child.0);
283
284        Ok(())
285    }
286}
287
288
#[cfg(test)]
mod test {
    use super::*;

    use std::thread;

    use crate::fork_id;


    /// Block the calling thread for `ms` milliseconds.
    fn sleep(ms: u64) {
        thread::sleep(::std::time::Duration::from_millis(ms));
    }

    /// Process modifier piping the child's stdout to the parent (see
    /// `wait_for_child_output`) while passing its stderr through.
    fn capturing_output(cmd: &mut process::Command) {
        cmd.stdout(Stdio::piped()).stderr(Stdio::inherit());
    }

    /// Process modifier letting the child write straight to our own stdout
    /// and stderr.
    fn inherit_output(cmd: &mut process::Command) {
        cmd.stdout(Stdio::inherit()).stderr(Stdio::inherit());
    }

    /// Parent callback reading the child's stdout until EOF and asserting
    /// that the child exited successfully. Returns the output read.
    fn wait_for_child_output(child: &mut Child) -> String {
        let mut output = String::new();
        child
            .stdout
            .as_mut()
            .unwrap()
            .read_to_string(&mut output)
            .unwrap();
        assert!(child.wait().unwrap().success());
        output
    }

    /// Parent callback asserting that the child exits successfully.
    fn wait_for_child(child: &mut Child) {
        assert!(child.wait().unwrap().success());
    }

    #[test]
    fn fork_basically_works() {
        let status = fork_int(
            "fork::test::fork_basically_works",
            fork_id!(),
            |_| (),
            |child| child.wait().unwrap(),
            || println!("hello from child"),
        )
        .unwrap();
        assert!(status.success());
    }

    #[test]
    fn child_output_captured_and_repeated() {
        let output = fork_int(
            "fork::test::child_output_captured_and_repeated",
            fork_id!(),
            capturing_output,
            wait_for_child_output,
            || {
                fork_int(
                    "fork::test::child_output_captured_and_repeated",
                    fork_id!(),
                    |_| (),
                    wait_for_child,
                    || println!("hello from child"),
                )
                .unwrap()
            },
        )
        .unwrap();
        assert!(output.contains("hello from child"));
    }

    /// Check that a grandchild still running when its parent (the child) is
    /// done gets killed before it can produce any output.
    #[test]
    fn child_killed_if_parent_exits_first() {
        let output = fork_int(
            "fork::test::child_killed_if_parent_exits_first",
            fork_id!(),
            capturing_output,
            wait_for_child_output,
            || {
                fork_int(
                    "fork::test::child_killed_if_parent_exits_first",
                    fork_id!(),
                    inherit_output,
                    |_| (),
                    || {
                        sleep(100);
                        println!("hello from child");
                    },
                )
                .unwrap()
            },
        )
        .unwrap();

        // `wait_for_child_output` reads until EOF, i.e., until every writer
        // of the pipe -- including the grandchild inheriting it -- is gone.
        // `output` is final at this point; there is nothing to wait for.
        assert!(
            !output.contains("hello from child"),
            "Had unexpected output:\n{}",
            output
        );
    }

    /// Check that a grandchild still running when its parent (the child)
    /// panics gets killed before it can produce any output.
    #[test]
    fn child_killed_if_parent_panics_first() {
        let output = fork_int(
            "fork::test::child_killed_if_parent_panics_first",
            fork_id!(),
            capturing_output,
            wait_for_child_output,
            || {
                assert!(panic::catch_unwind(panic::AssertUnwindSafe(|| fork_int(
                    "fork::test::child_killed_if_parent_panics_first",
                    fork_id!(),
                    inherit_output,
                    |_| panic!("testing a panic, nothing to see here"),
                    || {
                        sleep(100);
                        println!("hello from child");
                    }
                )
                .unwrap()))
                .is_err());
            },
        )
        .unwrap();

        // As above: the output was read until EOF and hence is final.
        assert!(
            !output.contains("hello from child"),
            "Had unexpected output:\n{}",
            output
        );
    }

    #[test]
    fn child_aborted_if_panics() {
        let status = fork_int::<_, _, _, _, ()>(
            "fork::test::child_aborted_if_panics",
            fork_id!(),
            |_| (),
            |child| child.wait().unwrap(),
            || panic!("testing a panic, nothing to see here"),
        )
        .unwrap();
        assert_eq!(70, status.code().unwrap());
    }

    /// Check that we can exchange data with the child process.
    #[test]
    fn data_exchange() {
        let mut data = [1, 2, 3, 4, 5];

        let () = fork_in_out(
            fork_id!(),
            "fork::test::data_exchange",
            |data| {
                assert_eq!(data.len(), 5);
                let () = data.iter_mut().for_each(|x| *x += 1);
            },
            data.as_mut_slice(),
        )
        .unwrap();

        assert_eq!(data, [2, 3, 4, 5, 6]);
    }

    /// Check that exchanging an empty buffer with the child process works.
    #[test]
    fn data_exchange_empty() {
        let mut data: [u8; 0] = [];

        let () = fork_in_out(
            fork_id!(),
            "fork::test::data_exchange_empty",
            |data| assert!(data.is_empty()),
            data.as_mut_slice(),
        )
        .unwrap();
    }
}