Skip to main content

vyre_driver/
device_convergence.rs

1//! Backend-neutral device-side convergence planning for iterative analyses.
2
/// How the host observes convergence of a device-side fixed-point loop.
///
/// Only one policy exists today: the changed flag is read a single time,
/// after the device has used up its iteration budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvergenceReadbackPolicy {
    /// Read the changed flag once, after the device-side iteration budget
    /// completes, rather than once per iteration.
    FinalFlagOnly,
}
9
10/// Execution plan for device-side fixed-point convergence.
11#[derive(Clone, Copy, Debug, Eq, PartialEq)]
12pub struct DeviceConvergencePlan {
13    /// Maximum number of device iterations before the final convergence flag is read.
14    pub max_device_iterations: u32,
15    /// Number of host-visible synchronization points caused by convergence detection.
16    pub host_sync_points: u32,
17    /// Number of changed-flag bytes read back to the host.
18    pub changed_flag_readback_bytes: u32,
19    /// Number of per-iteration host polls.
20    pub host_iteration_polls: u32,
21    /// Readback policy used by the plan.
22    pub readback_policy: ConvergenceReadbackPolicy,
23}
24
/// Reasons [`plan_device_convergence`] can refuse to build a plan.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeviceConvergencePlanError {
    /// The device iteration budget was zero.
    EmptyIterationBudget,
    /// The changed flag is not the 4-byte width the device ABI expects.
    InvalidChangedFlagWidth {
        /// Byte width the caller supplied for the changed flag.
        bytes: u32,
    },
    /// The caller asked for host-side polling of convergence during iteration.
    HostPolledConvergence {
        /// Number of host-side iteration polls the caller requested.
        polls: u32,
    },
}

impl std::fmt::Display for DeviceConvergencePlanError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each message names the problem and then the remedy after "Fix:".
        match *self {
            Self::EmptyIterationBudget => write!(
                f,
                "device convergence iteration budget is zero. Fix: use at least one device iteration."
            ),
            Self::InvalidChangedFlagWidth { bytes } => write!(
                f,
                "device convergence changed-flag width is {} bytes. Fix: use a 4-byte device u32 changed flag.",
                bytes
            ),
            Self::HostPolledConvergence { polls } => write!(
                f,
                "device convergence requested {} host iteration polls. Fix: keep convergence detection device-side and read only the final changed flag.",
                polls
            ),
        }
    }
}

/// Lets the error be boxed as `dyn std::error::Error` and propagated with `?`.
impl std::error::Error for DeviceConvergencePlanError {}
61
62/// Plan convergence detection for an iterative device dataflow kernel.
63///
64/// # Errors
65///
66/// Returns [`DeviceConvergencePlanError`] when the iteration budget is empty,
67/// the changed flag does not match the device ABI, or the caller asks for
68/// host-polled iteration convergence.
69pub fn plan_device_convergence(
70    max_device_iterations: u32,
71    changed_flag_bytes: u32,
72    requested_host_iteration_polls: u32,
73) -> Result<DeviceConvergencePlan, DeviceConvergencePlanError> {
74    if max_device_iterations == 0 {
75        return Err(DeviceConvergencePlanError::EmptyIterationBudget);
76    }
77    if changed_flag_bytes != 4 {
78        return Err(DeviceConvergencePlanError::InvalidChangedFlagWidth {
79            bytes: changed_flag_bytes,
80        });
81    }
82    if requested_host_iteration_polls != 0 {
83        return Err(DeviceConvergencePlanError::HostPolledConvergence {
84            polls: requested_host_iteration_polls,
85        });
86    }
87
88    Ok(DeviceConvergencePlan {
89        max_device_iterations,
90        host_sync_points: 1,
91        changed_flag_readback_bytes: changed_flag_bytes,
92        host_iteration_polls: 0,
93        readback_policy: ConvergenceReadbackPolicy::FinalFlagOnly,
94    })
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `plan` honours the final-flag-only contract for a budget
    /// of `iterations` device iterations.
    fn assert_final_flag_only(plan: DeviceConvergencePlan, iterations: u32) {
        assert_eq!(plan.max_device_iterations, iterations);
        assert_eq!(plan.host_sync_points, 1);
        assert_eq!(plan.changed_flag_readback_bytes, 4);
        assert_eq!(plan.host_iteration_polls, 0);
        assert_eq!(plan.readback_policy, ConvergenceReadbackPolicy::FinalFlagOnly);
    }

    #[test]
    fn convergence_plan_reads_final_flag_once() {
        let plan = plan_device_convergence(128, 4, 0).expect("Fix: valid plan should build");

        assert_final_flag_only(plan, 128);
    }

    #[test]
    fn convergence_plan_accepts_maximum_iteration_budget() {
        let plan = plan_device_convergence(u32::MAX, 4, 0)
            .expect("Fix: the largest u32 iteration budget should plan");

        assert_final_flag_only(plan, u32::MAX);
    }

    #[test]
    fn convergence_plan_rejects_empty_iteration_budget() {
        let err = plan_device_convergence(0, 4, 0).expect_err("zero iterations cannot converge");

        assert_eq!(err, DeviceConvergencePlanError::EmptyIterationBudget);
        assert!(err.to_string().contains("at least one device iteration"));
    }

    #[test]
    fn convergence_plan_rejects_wrong_changed_flag_width() {
        let err = plan_device_convergence(8, 1, 0).expect_err("changed flag must be a u32");

        assert_eq!(
            err,
            DeviceConvergencePlanError::InvalidChangedFlagWidth { bytes: 1 }
        );
        assert!(err.to_string().contains("4-byte device u32 changed flag"));
    }

    #[test]
    fn convergence_plan_rejects_every_non_u32_flag_width() {
        // Only exactly 4 bytes is accepted; neighbours and extremes must fail.
        for bytes in [0, 1, 2, 3, 5, 8, u32::MAX] {
            let err = plan_device_convergence(8, bytes, 0)
                .expect_err("only 4-byte changed flags are valid");
            assert_eq!(
                err,
                DeviceConvergencePlanError::InvalidChangedFlagWidth { bytes }
            );
        }
    }

    #[test]
    fn convergence_plan_rejects_host_polled_iterations() {
        let err = plan_device_convergence(8, 4, 8)
            .expect_err("host polling every iteration is forbidden");

        assert_eq!(
            err,
            DeviceConvergencePlanError::HostPolledConvergence { polls: 8 }
        );
        assert!(err.to_string().contains("read only the final changed flag"));
    }

    #[test]
    fn convergence_plan_rejects_even_a_single_host_poll() {
        assert_eq!(
            plan_device_convergence(8, 4, 1),
            Err(DeviceConvergencePlanError::HostPolledConvergence { polls: 1 })
        );
    }

    #[test]
    fn convergence_plan_reports_first_failing_check() {
        // Pins the validation order: budget, then flag width, then host polls.
        assert_eq!(
            plan_device_convergence(0, 1, 8),
            Err(DeviceConvergencePlanError::EmptyIterationBudget)
        );
        assert_eq!(
            plan_device_convergence(8, 1, 8),
            Err(DeviceConvergencePlanError::InvalidChangedFlagWidth { bytes: 1 })
        );
    }

    #[test]
    fn generated_convergence_iteration_budgets_preserve_final_only_contract() {
        for max_device_iterations in 1..=4_096 {
            let plan = plan_device_convergence(max_device_iterations, 4, 0)
                .expect("Fix: generated nonzero iteration budgets should plan");
            assert_final_flag_only(plan, max_device_iterations);
        }
    }
}