1use crate::error::{Result, SklearsError};
75use proc_macro2::{Span, TokenStream};
76use quote::quote;
77use serde::{Deserialize, Serialize};
78use std::collections::HashMap;
79use std::sync::{Arc, Mutex};
80use syn::{Attribute, Expr, FnArg, ItemFn, ReturnType, Stmt, Type};
81
/// A dual number `real + dual·ε` (with ε² = 0) for forward-mode
/// automatic differentiation: `real` carries the function value and
/// `dual` carries the derivative with respect to the seeded variable.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Dual {
    /// Primal (function value) component.
    pub real: f64,
    /// Tangent (derivative) component.
    pub dual: f64,
}
94
95impl Dual {
96 pub fn new(real: f64, dual: f64) -> Self {
98 Self { real, dual }
99 }
100
101 pub fn variable(value: f64) -> Self {
103 Self::new(value, 1.0)
104 }
105
106 pub fn constant(value: f64) -> Self {
108 Self::new(value, 0.0)
109 }
110
111 pub fn value(&self) -> f64 {
113 self.real
114 }
115
116 pub fn derivative(&self) -> f64 {
118 self.dual
119 }
120}
121
122impl std::ops::Add for Dual {
124 type Output = Self;
125
126 fn add(self, other: Self) -> Self {
127 Self::new(self.real + other.real, self.dual + other.dual)
128 }
129}
130
131impl std::ops::Sub for Dual {
132 type Output = Self;
133
134 fn sub(self, other: Self) -> Self {
135 Self::new(self.real - other.real, self.dual - other.dual)
136 }
137}
138
139impl std::ops::Mul for Dual {
140 type Output = Self;
141
142 fn mul(self, other: Self) -> Self {
143 Self::new(
144 self.real * other.real,
145 self.real * other.dual + self.dual * other.real,
146 )
147 }
148}
149
150impl std::ops::Div for Dual {
151 type Output = Self;
152
153 fn div(self, other: Self) -> Self {
154 let inv_other_real = 1.0 / other.real;
155 Self::new(
156 self.real * inv_other_real,
157 (self.dual * other.real - self.real * other.dual) * inv_other_real * inv_other_real,
158 )
159 }
160}
161
/// A node value in reverse-mode (tape-based) automatic differentiation.
#[derive(Debug, Clone)]
pub struct Variable {
    /// Process-unique identifier (allocated in `Variable::new`).
    pub id: VariableId,
    /// Current primal value.
    pub value: f64,
    /// Accumulated adjoint, populated by `ComputationTape::backward`.
    pub gradient: f64,
    /// Graph node that produced this variable, if recorded on a tape.
    pub node: Option<Arc<ComputationNode>>,
}
174
/// Opaque, process-unique identifier for a `Variable` on a tape.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct VariableId(pub u64);
178
179impl Variable {
180 pub fn new(value: f64) -> Self {
182 static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
183 let id = VariableId(NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst));
184
185 Self {
186 id,
187 value,
188 gradient: 0.0,
189 node: None,
190 }
191 }
192
193 pub fn with_graph(value: f64, tape: Arc<Mutex<ComputationTape>>) -> Self {
195 let mut var = Self::new(value);
196
197 let node = ComputationNode {
198 operation: Operation::Input,
199 inputs: Vec::new(),
200 output_id: var.id,
201 gradient_fn: Box::new(|_inputs, _output_grad| Vec::new()),
202 };
203
204 var.node = Some(Arc::new(node));
205
206 if let Ok(mut tape_guard) = tape.lock() {
208 tape_guard.add_node(var.node.as_ref().unwrap().clone());
209 }
210
211 var
212 }
213
214 pub fn set_gradient(&mut self, gradient: f64) {
216 self.gradient = gradient;
217 }
218
219 pub fn add_gradient(&mut self, gradient: f64) {
221 self.gradient += gradient;
222 }
223
224 pub fn zero_gradient(&mut self) {
226 self.gradient = 0.0;
227 }
228}
229
/// Backward function for a node: given the node's input values and the
/// output's adjoint, returns one gradient contribution per input.
pub type GradientFunction = Box<dyn Fn(&[f64], f64) -> Vec<f64> + Send + Sync>;
232
/// One recorded operation on the computation tape.
pub struct ComputationNode {
    /// The operation this node represents.
    pub operation: Operation,
    /// Ids of the variables consumed by the operation.
    pub inputs: Vec<VariableId>,
    /// Id of the variable produced by the operation.
    pub output_id: VariableId,
    /// Closure computing per-input gradients from input values and the
    /// output adjoint (see `GradientFunction`).
    pub gradient_fn: GradientFunction,
}
244
245impl std::fmt::Debug for ComputationNode {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 f.debug_struct("ComputationNode")
248 .field("operation", &self.operation)
249 .field("inputs", &self.inputs)
250 .field("output_id", &self.output_id)
251 .field("gradient_fn", &"<function>")
252 .finish()
253 }
254}
255
/// Primitive operations a computation-tape node can represent.
#[derive(Debug, Clone, PartialEq)]
pub enum Operation {
    /// Leaf input variable (no parents).
    Input,
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Exponentiation.
    Pow,
    /// Natural exponential.
    Exp,
    /// Natural logarithm.
    Ln,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Hyperbolic tangent.
    Tanh,
    /// Logistic sigmoid.
    Sigmoid,
    /// Rectified linear unit.
    ReLU,
    /// User-defined operation, identified by name.
    Custom(String),
}
288
/// Record of a computation for reverse-mode differentiation.
#[derive(Debug)]
pub struct ComputationTape {
    /// Recorded nodes, in insertion (forward-execution) order.
    pub nodes: Vec<Arc<ComputationNode>>,
    /// Registered variables, keyed by id.
    pub variables: HashMap<VariableId, Variable>,
    /// Output ids in recording order; traversed in reverse by `backward`.
    pub execution_order: Vec<VariableId>,
}
299
300impl ComputationTape {
301 pub fn new() -> Self {
303 Self {
304 nodes: Vec::new(),
305 variables: HashMap::new(),
306 execution_order: Vec::new(),
307 }
308 }
309
310 pub fn add_node(&mut self, node: Arc<ComputationNode>) {
312 self.execution_order.push(node.output_id);
313 self.nodes.push(node);
314 }
315
316 pub fn register_variable(&mut self, var: Variable) {
318 self.variables.insert(var.id, var);
319 }
320
321 pub fn backward(&mut self, root_gradient: f64) -> Result<()> {
323 for var in self.variables.values_mut() {
325 var.zero_gradient();
326 }
327
328 if let Some(root_id) = self.execution_order.last() {
330 if let Some(root_var) = self.variables.get_mut(root_id) {
331 root_var.set_gradient(root_gradient);
332 }
333 }
334
335 for &node_id in self.execution_order.iter().rev() {
337 if let Some(node) = self.nodes.iter().find(|n| n.output_id == node_id) {
338 let output_gradient = self
339 .variables
340 .get(&node_id)
341 .map(|v| v.gradient)
342 .unwrap_or(0.0);
343
344 let input_values: Vec<f64> = node
346 .inputs
347 .iter()
348 .filter_map(|&id| self.variables.get(&id).map(|v| v.value))
349 .collect();
350
351 let input_gradients = (node.gradient_fn)(&input_values, output_gradient);
353
354 for (&input_id, &gradient) in node.inputs.iter().zip(input_gradients.iter()) {
356 if let Some(input_var) = self.variables.get_mut(&input_id) {
357 input_var.add_gradient(gradient);
358 }
359 }
360 }
361 }
362
363 Ok(())
364 }
365
366 pub fn get_gradient(&self, id: VariableId) -> Option<f64> {
368 self.variables.get(&id).map(|v| v.gradient)
369 }
370
371 pub fn clear(&mut self) {
373 self.nodes.clear();
374 self.variables.clear();
375 self.execution_order.clear();
376 }
377}
378
/// `Default` delegates to `new` so the tape works in derive contexts
/// and with `ComputationTape::default()`.
impl Default for ComputationTape {
    fn default() -> Self {
        Self::new()
    }
}
384
/// Configuration for code generation, parsed from `#[autodiff]` attributes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutodiffConfig {
    /// Differentiation strategy to generate code for.
    pub mode: ADMode,
    /// Highest derivative order to generate (1 = first derivatives).
    pub max_order: u32,
    /// Request SIMD acceleration — NOTE(review): not read by any
    /// generator visible in this file; confirm intended use.
    pub simd: bool,
    /// Request GPU code — NOTE(review): not read by the generators here.
    pub gpu: bool,
    /// Request symbolic preprocessing — NOTE(review): not read by the
    /// generators here.
    pub symbolic: bool,
    /// Names of extra optimization passes to apply.
    pub optimizations: Vec<String>,
}
405
/// Automatic-differentiation strategy.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ADMode {
    /// Forward mode via dual numbers.
    Forward,
    /// Reverse mode via a computation tape.
    Reverse,
    /// Combination of forward and reverse passes.
    Mixed,
    /// Symbolic differentiation producing expression trees.
    Symbolic,
}
418
419impl Default for AutodiffConfig {
420 fn default() -> Self {
421 Self {
422 mode: ADMode::Forward,
423 max_order: 1,
424 simd: false,
425 gpu: false,
426 symbolic: false,
427 optimizations: Vec::new(),
428 }
429 }
430}
431
/// Extract an `AutodiffConfig` from a function's attribute list.
///
/// Currently only detects the *presence* of `#[autodiff]` and selects
/// forward mode (which is also the default); attribute arguments such as
/// a mode selector are not yet parsed — TODO confirm the intended
/// attribute grammar before extending this.
pub fn parse_autodiff_attributes(attrs: &[Attribute]) -> Result<AutodiffConfig> {
    let mut config = AutodiffConfig::default();

    for attr in attrs {
        if attr.path().is_ident("autodiff") {
            // Presence of the attribute selects the (default) forward mode.
            config.mode = ADMode::Forward; }
    }

    Ok(config)
}
446
447pub fn generate_autodiff_impl(func: &ItemFn, config: &AutodiffConfig) -> Result<TokenStream> {
449 let original_name = &func.sig.ident;
450 let autodiff_name = syn::Ident::new(&format!("{}_autodiff", original_name), Span::call_site());
451
452 match config.mode {
453 ADMode::Forward => generate_forward_mode(func, &autodiff_name, config),
454 ADMode::Reverse => generate_reverse_mode(func, &autodiff_name, config),
455 ADMode::Mixed => generate_mixed_mode(func, &autodiff_name, config),
456 ADMode::Symbolic => generate_symbolic_mode(func, &autodiff_name, config),
457 }
458}
459
460fn generate_forward_mode(
462 func: &ItemFn,
463 autodiff_name: &syn::Ident,
464 _config: &AutodiffConfig,
465) -> Result<TokenStream> {
466 let original_name = &func.sig.ident;
467 let inputs = &func.sig.inputs;
468 let output = &func.sig.output;
469
470 let dual_inputs = transform_inputs_to_dual(inputs)?;
472 let dual_output = transform_output_to_dual(output)?;
473
474 let dual_body = transform_body_to_dual(&func.block)?;
476
477 let generated = quote! {
478 pub fn #autodiff_name(#dual_inputs) -> #dual_output {
480 #dual_body
481 }
482
483 pub fn #original_name _derivative(x: f64) -> (f64, f64) {
485 let dual_x = Dual::variable(x);
486 let result = #autodiff_name(dual_x);
487 (result.value(), result.derivative())
488 }
489 };
490
491 Ok(generated)
492}
493
494fn generate_reverse_mode(
496 func: &ItemFn,
497 autodiff_name: &syn::Ident,
498 _config: &AutodiffConfig,
499) -> Result<TokenStream> {
500 let original_name = &func.sig.ident;
501 let inputs = &func.sig.inputs;
502
503 let var_inputs = transform_inputs_to_variables(inputs)?;
505 let tape_body = transform_body_to_tape(&func.block)?;
506
507 let generated = quote! {
508 pub fn #autodiff_name(#var_inputs, tape: Arc<Mutex<ComputationTape>>) -> Variable {
510 #tape_body
511 }
512
513 pub fn #original_name _gradients(inputs: &[f64]) -> Vec<f64> {
515 let tape = Arc::new(Mutex::new(ComputationTape::new()));
516
517 let vars: Vec<Variable> = inputs.iter()
519 .map(|&x| Variable::with_graph(x, tape.clone()))
520 .collect();
521
522 let output = #autodiff_name(vars, tape.clone());
524
525 if let Ok(mut tape_guard) = tape.lock() {
527 let _ = tape_guard.backward(1.0);
528
529 vars.iter()
531 .map(|v| tape_guard.get_gradient(v.id).unwrap_or(0.0))
532 .collect()
533 } else {
534 vec![0.0; inputs.len()]
535 }
536 }
537 };
538
539 Ok(generated)
540}
541
542fn generate_mixed_mode(
544 func: &ItemFn,
545 autodiff_name: &syn::Ident,
546 config: &AutodiffConfig,
547) -> Result<TokenStream> {
548 let forward_impl = generate_forward_mode(func, autodiff_name, config)?;
550
551 let reverse_name = syn::Ident::new(&format!("{}_reverse", autodiff_name), Span::call_site());
552 let reverse_impl = generate_reverse_mode(func, &reverse_name, config)?;
553
554 let generated = quote! {
555 #forward_impl
556 #reverse_impl
557
558 pub fn #autodiff_name _mixed(inputs: &[f64], forward_vars: &[usize]) -> (f64, Vec<f64>) {
560 let gradients = vec![0.0; inputs.len()];
563 (0.0, gradients)
564 }
565 };
566
567 Ok(generated)
568}
569
570fn generate_symbolic_mode(
572 func: &ItemFn,
573 autodiff_name: &syn::Ident,
574 _config: &AutodiffConfig,
575) -> Result<TokenStream> {
576 let original_name = &func.sig.ident;
577
578 let generated = quote! {
579 pub fn #autodiff_name() -> SymbolicExpression {
581 SymbolicExpression::new("derivative")
584 }
585
586 pub fn #original_name _latex() -> String {
588 let expr = #autodiff_name();
589 expr.to_latex()
590 }
591 };
592
593 Ok(generated)
594}
595
596fn transform_inputs_to_dual(
602 inputs: &syn::punctuated::Punctuated<FnArg, syn::Token![,]>,
603) -> Result<TokenStream> {
604 let mut dual_inputs = Vec::new();
605
606 for input in inputs {
607 match input {
608 FnArg::Typed(pat_type) => {
609 let pat = &pat_type.pat;
610 match &*pat_type.ty {
612 Type::Path(type_path) if type_path.path.is_ident("f64") => {
613 dual_inputs.push(quote! { #pat: Dual });
614 }
615 ty => {
616 dual_inputs.push(quote! { #pat: #ty });
617 }
618 }
619 }
620 _ => {
621 return Err(SklearsError::InvalidOperation(
622 "Unsupported function parameter type".to_string(),
623 ));
624 }
625 }
626 }
627
628 Ok(quote! { #(#dual_inputs),* })
629}
630
631fn transform_output_to_dual(output: &ReturnType) -> Result<TokenStream> {
633 match output {
634 ReturnType::Type(_, ty) => match &**ty {
635 Type::Path(type_path) if type_path.path.is_ident("f64") => Ok(quote! { Dual }),
636 ty => Ok(quote! { #ty }),
637 },
638 ReturnType::Default => Ok(quote! { () }),
639 }
640}
641
642fn transform_body_to_dual(block: &syn::Block) -> Result<TokenStream> {
644 let mut transformed_stmts = Vec::new();
645
646 for stmt in &block.stmts {
647 let transformed = transform_statement_to_dual(stmt)?;
648 transformed_stmts.push(transformed);
649 }
650
651 Ok(quote! { { #(#transformed_stmts)* } })
652}
653
654fn transform_statement_to_dual(stmt: &Stmt) -> Result<TokenStream> {
656 match stmt {
657 Stmt::Expr(expr, _) => {
658 let transformed_expr = transform_expression_to_dual(expr)?;
659 Ok(quote! { #transformed_expr })
660 }
661 Stmt::Local(local) => {
662 let pat = &local.pat;
664 if let Some(local_init) = &local.init {
665 let init = &local_init.expr;
666 let transformed_init = transform_expression_to_dual(init)?;
667 Ok(quote! { let #pat = #transformed_init; })
668 } else {
669 Ok(quote! { #stmt })
670 }
671 }
672 _ => Ok(quote! { #stmt }),
673 }
674}
675
676fn transform_expression_to_dual(expr: &Expr) -> Result<TokenStream> {
678 match expr {
679 Expr::Binary(binary_expr) => {
680 let left = transform_expression_to_dual(&binary_expr.left)?;
681 let right = transform_expression_to_dual(&binary_expr.right)?;
682 let op = &binary_expr.op;
683
684 Ok(quote! { (#left) #op (#right) })
686 }
687 Expr::Call(call_expr) => {
688 let func = &call_expr.func;
689 let args: Vec<TokenStream> = call_expr
690 .args
691 .iter()
692 .map(transform_expression_to_dual)
693 .collect::<Result<Vec<_>>>()?;
694
695 match &**func {
697 Expr::Path(path) if path.path.is_ident("exp") => {
698 Ok(quote! { dual_exp(#(#args),*) })
699 }
700 Expr::Path(path) if path.path.is_ident("ln") => Ok(quote! { dual_ln(#(#args),*) }),
701 Expr::Path(path) if path.path.is_ident("sin") => {
702 Ok(quote! { dual_sin(#(#args),*) })
703 }
704 Expr::Path(path) if path.path.is_ident("cos") => {
705 Ok(quote! { dual_cos(#(#args),*) })
706 }
707 _ => Ok(quote! { #func(#(#args),*) }),
708 }
709 }
710 Expr::Lit(lit_expr) => {
711 match &lit_expr.lit {
713 syn::Lit::Float(float_lit) => {
714 let value = &float_lit.base10_digits();
715 let parsed_value: f64 = value.parse().map_err(|_| {
716 SklearsError::InvalidOperation("Invalid float literal".to_string())
717 })?;
718 Ok(quote! { Dual::constant(#parsed_value) })
719 }
720 syn::Lit::Int(int_lit) => {
721 let value = &int_lit.base10_digits();
722 let parsed_value: i64 = value.parse().map_err(|_| {
723 SklearsError::InvalidOperation("Invalid int literal".to_string())
724 })?;
725 Ok(quote! { Dual::constant(#parsed_value as f64) })
726 }
727 _ => Ok(quote! { #expr }),
728 }
729 }
730 _ => Ok(quote! { #expr }),
731 }
732}
733
734fn transform_inputs_to_variables(
736 inputs: &syn::punctuated::Punctuated<FnArg, syn::Token![,]>,
737) -> Result<TokenStream> {
738 let mut var_inputs = Vec::new();
739
740 for input in inputs {
741 match input {
742 FnArg::Typed(pat_type) => {
743 let pat = &pat_type.pat;
744 match &*pat_type.ty {
746 Type::Path(type_path) if type_path.path.is_ident("f64") => {
747 var_inputs.push(quote! { #pat: Variable });
748 }
749 ty => {
750 var_inputs.push(quote! { #pat: #ty });
751 }
752 }
753 }
754 _ => {
755 return Err(SklearsError::InvalidOperation(
756 "Unsupported function parameter type".to_string(),
757 ));
758 }
759 }
760 }
761
762 Ok(quote! { #(#var_inputs),* })
763}
764
/// Placeholder lowering of a function body onto the computation tape.
///
/// NOTE(review): the original body is currently ignored and replaced by
/// a single zero-valued input variable recorded on `tape` — real
/// statement-by-statement tape recording is still TODO.
fn transform_body_to_tape(_block: &syn::Block) -> Result<TokenStream> {
    Ok(quote! {
        {
            Variable::with_graph(0.0, tape)
        }
    })
}
776
777pub fn dual_exp(x: Dual) -> Dual {
783 let exp_x = x.real.exp();
784 Dual::new(exp_x, x.dual * exp_x)
785}
786
787pub fn dual_ln(x: Dual) -> Dual {
789 Dual::new(x.real.ln(), x.dual / x.real)
790}
791
792pub fn dual_sin(x: Dual) -> Dual {
794 Dual::new(x.real.sin(), x.dual * x.real.cos())
795}
796
797pub fn dual_cos(x: Dual) -> Dual {
799 Dual::new(x.real.cos(), -x.dual * x.real.sin())
800}
801
802pub fn dual_tanh(x: Dual) -> Dual {
804 let tanh_x = x.real.tanh();
805 Dual::new(tanh_x, x.dual * (1.0 - tanh_x * tanh_x))
806}
807
808pub fn dual_sigmoid(x: Dual) -> Dual {
810 let sigmoid_x = 1.0 / (1.0 + (-x.real).exp());
811 Dual::new(sigmoid_x, x.dual * sigmoid_x * (1.0 - sigmoid_x))
812}
813
814pub fn dual_pow(base: Dual, exponent: f64) -> Dual {
816 let pow_result = base.real.powf(exponent);
817 Dual::new(
818 pow_result,
819 base.dual * exponent * base.real.powf(exponent - 1.0),
820 )
821}
822
/// Expression tree for symbolic differentiation.
#[derive(Debug, Clone, PartialEq)]
pub enum SymbolicExpression {
    /// Named variable, e.g. `x`.
    Variable(String),
    /// Numeric constant.
    Constant(f64),
    /// Sum of two subexpressions.
    Add(Box<SymbolicExpression>, Box<SymbolicExpression>),
    /// Difference of two subexpressions.
    Sub(Box<SymbolicExpression>, Box<SymbolicExpression>),
    /// Product of two subexpressions.
    Mul(Box<SymbolicExpression>, Box<SymbolicExpression>),
    /// Quotient of two subexpressions.
    Div(Box<SymbolicExpression>, Box<SymbolicExpression>),
    /// `base ^ exponent`.
    Pow(Box<SymbolicExpression>, Box<SymbolicExpression>),
    /// Named function applied to argument expressions, e.g. `sin(x)`.
    Function(String, Vec<SymbolicExpression>),
}
847
848impl SymbolicExpression {
849 pub fn new(name: &str) -> Self {
851 Self::Variable(name.to_string())
852 }
853
854 pub fn differentiate(&self, var: &str) -> Self {
856 match self {
857 SymbolicExpression::Variable(v) if v == var => SymbolicExpression::Constant(1.0),
858 SymbolicExpression::Variable(_) => SymbolicExpression::Constant(0.0),
859 SymbolicExpression::Constant(_) => SymbolicExpression::Constant(0.0),
860 SymbolicExpression::Add(left, right) => SymbolicExpression::Add(
861 Box::new(left.differentiate(var)),
862 Box::new(right.differentiate(var)),
863 ),
864 SymbolicExpression::Sub(left, right) => SymbolicExpression::Sub(
865 Box::new(left.differentiate(var)),
866 Box::new(right.differentiate(var)),
867 ),
868 SymbolicExpression::Mul(left, right) => {
869 SymbolicExpression::Add(
871 Box::new(SymbolicExpression::Mul(
872 Box::new(left.differentiate(var)),
873 right.clone(),
874 )),
875 Box::new(SymbolicExpression::Mul(
876 left.clone(),
877 Box::new(right.differentiate(var)),
878 )),
879 )
880 }
881 SymbolicExpression::Div(left, right) => {
882 SymbolicExpression::Div(
884 Box::new(SymbolicExpression::Sub(
885 Box::new(SymbolicExpression::Mul(
886 Box::new(left.differentiate(var)),
887 right.clone(),
888 )),
889 Box::new(SymbolicExpression::Mul(
890 left.clone(),
891 Box::new(right.differentiate(var)),
892 )),
893 )),
894 Box::new(SymbolicExpression::Pow(
895 right.clone(),
896 Box::new(SymbolicExpression::Constant(2.0)),
897 )),
898 )
899 }
900 SymbolicExpression::Pow(base, exp) => {
901 match (&**base, &**exp) {
903 (_, SymbolicExpression::Constant(n)) => {
904 SymbolicExpression::Mul(
906 Box::new(SymbolicExpression::Mul(
907 Box::new(SymbolicExpression::Constant(*n)),
908 Box::new(SymbolicExpression::Pow(
909 base.clone(),
910 Box::new(SymbolicExpression::Constant(n - 1.0)),
911 )),
912 )),
913 Box::new(base.differentiate(var)),
914 )
915 }
916 _ => {
917 SymbolicExpression::Mul(
919 Box::new(self.clone()),
920 Box::new(SymbolicExpression::Add(
921 Box::new(SymbolicExpression::Mul(
922 Box::new(exp.differentiate(var)),
923 Box::new(SymbolicExpression::Function(
924 "ln".to_string(),
925 vec![*base.clone()],
926 )),
927 )),
928 Box::new(SymbolicExpression::Mul(
929 exp.clone(),
930 Box::new(SymbolicExpression::Div(
931 Box::new(base.differentiate(var)),
932 base.clone(),
933 )),
934 )),
935 )),
936 )
937 }
938 }
939 }
940 SymbolicExpression::Function(name, args) => {
941 self.differentiate_function(name, args, var)
942 }
943 }
944 }
945
946 fn differentiate_function(&self, name: &str, args: &[SymbolicExpression], var: &str) -> Self {
948 match name {
949 "sin" if args.len() == 1 => {
950 SymbolicExpression::Mul(
952 Box::new(SymbolicExpression::Function(
953 "cos".to_string(),
954 args.to_vec(),
955 )),
956 Box::new(args[0].differentiate(var)),
957 )
958 }
959 "cos" if args.len() == 1 => {
960 SymbolicExpression::Mul(
962 Box::new(SymbolicExpression::Constant(-1.0)),
963 Box::new(SymbolicExpression::Mul(
964 Box::new(SymbolicExpression::Function(
965 "sin".to_string(),
966 args.to_vec(),
967 )),
968 Box::new(args[0].differentiate(var)),
969 )),
970 )
971 }
972 "exp" if args.len() == 1 => {
973 SymbolicExpression::Mul(
975 Box::new(self.clone()),
976 Box::new(args[0].differentiate(var)),
977 )
978 }
979 "ln" if args.len() == 1 => {
980 SymbolicExpression::Div(
982 Box::new(args[0].differentiate(var)),
983 Box::new(args[0].clone()),
984 )
985 }
986 _ => {
987 SymbolicExpression::Function(format!("d{}_d{}", name, var), args.to_vec())
989 }
990 }
991 }
992
993 pub fn to_latex(&self) -> String {
995 match self {
996 SymbolicExpression::Variable(v) => v.clone(),
997 SymbolicExpression::Constant(c) => {
998 if c.fract() == 0.0 {
999 format!("{}", *c as i64)
1000 } else {
1001 format!("{:.3}", c)
1002 }
1003 }
1004 SymbolicExpression::Add(left, right) => {
1005 format!("({} + {})", left.to_latex(), right.to_latex())
1006 }
1007 SymbolicExpression::Sub(left, right) => {
1008 format!("({} - {})", left.to_latex(), right.to_latex())
1009 }
1010 SymbolicExpression::Mul(left, right) => {
1011 format!("({} \\cdot {})", left.to_latex(), right.to_latex())
1012 }
1013 SymbolicExpression::Div(left, right) => {
1014 format!("\\frac{{{}}}{{{}}}", left.to_latex(), right.to_latex())
1015 }
1016 SymbolicExpression::Pow(base, exp) => {
1017 format!("{}^{{{}}}", base.to_latex(), exp.to_latex())
1018 }
1019 SymbolicExpression::Function(name, args) => {
1020 if args.is_empty() {
1021 format!("\\{}", name)
1022 } else if args.len() == 1 {
1023 format!("\\{}({})", name, args[0].to_latex())
1024 } else {
1025 let arg_strs: Vec<String> = args.iter().map(|a| a.to_latex()).collect();
1026 format!("\\{}({})", name, arg_strs.join(", "))
1027 }
1028 }
1029 }
1030 }
1031
1032 pub fn simplify(&self) -> Self {
1034 match self {
1035 SymbolicExpression::Add(left, right) => {
1036 let left_simp = left.simplify();
1037 let right_simp = right.simplify();
1038
1039 match (&left_simp, &right_simp) {
1040 (SymbolicExpression::Constant(0.0), _) => right_simp,
1041 (_, SymbolicExpression::Constant(0.0)) => left_simp,
1042 (SymbolicExpression::Constant(a), SymbolicExpression::Constant(b)) => {
1043 SymbolicExpression::Constant(a + b)
1044 }
1045 _ => SymbolicExpression::Add(Box::new(left_simp), Box::new(right_simp)),
1046 }
1047 }
1048 SymbolicExpression::Mul(left, right) => {
1049 let left_simp = left.simplify();
1050 let right_simp = right.simplify();
1051
1052 match (&left_simp, &right_simp) {
1053 (SymbolicExpression::Constant(0.0), _)
1054 | (_, SymbolicExpression::Constant(0.0)) => SymbolicExpression::Constant(0.0),
1055 (SymbolicExpression::Constant(1.0), _) => right_simp,
1056 (_, SymbolicExpression::Constant(1.0)) => left_simp,
1057 (SymbolicExpression::Constant(a), SymbolicExpression::Constant(b)) => {
1058 SymbolicExpression::Constant(a * b)
1059 }
1060 _ => SymbolicExpression::Mul(Box::new(left_simp), Box::new(right_simp)),
1061 }
1062 }
1063 SymbolicExpression::Pow(base, exponent) => {
1064 let base_simp = base.simplify();
1065 let exp_simp = exponent.simplify();
1066
1067 match (&base_simp, &exp_simp) {
1068 (_, SymbolicExpression::Constant(1.0)) => base_simp,
1070 (_, SymbolicExpression::Constant(0.0)) => SymbolicExpression::Constant(1.0),
1072 (SymbolicExpression::Constant(1.0), _) => SymbolicExpression::Constant(1.0),
1074 (SymbolicExpression::Constant(0.0), SymbolicExpression::Constant(n))
1076 if *n > 0.0 =>
1077 {
1078 SymbolicExpression::Constant(0.0)
1079 }
1080 (SymbolicExpression::Constant(a), SymbolicExpression::Constant(b)) => {
1082 SymbolicExpression::Constant(a.powf(*b))
1083 }
1084 _ => SymbolicExpression::Pow(Box::new(base_simp), Box::new(exp_simp)),
1085 }
1086 }
1087 _ => self.clone(),
1088 }
1089 }
1090}
1091
1092pub fn second_derivative<F>(_f: F, x: f64) -> f64
1098where
1099 F: Fn(Dual) -> Dual,
1100{
1101 let _dual_x = Dual::new(x, 1.0);
1103
1104 0.0
1106}
1107
/// Numerical Hessian of `f` at `x` via central finite differences.
///
/// FIXES versus the previous version:
/// * step size was `1e-8`, so the stencil denominators were `h² = 1e-16`
///   ≈ machine epsilon — the second differences were pure cancellation
///   noise. `1e-4` (≈ ε^¼) balances truncation against round-off for
///   second-order differences;
/// * the mixed-partial stencil is symmetric in (i, j), so only the upper
///   triangle is computed and mirrored, and `f(x)` (used only on the
///   diagonal) is evaluated once — roughly halving the evaluations.
///
/// Note: assumes `f` is deterministic (it is re-evaluated fewer times
/// than before).
pub fn hessian<F>(f: F, x: &[f64]) -> Vec<Vec<f64>>
where
    F: Fn(&[f64]) -> f64,
{
    let n = x.len();
    let mut hess = vec![vec![0.0; n]; n];

    // ε^(1/4)-scale step: h² = 1e-8 keeps the second difference well
    // above f64 round-off while truncation error stays O(h²).
    let h = 1e-4;

    // f(x) only appears in the diagonal stencil; evaluate it once.
    let f_center = f(x);

    for i in 0..n {
        // Diagonal: (f(x + h·eᵢ) − 2 f(x) + f(x − h·eᵢ)) / h².
        let mut x_plus = x.to_vec();
        let mut x_minus = x.to_vec();
        x_plus[i] += h;
        x_minus[i] -= h;
        hess[i][i] = (f(&x_plus) - 2.0 * f_center + f(&x_minus)) / (h * h);

        // Mixed partials via the four-point stencil; symmetric, so fill
        // the upper triangle and mirror.
        for j in (i + 1)..n {
            let mut x_pp = x.to_vec();
            let mut x_pm = x.to_vec();
            let mut x_mp = x.to_vec();
            let mut x_mm = x.to_vec();

            x_pp[i] += h;
            x_pp[j] += h;
            x_pm[i] += h;
            x_pm[j] -= h;
            x_mp[i] -= h;
            x_mp[j] += h;
            x_mm[i] -= h;
            x_mm[j] -= h;

            let value = (f(&x_pp) - f(&x_pm) - f(&x_mp) + f(&x_mm)) / (4.0 * h * h);
            hess[i][j] = value;
            hess[j][i] = value;
        }
    }

    hess
}
1161
1162pub fn forward_diff<F>(f: F, x: f64) -> (f64, f64)
1168where
1169 F: Fn(Dual) -> Dual,
1170{
1171 let dual_x = Dual::variable(x);
1172 let result = f(dual_x);
1173 (result.value(), result.derivative())
1174}
1175
/// Numerical gradient of `f` at `x` using central differences
/// (step 1e-8 ≈ √ε, appropriate for a first difference).
pub fn gradient<F>(f: F, x: &[f64]) -> Vec<f64>
where
    F: Fn(&[f64]) -> f64,
{
    let h = 1e-8;

    (0..x.len())
        .map(|i| {
            // Perturb coordinate i in both directions.
            let mut fwd = x.to_vec();
            let mut bwd = x.to_vec();
            fwd[i] += h;
            bwd[i] -= h;
            (f(&fwd) - f(&bwd)) / (2.0 * h)
        })
        .collect()
}
1195
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    // Dual +/× propagate derivatives linearly and by the product rule.
    #[test]
    fn test_dual_arithmetic() {
        let x = Dual::new(2.0, 1.0);
        let y = Dual::new(3.0, 0.0);

        let sum = x + y;
        assert_eq!(sum.real, 5.0);
        assert_eq!(sum.dual, 1.0);

        let product = x * y;
        assert_eq!(product.real, 6.0);
        assert_eq!(product.dual, 3.0);
    }

    // exp/ln dual helpers: values and derivatives at x = 1.
    #[test]
    fn test_dual_math_functions() {
        let x = Dual::variable(1.0);

        let exp_result = dual_exp(x);
        assert!((exp_result.real - std::f64::consts::E).abs() < 1e-10);
        assert!((exp_result.dual - std::f64::consts::E).abs() < 1e-10);

        let ln_result = dual_ln(x);
        assert!((ln_result.real - 0.0).abs() < 1e-10);
        assert!((ln_result.dual - 1.0).abs() < 1e-10);
    }

    // forward_diff on f(x) = x²: value 9, derivative 6 at x = 3 (exact).
    #[test]
    fn test_forward_diff() {
        let f = |x: Dual| x * x;
        let (value, derivative) = forward_diff(f, 3.0);

        assert_eq!(value, 9.0);
        assert_eq!(derivative, 6.0);
    }

    // d/dx x² simplifies to 2·x (Mul of constant 2 and variable x).
    #[test]
    fn test_symbolic_differentiation() {
        let x = SymbolicExpression::Variable("x".to_string());
        let x_squared = SymbolicExpression::Pow(
            Box::new(x.clone()),
            Box::new(SymbolicExpression::Constant(2.0)),
        );

        let derivative = x_squared.differentiate("x");
        let simplified = derivative.simplify();

        match simplified {
            SymbolicExpression::Mul(left, right) => {
                assert_eq!(*left, SymbolicExpression::Constant(2.0));
                assert_eq!(*right, SymbolicExpression::Variable("x".to_string()));
            }
            _ => panic!("Expected multiplication"),
        }
    }

    // Finite-difference gradient of x² + y² at (2, 3) ≈ (4, 6).
    #[test]
    fn test_gradient_computation() {
        let f = |vars: &[f64]| vars[0] * vars[0] + vars[1] * vars[1];
        let grad = gradient(f, &[2.0, 3.0]);

        assert!((grad[0] - 4.0).abs() < 1e-6);
        assert!((grad[1] - 6.0).abs() < 1e-6);
    }

    // Registered variables are retrievable (zero gradient before backward).
    #[test]
    fn test_computation_tape() {
        let mut tape = ComputationTape::new();

        let x = Variable::new(2.0);
        let y = Variable::new(3.0);

        tape.register_variable(x.clone());
        tape.register_variable(y.clone());

        assert_eq!(tape.variables.len(), 2);
        assert!(tape.get_gradient(x.id).is_some());
    }

    // Fresh variables get distinct ids and zeroed gradients.
    #[test]
    fn test_variable_creation() {
        let var1 = Variable::new(1.0);
        let var2 = Variable::new(2.0);

        assert_ne!(var1.id, var2.id);
        assert_eq!(var1.value, 1.0);
        assert_eq!(var2.value, 2.0);
        assert_eq!(var1.gradient, 0.0);
        assert_eq!(var2.gradient, 0.0);
    }

    // Defaults: first-order forward mode with accelerations disabled.
    #[test]
    fn test_autodiff_config() {
        let config = AutodiffConfig::default();
        assert_eq!(config.mode, ADMode::Forward);
        assert_eq!(config.max_order, 1);
        assert!(!config.simd);
        assert!(!config.gpu);
    }

    // Division renders as a LaTeX \frac.
    #[test]
    fn test_symbolic_latex_output() {
        let expr = SymbolicExpression::Div(
            Box::new(SymbolicExpression::Variable("x".to_string())),
            Box::new(SymbolicExpression::Constant(2.0)),
        );

        let latex = expr.to_latex();
        assert_eq!(latex, "\\frac{x}{2}");
    }
}
1316}