tensorlogic_scirs_backend/autodiff.rs
//! Automatic differentiation support (forward/backward passes).

use tensorlogic_infer::{ExecutorError, TlAutodiff, TlExecutor};
use tensorlogic_ir::EinsumGraph;

use crate::einsum_grad::compute_einsum_gradients;
use crate::ops::{parse_elem_op, parse_reduce_op};
use crate::{Scirs2Exec, Scirs2Tensor};

/// Stores intermediate values from the forward pass for gradient computation.
///
/// The same structure is also returned by `backward`, in which case `tensors`
/// holds the accumulated gradients rather than the forward values.
#[derive(Clone)]
pub struct ForwardTape {
    /// All computed tensors (or gradients), indexed by their tensor index in the graph
    pub tensors: Vec<Option<Scirs2Tensor>>,
    /// Input tensors for each node (saved for gradient computation)
    pub node_inputs: Vec<Vec<Scirs2Tensor>>,
}

impl ForwardTape {
    /// Check whether the tape has any computed entries
    pub fn is_empty(&self) -> bool {
        self.tensors.iter().all(|t| t.is_none())
    }

    /// Get the number of non-`None` entries in the tape
    pub fn len(&self) -> usize {
        self.tensors.iter().filter(|t| t.is_some()).count()
    }
}
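// Autodiff flow: `forward` executes the graph node by node, caching every
// intermediate tensor (and each node's inputs) on `self.tape`; `backward`
// replays the nodes in reverse and turns that tape into one gradient per
// graph tensor.
//
// Usage sketch (illustrative only; assumes `exec` is a `Scirs2Exec` that
// already holds the graph's input tensors, `graph` is a lowered
// `EinsumGraph`, and `w_idx` is a hypothetical index of the tensor whose
// gradient is wanted):
//
//     let output = exec.forward(&graph)?;
//     let loss_grad = Scirs2Tensor::ones(output.raw_dim());
//     let grads = exec.backward(&graph, &loss_grad)?;
//     let d_w = grads.tensors[w_idx].as_ref();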

impl TlAutodiff for Scirs2Exec {
    type Tape = ForwardTape;

    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
        if graph.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "Empty graph provided".to_string(),
            ));
        }

        if graph.outputs.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "No output tensors specified".to_string(),
            ));
        }

        let mut computed_tensors: Vec<Option<Scirs2Tensor>> = vec![None; graph.tensors.len()];
        let mut node_inputs: Vec<Vec<Scirs2Tensor>> = Vec::with_capacity(graph.nodes.len());

        // Initialize input tensors from our stored tensors
        for (idx, tensor_name) in graph.tensors.iter().enumerate() {
            // Try direct lookup first
            if let Some(tensor) = self.tensors.get(tensor_name) {
                computed_tensors[idx] = Some(tensor.clone());
            } else {
                // Handle tensors with axes notation (e.g., "age[a]" -> "age")
                let base_name = tensor_name.split('[').next().unwrap_or(tensor_name);

                if let Some(tensor) = self.tensors.get(base_name) {
                    computed_tensors[idx] = Some(tensor.clone());
                } else if tensor_name.starts_with("const_") || base_name.starts_with("const_") {
                    // Handle constant tensors: parse value from name like "const_5" or "const_3.14"
                    let const_name = if tensor_name.starts_with("const_") {
                        tensor_name
                    } else {
                        base_name
                    };

                    if let Some(value_str) = const_name.strip_prefix("const_") {
                        if let Ok(value) = value_str.parse::<f64>() {
                            // Create a scalar tensor with the constant value
                            use scirs2_core::ndarray::arr0;
                            computed_tensors[idx] = Some(arr0(value).into_dyn());
                        }
                    }
                }
            }
        }
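        // Any slot still `None` at this point (an unknown name that is not a
        // `const_*` literal) will surface as `TensorNotFound` the first time a
        // node below tries to consume it.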

        // Execute each operation node in the graph
        for node in &graph.nodes {
            let inputs: Result<Vec<_>, _> = node
                .inputs
                .iter()
                .map(|&idx| {
                    computed_tensors
                        .get(idx)
                        .and_then(|t| t.as_ref())
                        .cloned()
                        .ok_or_else(|| {
                            ExecutorError::TensorNotFound(format!(
                                "Tensor at index {} not found for node with op: {:?}",
                                idx, node.op
                            ))
                        })
                })
                .collect();

            let input_tensors = inputs?;

            // Store input tensors for backward pass
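            // (cloned eagerly so the tape stays usable even if `self.tensors`
            // changes before `backward` is called)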
            node_inputs.push(input_tensors.clone());

            // Dispatch based on operation type
            let result = match &node.op {
                tensorlogic_ir::OpType::Einsum { spec } => self.einsum(spec, &input_tensors)?,
                tensorlogic_ir::OpType::ElemUnary { op } => {
                    if input_tensors.len() != 1 {
                        return Err(ExecutorError::InvalidEinsumSpec(format!(
                            "Element-wise unary op '{}' requires 1 input, got {}",
                            op,
                            input_tensors.len()
                        )));
                    }
                    let elem_op = parse_elem_op(op)?;
                    self.elem_op(elem_op, &input_tensors[0])?
                }
                tensorlogic_ir::OpType::ElemBinary { op } => {
                    if input_tensors.len() != 2 {
                        return Err(ExecutorError::InvalidEinsumSpec(format!(
                            "Element-wise binary op '{}' requires 2 inputs, got {}",
                            op,
                            input_tensors.len()
                        )));
                    }
                    let elem_op = parse_elem_op(op)?;
                    self.elem_op_binary(elem_op, &input_tensors[0], &input_tensors[1])?
                }
                tensorlogic_ir::OpType::Reduce { op, axes } => {
                    if input_tensors.len() != 1 {
                        return Err(ExecutorError::InvalidEinsumSpec(format!(
                            "Reduce op '{}' requires 1 input, got {}",
                            op,
                            input_tensors.len()
                        )));
                    }
                    let reduce_op = parse_reduce_op(op)?;
                    self.reduce(reduce_op, &input_tensors[0], axes)?
                }
            };

            // Store the result at the correct output index specified by the node
            if let Some(&output_idx) = node.outputs.first() {
                computed_tensors[output_idx] = Some(result);
            } else {
                return Err(ExecutorError::InvalidEinsumSpec(
                    "Node has no output index specified".to_string(),
                ));
            }
        }

        // Store tape for potential backward pass
        self.tape = Some(ForwardTape {
            tensors: computed_tensors.clone(),
            node_inputs,
        });

        // Return the output tensor
        let output_idx = graph.outputs[0];
        computed_tensors
            .get(output_idx)
            .and_then(|t| t.clone())
            .ok_or_else(|| ExecutorError::TensorNotFound("Output tensor not computed".to_string()))
    }

    fn backward(
        &mut self,
        graph: &EinsumGraph,
        loss_grad: &Self::Tensor,
    ) -> Result<Self::Tape, Self::Error> {
        if graph.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "Empty graph provided".to_string(),
            ));
        }

        // Get the stored forward tape and clone node_inputs to avoid borrow conflicts
        let node_inputs_vec = {
            let forward_tape = self.tape.as_ref().ok_or_else(|| {
                ExecutorError::InvalidEinsumSpec(
                    "Forward pass must be called before backward pass".to_string(),
                )
            })?;
            forward_tape.node_inputs.clone()
        };

        // Initialize gradient storage - one gradient per tensor in the graph
        let mut gradients: Vec<Option<Scirs2Tensor>> = vec![None; graph.tensors.len()];

        // Set the gradient of the output tensor to the provided loss gradient
        if !graph.outputs.is_empty() {
            let output_idx = graph.outputs[0];
            gradients[output_idx] = Some(loss_grad.clone());
        }

        // Backward pass through nodes in reverse order
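        // Each node distributes the gradient of its output to its inputs; where a
        // tensor feeds several nodes, the partial gradients are accumulated by summation.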
        for (node_idx, node) in graph.nodes.iter().enumerate().rev() {
            // Get the gradient of this node's output
            let output_idx = if let Some(&idx) = node.outputs.first() {
                idx
            } else {
                continue;
            };

            let output_grad = if let Some(grad) = &gradients[output_idx] {
                grad.clone()
            } else {
                // No gradient for this node's output - skip it
                continue;
            };

            // Get the input tensors that were used in forward pass
            let input_tensors = &node_inputs_vec[node_idx];

            // Compute gradients for inputs based on operation type
            match &node.op {
                tensorlogic_ir::OpType::Einsum { spec } => {
                    // Proper einsum gradient computation
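                    // The gradient of an einsum w.r.t. one input is itself an einsum that
                    // contracts the output gradient with the remaining inputs, e.g. for
                    // C = einsum("ij,jk->ik", A, B):
                    //   dA = einsum("ik,jk->ij", dC, B),  dB = einsum("ij,ik->jk", A, dC)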
                    match compute_einsum_gradients(spec, input_tensors, &output_grad, self) {
                        Ok(einsum_grads) => {
                            // Accumulate gradients for each input
                            for (i, &input_idx) in node.inputs.iter().enumerate() {
                                if i < einsum_grads.len() {
                                    let grad = &einsum_grads[i];
                                    if gradients[input_idx].is_none() {
                                        gradients[input_idx] = Some(grad.clone());
                                    } else if let Some(existing_grad) = &mut gradients[input_idx] {
                                        *existing_grad = &*existing_grad + grad;
                                    }
                                }
                            }
                        }
                        Err(_) => {
                            // Fallback: pass gradients through (for unsupported einsum patterns)
                            for &input_idx in &node.inputs {
                                if gradients[input_idx].is_none() {
                                    gradients[input_idx] = Some(output_grad.clone());
                                } else if let Some(existing_grad) = &mut gradients[input_idx] {
                                    *existing_grad = &*existing_grad + &output_grad;
                                }
                            }
                        }
                    }
                }
                tensorlogic_ir::OpType::ElemUnary { op } => {
                    // Gradient through unary operations
                    if node.inputs.len() == 1 && !input_tensors.is_empty() {
                        let input_idx = node.inputs[0];
                        let input = &input_tensors[0];

                        let grad = match op.as_str() {
                            "relu" => {
                                // ReLU gradient: grad * (input > 0)
                                use scirs2_core::ndarray::Zip;
                                Zip::from(&output_grad).and(input).map_collect(|&g, &x| {
                                    if x > 0.0 {
                                        g
                                    } else {
                                        0.0
                                    }
                                })
                            }
                            "sigmoid" => {
                                // Sigmoid gradient: grad * sigmoid(x) * (1 - sigmoid(x))
                                use scirs2_core::ndarray::Zip;
                                Zip::from(&output_grad).and(input).map_collect(|&g, &x| {
                                    let s = 1.0 / (1.0 + (-x).exp());
                                    g * s * (1.0 - s)
                                })
                            }
                            "oneminus" => {
                                // OneMinus gradient: d/dx(1 - x) = -1
                                &output_grad * (-1.0)
                            }
                            _ => output_grad.clone(),
                        };

                        if gradients[input_idx].is_none() {
                            gradients[input_idx] = Some(grad);
                        } else if let Some(existing_grad) = &mut gradients[input_idx] {
                            *existing_grad = &*existing_grad + &grad;
                        }
                    }
                }
                tensorlogic_ir::OpType::ElemBinary { op } => {
                    // Gradient through binary operations with access to input values
                    if node.inputs.len() == 2 && input_tensors.len() == 2 {
                        let x = &input_tensors[0];
                        let y = &input_tensors[1];

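                        // The shapes of `x`, `y`, and `output_grad` are assumed to match here;
                        // no broadcast reduction is applied to the per-input gradients.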
                        let (grad_x, grad_y) = match op.as_str() {
                            "add" => {
                                // d/dx(x + y) = 1, d/dy(x + y) = 1
                                (output_grad.clone(), output_grad.clone())
                            }
                            "subtract" | "sub" => {
                                // d/dx(x - y) = 1, d/dy(x - y) = -1
                                (output_grad.clone(), &output_grad * (-1.0))
                            }
                            "multiply" | "mul" => {
                                // d/dx(x * y) = y, d/dy(x * y) = x
                                (&output_grad * y, &output_grad * x)
                            }
                            "divide" | "div" => {
                                // d/dx(x / y) = 1/y, d/dy(x / y) = -x/y^2
                                (&output_grad / y, &output_grad * (-x) / (y * y))
                            }
                            // Comparison operations have zero gradients (non-differentiable)
                            "eq" | "lt" | "gt" | "lte" | "gte" => {
                                let zero_grad = Scirs2Tensor::zeros(output_grad.raw_dim());
                                (zero_grad.clone(), zero_grad)
                            }
                            // Extended logical operations with proper gradients
                            "or_max" | "ormax" => {
                                // OR(max): gradient flows to the larger value
                                use scirs2_core::ndarray::Zip;
                                let grad_x = Zip::from(&output_grad)
                                    .and(x)
                                    .and(y)
                                    .map_collect(|&g, &a, &b| if a >= b { g } else { 0.0 });
                                let grad_y = Zip::from(&output_grad)
                                    .and(x)
                                    .and(y)
                                    .map_collect(|&g, &a, &b| if b > a { g } else { 0.0 });
                                (grad_x, grad_y)
                            }
                            "or_prob_sum" | "orprobsum" | "or_probabilistic" => {
                                // OR(prob): a + b - ab, gradient: da = (1-b), db = (1-a)
                                use scirs2_core::ndarray::Zip;
                                let grad_x = Zip::from(&output_grad)
                                    .and(y)
                                    .map_collect(|&g, &b| g * (1.0 - b));
                                let grad_y = Zip::from(&output_grad)
                                    .and(x)
                                    .map_collect(|&g, &a| g * (1.0 - a));
                                (grad_x, grad_y)
                            }
                            "nand" => {
                                // NAND: 1 - ab, gradient: da = -b, db = -a
                                (&output_grad * (-y), &output_grad * (-x))
                            }
                            "nor" => {
                                // NOR: 1 - max(a,b), gradient flows negatively to max
                                use scirs2_core::ndarray::Zip;
                                let grad_x = Zip::from(&output_grad)
                                    .and(x)
                                    .and(y)
                                    .map_collect(|&g, &a, &b| if a >= b { -g } else { 0.0 });
                                let grad_y = Zip::from(&output_grad)
                                    .and(x)
                                    .and(y)
                                    .map_collect(|&g, &a, &b| if b > a { -g } else { 0.0 });
                                (grad_x, grad_y)
                            }
                            "xor" => {
                                // XOR: a + b - 2ab, gradient: da = 1 - 2b, db = 1 - 2a
                                use scirs2_core::ndarray::Zip;
                                let grad_x = Zip::from(&output_grad)
                                    .and(y)
                                    .map_collect(|&g, &b| g * (1.0 - 2.0 * b));
                                let grad_y = Zip::from(&output_grad)
                                    .and(x)
                                    .map_collect(|&g, &a| g * (1.0 - 2.0 * a));
                                (grad_x, grad_y)
                            }
                            _ => (output_grad.clone(), output_grad.clone()),
                        };

                        // Accumulate gradient for first input
                        let input_idx_0 = node.inputs[0];
                        if gradients[input_idx_0].is_none() {
                            gradients[input_idx_0] = Some(grad_x);
                        } else if let Some(existing_grad) = &mut gradients[input_idx_0] {
                            *existing_grad = &*existing_grad + &grad_x;
                        }

                        // Accumulate gradient for second input
                        let input_idx_1 = node.inputs[1];
                        if gradients[input_idx_1].is_none() {
                            gradients[input_idx_1] = Some(grad_y);
                        } else if let Some(existing_grad) = &mut gradients[input_idx_1] {
                            *existing_grad = &*existing_grad + &grad_y;
                        }
                    }
                }
                tensorlogic_ir::OpType::Reduce { op: _, axes } => {
                    // Gradient through reduction: broadcast gradient back to original shape
                    if node.inputs.len() == 1 && !input_tensors.is_empty() {
                        let input_idx = node.inputs[0];
                        let input_shape = input_tensors[0].shape();

                        // For reduction, gradient needs to be broadcast back to input shape
                        let grad = if axes.is_empty() {
                            // Global reduction - broadcast scalar to original shape
                            let mut result = Scirs2Tensor::zeros(input_shape);
                            result.fill(output_grad[[]]);
                            result
                        } else {
                            // Reduction over specific axes - expand dimensions
                            // For sum reduction, gradient is broadcast
                            // For max/min, gradient goes to the locations that were selected
                            use scirs2_core::ndarray::ArrayD;
                            let mut expanded_shape: Vec<usize> = input_shape.to_vec();
                            for &axis in axes {
                                expanded_shape[axis] = 1;
                            }

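                            // With the reduced axes kept at size 1 (keepdims-style), the output
                            // gradient can broadcast back over exactly those axes. This is the
                            // exact adjoint of `sum`; max/min would additionally need a mask of
                            // the selected positions, which is not applied here.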
                            // Reshape output grad to match expanded shape
                            let reshaped = if let Ok(reshaped) = output_grad
                                .clone()
                                .into_shape_with_order(expanded_shape.clone())
                            {
                                reshaped
                            } else {
                                output_grad.clone()
                            };

                            // Broadcast to original shape
                            if let Some(broadcasted) = reshaped.broadcast(input_shape) {
                                broadcasted.to_owned()
                            } else {
                                // Fallback if broadcasting fails: fill every element with the
                                // total gradient mass. This is only a rough approximation
                                // (exact only when the output gradient has a single element).
                                let mut result = ArrayD::zeros(input_shape);
                                result
                                    .iter_mut()
                                    .for_each(|v| *v = output_grad.iter().sum::<f64>());
                                result
                            }
                        };

                        if gradients[input_idx].is_none() {
                            gradients[input_idx] = Some(grad);
                        } else if let Some(existing_grad) = &mut gradients[input_idx] {
                            *existing_grad = &*existing_grad + &grad;
                        }
                    }
                }
            }
        }

        // Return the forward tape with gradients computed
        Ok(ForwardTape {
            tensors: gradients,
            node_inputs: node_inputs_vec,
        })
    }
}