Skip to main content

singe_cuda/
checkpoint.rs

1use std::{
2    fmt::{self, Display, Formatter},
3    os::raw::c_int,
4    ptr, thread,
5    time::{Duration, Instant},
6};
7
8use num_enum::{IntoPrimitive, TryFromPrimitive};
9use singe_core::impl_enum_conversion;
10use singe_cuda_sys::driver;
11
12use crate::{
13    device::{Device, Uuid},
14    error::{Error, Result, Status},
15    try_ffi,
16};
17
18/// CUDA process state used by the checkpoint and restore driver APIs.
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
20#[repr(u32)]
21#[non_exhaustive]
22pub enum ProcessState {
23    /// The process can make CUDA API calls.
24    Running = driver::CUprocessState::CU_PROCESS_STATE_RUNNING as _,
25    /// CUDA API locks are taken and further CUDA API calls will block.
26    Locked = driver::CUprocessState::CU_PROCESS_STATE_LOCKED as _,
27    /// GPU memory has been moved to host memory and device handles were released.
28    Checkpointed = driver::CUprocessState::CU_PROCESS_STATE_CHECKPOINTED as _,
29    /// The process entered an unrecoverable error during checkpoint or restore.
30    Failed = driver::CUprocessState::CU_PROCESS_STATE_FAILED as _,
31    /// Unknown process state returned by a newer CUDA driver version.
32    #[num_enum(catch_all)]
33    Unknown(u32),
34}
35
36impl_enum_conversion!(driver::CUprocessState, ProcessState);
37
38impl Display for ProcessState {
39    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
40        match self {
41            Self::Running => write!(f, "CU_PROCESS_STATE_RUNNING"),
42            Self::Locked => write!(f, "CU_PROCESS_STATE_LOCKED"),
43            Self::Checkpointed => write!(f, "CU_PROCESS_STATE_CHECKPOINTED"),
44            Self::Failed => write!(f, "CU_PROCESS_STATE_FAILED"),
45            Self::Unknown(value) => write!(f, "CU_PROCESS_STATE_UNKNOWN({value})"),
46        }
47    }
48}
49
/// Process identifier (PID) of the target OS process, in the C `int`
/// representation that the CUDA checkpoint driver entry points take.
pub type ProcessId = std::os::raw::c_int;
52
/// Configuration for [`CheckpointProcess::lock`] and [`CheckpointProcess::try_lock`].
///
/// The default value sets no timeout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct LockOptions {
    /// Upper bound on how long the driver may try to take the lock.
    /// `None` means unbounded.
    timeout: Option<Duration>,
}
58
59impl LockOptions {
60    /// Creates lock options without a timeout.
61    pub const fn new() -> Self {
62        Self { timeout: None }
63    }
64
65    /// Sets the maximum time CUDA should spend attempting to lock the process.
66    ///
67    /// A missing timeout passes `0` to CUDA, which means no timeout.
68    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
69        self.timeout = Some(timeout);
70        self
71    }
72
73    fn to_raw(self) -> Result<driver::CUcheckpointLockArgs> {
74        let timeout_ms = match self.timeout {
75            Some(timeout) => timeout
76                .as_millis()
77                .try_into()
78                .map_err(|_| Error::InvalidValue)?,
79            None => 0,
80        };
81
82        Ok(driver::CUcheckpointLockArgs {
83            timeoutMs: timeout_ms,
84            ..Default::default()
85        })
86    }
87}
88
/// What [`CheckpointProcess::try_lock`] reports when the driver did not fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum LockResult {
    /// The lock was acquired and the process is now in [`ProcessState::Locked`].
    Locked,
    /// The configured timeout elapsed first, so the process stays in
    /// [`ProcessState::Running`].
    TimedOut,
}
98
99/// GPU UUID remapping entry used during restore.
100#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
101pub struct GpuPair {
102    /// UUID of the GPU that was checkpointed.
103    pub old_uuid: Uuid,
104    /// UUID of the GPU to restore onto.
105    pub new_uuid: Uuid,
106}
107
108impl From<(Uuid, Uuid)> for GpuPair {
109    fn from((old_uuid, new_uuid): (Uuid, Uuid)) -> Self {
110        Self::new(old_uuid, new_uuid)
111    }
112}
113
114impl GpuPair {
115    /// Creates a GPU pair from the original and target GPU UUIDs.
116    pub const fn new(old_uuid: Uuid, new_uuid: Uuid) -> Self {
117        Self { old_uuid, new_uuid }
118    }
119
120    /// Creates a GPU pair from two CUDA devices by reading their UUIDs.
121    pub fn from_devices(old: Device, new: Device) -> Result<Self> {
122        let old_uuid = old.properties()?.uuid;
123        let new_uuid = new.properties()?.uuid;
124        Ok(Self::new(old_uuid, new_uuid))
125    }
126
127    fn to_raw(self) -> driver::CUcheckpointGpuPair {
128        driver::CUcheckpointGpuPair {
129            oldUuid: self.old_uuid.into(),
130            newUuid: self.new_uuid.into(),
131        }
132    }
133}
134
/// Configuration for [`CheckpointProcess::checkpoint_with_options`].
///
/// Carries no settings today. It is reserved so that options added by future
/// CUDA versions do not break the API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CheckpointOptions;
138
139impl CheckpointOptions {
140    /// Creates default checkpoint options.
141    pub const fn new() -> Self {
142        Self
143    }
144
145    fn to_raw(self) -> driver::CUcheckpointCheckpointArgs {
146        let _ = self;
147        driver::CUcheckpointCheckpointArgs::default()
148    }
149}
150
151impl Default for CheckpointOptions {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157/// Options for [`CheckpointProcess::restore`], including optional GPU remap entries.
158#[derive(Debug, Clone, PartialEq, Eq)]
159pub struct RestoreOptions {
160    gpu_pairs: Vec<driver::CUcheckpointGpuPair>,
161}
162
163impl RestoreOptions {
164    /// Creates restore options with no GPU remapping pairs.
165    pub fn new() -> Self {
166        Self {
167            gpu_pairs: Vec::new(),
168        }
169    }
170
171    /// Creates restore options from checkpointed/restored GPU UUID pairs.
172    pub fn with_gpu_pairs(gpu_pairs: &[GpuPair]) -> Self {
173        let mut options = Self::new();
174        options.gpu_pairs = gpu_pairs.iter().copied().map(GpuPair::to_raw).collect();
175        options
176    }
177
178    /// Adds a single GPU remapping pair.
179    pub fn with_gpu_pair(mut self, pair: GpuPair) -> Self {
180        self.gpu_pairs.push(pair.to_raw());
181        self
182    }
183
184    /// Adds a GPU remapping pair from CUDA devices.
185    pub fn with_device_pair(mut self, old: Device, new: Device) -> Result<Self> {
186        self.push_device_pair(old, new)?;
187        Ok(self)
188    }
189
190    /// Adds a single GPU remapping pair in place.
191    pub fn push_gpu_pair(&mut self, pair: GpuPair) {
192        self.gpu_pairs.push(pair.to_raw());
193    }
194
195    /// Adds a GPU remapping pair from CUDA devices in place.
196    pub fn push_device_pair(&mut self, old: Device, new: Device) -> Result<()> {
197        let pair = GpuPair::from_devices(old, new)?;
198        self.push_gpu_pair(pair);
199        Ok(())
200    }
201
202    /// Creates restore options from checkpointed/restored GPU device pairs.
203    pub fn with_device_pairs(gpu_pairs: impl AsRef<[(Device, Device)]>) -> Result<Self> {
204        let mut options = Self::new();
205        for &(old, new) in gpu_pairs.as_ref() {
206            options.push_device_pair(old, new)?;
207        }
208        Ok(options)
209    }
210
211    fn into_raw(mut self) -> Result<driver::CUcheckpointRestoreArgs> {
212        let gpu_pairs_count = self
213            .gpu_pairs
214            .len()
215            .try_into()
216            .map_err(|_| Error::InvalidValue)?;
217        let gpu_pairs = if self.gpu_pairs.is_empty() {
218            ptr::null_mut()
219        } else {
220            self.gpu_pairs.as_mut_ptr()
221        };
222
223        Ok(driver::CUcheckpointRestoreArgs {
224            gpuPairs: gpu_pairs,
225            gpuPairsCount: gpu_pairs_count,
226            ..Default::default()
227        })
228    }
229}
230
231impl Default for RestoreOptions {
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
237impl From<&[GpuPair]> for RestoreOptions {
238    fn from(gpu_pairs: &[GpuPair]) -> Self {
239        Self::with_gpu_pairs(gpu_pairs)
240    }
241}
242
243impl From<Vec<GpuPair>> for RestoreOptions {
244    fn from(gpu_pairs: Vec<GpuPair>) -> Self {
245        Self::with_gpu_pairs(&gpu_pairs)
246    }
247}
248
/// Configuration for [`CheckpointProcess::unlock_with_options`].
///
/// Carries no settings today. It is reserved so that options added by future
/// CUDA versions do not break the API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct UnlockOptions;
252
253impl UnlockOptions {
254    /// Creates default unlock options.
255    pub const fn new() -> Self {
256        Self
257    }
258
259    fn to_raw(self) -> driver::CUcheckpointUnlockArgs {
260        let _ = self;
261        driver::CUcheckpointUnlockArgs::default()
262    }
263}
264
265impl Default for UnlockOptions {
266    fn default() -> Self {
267        Self::new()
268    }
269}
270
271/// A CUDA process controlled through the driver checkpoint APIs.
272///
273/// These APIs are intended for an external controller process. Locking a
274/// process blocks further CUDA API calls in that process until it is restored
275/// and unlocked.
276#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
277pub struct CheckpointProcess {
278    pid: ProcessId,
279}
280
281impl CheckpointProcess {
282    /// Creates a CUDA checkpoint target from an operating-system process ID.
283    pub const fn new(pid: ProcessId) -> Self {
284        Self { pid }
285    }
286
287    /// Creates a CUDA checkpoint target from an operating-system process ID.
288    pub const fn from_pid(pid: ProcessId) -> Self {
289        Self::new(pid)
290    }
291
292    /// Creates a CUDA checkpoint target for the current process.
293    pub fn current() -> Self {
294        Self::new(std::process::id() as ProcessId)
295    }
296
297    /// Returns the operating-system process ID controlled by this value.
298    pub const fn pid(self) -> ProcessId {
299        self.pid
300    }
301
302    /// Returns the current CUDA checkpoint state of the process.
303    pub fn state(self) -> Result<ProcessState> {
304        let mut state = driver::CUprocessState::CU_PROCESS_STATE_RUNNING;
305        unsafe {
306            try_ffi!(driver::cuCheckpointProcessGetState(
307                self.pid,
308                &raw mut state,
309            ))?;
310        }
311        ProcessState::try_from(state as u32).map_err(|_| Error::InvalidValue)
312    }
313
314    /// Returns `true` when this process currently has [`ProcessState::Running`].
315    pub fn is_running(self) -> bool {
316        self.state()
317            .is_ok_and(|state| state == ProcessState::Running)
318    }
319
320    /// Returns `true` when this process currently has [`ProcessState::Locked`].
321    pub fn is_locked(self) -> bool {
322        self.state()
323            .is_ok_and(|state| state == ProcessState::Locked)
324    }
325
326    /// Returns `true` when this process currently has [`ProcessState::Checkpointed`].
327    pub fn is_checkpointed(self) -> bool {
328        self.state()
329            .is_ok_and(|state| state == ProcessState::Checkpointed)
330    }
331
332    /// Blocks until the process reaches `expected`, or returns a timeout error.
333    pub fn wait_for_state(self, expected: ProcessState, timeout: Duration) -> Result<ProcessState> {
334        let end = Instant::now()
335            .checked_add(timeout)
336            .ok_or(Error::InvalidValue)?;
337        const POLL_INTERVAL: Duration = Duration::from_millis(25);
338
339        loop {
340            let state = self.state()?;
341            if state == expected {
342                return Ok(state);
343            }
344            if Instant::now() >= end {
345                return Err(Error::Cuda {
346                    code: Status::Timeout,
347                    message: format!(
348                        "timed out waiting for checkpoint process {} to reach {}",
349                        self.pid, expected
350                    ),
351                });
352            }
353            thread::sleep(POLL_INTERVAL);
354        }
355    }
356
357    /// Returns the CUDA restore thread ID for the process.
358    pub fn restore_thread_id(self) -> Result<i32> {
359        let mut thread_id = 0;
360        unsafe {
361            try_ffi!(driver::cuCheckpointProcessGetRestoreThreadId(
362                self.pid,
363                &raw mut thread_id,
364            ))?;
365        }
366        Ok(thread_id)
367    }
368
369    /// Locks a running CUDA process so further CUDA API calls in that process block.
370    ///
371    /// On success the process enters [`ProcessState::Locked`].
372    pub fn lock(self, options: LockOptions) -> Result<()> {
373        match self.try_lock(options)? {
374            LockResult::Locked => Ok(()),
375            LockResult::TimedOut => Err(driver::CUresult::CUDA_ERROR_NOT_READY.into()),
376        }
377    }
378
379    /// Attempts to lock a running CUDA process and reports whether the timeout was hit.
380    ///
381    /// On success the process usually enters [`ProcessState::Locked`]. If a timeout
382    /// is set and reached, this returns [`LockResult::TimedOut`].
383    pub fn try_lock(self, options: LockOptions) -> Result<LockResult> {
384        let mut args = options.to_raw()?;
385        let result = unsafe { driver::cuCheckpointProcessLock(self.pid, &raw mut args) };
386        match result {
387            driver::CUresult::CUDA_SUCCESS => Ok(LockResult::Locked),
388            driver::CUresult::CUDA_ERROR_NOT_READY => Ok(LockResult::TimedOut),
389            status => Err(status.into()),
390        }
391    }
392
393    /// Moves the locked process's GPU memory into host memory managed by the driver.
394    ///
395    /// On success the process enters [`ProcessState::Checkpointed`].
396    pub fn checkpoint(self) -> Result<()> {
397        self.checkpoint_with_options(CheckpointOptions::new())
398    }
399
400    /// Moves the locked process's GPU memory into host memory managed by the driver.
401    ///
402    /// This variant accepts explicit checkpoint options.
403    pub fn checkpoint_with_options(self, options: CheckpointOptions) -> Result<()> {
404        let mut args = options.to_raw();
405        unsafe {
406            try_ffi!(driver::cuCheckpointProcessCheckpoint(
407                self.pid,
408                &raw mut args
409            ))
410        }
411    }
412
413    /// Locks and checkpoints a running CUDA process.
414    ///
415    /// On success the process enters [`ProcessState::Checkpointed`].
416    pub fn suspend(self, options: LockOptions) -> Result<()> {
417        self.lock(options)?;
418        self.checkpoint()
419    }
420
421    /// Toggles the process between running and checkpointed states.
422    ///
423    /// - From `Running`, performs [`CheckpointProcess::suspend`].
424    /// - From `Checkpointed`, performs [`CheckpointProcess::resume`].
425    pub fn toggle(
426        self,
427        options: LockOptions,
428        gpu_pairs: impl AsRef<[GpuPair]>,
429    ) -> Result<ProcessState> {
430        match self.state()? {
431            ProcessState::Running => {
432                self.suspend(options)?;
433                Ok(ProcessState::Checkpointed)
434            }
435            ProcessState::Checkpointed => {
436                self.resume(gpu_pairs)?;
437                Ok(ProcessState::Running)
438            }
439            _ => Err(Error::Cuda {
440                code: Status::IllegalState,
441                message: String::from("cannot toggle checkpoint process from current state"),
442            }),
443        }
444    }
445
446    /// Restores a checkpointed process, optionally remapping checkpointed GPUs.
447    ///
448    /// If `gpu_pairs` is not empty, CUDA requires it to contain every
449    /// checkpointed GPU.
450    ///
451    /// On success the process enters [`ProcessState::Locked`].
452    pub fn restore_with_options(self, options: RestoreOptions) -> Result<()> {
453        let mut args = options.into_raw()?;
454        unsafe { try_ffi!(driver::cuCheckpointProcessRestore(self.pid, &raw mut args)) }
455    }
456
457    /// Restores a checkpointed process, optionally remapping checkpointed GPUs.
458    ///
459    /// If `gpu_pairs` is not empty, CUDA requires it to contain every
460    /// checkpointed GPU.
461    ///
462    /// On success the process enters [`ProcessState::Locked`].
463    pub fn restore(self, gpu_pairs: impl AsRef<[GpuPair]>) -> Result<()> {
464        self.restore_with_options(gpu_pairs.as_ref().into())
465    }
466
467    /// Restores and unlocks a checkpointed CUDA process.
468    ///
469    /// On success the process enters [`ProcessState::Running`].
470    pub fn resume(self, gpu_pairs: impl AsRef<[GpuPair]>) -> Result<()> {
471        self.restore(gpu_pairs)?;
472        self.unlock()
473    }
474
475    /// Unlocks a locked CUDA process so it can resume CUDA API calls.
476    ///
477    /// On success the process enters [`ProcessState::Running`].
478    pub fn unlock(self) -> Result<()> {
479        self.unlock_with_options(UnlockOptions::new())
480    }
481
482    /// Unlocks a locked CUDA process so it can resume CUDA API calls with options.
483    ///
484    /// On success the process enters [`ProcessState::Running`].
485    pub fn unlock_with_options(self, options: UnlockOptions) -> Result<()> {
486        let mut args = options.to_raw();
487        unsafe { try_ffi!(driver::cuCheckpointProcessUnlock(self.pid, &raw mut args)) }
488    }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: each checkpoint entry point must either succeed with a
    /// known value or fail with one of the driver errors expected when CUDA,
    /// or the target process, is unavailable.
    #[test]
    fn it_works() {
        let current = CheckpointProcess::current();
        match current.state() {
            Err(error) => assert_checkpoint_error(error),
            Ok(state) => assert!(matches!(
                state,
                ProcessState::Running
                    | ProcessState::Locked
                    | ProcessState::Checkpointed
                    | ProcessState::Failed
                    | ProcessState::Unknown(_)
            )),
        }

        // PID -1 never names a real process, so every call has to fail.
        let bogus = CheckpointProcess::from_pid(-1);
        expect_failure(bogus.restore_thread_id());
        expect_failure(bogus.lock(LockOptions::new().with_timeout(Duration::from_millis(1))));
        expect_failure(bogus.checkpoint());
        expect_failure(bogus.checkpoint_with_options(CheckpointOptions::new()));
        expect_failure(bogus.restore(&[]));
        expect_failure(bogus.restore_with_options(RestoreOptions::new()));
        expect_failure(bogus.unlock());
        expect_failure(bogus.unlock_with_options(UnlockOptions::new()));
    }

    /// Asserts that `result` is an error accepted by [`assert_checkpoint_error`].
    fn expect_failure<T>(result: Result<T>) {
        let Err(error) = result else {
            panic!("checkpoint call unexpectedly succeeded");
        };
        assert_checkpoint_error(error);
    }

    /// Panics unless `error` is a CUDA error whose status is one that an
    /// unavailable driver or an invalid process can produce.
    fn assert_checkpoint_error(error: Error) {
        let acceptable = matches!(
            &error,
            Error::Cuda {
                code: Status::InvalidValue
                    | Status::NotInitialized
                    | Status::NotSupported
                    | Status::IllegalState
                    | Status::OperatingSystem,
                ..
            }
        );
        if !acceptable {
            panic!("{error:?}");
        }
    }
}