Skip to main content

tensorlogic_infer/
backend_kind.rs

1//! Backend selection and capability negotiation for TensorLogic inference engines.
2//!
3//! `BackendKind` is the canonical discriminant for which compute backend is active.
4//! - `Scirs` — pure-Rust CPU backend, always available.
5//! - `OxiCuda` — NVIDIA GPU backend via `tensorlogic-oxicuda-backend`; requires the
6//!   `gpu` feature on that crate.
7//! - All remaining variants are Round-5 stubs, planned for Round 6.
8
9use std::env;
10use std::str::FromStr;
11
12use thiserror::Error;
13
14// ─── Error type ──────────────────────────────────────────────────────────────
15
16/// Errors produced when resolving or validating a [`BackendKind`].
17#[derive(Debug, Error)]
18pub enum BackendKindError {
19    /// The named backend exists in the enum but has not been implemented yet.
20    #[error("backend '{name}' is not yet implemented (planned for Round 6)")]
21    Unimplemented { name: &'static str },
22
23    /// The string passed to [`BackendKind::from_str`] did not match any known backend.
24    #[error("unknown backend name: {0}")]
25    UnknownName(String),
26}
27
28// ─── Enum ────────────────────────────────────────────────────────────────────
29
30/// Discriminant for the active compute backend.
31///
32/// Only [`BackendKind::Scirs`] and [`BackendKind::OxiCuda`] are fully supported in
33/// Round 5. All other variants are doc-hidden stubs that will be enabled in Round 6.
/// Identifies which compute backend is in use.
///
/// In Round 5 only [`BackendKind::Scirs`] and [`BackendKind::OxiCuda`] are
/// fully supported. The remaining variants are hidden from the docs; they are
/// placeholders scheduled for Round 6.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BackendKind {
    /// SciRS2-Core CPU backend. Always available.
    Scirs,

    /// NVIDIA GPU backend. Needs `feature = "gpu"` on `tensorlogic-oxicuda-backend`.
    OxiCuda,

    /// Apple Metal GPU backend (placeholder, not implemented).
    #[doc(hidden)]
    Metal,

    /// Vulkan compute backend (placeholder, not implemented).
    #[doc(hidden)]
    Vulkan,

    /// AMD ROCm backend (placeholder, not implemented).
    #[doc(hidden)]
    Rocm,

    /// WebGPU backend (placeholder, not implemented).
    #[doc(hidden)]
    Webgpu,

    /// Intel Level Zero backend (placeholder, not implemented).
    #[doc(hidden)]
    Levelzero,
}
62
63// ─── Core methods ─────────────────────────────────────────────────────────────
64
65impl BackendKind {
66    /// Returns `OxiCuda` when the environment variable `TENSORLOGIC_BACKEND` is set
67    /// to `"oxicuda"`, otherwise returns `Scirs`.
68    ///
69    /// This is the recommended entry-point for runtime backend selection.
70    pub fn default_backend() -> Self {
71        match env::var("TENSORLOGIC_BACKEND")
72            .unwrap_or_default()
73            .to_ascii_lowercase()
74            .as_str()
75        {
76            "oxicuda" => Self::OxiCuda,
77            _ => Self::Scirs,
78        }
79    }
80
81    /// Reads `TENSORLOGIC_BACKEND` and maps it to a `BackendKind`.
82    ///
83    /// Recognised values (case-insensitive):
84    /// - `"oxicuda"`, `"gpu"`, `"cuda"` → [`BackendKind::OxiCuda`]
85    /// - anything else (including unset) → [`BackendKind::Scirs`]
86    ///
87    /// Non-NVIDIA variants (`metal`, `vulkan`, `rocm`, `webgpu`, `levelzero`) are
88    /// parsed correctly by [`BackendKind::from_str`] but always fail
89    /// [`BackendKind::validate`] with [`BackendKindError::Unimplemented`].
90    pub fn from_env() -> Self {
91        env::var("TENSORLOGIC_BACKEND")
92            .ok()
93            .and_then(|s| s.parse::<Self>().ok())
94            .unwrap_or(Self::Scirs)
95    }
96
97    /// Returns `true` when this backend uses a GPU device.
98    pub fn is_gpu(&self) -> bool {
99        matches!(
100            self,
101            Self::OxiCuda
102                | Self::Metal
103                | Self::Vulkan
104                | Self::Rocm
105                | Self::Webgpu
106                | Self::Levelzero
107        )
108    }
109
110    /// Returns `true` when this backend supports automatic differentiation.
111    ///
112    /// In Round 5 only `Scirs` and `OxiCuda` implement autodiff; all other
113    /// variants are stubs.
114    pub fn supports_autodiff(&self) -> bool {
115        matches!(self, Self::Scirs | Self::OxiCuda)
116    }
117
118    /// Returns a human-readable, stable identifier for the backend.
119    pub fn as_str(&self) -> &'static str {
120        match self {
121            Self::Scirs => "scirs",
122            Self::OxiCuda => "oxicuda",
123            Self::Metal => "metal",
124            Self::Vulkan => "vulkan",
125            Self::Rocm => "rocm",
126            Self::Webgpu => "webgpu",
127            Self::Levelzero => "levelzero",
128        }
129    }
130
131    /// Returns all currently enumerated backends.
132    ///
133    /// `Scirs` is always fully supported. The remaining backends are listed for
134    /// completeness but are not yet implemented (Round 6).
135    pub fn available_backends() -> Vec<Self> {
136        vec![
137            Self::Scirs,     // fully supported
138            Self::OxiCuda,   // fully supported when GPU feature enabled
139            Self::Metal,     // stub
140            Self::Vulkan,    // stub
141            Self::Rocm,      // stub
142            Self::Webgpu,    // stub
143            Self::Levelzero, // stub
144        ]
145    }
146
147    /// Validates that this backend is fully implemented and can be activated.
148    ///
149    /// Returns `Ok(())` for `Scirs` and `OxiCuda`; returns
150    /// [`BackendKindError::Unimplemented`] for all Round-5 stubs.
151    pub fn validate(&self) -> Result<(), BackendKindError> {
152        match self {
153            Self::Scirs | Self::OxiCuda => Ok(()),
154            other => Err(BackendKindError::Unimplemented {
155                name: other.as_str(),
156            }),
157        }
158    }
159}
160
161// ─── FromStr impl ────────────────────────────────────────────────────────────
162
163impl FromStr for BackendKind {
164    type Err = BackendKindError;
165
166    /// Parses a backend name string into a `BackendKind`.
167    ///
168    /// Accepted aliases (all case-insensitive):
169    ///
170    /// | Input | Result |
171    /// |---|---|
172    /// | `"scirs"`, `"cpu"` | `Scirs` |
173    /// | `"oxicuda"`, `"gpu"`, `"cuda"` | `OxiCuda` |
174    /// | `"metal"` | `Metal` |
175    /// | `"vulkan"` | `Vulkan` |
176    /// | `"rocm"`, `"hip"` | `Rocm` |
177    /// | `"webgpu"` | `Webgpu` |
178    /// | `"levelzero"`, `"level_zero"` | `Levelzero` |
179    /// | anything else | `Err(BackendKindError::UnknownName)` |
180    fn from_str(s: &str) -> Result<Self, Self::Err> {
181        match s.to_ascii_lowercase().as_str() {
182            "scirs" | "cpu" => Ok(Self::Scirs),
183            "oxicuda" | "gpu" | "cuda" => Ok(Self::OxiCuda),
184            "metal" => Ok(Self::Metal),
185            "vulkan" => Ok(Self::Vulkan),
186            "rocm" | "hip" => Ok(Self::Rocm),
187            "webgpu" => Ok(Self::Webgpu),
188            "levelzero" | "level_zero" => Ok(Self::Levelzero),
189            other => Err(BackendKindError::UnknownName(other.to_string())),
190        }
191    }
192}