1pub(crate) mod average;
2pub(crate) mod factor_lie_group;
3pub(crate) mod group_mul;
4pub(crate) mod lie_group_manifold;
5pub(crate) mod real_lie_group;
6
7use core::fmt::Debug;
8
9use approx::assert_relative_eq;
10use sophus_autodiff::{
11 manifold::IsTangent,
12 params::{
13 HasParams,
14 IsParamsImpl,
15 },
16};
17
18use crate::{
19 IsLieGroupImpl,
20 prelude::*,
21};
22
23extern crate alloc;
24
25#[derive(Debug, Copy, Clone, Default)]
69pub struct LieGroup<
70 S: IsScalar<BATCH, DM, DN>,
71 const DOF: usize,
72 const PARAMS: usize,
73 const POINT: usize,
74 const AMBIENT: usize,
75 const BATCH: usize,
76 const DM: usize,
77 const DN: usize,
78 G: IsLieGroupImpl<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN>,
79> {
80 pub(crate) params: S::Vector<PARAMS>,
81 phantom: core::marker::PhantomData<G>,
82}
83
84impl<
85 S: IsScalar<BATCH, DM, DN>,
86 const DOF: usize,
87 const PARAMS: usize,
88 const POINT: usize,
89 const AMBIENT: usize,
90 const BATCH: usize,
91 const DM: usize,
92 const DN: usize,
93 G: IsLieGroupImpl<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN>,
94> LieGroup<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN, G>
95{
96 pub fn adj(&self) -> S::Matrix<DOF, DOF> {
98 G::adj(&self.params)
99 }
100
101 pub fn ad(tangent: S::Vector<DOF>) -> S::Matrix<DOF, DOF> {
103 G::ad(tangent)
104 }
105
106 pub fn exp(omega: S::Vector<DOF>) -> Self {
108 Self::from_params(G::exp(omega))
109 }
110
111 pub fn interpolate(&self, other: &Self, w: S) -> Self {
115 self * Self::exp((self.inverse() * other).log().scaled(w))
116 }
117
118 pub fn log(&self) -> S::Vector<DOF> {
120 G::log(&self.params)
121 }
122
123 pub fn hat(omega: S::Vector<DOF>) -> S::Matrix<AMBIENT, AMBIENT> {
125 G::hat(omega)
126 }
127
128 pub fn vee(xi: S::Matrix<AMBIENT, AMBIENT>) -> S::Vector<DOF> {
130 G::vee(xi)
131 }
132
133 pub fn identity() -> Self {
135 Self::from_params(G::identity_params())
136 }
137
138 pub fn group_mul(&self, other: Self) -> Self {
140 Self::from_params(G::group_mul(&self.params, other.params))
141 }
142
143 pub fn inverse(&self) -> Self {
145 Self::from_params(G::inverse(&self.params))
146 }
147
148 pub fn transform(&self, point: S::Vector<POINT>) -> S::Vector<POINT> {
150 G::transform(&self.params, point)
151 }
152
153 pub fn to_ambient(point: S::Vector<POINT>) -> S::Vector<AMBIENT> {
155 G::to_ambient(point)
156 }
157
158 pub fn compact(&self) -> S::Matrix<POINT, AMBIENT> {
160 G::compact(&self.params)
161 }
162
163 pub fn matrix(&self) -> S::Matrix<AMBIENT, AMBIENT> {
165 G::matrix(&self.params)
166 }
167
168 pub fn element_examples() -> alloc::vec::Vec<Self> {
170 let mut elements = alloc::vec![];
171 for params in Self::params_examples() {
172 elements.push(Self::from_params(params));
173 }
174 elements
175 }
176
177 fn presentability_tests() {
178 if G::IS_ORIGIN_PRESERVING {
179 for g in &Self::element_examples() {
180 let o = S::Vector::<POINT>::zeros();
181
182 approx::assert_abs_diff_eq!(
183 g.transform(o).real_vector(),
184 o.real_vector(),
185 epsilon = 0.0001
186 );
187 }
188 } else {
189 let mut num_preserves = 0;
190 let mut num = 0;
191 for g in &Self::element_examples() {
192 let o = S::Vector::<POINT>::zeros();
193 let o_transformed = g.transform(o);
194 let mask = (o_transformed.real_vector())
195 .norm()
196 .less_equal(&S::RealScalar::from_f64(0.0001));
197
198 num_preserves += mask.count();
199 num += S::Mask::all_true().count();
200 }
201 let percentage = num_preserves as f64 / num as f64;
202 assert!(percentage <= 0.75, "{percentage} <= 0.75");
203 }
204 }
205
206 fn adjoint_tests() {
207 let group_examples = Self::element_examples();
208 let basis: alloc::vec::Vec<S::Vector<DOF>> = (0..DOF)
209 .map(|i| {
210 let mut e = S::Vector::<DOF>::zeros();
211 *e.elem_mut(i) = S::ones();
212 e
213 })
214 .collect();
215
216 for g in &group_examples {
217 let mat_g = g.matrix();
218 let inv_mat_g = g.inverse().matrix();
219
220 let mut ad_ref = S::Matrix::<DOF, DOF>::zeros();
221
222 for i in 0..DOF {
223 let col_i = Self::vee(mat_g.mat_mul(Self::hat(basis[i]).mat_mul(inv_mat_g)));
224 ad_ref.set_col_vec(i, col_i);
225 }
226
227 let ad_impl = g.adj();
228 assert_relative_eq!(
229 ad_impl.real_matrix(),
230 ad_ref.real_matrix(),
231 epsilon = 0.0001
232 );
233 }
234 let tangent_examples: alloc::vec::Vec<S::Vector<DOF>> = G::tangent_examples();
235 for a in tangent_examples.clone() {
236 for b in tangent_examples.clone() {
237 let ad_a = Self::ad(a);
238 let ad_a_b = ad_a * b;
239 let hat_ab = Self::hat(a).mat_mul(Self::hat(b));
240 let hat_ba = Self::hat(b).mat_mul(Self::hat(a));
241
242 let lie_bracket_a_b = Self::vee(hat_ab - hat_ba);
243 assert_relative_eq!(
244 ad_a_b.real_vector(),
245 lie_bracket_a_b.real_vector(),
246 epsilon = 0.0001
247 );
248 }
249 }
250 }
251
252 fn exp_tests() {
253 let group_examples: alloc::vec::Vec<_> = Self::element_examples();
254 let tangent_examples: alloc::vec::Vec<S::Vector<DOF>> = G::tangent_examples();
255
256 for g in &group_examples {
257 let matrix_before = g.compact().real_matrix();
258 let matrix_after = Self::exp(g.log()).compact().real_matrix();
259
260 assert_relative_eq!(matrix_before, matrix_after, epsilon = 0.0001);
261
262 let t = g.inverse().log().real_vector();
263 let t2 = -(g.log().real_vector());
264 assert_relative_eq!(t, t2, epsilon = 0.0001);
265 }
266 for omega in tangent_examples {
267 let exp_inverse = Self::exp(omega).inverse();
268 let neg_omega = -omega;
269
270 let exp_neg_omega = Self::exp(neg_omega);
271 assert_relative_eq!(
272 exp_inverse.compact(),
273 exp_neg_omega.compact(),
274 epsilon = 0.0001
275 );
276 }
277 }
278
279 fn hat_tests() {
280 let tangent_examples: alloc::vec::Vec<S::Vector<DOF>> = G::tangent_examples();
281
282 for omega in tangent_examples {
283 assert_relative_eq!(
284 omega.real_vector(),
285 Self::vee(Self::hat(omega)).real_vector(),
286 epsilon = 0.0001
287 );
288 }
289 }
290
291 fn group_operation_tests() {
292 let group_examples: alloc::vec::Vec<_> = Self::element_examples();
293
294 for g1 in &group_examples {
295 for g2 in &group_examples {
296 for g3 in &group_examples {
297 let left_hugging = (g1 * g2) * g3;
298 let right_hugging = g1 * (g2 * g3);
299 assert_relative_eq!(
300 left_hugging.compact(),
301 right_hugging.compact(),
302 epsilon = 0.0001
303 );
304 }
305 }
306 }
307 for g1 in &group_examples {
308 for g2 in &group_examples {
309 let daz_from_foo_transform_1 = g2.inverse() * g1.inverse();
310 let daz_from_foo_transform_2 = (g1 * g2).inverse();
311 assert_relative_eq!(
312 daz_from_foo_transform_1.compact(),
313 daz_from_foo_transform_2.compact(),
314 epsilon = 0.0001
315 );
316 }
317 }
318 }
319
320 pub fn test_suite() {
322 let group_examples: alloc::vec::Vec<_> = Self::element_examples();
325 assert!(group_examples.len() >= 3);
326 let tangent_examples: alloc::vec::Vec<S::Vector<DOF>> = G::tangent_examples();
327 assert!(tangent_examples.len() >= 3);
328
329 Self::presentability_tests();
330 Self::group_operation_tests();
331 Self::hat_tests();
332 Self::exp_tests();
333 Self::adjoint_tests();
334 }
335}
336
337impl<
338 S: IsScalar<BATCH, DM, DN>,
339 const DOF: usize,
340 const PARAMS: usize,
341 const POINT: usize,
342 const AMBIENT: usize,
343 const BATCH: usize,
344 const DM: usize,
345 const DN: usize,
346 G: IsLieGroupImpl<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN>,
347> IsParamsImpl<S, PARAMS, BATCH, DM, DN>
348 for LieGroup<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN, G>
349{
350 fn are_params_valid(params: S::Vector<PARAMS>) -> S::Mask {
351 G::are_params_valid(params)
352 }
353
354 fn params_examples() -> alloc::vec::Vec<S::Vector<PARAMS>> {
355 G::params_examples()
356 }
357
358 fn invalid_params_examples() -> alloc::vec::Vec<S::Vector<PARAMS>> {
359 G::invalid_params_examples()
360 }
361}
362
363impl<
364 S: IsScalar<BATCH, DM, DN>,
365 const DOF: usize,
366 const PARAMS: usize,
367 const POINT: usize,
368 const AMBIENT: usize,
369 const BATCH: usize,
370 const DM: usize,
371 const DN: usize,
372 G: IsLieGroupImpl<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN>,
373> HasParams<S, PARAMS, BATCH, DM, DN>
374 for LieGroup<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN, G>
375{
376 fn from_params(params: S::Vector<PARAMS>) -> Self {
377 assert!(
378 G::are_params_valid(params).all(),
379 "Invalid parameters for {:?}",
380 params.real_vector()
381 );
382 Self {
383 params: G::disambiguate(params),
384 phantom: core::marker::PhantomData,
385 }
386 }
387
388 fn set_params(&mut self, params: S::Vector<PARAMS>) {
389 self.params = G::disambiguate(params);
390 }
391
392 fn params(&self) -> &S::Vector<PARAMS> {
393 &self.params
394 }
395}
396
397impl<
398 S: IsScalar<BATCH, DM, DN>,
399 const DOF: usize,
400 const PARAMS: usize,
401 const POINT: usize,
402 const AMBIENT: usize,
403 const BATCH: usize,
404 const DM: usize,
405 const DN: usize,
406 G: IsLieGroupImpl<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN>,
407> IsTangent<S, DOF, BATCH, DM, DN> for LieGroup<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN, G>
408{
409 fn tangent_examples() -> alloc::vec::Vec<<S as IsScalar<BATCH, DM, DN>>::Vector<DOF>> {
410 G::tangent_examples()
411 }
412}
413
414impl<
415 S: IsScalar<BATCH, DM, DN>,
416 const DOF: usize,
417 const PARAMS: usize,
418 const POINT: usize,
419 const AMBIENT: usize,
420 const BATCH: usize,
421 const DM: usize,
422 const DN: usize,
423 G: IsLieGroupImpl<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN>,
424> IsLieGroup<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN>
425 for LieGroup<S, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN, G>
426{
427 type G = G;
428 type GenG<S2: IsScalar<BATCH, DM, DN>> = G::GenG<S2>;
429 type RealG = G::RealG;
430 type DualG<const M: usize, const N: usize> = G::DualG<M, N>;
431
432 type GenGroup<
433 S2: IsScalar<BATCH, DM, DN>,
434 G2: IsLieGroupImpl<S2, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN>,
435 > = LieGroup<S2, DOF, PARAMS, POINT, AMBIENT, BATCH, DM, DN, G2>;
436
437 const DOF: usize = DOF;
438
439 const PARAMS: usize = PARAMS;
440
441 const POINT: usize = POINT;
442
443 const AMBIENT: usize = AMBIENT;
444}