tensorlogic_scirs_backend/autodiff.rs

//! Automatic differentiation support (forward/backward passes).
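//!
//! `forward` executes a lowered `EinsumGraph` and records a [`ForwardTape`] on the
//! executor; `backward` consumes that tape and returns a new tape whose entries
//! hold per-tensor gradients, indexed the same way as `graph.tensors`.
//!
//! A minimal usage sketch (assumptions: an executor whose input tensors are
//! already registered, a previously lowered `graph`, and a `ones`-like
//! constructor for the seed gradient; not a doctest):
//!
//! ```ignore
//! let output = exec.forward(&graph)?;
//! let seed = Scirs2Tensor::ones(output.raw_dim()); // d loss / d output
//! let grad_tape = exec.backward(&graph, &seed)?;
//! ```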

use tensorlogic_infer::{ExecutorError, TlAutodiff, TlExecutor};
use tensorlogic_ir::EinsumGraph;

use crate::einsum_grad::compute_einsum_gradients;
use crate::ops::{parse_elem_op, parse_reduce_op};
use crate::{Scirs2Exec, Scirs2Tensor};

/// Stores intermediate values from the forward pass for gradient computation.
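///
/// After `forward`, `tensors` holds the computed values; the tape returned by
/// `backward` has the same layout but stores gradients instead. A minimal
/// construction sketch (not a doctest):
///
/// ```ignore
/// let tape = ForwardTape { tensors: vec![None, None], node_inputs: Vec::new() };
/// assert!(tape.is_empty());
/// assert_eq!(tape.len(), 0);
/// ```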
#[derive(Clone)]
pub struct ForwardTape {
    /// All computed tensors (or gradients), indexed by their tensor index in the graph
    pub tensors: Vec<Option<Scirs2Tensor>>,
    /// Input tensors captured for each node (used for gradient computation)
    pub node_inputs: Vec<Vec<Scirs2Tensor>>,
}

impl ForwardTape {
    /// Check whether the tape contains any computed entries
    pub fn is_empty(&self) -> bool {
        self.tensors.iter().all(|t| t.is_none())
    }

    /// Get the number of non-`None` entries in the tape
    pub fn len(&self) -> usize {
        self.tensors.iter().filter(|t| t.is_some()).count()
    }
}

impl TlAutodiff for Scirs2Exec {
    type Tape = ForwardTape;

    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
        if graph.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "Empty graph provided".to_string(),
            ));
        }

        if graph.outputs.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "No output tensors specified".to_string(),
            ));
        }

        let mut computed_tensors: Vec<Option<Scirs2Tensor>> = vec![None; graph.tensors.len()];
        let mut node_inputs: Vec<Vec<Scirs2Tensor>> = Vec::with_capacity(graph.nodes.len());

        // Initialize input tensors from our stored tensors
        for (idx, tensor_name) in graph.tensors.iter().enumerate() {
            // Try direct lookup first
            if let Some(tensor) = self.tensors.get(tensor_name) {
                computed_tensors[idx] = Some(tensor.clone());
            } else {
                // Handle tensors with axes notation (e.g., "age[a]" -> "age")
                let base_name = tensor_name.split('[').next().unwrap_or(tensor_name);

                if let Some(tensor) = self.tensors.get(base_name) {
                    computed_tensors[idx] = Some(tensor.clone());
                } else if tensor_name.starts_with("const_") || base_name.starts_with("const_") {
                    // Handle constant tensors: parse value from name like "const_5" or "const_3.14"
                    let const_name = if tensor_name.starts_with("const_") {
                        tensor_name
                    } else {
                        base_name
                    };

                    if let Some(value_str) = const_name.strip_prefix("const_") {
                        if let Ok(value) = value_str.parse::<f64>() {
                            // Create a scalar tensor with the constant value
                            use scirs2_core::ndarray::arr0;
                            computed_tensors[idx] = Some(arr0(value).into_dyn());
                        }
                    }
                }
            }
        }

        // Execute each operation node in the graph
        for node in &graph.nodes {
            let inputs: Result<Vec<_>, _> = node
                .inputs
                .iter()
                .map(|&idx| {
                    computed_tensors
                        .get(idx)
                        .and_then(|t| t.as_ref())
                        .cloned()
                        .ok_or_else(|| {
                            ExecutorError::TensorNotFound(format!(
                                "Tensor at index {} not found for node with op: {:?}",
                                idx, node.op
                            ))
                        })
                })
                .collect();

            let input_tensors = inputs?;

            // Store input tensors for backward pass
            node_inputs.push(input_tensors.clone());

            // Dispatch based on operation type
            let result = match &node.op {
                tensorlogic_ir::OpType::Einsum { spec } => self.einsum(spec, &input_tensors)?,
                tensorlogic_ir::OpType::ElemUnary { op } => {
                    if input_tensors.len() != 1 {
                        return Err(ExecutorError::InvalidEinsumSpec(format!(
                            "Element-wise unary op '{}' requires 1 input, got {}",
                            op,
                            input_tensors.len()
                        )));
                    }
                    let elem_op = parse_elem_op(op)?;
                    self.elem_op(elem_op, &input_tensors[0])?
                }
                tensorlogic_ir::OpType::ElemBinary { op } => {
                    if input_tensors.len() != 2 {
                        return Err(ExecutorError::InvalidEinsumSpec(format!(
                            "Element-wise binary op '{}' requires 2 inputs, got {}",
                            op,
                            input_tensors.len()
                        )));
                    }
                    let elem_op = parse_elem_op(op)?;
                    self.elem_op_binary(elem_op, &input_tensors[0], &input_tensors[1])?
                }
                tensorlogic_ir::OpType::Reduce { op, axes } => {
                    if input_tensors.len() != 1 {
                        return Err(ExecutorError::InvalidEinsumSpec(format!(
                            "Reduce op '{}' requires 1 input, got {}",
                            op,
                            input_tensors.len()
                        )));
                    }
                    let reduce_op = parse_reduce_op(op)?;
                    self.reduce(reduce_op, &input_tensors[0], axes)?
                }
            };

            // Store the result at the correct output index specified by the node
            if let Some(&output_idx) = node.outputs.first() {
                computed_tensors[output_idx] = Some(result);
            } else {
                return Err(ExecutorError::InvalidEinsumSpec(
                    "Node has no output index specified".to_string(),
                ));
            }
        }

        // Store tape for potential backward pass
        self.tape = Some(ForwardTape {
            tensors: computed_tensors.clone(),
            node_inputs,
        });

        // Return the first output tensor
        let output_idx = graph.outputs[0];
        computed_tensors
            .get(output_idx)
            .and_then(|t| t.clone())
            .ok_or_else(|| ExecutorError::TensorNotFound("Output tensor not computed".to_string()))
    }

    fn backward(
        &mut self,
        graph: &EinsumGraph,
        loss_grad: &Self::Tensor,
    ) -> Result<Self::Tape, Self::Error> {
        if graph.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "Empty graph provided".to_string(),
            ));
        }

        // Get the stored forward tape and clone node_inputs to avoid borrow conflicts
        let node_inputs_vec = {
            let forward_tape = self.tape.as_ref().ok_or_else(|| {
                ExecutorError::InvalidEinsumSpec(
                    "Forward pass must be called before backward pass".to_string(),
                )
            })?;
            forward_tape.node_inputs.clone()
        };

        // Initialize gradient storage - one gradient per tensor in the graph
        let mut gradients: Vec<Option<Scirs2Tensor>> = vec![None; graph.tensors.len()];

        // Set the gradient of the output tensor to the provided loss gradient
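        // For a scalar loss this is typically a tensor of ones (d loss / d loss = 1),
        // but any upstream gradient can be supplied.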
        if !graph.outputs.is_empty() {
            let output_idx = graph.outputs[0];
            gradients[output_idx] = Some(loss_grad.clone());
        }

        // Backward pass through nodes in reverse order
        for (node_idx, node) in graph.nodes.iter().enumerate().rev() {
            // Get the gradient of this node's output
            let output_idx = if let Some(&idx) = node.outputs.first() {
                idx
            } else {
                continue;
            };

            let output_grad = if let Some(grad) = &gradients[output_idx] {
                grad.clone()
            } else {
                // No gradient for this node's output - skip it
                continue;
            };

            // Get the input tensors that were used in the forward pass
            let input_tensors = &node_inputs_vec[node_idx];

            // Compute gradients for inputs based on operation type
            match &node.op {
                tensorlogic_ir::OpType::Einsum { spec } => {
                    // Proper einsum gradient computation
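                    // For a matmul-like spec "ij,jk->ik" with inputs A and B, the exact
                    // input gradients are:
                    //   dA = einsum("ik,jk->ij", dOut, B)
                    //   dB = einsum("ij,ik->jk", A, dOut)
                    // i.e. each input's gradient reuses the spec with that input and
                    // the output swapped.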
                    match compute_einsum_gradients(spec, input_tensors, &output_grad, self) {
                        Ok(einsum_grads) => {
                            // Accumulate gradients for each input
                            for (i, &input_idx) in node.inputs.iter().enumerate() {
                                if i < einsum_grads.len() {
                                    let grad = &einsum_grads[i];
                                    if gradients[input_idx].is_none() {
                                        gradients[input_idx] = Some(grad.clone());
                                    } else if let Some(existing_grad) = &mut gradients[input_idx] {
                                        *existing_grad = &*existing_grad + grad;
                                    }
                                }
                            }
                        }
                        Err(_) => {
                            // Fallback: pass gradients through (for unsupported einsum patterns)
                            for &input_idx in &node.inputs {
                                if gradients[input_idx].is_none() {
                                    gradients[input_idx] = Some(output_grad.clone());
                                } else if let Some(existing_grad) = &mut gradients[input_idx] {
                                    *existing_grad = &*existing_grad + &output_grad;
                                }
                            }
                        }
                    }
                }
                tensorlogic_ir::OpType::ElemUnary { op } => {
                    // Gradient through unary operations
                    if node.inputs.len() == 1 && !input_tensors.is_empty() {
                        let input_idx = node.inputs[0];
                        let input = &input_tensors[0];

                        let grad = match op.as_str() {
                            "relu" => {
                                // ReLU gradient: grad * (input > 0)
                                use scirs2_core::ndarray::Zip;
                                Zip::from(&output_grad).and(input).map_collect(|&g, &x| {
                                    if x > 0.0 {
                                        g
                                    } else {
                                        0.0
                                    }
                                })
                            }
                            "sigmoid" => {
                                // Sigmoid gradient: grad * sigmoid(x) * (1 - sigmoid(x))
                                use scirs2_core::ndarray::Zip;
                                Zip::from(&output_grad).and(input).map_collect(|&g, &x| {
                                    let s = 1.0 / (1.0 + (-x).exp());
                                    g * s * (1.0 - s)
                                })
                            }
                            "oneminus" => {
                                // OneMinus gradient: d/dx(1 - x) = -1
                                &output_grad * (-1.0)
                            }
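                            // Any other unary op falls back to passing the gradient
                            // through unchanged (its exact derivative is not applied here).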
                            _ => output_grad.clone(),
                        };

                        if gradients[input_idx].is_none() {
                            gradients[input_idx] = Some(grad);
                        } else if let Some(existing_grad) = &mut gradients[input_idx] {
                            *existing_grad = &*existing_grad + &grad;
                        }
                    }
                }
                tensorlogic_ir::OpType::ElemBinary { op } => {
                    // Gradient through binary operations with access to input values
                    if node.inputs.len() == 2 && input_tensors.len() == 2 {
                        let x = &input_tensors[0];
                        let y = &input_tensors[1];

                        let (grad_x, grad_y) = match op.as_str() {
                            "add" => {
                                // d/dx(x + y) = 1, d/dy(x + y) = 1
                                (output_grad.clone(), output_grad.clone())
                            }
                            "subtract" | "sub" => {
                                // d/dx(x - y) = 1, d/dy(x - y) = -1
                                (output_grad.clone(), &output_grad * (-1.0))
                            }
                            "multiply" | "mul" => {
                                // d/dx(x * y) = y, d/dy(x * y) = x
                                (&output_grad * y, &output_grad * x)
                            }
                            "divide" | "div" => {
                                // d/dx(x / y) = 1/y, d/dy(x / y) = -x/y^2
                                (&output_grad / y, &output_grad * (-x) / (y * y))
                            }
                            // Comparison operations have zero gradients (non-differentiable)
                            "eq" | "lt" | "gt" | "lte" | "gte" => {
                                let zero_grad = Scirs2Tensor::zeros(output_grad.raw_dim());
                                (zero_grad.clone(), zero_grad)
                            }
                            // Extended logical operations with proper gradients
                            "or_max" | "ormax" => {
                                // OR(max): gradient flows to the larger value
                                use scirs2_core::ndarray::Zip;
                                let grad_x = Zip::from(&output_grad)
                                    .and(x)
                                    .and(y)
                                    .map_collect(|&g, &a, &b| if a >= b { g } else { 0.0 });
                                let grad_y = Zip::from(&output_grad)
                                    .and(x)
                                    .and(y)
                                    .map_collect(|&g, &a, &b| if b > a { g } else { 0.0 });
                                (grad_x, grad_y)
                            }
                            "or_prob_sum" | "orprobsum" | "or_probabilistic" => {
                                // OR(prob): a + b - ab, gradient: da = (1-b), db = (1-a)
                                use scirs2_core::ndarray::Zip;
                                let grad_x = Zip::from(&output_grad)
                                    .and(y)
                                    .map_collect(|&g, &b| g * (1.0 - b));
                                let grad_y = Zip::from(&output_grad)
                                    .and(x)
                                    .map_collect(|&g, &a| g * (1.0 - a));
                                (grad_x, grad_y)
                            }
                            "nand" => {
                                // NAND: 1 - ab, gradient: da = -b, db = -a
                                (&output_grad * (-y), &output_grad * (-x))
                            }
                            "nor" => {
                                // NOR: 1 - max(a,b), gradient flows negatively to max
                                use scirs2_core::ndarray::Zip;
                                let grad_x = Zip::from(&output_grad)
                                    .and(x)
                                    .and(y)
                                    .map_collect(|&g, &a, &b| if a >= b { -g } else { 0.0 });
                                let grad_y = Zip::from(&output_grad)
                                    .and(x)
                                    .and(y)
                                    .map_collect(|&g, &a, &b| if b > a { -g } else { 0.0 });
                                (grad_x, grad_y)
                            }
                            "xor" => {
                                // XOR: a + b - 2ab, gradient: da = 1 - 2b, db = 1 - 2a
                                use scirs2_core::ndarray::Zip;
                                let grad_x = Zip::from(&output_grad)
                                    .and(y)
                                    .map_collect(|&g, &b| g * (1.0 - 2.0 * b));
                                let grad_y = Zip::from(&output_grad)
                                    .and(x)
                                    .map_collect(|&g, &a| g * (1.0 - 2.0 * a));
                                (grad_x, grad_y)
                            }
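                            // Unrecognized binary ops fall back to passing the gradient
                            // through unchanged to both inputs.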
                            _ => (output_grad.clone(), output_grad.clone()),
                        };

                        // Accumulate gradient for first input
                        let input_idx_0 = node.inputs[0];
                        if gradients[input_idx_0].is_none() {
                            gradients[input_idx_0] = Some(grad_x);
                        } else if let Some(existing_grad) = &mut gradients[input_idx_0] {
                            *existing_grad = &*existing_grad + &grad_x;
                        }

                        // Accumulate gradient for second input
                        let input_idx_1 = node.inputs[1];
                        if gradients[input_idx_1].is_none() {
                            gradients[input_idx_1] = Some(grad_y);
                        } else if let Some(existing_grad) = &mut gradients[input_idx_1] {
                            *existing_grad = &*existing_grad + &grad_y;
                        }
                    }
                }
                tensorlogic_ir::OpType::Reduce { op: _, axes } => {
                    // Gradient through reduction: broadcast the gradient back to the original shape
                    if node.inputs.len() == 1 && !input_tensors.is_empty() {
                        let input_idx = node.inputs[0];
                        let input_shape = input_tensors[0].shape();

                        // For reduction, the gradient needs to be broadcast back to the input shape
                        let grad = if axes.is_empty() {
                            // Global reduction: broadcast the scalar gradient to the original shape
                            let mut result = Scirs2Tensor::zeros(input_shape);
                            result.fill(output_grad[[]]);
                            result
                        } else {
                            // Reduction over specific axes: expand the reduced dimensions.
                            // For sum, broadcasting is the exact gradient; for max/min the
                            // exact gradient would flow only to the selected locations,
                            // but this implementation broadcasts in all cases (approximate).
                            use scirs2_core::ndarray::ArrayD;
                            let mut expanded_shape: Vec<usize> = input_shape.to_vec();
                            for &axis in axes {
                                expanded_shape[axis] = 1;
                            }

                            // Reshape output grad to match expanded shape
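                            // e.g. an input of shape [2, 3] reduced over axes = [1] gives
                            // expanded_shape = [2, 1]; the [2, 1] gradient is then broadcast
                            // back to [2, 3] below.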
                            let reshaped = if let Ok(reshaped) = output_grad
                                .clone()
                                .into_shape_with_order(expanded_shape.clone())
                            {
                                reshaped
                            } else {
                                output_grad.clone()
                            };

                            // Broadcast to original shape
                            if let Some(broadcasted) = reshaped.broadcast(input_shape) {
                                broadcasted.to_owned()
                            } else {
                                // Fallback: fill every element with the total of the output
                                // gradient (exact only when the reduced output is a scalar,
                                // otherwise a coarse approximation)
                                let mut result = ArrayD::zeros(input_shape);
                                result
                                    .iter_mut()
                                    .for_each(|v| *v = output_grad.iter().sum::<f64>());
                                result
                            }
                        };

                        if gradients[input_idx].is_none() {
                            gradients[input_idx] = Some(grad);
                        } else if let Some(existing_grad) = &mut gradients[input_idx] {
                            *existing_grad = &*existing_grad + &grad;
                        }
                    }
                }
            }
        }

        // Return a tape whose `tensors` field holds the computed per-tensor gradients
        Ok(ForwardTape {
            tensors: gradients,
            node_inputs: node_inputs_vec,
        })
    }
}