1use std::ops;
7
8use scivex_core::{Float, Tensor};
9
10use crate::variable::Variable;
11
12pub fn add<T: Float>(a: &Variable<T>, b: &Variable<T>) -> Variable<T> {
28 let data = &a.data() + &b.data();
29 Variable::from_op(
30 data,
31 vec![a.clone(), b.clone()],
32 Box::new(|g: &Tensor<T>| vec![g.clone(), g.clone()]),
33 )
34}
35
36pub fn sub<T: Float>(a: &Variable<T>, b: &Variable<T>) -> Variable<T> {
50 let data = &a.data() - &b.data();
51 Variable::from_op(
52 data,
53 vec![a.clone(), b.clone()],
54 Box::new(|g: &Tensor<T>| vec![g.clone(), -g]),
55 )
56}
57
58pub fn mul<T: Float>(a: &Variable<T>, b: &Variable<T>) -> Variable<T> {
72 let a_data = a.data();
73 let b_data = b.data();
74 let data = &a_data * &b_data;
75 Variable::from_op(
76 data,
77 vec![a.clone(), b.clone()],
78 Box::new(move |g: &Tensor<T>| {
79 let ga = g
80 .zip_map(&b_data, |gi, bi| gi * bi)
81 .expect("shapes match from forward pass");
82 let gb = g
83 .zip_map(&a_data, |gi, ai| gi * ai)
84 .expect("shapes match from forward pass");
85 vec![ga, gb]
86 }),
87 )
88}
89
90pub fn neg<T: Float>(a: &Variable<T>) -> Variable<T> {
103 let data = -&a.data();
104 Variable::from_op(data, vec![a.clone()], Box::new(|g: &Tensor<T>| vec![-g]))
105}
106
107pub fn matmul<T: Float>(a: &Variable<T>, b: &Variable<T>) -> Variable<T> {
125 let a_data = a.data();
126 let b_data = b.data();
127 let data = a_data
128 .matmul(&b_data)
129 .expect("matmul shapes validated at call site");
130 Variable::from_op(
131 data,
132 vec![a.clone(), b.clone()],
133 Box::new(move |g: &Tensor<T>| {
134 let bt = b_data.transpose().expect("2-D from forward pass");
136 let ga = g.matmul(&bt).expect("shapes match from forward pass");
137 let at = a_data.transpose().expect("2-D from forward pass");
139 let gb = at.matmul(g).expect("shapes match from forward pass");
140 vec![ga, gb]
141 }),
142 )
143}
144
145pub fn sum<T: Float>(a: &Variable<T>) -> Variable<T> {
160 let s = a.data().sum();
161 let shape = a.shape();
162 let data = Tensor::from_vec(vec![s], vec![1]).expect("scalar tensor");
163 Variable::from_op(
164 data,
165 vec![a.clone()],
166 Box::new(move |g: &Tensor<T>| {
167 let g_val = g.as_slice()[0];
169 vec![Tensor::full(shape.clone(), g_val)]
170 }),
171 )
172}
173
174pub fn mean<T: Float>(a: &Variable<T>) -> Variable<T> {
187 let n = a.data().numel();
188 let m = a.data().mean();
189 let shape = a.shape();
190 let data = Tensor::from_vec(vec![m], vec![1]).expect("scalar tensor");
191 Variable::from_op(
192 data,
193 vec![a.clone()],
194 Box::new(move |g: &Tensor<T>| {
195 let g_val = g.as_slice()[0];
196 let scale = g_val / T::from_usize(n);
197 vec![Tensor::full(shape.clone(), scale)]
198 }),
199 )
200}
201
202pub fn pow<T: Float>(a: &Variable<T>, exponent: T) -> Variable<T> {
215 let a_data = a.data();
216 let data = a_data.powf(exponent);
217 Variable::from_op(
218 data,
219 vec![a.clone()],
220 Box::new(move |g: &Tensor<T>| {
221 let n_minus_1 = exponent - T::one();
223 let deriv = a_data.powf(n_minus_1).map(|v| exponent * v);
224 let grad = g
225 .zip_map(&deriv, |gi, di| gi * di)
226 .expect("shapes match from forward pass");
227 vec![grad]
228 }),
229 )
230}
231
232pub fn scalar_mul<T: Float>(a: &Variable<T>, scalar: T) -> Variable<T> {
245 let data = &a.data() * scalar;
246 Variable::from_op(
247 data,
248 vec![a.clone()],
249 Box::new(move |g: &Tensor<T>| vec![g.map(|v| v * scalar)]),
250 )
251}
252
253pub fn scalar_div<T: Float>(a: &Variable<T>, scalar: T) -> Variable<T> {
266 scalar_mul(a, T::one() / scalar)
267}
268
269pub fn add_bias<T: Float>(input: &Variable<T>, bias: &Variable<T>) -> Variable<T> {
285 let x = input.data();
286 let b = bias.data();
287 let shape = x.shape().to_vec();
288 let rows = shape[0];
289 let cols = shape[1];
290
291 let mut out_data = Vec::with_capacity(rows * cols);
293 let b_slice = b.as_slice();
294 let x_slice = x.as_slice();
295 for r in 0..rows {
296 for c in 0..cols {
297 out_data.push(x_slice[r * cols + c] + b_slice[c]);
298 }
299 }
300 let data =
301 Tensor::from_vec(out_data, shape).expect("output data length matches shape from input");
302
303 let cols_copy = cols;
304 Variable::from_op(
305 data,
306 vec![input.clone(), bias.clone()],
307 Box::new(move |g: &Tensor<T>| {
308 let g_input = g.clone();
310 let g_slice = g.as_slice();
312 let g_rows = g.shape()[0];
313 let mut bias_grad = vec![T::zero(); cols_copy];
314 for r in 0..g_rows {
315 for c in 0..cols_copy {
316 bias_grad[c] += g_slice[r * cols_copy + c];
317 }
318 }
319 let g_bias = Tensor::from_vec(bias_grad, vec![cols_copy])
320 .expect("bias grad length matches feature count");
321 vec![g_input, g_bias]
322 }),
323 )
324}
325
326impl<T: Float> ops::Add for &Variable<T> {
329 type Output = Variable<T>;
330 fn add(self, rhs: Self) -> Variable<T> {
331 add(self, rhs)
332 }
333}
334
335impl<T: Float> ops::Add for Variable<T> {
336 type Output = Variable<T>;
337 fn add(self, rhs: Self) -> Variable<T> {
338 add(&self, &rhs)
339 }
340}
341
342impl<T: Float> ops::Sub for &Variable<T> {
343 type Output = Variable<T>;
344 fn sub(self, rhs: Self) -> Variable<T> {
345 sub(self, rhs)
346 }
347}
348
349impl<T: Float> ops::Sub for Variable<T> {
350 type Output = Variable<T>;
351 fn sub(self, rhs: Self) -> Variable<T> {
352 sub(&self, &rhs)
353 }
354}
355
356impl<T: Float> ops::Mul for &Variable<T> {
357 type Output = Variable<T>;
358 fn mul(self, rhs: Self) -> Variable<T> {
359 mul(self, rhs)
360 }
361}
362
363impl<T: Float> ops::Mul for Variable<T> {
364 type Output = Variable<T>;
365 fn mul(self, rhs: Self) -> Variable<T> {
366 mul(&self, &rhs)
367 }
368}
369
370impl<T: Float> ops::Neg for &Variable<T> {
371 type Output = Variable<T>;
372 fn neg(self) -> Variable<T> {
373 neg(self)
374 }
375}
376
377impl<T: Float> ops::Neg for Variable<T> {
378 type Output = Variable<T>;
379 fn neg(self) -> Variable<T> {
380 neg(&self)
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-D variable from `vals`, constructed with the `true` flag
    /// (presumably "requires gradient"; confirm against `Variable::new`).
    fn var(vals: &[f64]) -> Variable<f64> {
        let tensor = Tensor::from_vec(vals.to_vec(), vec![vals.len()]).unwrap();
        Variable::new(tensor, true)
    }

    /// Builds a `rows x cols` variable from row-major `vals`, constructed with
    /// the same `true` flag as [`var`].
    fn matrix(vals: &[f64], rows: usize, cols: usize) -> Variable<f64> {
        let tensor = Tensor::from_vec(vals.to_vec(), vec![rows, cols]).unwrap();
        Variable::new(tensor, true)
    }

    /// Copies the gradient stored on `v` into a `Vec`; panics if there is none.
    fn grad_of(v: &Variable<f64>) -> Vec<f64> {
        v.grad().unwrap().as_slice().to_vec()
    }

    #[test]
    fn test_add_backward() {
        let lhs = var(&[2.0, 3.0]);
        let rhs = var(&[4.0, 5.0]);
        let total = add(&lhs, &rhs);
        let loss = sum(&total);
        loss.backward();
        // d(a + b)/da = d(a + b)/db = 1 for every element.
        assert_eq!(grad_of(&lhs), [1.0, 1.0]);
        assert_eq!(grad_of(&rhs), [1.0, 1.0]);
    }

    #[test]
    fn test_mul_backward() {
        let lhs = var(&[2.0, 3.0]);
        let rhs = var(&[4.0, 5.0]);
        let product = mul(&lhs, &rhs);
        let loss = sum(&product);
        loss.backward();
        // Each operand's gradient is the other operand's value.
        assert_eq!(grad_of(&lhs), [4.0, 5.0]);
        assert_eq!(grad_of(&rhs), [2.0, 3.0]);
    }

    #[test]
    fn test_sub_backward() {
        let lhs = var(&[5.0]);
        let rhs = var(&[3.0]);
        let difference = sub(&lhs, &rhs);
        let loss = sum(&difference);
        loss.backward();
        assert_eq!(grad_of(&lhs), [1.0]);
        assert_eq!(grad_of(&rhs), [-1.0]);
    }

    #[test]
    fn test_matmul_backward() {
        let lhs = matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let rhs = matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let product = matmul(&lhs, &rhs);
        let loss = sum(&product);
        loss.backward();
        // Gradients must have the shapes of their respective operands.
        assert_eq!(lhs.grad().unwrap().shape(), &[2, 3]);
        assert_eq!(rhs.grad().unwrap().shape(), &[3, 2]);
    }

    #[test]
    fn test_pow_backward() {
        let base = var(&[2.0, 3.0]);
        let squared = pow(&base, 2.0);
        let loss = sum(&squared);
        loss.backward();
        // d(x^2)/dx = 2x.
        assert_eq!(grad_of(&base), [4.0, 6.0]);
    }

    #[test]
    fn test_mean_backward() {
        let input = var(&[2.0, 4.0]);
        let average = mean(&input);
        average.backward();
        // Each of the two elements contributes 1/2 to the mean.
        assert_eq!(grad_of(&input), [0.5, 0.5]);
    }

    #[test]
    fn test_neg_backward() {
        let input = var(&[3.0]);
        let negated = neg(&input);
        let loss = sum(&negated);
        loss.backward();
        assert_eq!(grad_of(&input), [-1.0]);
    }

    #[test]
    fn test_operator_overloads() {
        let a = var(&[1.0, 2.0]);
        let b = var(&[3.0, 4.0]);
        let added = &a + &b;
        let multiplied = &a * &b;
        let loss = sum(&(&added + &multiplied));
        loss.backward();
        // loss = sum(a + b + a * b), so d/da = 1 + b and d/db = 1 + a.
        assert_eq!(grad_of(&a), [4.0, 5.0]);
        assert_eq!(grad_of(&b), [2.0, 3.0]);
    }

    #[test]
    fn test_scalar_mul_backward() {
        let input = var(&[2.0, 3.0]);
        let scaled = scalar_mul(&input, 5.0);
        let loss = sum(&scaled);
        loss.backward();
        assert_eq!(grad_of(&input), [5.0, 5.0]);
    }

    #[test]
    fn test_scalar_div_backward() {
        let input = var(&[4.0, 8.0]);
        let scaled = scalar_div(&input, 2.0);
        let loss = sum(&scaled);
        loss.backward();
        assert_eq!(grad_of(&input), [0.5, 0.5]);
    }

    #[test]
    fn test_add_bias_forward_and_backward() {
        let input = matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let bias = var(&[0.1, 0.2, 0.3]);
        let y = add_bias(&input, &bias);

        // Forward: element (0, 0) is 1.0 + 0.1, element (1, 1) is 5.0 + 0.2.
        let out = y.data();
        let out_vals = out.as_slice();
        assert!((out_vals[0] - 1.1).abs() < 1e-10);
        assert!((out_vals[4] - 5.2).abs() < 1e-10);

        let loss = sum(&y);
        loss.backward();

        // The input gradient is all ones with the input's shape.
        let g_input = input.grad().unwrap();
        assert_eq!(g_input.shape(), &[2, 3]);
        assert!(g_input.as_slice().iter().all(|&v| (v - 1.0).abs() < 1e-10));

        // The bias gradient sums over the two rows.
        let g_bias = bias.grad().unwrap();
        assert_eq!(g_bias.shape(), &[3]);
        assert!(g_bias.as_slice().iter().all(|&v| (v - 2.0).abs() < 1e-10));
    }

    #[test]
    fn test_single_element_sum() {
        let input = var(&[7.0]);
        let total = sum(&input);
        assert_eq!(total.data().as_slice(), &[7.0]);
        total.backward();
        assert_eq!(grad_of(&input), [1.0]);
    }

    #[test]
    fn test_pow_cubic_backward() {
        let base = var(&[2.0]);
        let cubed = pow(&base, 3.0);
        let loss = sum(&cubed);
        loss.backward();
        // d(x^3)/dx at x = 2 is 3 * 2^2 = 12.
        assert!((grad_of(&base)[0] - 12.0).abs() < 1e-10);
    }
}