1use std::{
2 fmt::{self, Display, Formatter},
3 os::raw::c_int,
4 ptr, thread,
5 time::{Duration, Instant},
6};
7
8use num_enum::{IntoPrimitive, TryFromPrimitive};
9use singe_core::impl_enum_conversion;
10use singe_cuda_sys::driver;
11
12use crate::{
13 device::{Device, Uuid},
14 error::{Error, Result, Status},
15 try_ffi,
16};
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
20#[repr(u32)]
21#[non_exhaustive]
22pub enum ProcessState {
23 Running = driver::CUprocessState::CU_PROCESS_STATE_RUNNING as _,
25 Locked = driver::CUprocessState::CU_PROCESS_STATE_LOCKED as _,
27 Checkpointed = driver::CUprocessState::CU_PROCESS_STATE_CHECKPOINTED as _,
29 Failed = driver::CUprocessState::CU_PROCESS_STATE_FAILED as _,
31 #[num_enum(catch_all)]
33 Unknown(u32),
34}
35
// Conversion glue between the raw driver enum `driver::CUprocessState` and
// `ProcessState`, generated by `singe_core::impl_enum_conversion!`.
// NOTE(review): exactly which impls are generated (direction, fallibility) is defined
// by the macro and not visible here — confirm in `singe_core` before relying on them.
impl_enum_conversion!(driver::CUprocessState, ProcessState);
37
38impl Display for ProcessState {
39 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
40 match self {
41 Self::Running => write!(f, "CU_PROCESS_STATE_RUNNING"),
42 Self::Locked => write!(f, "CU_PROCESS_STATE_LOCKED"),
43 Self::Checkpointed => write!(f, "CU_PROCESS_STATE_CHECKPOINTED"),
44 Self::Failed => write!(f, "CU_PROCESS_STATE_FAILED"),
45 Self::Unknown(value) => write!(f, "CU_PROCESS_STATE_UNKNOWN({value})"),
46 }
47 }
48}
49
/// Operating-system process ID in the C `int` form the `cuCheckpointProcess*`
/// driver functions take.
pub type ProcessId = std::os::raw::c_int;
52
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
55pub struct LockOptions {
56 timeout: Option<Duration>,
57}
58
59impl LockOptions {
60 pub const fn new() -> Self {
62 Self { timeout: None }
63 }
64
65 pub const fn with_timeout(mut self, timeout: Duration) -> Self {
69 self.timeout = Some(timeout);
70 self
71 }
72
73 fn to_raw(self) -> Result<driver::CUcheckpointLockArgs> {
74 let timeout_ms = match self.timeout {
75 Some(timeout) => timeout
76 .as_millis()
77 .try_into()
78 .map_err(|_| Error::InvalidValue)?,
79 None => 0,
80 };
81
82 Ok(driver::CUcheckpointLockArgs {
83 timeoutMs: timeout_ms,
84 ..Default::default()
85 })
86 }
87}
88
/// Outcome of [`CheckpointProcess::try_lock`] when the driver call itself succeeds
/// or reports a timeout.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum LockResult {
    /// The driver returned `CUDA_SUCCESS`.
    Locked,
    /// The driver returned `CUDA_ERROR_NOT_READY`.
    TimedOut,
}
98
99#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
101pub struct GpuPair {
102 pub old_uuid: Uuid,
104 pub new_uuid: Uuid,
106}
107
108impl From<(Uuid, Uuid)> for GpuPair {
109 fn from((old_uuid, new_uuid): (Uuid, Uuid)) -> Self {
110 Self::new(old_uuid, new_uuid)
111 }
112}
113
114impl GpuPair {
115 pub const fn new(old_uuid: Uuid, new_uuid: Uuid) -> Self {
117 Self { old_uuid, new_uuid }
118 }
119
120 pub fn from_devices(old: Device, new: Device) -> Result<Self> {
122 let old_uuid = old.properties()?.uuid;
123 let new_uuid = new.properties()?.uuid;
124 Ok(Self::new(old_uuid, new_uuid))
125 }
126
127 fn to_raw(self) -> driver::CUcheckpointGpuPair {
128 driver::CUcheckpointGpuPair {
129 oldUuid: self.old_uuid.into(),
130 newUuid: self.new_uuid.into(),
131 }
132 }
133}
134
135#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
137pub struct CheckpointOptions;
138
139impl CheckpointOptions {
140 pub const fn new() -> Self {
142 Self
143 }
144
145 fn to_raw(self) -> driver::CUcheckpointCheckpointArgs {
146 let _ = self;
147 driver::CUcheckpointCheckpointArgs::default()
148 }
149}
150
151impl Default for CheckpointOptions {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157#[derive(Debug, Clone, PartialEq, Eq)]
159pub struct RestoreOptions {
160 gpu_pairs: Vec<driver::CUcheckpointGpuPair>,
161}
162
163impl RestoreOptions {
164 pub fn new() -> Self {
166 Self {
167 gpu_pairs: Vec::new(),
168 }
169 }
170
171 pub fn with_gpu_pairs(gpu_pairs: &[GpuPair]) -> Self {
173 let mut options = Self::new();
174 options.gpu_pairs = gpu_pairs.iter().copied().map(GpuPair::to_raw).collect();
175 options
176 }
177
178 pub fn with_gpu_pair(mut self, pair: GpuPair) -> Self {
180 self.gpu_pairs.push(pair.to_raw());
181 self
182 }
183
184 pub fn with_device_pair(mut self, old: Device, new: Device) -> Result<Self> {
186 self.push_device_pair(old, new)?;
187 Ok(self)
188 }
189
190 pub fn push_gpu_pair(&mut self, pair: GpuPair) {
192 self.gpu_pairs.push(pair.to_raw());
193 }
194
195 pub fn push_device_pair(&mut self, old: Device, new: Device) -> Result<()> {
197 let pair = GpuPair::from_devices(old, new)?;
198 self.push_gpu_pair(pair);
199 Ok(())
200 }
201
202 pub fn with_device_pairs(gpu_pairs: impl AsRef<[(Device, Device)]>) -> Result<Self> {
204 let mut options = Self::new();
205 for &(old, new) in gpu_pairs.as_ref() {
206 options.push_device_pair(old, new)?;
207 }
208 Ok(options)
209 }
210
211 fn into_raw(mut self) -> Result<driver::CUcheckpointRestoreArgs> {
212 let gpu_pairs_count = self
213 .gpu_pairs
214 .len()
215 .try_into()
216 .map_err(|_| Error::InvalidValue)?;
217 let gpu_pairs = if self.gpu_pairs.is_empty() {
218 ptr::null_mut()
219 } else {
220 self.gpu_pairs.as_mut_ptr()
221 };
222
223 Ok(driver::CUcheckpointRestoreArgs {
224 gpuPairs: gpu_pairs,
225 gpuPairsCount: gpu_pairs_count,
226 ..Default::default()
227 })
228 }
229}
230
231impl Default for RestoreOptions {
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
237impl From<&[GpuPair]> for RestoreOptions {
238 fn from(gpu_pairs: &[GpuPair]) -> Self {
239 Self::with_gpu_pairs(gpu_pairs)
240 }
241}
242
243impl From<Vec<GpuPair>> for RestoreOptions {
244 fn from(gpu_pairs: Vec<GpuPair>) -> Self {
245 Self::with_gpu_pairs(&gpu_pairs)
246 }
247}
248
249#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
251pub struct UnlockOptions;
252
253impl UnlockOptions {
254 pub const fn new() -> Self {
256 Self
257 }
258
259 fn to_raw(self) -> driver::CUcheckpointUnlockArgs {
260 let _ = self;
261 driver::CUcheckpointUnlockArgs::default()
262 }
263}
264
265impl Default for UnlockOptions {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
271#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
277pub struct CheckpointProcess {
278 pid: ProcessId,
279}
280
281impl CheckpointProcess {
282 pub const fn new(pid: ProcessId) -> Self {
284 Self { pid }
285 }
286
287 pub const fn from_pid(pid: ProcessId) -> Self {
289 Self::new(pid)
290 }
291
292 pub fn current() -> Self {
294 Self::new(std::process::id() as ProcessId)
295 }
296
297 pub const fn pid(self) -> ProcessId {
299 self.pid
300 }
301
302 pub fn state(self) -> Result<ProcessState> {
304 let mut state = driver::CUprocessState::CU_PROCESS_STATE_RUNNING;
305 unsafe {
306 try_ffi!(driver::cuCheckpointProcessGetState(
307 self.pid,
308 &raw mut state,
309 ))?;
310 }
311 ProcessState::try_from(state as u32).map_err(|_| Error::InvalidValue)
312 }
313
314 pub fn is_running(self) -> bool {
316 self.state()
317 .is_ok_and(|state| state == ProcessState::Running)
318 }
319
320 pub fn is_locked(self) -> bool {
322 self.state()
323 .is_ok_and(|state| state == ProcessState::Locked)
324 }
325
326 pub fn is_checkpointed(self) -> bool {
328 self.state()
329 .is_ok_and(|state| state == ProcessState::Checkpointed)
330 }
331
332 pub fn wait_for_state(self, expected: ProcessState, timeout: Duration) -> Result<ProcessState> {
334 let end = Instant::now()
335 .checked_add(timeout)
336 .ok_or(Error::InvalidValue)?;
337 const POLL_INTERVAL: Duration = Duration::from_millis(25);
338
339 loop {
340 let state = self.state()?;
341 if state == expected {
342 return Ok(state);
343 }
344 if Instant::now() >= end {
345 return Err(Error::Cuda {
346 code: Status::Timeout,
347 message: format!(
348 "timed out waiting for checkpoint process {} to reach {}",
349 self.pid, expected
350 ),
351 });
352 }
353 thread::sleep(POLL_INTERVAL);
354 }
355 }
356
357 pub fn restore_thread_id(self) -> Result<i32> {
359 let mut thread_id = 0;
360 unsafe {
361 try_ffi!(driver::cuCheckpointProcessGetRestoreThreadId(
362 self.pid,
363 &raw mut thread_id,
364 ))?;
365 }
366 Ok(thread_id)
367 }
368
369 pub fn lock(self, options: LockOptions) -> Result<()> {
373 match self.try_lock(options)? {
374 LockResult::Locked => Ok(()),
375 LockResult::TimedOut => Err(driver::CUresult::CUDA_ERROR_NOT_READY.into()),
376 }
377 }
378
379 pub fn try_lock(self, options: LockOptions) -> Result<LockResult> {
384 let mut args = options.to_raw()?;
385 let result = unsafe { driver::cuCheckpointProcessLock(self.pid, &raw mut args) };
386 match result {
387 driver::CUresult::CUDA_SUCCESS => Ok(LockResult::Locked),
388 driver::CUresult::CUDA_ERROR_NOT_READY => Ok(LockResult::TimedOut),
389 status => Err(status.into()),
390 }
391 }
392
393 pub fn checkpoint(self) -> Result<()> {
397 self.checkpoint_with_options(CheckpointOptions::new())
398 }
399
400 pub fn checkpoint_with_options(self, options: CheckpointOptions) -> Result<()> {
404 let mut args = options.to_raw();
405 unsafe {
406 try_ffi!(driver::cuCheckpointProcessCheckpoint(
407 self.pid,
408 &raw mut args
409 ))
410 }
411 }
412
413 pub fn suspend(self, options: LockOptions) -> Result<()> {
417 self.lock(options)?;
418 self.checkpoint()
419 }
420
421 pub fn toggle(
426 self,
427 options: LockOptions,
428 gpu_pairs: impl AsRef<[GpuPair]>,
429 ) -> Result<ProcessState> {
430 match self.state()? {
431 ProcessState::Running => {
432 self.suspend(options)?;
433 Ok(ProcessState::Checkpointed)
434 }
435 ProcessState::Checkpointed => {
436 self.resume(gpu_pairs)?;
437 Ok(ProcessState::Running)
438 }
439 _ => Err(Error::Cuda {
440 code: Status::IllegalState,
441 message: String::from("cannot toggle checkpoint process from current state"),
442 }),
443 }
444 }
445
446 pub fn restore_with_options(self, options: RestoreOptions) -> Result<()> {
453 let mut args = options.into_raw()?;
454 unsafe { try_ffi!(driver::cuCheckpointProcessRestore(self.pid, &raw mut args)) }
455 }
456
457 pub fn restore(self, gpu_pairs: impl AsRef<[GpuPair]>) -> Result<()> {
464 self.restore_with_options(gpu_pairs.as_ref().into())
465 }
466
467 pub fn resume(self, gpu_pairs: impl AsRef<[GpuPair]>) -> Result<()> {
471 self.restore(gpu_pairs)?;
472 self.unlock()
473 }
474
475 pub fn unlock(self) -> Result<()> {
479 self.unlock_with_options(UnlockOptions::new())
480 }
481
482 pub fn unlock_with_options(self, options: UnlockOptions) -> Result<()> {
486 let mut args = options.to_raw();
487 unsafe { try_ffi!(driver::cuCheckpointProcessUnlock(self.pid, &raw mut args)) }
488 }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        // The current process may or may not be visible to the checkpoint API:
        // any reported state, or a recognised error, is acceptable.
        match CheckpointProcess::current().state() {
            Ok(
                ProcessState::Running
                | ProcessState::Locked
                | ProcessState::Checkpointed
                | ProcessState::Failed
                | ProcessState::Unknown(_),
            ) => {}
            Err(error) => assert_checkpoint_error(error),
        }

        // PID -1 is not a real process ID, so every driver call should be rejected.
        let missing = CheckpointProcess::from_pid(-1);
        checkpoint_fails(missing.restore_thread_id());
        checkpoint_fails(missing.lock(LockOptions::new().with_timeout(Duration::from_millis(1))));
        checkpoint_fails(missing.checkpoint());
        checkpoint_fails(missing.checkpoint_with_options(CheckpointOptions::new()));
        checkpoint_fails(missing.restore(&[]));
        checkpoint_fails(missing.restore_with_options(RestoreOptions::new()));
        checkpoint_fails(missing.unlock());
        checkpoint_fails(missing.unlock_with_options(UnlockOptions::new()));
    }

    /// Asserts that `result` is an error with one of the accepted statuses.
    fn checkpoint_fails<T>(result: Result<T>) {
        let Err(error) = result else {
            panic!("checkpoint call unexpectedly succeeded");
        };
        assert_checkpoint_error(error);
    }

    /// Panics with the error's `Debug` form unless it is a CUDA error whose
    /// status is one of the accepted ones.
    fn assert_checkpoint_error(error: Error) {
        let accepted = matches!(
            &error,
            Error::Cuda {
                code: Status::InvalidValue
                    | Status::NotInitialized
                    | Status::NotSupported
                    | Status::IllegalState
                    | Status::OperatingSystem,
                ..
            }
        );
        if !accepted {
            panic!("{error:?}");
        }
    }
}