Skip to main content

swift_demangler/
autodiff.rs

//! Automatic differentiation symbol representation.
//!
//! These symbols represent Swift's automatic differentiation features,
//! which are used in machine learning and scientific computing.
5
6use crate::helpers::{HasModule, NodeExt};
7use crate::raw::{Node, NodeKind};
8
9/// A Swift automatic differentiation symbol.
10#[derive(Clone, Copy)]
11pub struct AutoDiff<'ctx> {
12    raw: Node<'ctx>,
13}
14
15impl<'ctx> AutoDiff<'ctx> {
16    /// Create an AutoDiff from a raw node.
17    pub fn new(raw: Node<'ctx>) -> Self {
18        Self { raw }
19    }
20
21    /// Get the underlying raw node.
22    pub fn raw(&self) -> Node<'ctx> {
23        self.raw
24    }
25
26    /// Get the kind of auto-diff symbol.
27    pub fn kind(&self) -> AutoDiffKind {
28        match self.raw.kind() {
29            NodeKind::AutoDiffFunction => AutoDiffKind::Function,
30            NodeKind::DifferentiabilityWitness => AutoDiffKind::DifferentiabilityWitness,
31            NodeKind::AutoDiffDerivativeVTableThunk => AutoDiffKind::DerivativeVTableThunk,
32            NodeKind::AutoDiffSubsetParametersThunk => AutoDiffKind::SubsetParametersThunk,
33            NodeKind::AutoDiffSelfReorderingReabstractionThunk => {
34                AutoDiffKind::SelfReorderingReabstractionThunk
35            }
36            _ => AutoDiffKind::Other,
37        }
38    }
39
40    /// Get the inner function being differentiated, if any.
41    pub fn inner_function(&self) -> Option<crate::function::Function<'ctx>> {
42        self.raw
43            .child_of_kind(NodeKind::Function)
44            .map(crate::function::Function::new)
45    }
46
47    /// Get the module containing this auto-diff symbol.
48    pub fn module(&self) -> Option<&'ctx str> {
49        // Try inner function first
50        if let Some(func) = self.inner_function() {
51            return func.module();
52        }
53        // Search descendants
54        for node in self.raw.descendants() {
55            if node.kind() == NodeKind::Module {
56                return node.text();
57            }
58        }
59        None
60    }
61}
62
63impl std::fmt::Debug for AutoDiff<'_> {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("AutoDiff")
66            .field("kind", &self.kind())
67            .field("inner_function", &self.inner_function())
68            .field("module", &self.module())
69            .finish()
70    }
71}
72
73impl std::fmt::Display for AutoDiff<'_> {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        write!(f, "{}", self.raw)
76    }
77}
78
/// The kind of automatic differentiation symbol.
///
/// `Hash` is derived alongside `Eq` so kinds can be used as keys in hash
/// maps and sets (e.g. when tallying symbols by kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AutoDiffKind {
    /// An auto-diff derivative function.
    Function,
    /// A differentiability witness.
    DifferentiabilityWitness,
    /// A derivative vtable thunk.
    DerivativeVTableThunk,
    /// A subset parameters thunk.
    SubsetParametersThunk,
    /// A self-reordering reabstraction thunk.
    SelfReorderingReabstractionThunk,
    /// Any other auto-diff symbol.
    Other,
}

impl AutoDiffKind {
    /// Returns a human-readable, lowercase name for this auto-diff kind.
    ///
    /// This is a `const fn`, so it can also be evaluated in constant contexts.
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match self {
            Self::Function => "auto-diff function",
            Self::DifferentiabilityWitness => "differentiability witness",
            Self::DerivativeVTableThunk => "derivative vtable thunk",
            Self::SubsetParametersThunk => "subset parameters thunk",
            Self::SelfReorderingReabstractionThunk => "self-reordering reabstraction thunk",
            Self::Other => "auto-diff symbol",
        }
    }
}
108}