from __future__ import annotations
from abc import ABC
from numbers import Number
from typing import overload
from torch import nn, Tensor
from .basics import vec2skew
import torch, warnings, importlib
from contextlib import contextmanager
from .operation import broadcast_inputs
from collections.abc import Iterable, Sequence
from torch.utils._pytree import tree_map, tree_flatten
from .operation import SO3_Log, SE3_Log, RxSO3_Log, Sim3_Log
from .operation import so3_Exp, se3_Exp, rxso3_Exp, sim3_Exp
from .operation import SO3_Act, SE3_Act, RxSO3_Act, Sim3_Act
from .operation import SO3_Mul, SE3_Mul, RxSO3_Mul, Sim3_Mul
from .operation import SO3_Inv, SE3_Inv, RxSO3_Inv, Sim3_Inv
from .operation import SO3_Act4, SE3_Act4, RxSO3_Act4, Sim3_Act4
from .operation import SO3_AdjXa, SE3_AdjXa, RxSO3_AdjXa, Sim3_AdjXa
from .operation import SO3_AdjTXa, SE3_AdjTXa, RxSO3_AdjTXa, Sim3_AdjTXa
from .operation import so3_Jl_inv, se3_Jl_inv, rxso3_Jl_inv, sim3_Jl_inv
from ..basics import pm, cumops_, cummul_, cumprod_, cumops, cummul, cumprod
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple, _ntuple
# Names of torch functions/methods whose Tensor outputs are re-wrapped as
# LieTensor (carrying the ltype of an input LieTensor) by
# LieTensor.__torch_function__. Each name is listed once.
HANDLED_FUNCTIONS = ['__getitem__', '__setitem__', 'cpu', 'cuda', 'float', 'double',
                     'to', 'detach', 'view', 'view_as', 'squeeze', 'unsqueeze', 'cat',
                     'stack', 'split', 'hsplit', 'dsplit', 'vsplit', 'tensor_split',
                     'chunk', 'concat', 'column_stack', 'dstack', 'vstack', 'hstack',
                     'index_select', 'masked_select', 'movedim', 'moveaxis', 'narrow',
                     'permute', 'reshape', 'row_stack', 'scatter', 'scatter_add', 'clone',
                     'swapaxes', 'swapdims', 'take', 'take_along_dim', 'tile', 'copy',
                     'transpose', 'unbind', 'gather', 'repeat', 'expand', 'expand_as',
                     'index_copy', 'index_copy_',
                     'select', 'select_scatter', 'index_put', 'index_put_', 'copy_']
class LieType(ABC):
    '''Base descriptor shared by every LieTensor type.

    An instance records three sizes (data dimension, embedding dimension and
    manifold dimension), each stored as a one-element ``torch.Size``. A type
    is considered "on the manifold" when its data dimension equals its
    manifold dimension. Most operations defined here only raise; concrete
    subclasses override the ones they support.
    '''

    def __init__(self, dimension, embedding, manifold):
        # Keep each size as torch.Size so it can be compared with tensor shapes.
        self._dimension, self._embedding, self._manifold = (
            torch.Size([n]) for n in (dimension, embedding, manifold))

    @property
    def dimension(self) -> torch.Size:
        '''Data dimension as a one-element ``torch.Size``.'''
        return self._dimension

    @property
    def embedding(self) -> torch.Size:
        '''Embedding dimension as a one-element ``torch.Size``.'''
        return self._embedding

    @property
    def manifold(self) -> torch.Size:
        '''Manifold dimension as a one-element ``torch.Size``.'''
        return self._manifold

    @property
    def on_manifold(self) -> bool:
        '''True when the data dimension equals the manifold dimension.'''
        return self.manifold == self.dimension

    def add_(self, input: LieTensor, other: Tensor) -> LieTensor:
        '''In-place addition for on-manifold types.

        Adds the first ``manifold[0]`` entries of the last dimension of
        ``other`` to ``input`` and copies the sum back into ``input``.

        Raises:
            NotImplementedError: if this type is not on the manifold.
        '''
        if not self.on_manifold:
            raise NotImplementedError("Instance has no add_ attribute.")
        base = Tensor.as_subclass(input, Tensor)
        delta = Tensor.as_subclass(other, Tensor)[..., :self.manifold[0]]
        return input.copy_(base + delta)

    def Log(self, X: LieTensor) -> LieTensor:
        '''Placeholder: always raises (AttributeError for on-manifold types).'''
        if not self.on_manifold:
            raise NotImplementedError("Instance has no Log attribute.")
        raise AttributeError("Lie Algebra has no Log attribute")

    def Exp(self, x: LieTensor) -> LieTensor:
        '''Placeholder: always raises (AttributeError for off-manifold types).'''
        if self.on_manifold:
            raise NotImplementedError("Instance has no Exp attribute.")
        raise AttributeError("Lie Group has no Exp attribute")

    def Inv(self, X: LieTensor) -> LieTensor:
        '''Negate ``X`` for on-manifold types; otherwise raise NotImplementedError.'''
        if not self.on_manifold:
            raise NotImplementedError("Instance has no Inv attribute.")
        return LieTensor(-X, ltype=self)

    def Act(self, X: LieTensor, p: Tensor) -> Tensor:
        """ action on a points tensor(*, 3[4]) (homogeneous); placeholder that always raises """
        if self.on_manifold:
            raise NotImplementedError("Instance has no Act attribute.")
        raise AttributeError("Lie Group has no Act attribute")

    @overload
    def Mul(self, X: LieTensor, Y: Number | LieTensor) -> LieTensor: ...

    @overload
    def Mul(self, X: LieTensor, Y: Tensor) -> Tensor | LieTensor: ...

    def Mul(self, X: LieTensor, Y: Number | Tensor | LieTensor) -> Tensor | LieTensor:
        '''Placeholder: always raises (AttributeError for off-manifold types).'''
        if self.on_manifold:
            raise NotImplementedError("Instance has no Mul attribute.")
        raise AttributeError("Lie Group has no Mul attribute")

    def Retr(self, X: LieTensor, a: LieTensor) -> LieTensor:
        '''Return ``a.Exp() * X`` for off-manifold types.

        Raises:
            AttributeError: if this type is on the manifold.
        '''
        if self.on_manifold:
            raise AttributeError("Has no Retr attribute")
        return a.Exp() * X

    def Adj(self, X: LieTensor, a: LieTensor) -> LieTensor:
        ''' X * Exp(a) = Exp(Adj) * X ; placeholder that always raises '''
        if self.on_manifold:
            raise NotImplementedError("Instance has no Adj attribute.")
        raise AttributeError("Lie Group has no Adj attribute")

    def AdjT(self, X: LieTensor, a: LieTensor) -> LieTensor:
        ''' Exp(a) * X = X * Exp(AdjT) ; placeholder that always raises '''
        if self.on_manifold:
            raise NotImplementedError("Instance has no AdjT attribute.")
        raise AttributeError("Lie Group has no AdjT attribute")

    def Jinvp(self, X: LieTensor, p: LieTensor) -> LieTensor:
        '''Placeholder: always raises (AttributeError for off-manifold types).'''
        if self.on_manifold:
            raise NotImplementedError("Instance has no Jinvp attribute.")
        raise AttributeError("Lie Group has no Jinvp attribute")

    def Jr(self, X: LieTensor) -> Tensor:
        '''Placeholder: always raises NotImplementedError.'''
        raise NotImplementedError("Instance has no Jr attribute")

    def matrix(self, input: LieTensor) -> Tensor:
        """ To 4x4 matrix: act on a 4x4 identity, then swap the last two dims """
        group = input if not self.on_manifold else input.Exp()
        eye = torch.eye(4, dtype=group.dtype, device=group.device)
        # Prepend singleton dims so the identity has the same rank as the
        # unsqueezed input below.
        eye = eye.view([1] * (group.dim() - 1) + [4, 4])
        return group.unsqueeze(-2).Act(eye).transpose(-1, -2)

    def rotation(self, input: LieTensor) -> LieTensor:
        '''Placeholder: always raises NotImplementedError.'''
        raise NotImplementedError("Rotation is not implemented for the instance.")

    def translation(self, input: LieTensor) -> Tensor:
        '''Warn, then return zeros of shape ``input.lshape + (3,)``.'''
        warnings.warn("Instance has no translation. Zero vector(s) is returned.")
        shape = input.lshape + (3,)
        return torch.zeros(shape, dtype=input.dtype, device=input.device,
                           requires_grad=input.requires_grad)

    def scale(self, input: LieTensor) -> Tensor:
        '''Warn, then return ones of shape ``input.lshape + (1,)``.'''
        warnings.warn("Instance has no scale. Scalar one(s) is returned.")
        shape = input.lshape + (1,)
        return torch.ones(shape, dtype=input.dtype, device=input.device,
                          requires_grad=input.requires_grad)

    @classmethod
    def to_tuple(cls, input):
        '''Flatten one level: iterable items are expanded, others kept as-is.'''
        flat = []
        for item in input:
            if isinstance(item, Iterable):
                flat.extend(item)
            else:
                flat.append(item)
        return tuple(flat)

    @classmethod
    def identity(cls, *args, **kwargs) -> LieTensor:
        '''Placeholder: always raises NotImplementedError.'''
        raise NotImplementedError("Instance has no identity.")

    @classmethod
    def identity_like(cls, *args, **kwargs):
        '''Forward all arguments to :meth:`identity`.'''
        return cls.identity(*args, **kwargs)

    @classmethod
    def identity_(cls, X: LieTensor) -> LieTensor:
        '''Placeholder: always raises NotImplementedError.'''
        raise NotImplementedError("Instance has no identity_ method")

    def randn_like(self, *args, sigma=1.0, **kwargs):
        '''Forward all arguments (including ``sigma``) to :meth:`randn`.'''
        return self.randn(*args, sigma=sigma, **kwargs)

    def randn(self, *args, **kwargs) -> LieTensor:
        '''Placeholder: always raises NotImplementedError.'''
        raise NotImplementedError("Instance has no randn method")

    # The six wrappers below forward unchanged to the module-level functions
    # of the same name.
    @classmethod
    def cumops(cls, X: LieTensor, dim: int, ops):
        return cumops(X, dim, ops)

    @classmethod
    def cummul(cls, X: LieTensor, dim: int, left = True) -> LieTensor:
        return cummul(X, dim, left)

    @classmethod
    def cumprod(cls, X: LieTensor, dim: int, left = True) -> LieTensor:
        return cumprod(X, dim, left)

    @classmethod
    def cumops_(cls, X: LieTensor, dim: int, ops):
        return cumops_(X, dim, ops)

    @classmethod
    def cummul_(cls, X: LieTensor, dim: int, left = True) -> LieTensor:
        return cummul_(X, dim, left)

    @classmethod
    def cumprod_(cls, X: LieTensor, dim: int, left = True) -> LieTensor:
        return cumprod_(X, dim, left)
class SO3Type(LieType):
    '''Lie type descriptor for SO3 LieTensors.

    Registered with data dimension 4, embedding dimension 4 and manifold
    dimension 3.
    '''

    def __init__(self):
        super().__init__(4, 4, 3)

    def Log(self, X):
        '''Apply ``SO3_Log`` and wrap the result with ``so3_type``.'''
        raw = X.tensor() if isinstance(X, LieTensor) else X
        return LieTensor(SO3_Log.apply(raw), ltype=so3_type)

    def Act(self, X, p):
        '''Apply ``X`` to a points tensor ``p`` whose last dimension is 3 or 4.

        Returns a plain tensor viewed to the broadcast shape of the inputs.
        '''
        assert not self.on_manifold and isinstance(p, Tensor)
        assert p.shape[-1] in (3, 4), "Invalid Tensor Dimension"
        if isinstance(X, LieTensor):
            X = X.tensor()
        operands, batch = broadcast_inputs(X, p)
        kernel = SO3_Act if p.shape[-1] == 3 else SO3_Act4
        out = kernel.apply(*operands)
        # With zero elements use the explicit width rather than -1.
        last = p.shape[-1] if out.nelement() == 0 else -1
        return out.view(batch + (last,))

    def Mul(self, X, Y):
        '''Multiply ``X`` by ``Y``.

        An off-manifold LieTensor ``Y`` is composed via ``SO3_Mul``; any other
        Tensor ``Y`` is treated as points and passed to :meth:`Act`.
        '''
        if isinstance(X, LieTensor):
            X = X.tensor()
        is_group = not self.on_manifold
        # Transform on transform
        if is_group and isinstance(Y, LieTensor) and not Y.ltype.on_manifold:
            operands, batch = broadcast_inputs(X, Y.tensor())
            out = SO3_Mul.apply(*operands)
            last = X.shape[-1] if out.nelement() == 0 else -1
            return LieTensor(out.view(batch + (last,)), ltype=SO3_type)
        # Transform on points
        if is_group and isinstance(Y, Tensor):
            return self.Act(X, Y)
        # (scalar or tensor) * manifold
        if not is_group:
            return LieTensor(torch.mul(X, Y), ltype=SO3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    def Inv(self, X):
        '''Apply ``SO3_Inv`` and wrap the result with ``SO3_type``.'''
        raw = X.tensor() if isinstance(X, LieTensor) else X
        return LieTensor(SO3_Inv.apply(raw), ltype=SO3_type)

    def Adj(self, X, a):
        '''Apply ``SO3_AdjXa`` to broadcast ``(X, a)``; result has ``so3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(a, LieTensor):
            a = a.tensor()
        operands, batch = broadcast_inputs(X, a)
        out = SO3_AdjXa.apply(*operands)
        last = a.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=so3_type)

    def AdjT(self, X, a):
        '''Apply ``SO3_AdjTXa`` to broadcast ``(X, a)``; result has ``so3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(a, LieTensor):
            a = a.tensor()
        operands, batch = broadcast_inputs(X, a)
        out = SO3_AdjTXa.apply(*operands)
        last = a.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=so3_type)

    def Jinvp(self, X, p):
        '''Multiply ``so3_Jl_inv(SO3_Log(X))`` with ``p``; result has ``so3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(p, LieTensor):
            p = p.tensor()
        (X, p), batch = broadcast_inputs(X, p)
        jac = so3_Jl_inv(SO3_Log.apply(X))
        out = (jac @ p.unsqueeze(-1)).squeeze(-1)
        last = p.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=so3_type)

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Tile the vector ``[0, 0, 0, 1]`` to ``size + (4,)``.'''
        unit = torch.tensor([0., 0., 0., 1.], **kwargs)
        return LieTensor(unit.repeat(size + (1,)), ltype=SO3_type)

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Sample via ``so3_type.randn`` followed by ``Exp`` (detached).'''
        tangent = so3_type.randn(*size, sigma=sigma, **kwargs)
        sample = LieTensor(so3_type.Exp(tangent).detach(), ltype=SO3_type)
        sample.requires_grad_(requires_grad)
        return sample

    @classmethod
    def add_(cls, input, other):
        '''Overwrite ``input`` with ``Exp(other[..., :3]) * input``.'''
        step = LieTensor(other[..., :3], ltype=so3_type).Exp()
        return input.copy_(step * input)

    def matrix(self, input):
        """ To 3x3 matrix """
        eye = torch.eye(3, dtype=input.dtype, device=input.device)
        eye = eye.view([1] * (input.dim() - 1) + [3, 3])
        return input.unsqueeze(-2).Act(eye).transpose(-1, -2)

    def rotation(self, input):
        '''Return ``input`` unchanged.'''
        return input

    @classmethod
    def identity_(cls, X):
        '''Fill ``X`` in place with zeros, then set the last entry to one.'''
        X.fill_(0)
        last = torch.tensor([-1], device=X.device)
        X.index_fill_(dim=-1, index=last, value=1)
        return X

    def Jr(self, X):
        """
        Right jacobian of SO(3), computed as ``X.Log().Jr()``
        """
        return X.Log().Jr()
class so3Type(LieType):
    '''Lie type descriptor for so3 LieTensors.

    Registered with data dimension 3, embedding dimension 4 and manifold
    dimension 3, so the type is on the manifold.
    '''

    def __init__(self):
        super().__init__(3, 4, 3)

    def Exp(self, x):
        '''Apply ``so3_Exp`` and wrap the result with ``SO3_type``.'''
        raw = x.tensor() if isinstance(x, LieTensor) else x
        return LieTensor(so3_Exp.apply(raw), ltype=SO3_type)

    def Mul(self, X, Y):
        '''Element-wise ``torch.mul(X, Y)`` wrapped with ``so3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        # (scalar or tensor) * manifold
        if not self.on_manifold:
            raise NotImplementedError('Invalid __mul__ operation')
        return LieTensor(torch.mul(X, Y), ltype=so3_type)

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Log of the SO3 identity of the requested size.'''
        return SO3_type.Log(SO3_type.identity(*size, **kwargs))

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Random so3 sample: a unit-normalised random vector scaled by
        ``sigma`` times a standard-normal draw.'''
        assert isinstance(sigma, Number), 'Only accepts sigma as a single number'
        size = self.to_tuple(size)
        direction = torch.randn(*(size + torch.Size([3])), **kwargs)
        length = direction.norm(dim=-1, keepdim=True)
        angle = sigma * torch.randn(*(size + torch.Size([1])), **kwargs)
        sample = LieTensor(direction / length * angle, ltype=so3_type)
        sample.requires_grad_(requires_grad)
        return sample

    def matrix(self, input):
        """ To 3x3 matrix """
        group = input.Exp()
        eye = torch.eye(3, dtype=group.dtype, device=group.device)
        eye = eye.view([1] * (group.dim() - 1) + [3, 3])
        return group.unsqueeze(-2).Act(eye).transpose(-1, -2)

    def rotation(self, input):
        '''Rotation of ``input.Exp()``.'''
        return input.Exp().rotation()

    def Jr(self, X):
        """
        Right jacobian of so(3)
        """
        skew = vec2skew(X)
        theta = torch.linalg.norm(X, dim=-1, keepdim=True).unsqueeze(-1)
        eye = torch.eye(3, device=X.device, dtype=X.dtype).expand(X.lshape+(3, 3))
        first = (1 - theta.cos()) / theta**2
        second = (theta - theta.sin()) / theta**3
        jac = eye - first * skew + (second * skew) @ skew
        # Where theta is not above machine epsilon, return the identity.
        return torch.where(theta > torch.finfo(theta.dtype).eps, jac, eye)
class SE3Type(LieType):
    '''Lie type descriptor for SE3 LieTensors.

    Registered with data dimension 7, embedding dimension 7 and manifold
    dimension 6. The last dimension holds translation in ``[0:3]`` and
    rotation in ``[3:7]`` (see :meth:`translation` and :meth:`rotation`).
    '''

    def __init__(self):
        super().__init__(7, 7, 6)

    def Log(self, X):
        '''Apply ``SE3_Log`` and wrap the result with ``se3_type``.'''
        raw = X.tensor() if isinstance(X, LieTensor) else X
        return LieTensor(SE3_Log.apply(raw), ltype=se3_type)

    def Act(self, X, p):
        '''Apply ``X`` to a points tensor ``p`` whose last dimension is 3 or 4.'''
        assert not self.on_manifold and isinstance(p, Tensor)
        assert p.shape[-1] in (3, 4), "Invalid Tensor Dimension"
        if isinstance(X, LieTensor):
            X = X.tensor()
        operands, batch = broadcast_inputs(X, p)
        kernel = SE3_Act if p.shape[-1] == 3 else SE3_Act4
        out = kernel.apply(*operands)
        # With zero elements use the explicit width rather than -1.
        last = p.shape[-1] if out.nelement() == 0 else -1
        return out.view(batch + (last,))

    def Mul(self, X, Y):
        '''Multiply ``X`` by ``Y``.

        An off-manifold LieTensor ``Y`` is composed via ``SE3_Mul``; any other
        Tensor ``Y`` is treated as points and passed to :meth:`Act`.
        '''
        if isinstance(X, LieTensor):
            X = X.tensor()
        is_group = not self.on_manifold
        # Transform on transform
        if is_group and isinstance(Y, LieTensor) and not Y.ltype.on_manifold:
            # ``Y.ltype`` was just read above, so Y always has an ltype here.
            operands, batch = broadcast_inputs(X, Y.tensor())
            out = SE3_Mul.apply(*operands)
            last = X.shape[-1] if out.nelement() == 0 else -1
            return LieTensor(out.view(batch + (last,)), ltype=SE3_type)
        # Transform on points
        if is_group and isinstance(Y, Tensor):
            return self.Act(X, Y)
        # (scalar or tensor) * manifold
        if not is_group:
            return LieTensor(torch.mul(X, Y), ltype=SE3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    def Inv(self, X):
        '''Apply ``SE3_Inv`` and wrap the result with ``SE3_type``.'''
        raw = X.tensor() if isinstance(X, LieTensor) else X
        return LieTensor(SE3_Inv.apply(raw), ltype=SE3_type)

    def rotation(self, input):
        '''Entries ``[3:7]`` of the last dimension, wrapped with ``SO3_type``.'''
        return LieTensor(input.tensor()[..., 3:7], ltype=SO3_type)

    def translation(self, input):
        '''Entries ``[0:3]`` of the last dimension as a plain tensor.'''
        return input.tensor()[..., 0:3]

    def Adj(self, X, a):
        '''Apply ``SE3_AdjXa`` to broadcast ``(X, a)``; result has ``se3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(a, LieTensor):
            a = a.tensor()
        operands, batch = broadcast_inputs(X, a)
        out = SE3_AdjXa.apply(*operands)
        last = a.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=se3_type)

    def AdjT(self, X, a):
        '''Apply ``SE3_AdjTXa`` to broadcast ``(X, a)``; result has ``se3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(a, LieTensor):
            a = a.tensor()
        operands, batch = broadcast_inputs(X, a)
        out = SE3_AdjTXa.apply(*operands)
        last = a.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=se3_type)

    def Jinvp(self, X, p):
        '''Multiply ``se3_Jl_inv(SE3_Log(X))`` with ``p``; result has ``se3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(p, LieTensor):
            p = p.tensor()
        (X, p), batch = broadcast_inputs(X, p)
        jac = se3_Jl_inv(SE3_Log.apply(X))
        out = (jac @ p.unsqueeze(-1)).squeeze(-1)
        last = p.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=se3_type)

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Tile the vector ``[0, 0, 0, 0, 0, 0, 1]`` to ``size + (7,)``.'''
        unit = torch.tensor([0., 0., 0., 0., 0., 0., 1.], **kwargs)
        return LieTensor(unit.repeat(size + (1,)), ltype=SE3_type)

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Sample via ``se3_type.randn`` followed by ``Exp`` (detached).'''
        tangent = se3_type.randn(*size, sigma=sigma, **kwargs)
        sample = LieTensor(se3_type.Exp(tangent).detach(), ltype=SE3_type)
        sample.requires_grad_(requires_grad)
        return sample

    @classmethod
    def add_(cls, input, other):
        '''Overwrite ``input`` with ``Exp(other[..., :6]) * input``.'''
        step = LieTensor(other[..., :6], ltype=se3_type).Exp()
        return input.copy_(step * input)
class se3Type(LieType):
    '''Lie type descriptor for se3 LieTensors.

    Registered with data dimension 6, embedding dimension 7 and manifold
    dimension 6, so the type is on the manifold.
    '''

    def __init__(self):
        super().__init__(6, 7, 6)

    def Exp(self, x):
        '''Apply ``se3_Exp`` and wrap the result with ``SE3_type``.'''
        raw = x.tensor() if isinstance(x, LieTensor) else x
        return LieTensor(se3_Exp.apply(raw), ltype=SE3_type)

    def Mul(self, X, Y):
        '''Element-wise ``torch.mul(X, Y)`` wrapped with ``se3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        # (scalar or tensor) * manifold
        if not self.on_manifold:
            raise NotImplementedError('Invalid __mul__ operation')
        return LieTensor(torch.mul(X, Y), ltype=se3_type)

    def rotation(self, input):
        '''Rotation of ``input.Exp()``.'''
        return input.Exp().rotation()

    def translation(self, input):
        '''Translation of ``input.Exp()``.'''
        return input.Exp().translation()

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Log of the SE3 identity of the requested size.'''
        return SE3_type.Log(SE3_type.identity(*size, **kwargs))

    def randn(self, *size, sigma: float | int | Sequence = 1.0, requires_grad=False, **kwargs):
        '''Random se3 sample: three translation entries followed by an so3 sample.

        ``sigma`` is normalised to four values (three for translation, one
        for rotation): a non-sequence is repeated four times, a 2-sequence
        is read as ``(translation, rotation)``, and a 4-sequence is used
        as given.
        '''
        # convert different types of inputs to SE3 sigma
        if not isinstance(sigma, Sequence):
            sigma = _quadruple(sigma)
        elif len(sigma) == 2:
            sigma = _triple(sigma[0]) + _single(sigma[-1])
        else:
            assert len(sigma)==4, 'Only accepts a tuple of sigma in size 1, 2, or 4.'
        size = self.to_tuple(size)
        rot = so3_type.randn(*size, sigma=sigma[-1], **kwargs).tensor().detach()
        trans_sigma = torch.tensor([sigma[i] for i in range(3)], **kwargs)
        trans = trans_sigma * torch.randn(*(size + torch.Size([3])), **kwargs)
        sample = LieTensor(torch.cat([trans, rot], dim=-1), ltype=se3_type)
        sample.requires_grad_(requires_grad)
        return sample
class Sim3Type(LieType):
    '''Lie type descriptor for Sim3 LieTensors.

    Registered with data dimension 8, embedding dimension 8 and manifold
    dimension 7. The last dimension holds translation in ``[0:3]``, rotation
    in ``[3:7]`` and scale in ``[7:8]``.
    '''

    def __init__(self):
        super().__init__(8, 8, 7)

    def Log(self, X):
        '''Apply ``Sim3_Log`` and wrap the result with ``sim3_type``.'''
        raw = X.tensor() if isinstance(X, LieTensor) else X
        return LieTensor(Sim3_Log.apply(raw), ltype=sim3_type)

    def Act(self, X, p):
        '''Apply ``X`` to a points tensor ``p`` whose last dimension is 3 or 4.'''
        assert not self.on_manifold and isinstance(p, Tensor)
        assert p.shape[-1] in (3, 4), "Invalid Tensor Dimension"
        if isinstance(X, LieTensor):
            X = X.tensor()
        operands, batch = broadcast_inputs(X, p)
        kernel = Sim3_Act if p.shape[-1] == 3 else Sim3_Act4
        out = kernel.apply(*operands)
        # With zero elements use the explicit width rather than -1.
        last = p.shape[-1] if out.nelement() == 0 else -1
        return out.view(batch + (last,))

    def Mul(self, X, Y):
        '''Multiply ``X`` by ``Y``.

        An off-manifold LieTensor ``Y`` is composed via ``Sim3_Mul``; any
        other Tensor ``Y`` is treated as points and passed to :meth:`Act`.
        '''
        if isinstance(X, LieTensor):
            X = X.tensor()
        is_group = not self.on_manifold
        # Transform on transform
        if is_group and isinstance(Y, LieTensor) and not Y.ltype.on_manifold:
            # ``Y.ltype`` was just read above, so Y always has an ltype here.
            operands, batch = broadcast_inputs(X, Y.tensor())
            out = Sim3_Mul.apply(*operands)
            last = X.shape[-1] if out.nelement() == 0 else -1
            return LieTensor(out.view(batch + (last,)), ltype=Sim3_type)
        # Transform on points
        if is_group and isinstance(Y, Tensor):
            return self.Act(X, Y)
        # (scalar or tensor) * manifold
        if not is_group:
            return LieTensor(torch.mul(X, Y), ltype=Sim3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    def Inv(self, X):
        '''Apply ``Sim3_Inv`` and wrap the result with ``Sim3_type``.'''
        raw = X.tensor() if isinstance(X, LieTensor) else X
        return LieTensor(Sim3_Inv.apply(raw), ltype=Sim3_type)

    def Adj(self, X, a):
        '''Apply ``Sim3_AdjXa`` to broadcast ``(X, a)``; result has ``sim3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(a, LieTensor):
            a = a.tensor()
        operands, batch = broadcast_inputs(X, a)
        out = Sim3_AdjXa.apply(*operands)
        last = a.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=sim3_type)

    def AdjT(self, X, a):
        '''Apply ``Sim3_AdjTXa`` to broadcast ``(X, a)``; result has ``sim3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(a, LieTensor):
            a = a.tensor()
        operands, batch = broadcast_inputs(X, a)
        out = Sim3_AdjTXa.apply(*operands)
        last = a.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=sim3_type)

    def Jinvp(self, X, p):
        '''Multiply ``sim3_Jl_inv(Sim3_Log(X))`` with ``p``; result has ``sim3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(p, LieTensor):
            p = p.tensor()
        (X, p), batch = broadcast_inputs(X, p)
        jac = sim3_Jl_inv(Sim3_Log.apply(X))
        out = (jac @ p.unsqueeze(-1)).squeeze(-1)
        last = p.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=sim3_type)

    def rotation(self, input):
        '''Entries ``[3:7]`` of the last dimension, wrapped with ``SO3_type``.'''
        return LieTensor(input.tensor()[..., 3:7], ltype=SO3_type)

    def translation(self, input):
        '''Entries ``[0:3]`` of the last dimension as a plain tensor.'''
        return input.tensor()[..., 0:3]

    def scale(self, input):
        '''Entry ``[7:8]`` of the last dimension as a plain tensor.'''
        return input.tensor()[..., 7:8]

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Tile the vector ``[0, 0, 0, 0, 0, 0, 1, 1]`` to ``size + (8,)``.'''
        unit = torch.tensor([0., 0., 0., 0., 0., 0., 1., 1.], **kwargs)
        return LieTensor(unit.repeat(size + (1,)), ltype=Sim3_type)

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Sample via ``sim3_type.randn`` followed by ``Exp`` (detached).'''
        tangent = sim3_type.randn(*size, sigma=sigma, **kwargs)
        sample = LieTensor(sim3_type.Exp(tangent).detach(), ltype=Sim3_type)
        sample.requires_grad_(requires_grad)
        return sample

    @classmethod
    def add_(cls, input, other):
        '''Overwrite ``input`` with ``Exp(other[..., :7]) * input``.'''
        step = LieTensor(other[..., :7], ltype=sim3_type).Exp()
        return input.copy_(step * input)
class sim3Type(LieType):
    '''Lie type descriptor for sim3 LieTensors.

    Registered with data dimension 7, embedding dimension 8 and manifold
    dimension 7, so the type is on the manifold.
    '''

    def __init__(self):
        super().__init__(7, 8, 7)

    def Exp(self, x):
        '''Apply ``sim3_Exp`` and wrap the result with ``Sim3_type``.'''
        raw = x.tensor() if isinstance(x, LieTensor) else x
        return LieTensor(sim3_Exp.apply(raw), ltype=Sim3_type)

    def Mul(self, X, Y):
        '''Element-wise ``torch.mul(X, Y)`` wrapped with ``sim3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        # (scalar or tensor) * manifold
        if not self.on_manifold:
            raise NotImplementedError('Invalid __mul__ operation')
        return LieTensor(torch.mul(X, Y), ltype=sim3_type)

    def rotation(self, input):
        '''Rotation of ``input.Exp()``.'''
        return input.Exp().rotation()

    def translation(self, input):
        '''Translation of ``input.Exp()``.'''
        return input.Exp().translation()

    def scale(self, input):
        '''Scale of ``input.Exp()``.'''
        return input.Exp().scale()

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Log of the Sim3 identity of the requested size.'''
        return Sim3_type.Log(Sim3_type.identity(*size, **kwargs))

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Random sim3 sample: translation, then an so3 sample, then scale.

        ``sigma`` is normalised to five values (three for translation, one
        for rotation, one for scale): a non-sequence is repeated five times,
        a 3-sequence is read as ``(translation, rotation, scale)``, and a
        5-sequence is used as given.
        '''
        if not isinstance(sigma, Sequence):
            sigma = _ntuple(5, "_penta")(sigma)
        elif len(sigma) == 3:
            sigma = _triple(sigma[0]) + _single(sigma[-2]) + _single(sigma[-1])
        else:
            assert len(sigma)==5, 'Only accepts a tuple of sigma in size 1, 3, or 5.'
        size = self.to_tuple(size)
        rot = so3_type.randn(*size, sigma=sigma[-2], **kwargs).tensor().detach()
        scl = sigma[-1] * torch.randn(*(size + torch.Size([1])), **kwargs)
        trans_sigma = torch.tensor([sigma[i] for i in range(3)], **kwargs)
        trans = trans_sigma * torch.randn(*(size + torch.Size([3])), **kwargs)
        packed = torch.cat([trans, rot, scl], dim=-1)
        return LieTensor(packed, ltype=sim3_type).requires_grad_(requires_grad)
class RxSO3Type(LieType):
    '''Lie type descriptor for RxSO3 LieTensors.

    Registered with data dimension 5, embedding dimension 5 and manifold
    dimension 4. The last dimension holds rotation in ``[0:4]`` and scale
    in ``[4:5]``.
    '''

    def __init__(self):
        super().__init__(5, 5, 4)

    def Log(self, X):
        '''Apply ``RxSO3_Log`` and wrap the result with ``rxso3_type``.'''
        raw = X.tensor() if isinstance(X, LieTensor) else X
        return LieTensor(RxSO3_Log.apply(raw), ltype=rxso3_type)

    def Act(self, X, p):
        '''Apply ``X`` to a points tensor ``p`` whose last dimension is 3 or 4.'''
        assert not self.on_manifold and isinstance(p, Tensor)
        assert p.shape[-1] in (3, 4), "Invalid Tensor Dimension"
        if isinstance(X, LieTensor):
            X = X.tensor()
        operands, batch = broadcast_inputs(X, p)
        kernel = RxSO3_Act if p.shape[-1] == 3 else RxSO3_Act4
        out = kernel.apply(*operands)
        # With zero elements use the explicit width rather than -1.
        last = p.shape[-1] if out.nelement() == 0 else -1
        return out.view(batch + (last,))

    def Mul(self, X, Y):
        '''Multiply ``X`` by ``Y``.

        An off-manifold LieTensor ``Y`` is composed via ``RxSO3_Mul``; any
        other Tensor ``Y`` is treated as points and passed to :meth:`Act`.
        '''
        if isinstance(X, LieTensor):
            X = X.tensor()
        is_group = not self.on_manifold
        # Transform on transform
        if is_group and isinstance(Y, LieTensor) and not Y.ltype.on_manifold:
            # ``Y.ltype`` was just read above, so Y always has an ltype here.
            operands, batch = broadcast_inputs(X, Y.tensor())
            out = RxSO3_Mul.apply(*operands)
            last = X.shape[-1] if out.nelement() == 0 else -1
            return LieTensor(out.view(batch + (last,)), ltype=RxSO3_type)
        # Transform on points
        if is_group and isinstance(Y, Tensor):
            return self.Act(X, Y)
        # (scalar or tensor) * manifold
        if not is_group:
            return LieTensor(torch.mul(X, Y), ltype=RxSO3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    def Inv(self, X):
        '''Apply ``RxSO3_Inv`` and wrap the result with ``RxSO3_type``.'''
        raw = X.tensor() if isinstance(X, LieTensor) else X
        return LieTensor(RxSO3_Inv.apply(raw), ltype=RxSO3_type)

    def Adj(self, X, a):
        '''Apply ``RxSO3_AdjXa`` to broadcast ``(X, a)``; result has ``rxso3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(a, LieTensor):
            a = a.tensor()
        operands, batch = broadcast_inputs(X, a)
        out = RxSO3_AdjXa.apply(*operands)
        last = a.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=rxso3_type)

    def AdjT(self, X, a):
        '''Apply ``RxSO3_AdjTXa`` to broadcast ``(X, a)``; result has ``rxso3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(a, LieTensor):
            a = a.tensor()
        operands, batch = broadcast_inputs(X, a)
        out = RxSO3_AdjTXa.apply(*operands)
        last = a.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=rxso3_type)

    def Jinvp(self, X, p):
        '''Multiply ``rxso3_Jl_inv(RxSO3_Log(X))`` with ``p``; result has ``rxso3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        if isinstance(p, LieTensor):
            p = p.tensor()
        (X, p), batch = broadcast_inputs(X, p)
        jac = rxso3_Jl_inv(RxSO3_Log.apply(X))
        out = (jac @ p.unsqueeze(-1)).squeeze(-1)
        last = p.shape[-1] if out.nelement() == 0 else -1
        return LieTensor(out.view(batch + (last,)), ltype=rxso3_type)

    def rotation(self, input):
        '''Entries ``[0:4]`` of the last dimension, wrapped with ``SO3_type``.'''
        return LieTensor(input.tensor()[..., 0:4], ltype=SO3_type)

    def scale(self, input):
        '''Entry ``[4:5]`` of the last dimension as a plain tensor.'''
        return input.tensor()[..., 4:5]

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Tile the vector ``[0, 0, 0, 1, 1]`` to ``size + (5,)``.'''
        unit = torch.tensor([0., 0., 0., 1., 1.], **kwargs)
        return LieTensor(unit.repeat(size + (1,)), ltype=RxSO3_type)

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Sample via ``rxso3_type.randn`` followed by ``Exp`` (detached).'''
        tangent = rxso3_type.randn(*size, sigma=sigma, **kwargs)
        sample = LieTensor(rxso3_type.Exp(tangent).detach(), ltype=RxSO3_type)
        sample.requires_grad_(requires_grad)
        return sample

    @classmethod
    def add_(cls, input, other):
        '''Overwrite ``input`` with ``Exp(other[..., :4]) * input``.'''
        step = LieTensor(other[..., :4], ltype=rxso3_type).Exp()
        return input.copy_(step * input)
class rxso3Type(LieType):
    '''Lie type descriptor for rxso3 LieTensors.

    Registered with data dimension 4, embedding dimension 5 and manifold
    dimension 4, so the type is on the manifold.
    '''

    def __init__(self):
        super().__init__(4, 5, 4)

    def Exp(self, x):
        '''Apply ``rxso3_Exp`` and wrap the result with ``RxSO3_type``.'''
        raw = x.tensor() if isinstance(x, LieTensor) else x
        return LieTensor(rxso3_Exp.apply(raw), ltype=RxSO3_type)

    def Mul(self, X, Y):
        '''Element-wise ``torch.mul(X, Y)`` wrapped with ``rxso3_type``.'''
        if isinstance(X, LieTensor):
            X = X.tensor()
        # (scalar or tensor) * manifold
        if not self.on_manifold:
            raise NotImplementedError('Invalid __mul__ operation')
        return LieTensor(torch.mul(X, Y), ltype=rxso3_type)

    def rotation(self, input):
        '''Rotation of ``input.Exp()``.'''
        return input.Exp().rotation()

    def scale(self, input):
        '''Scale of ``input.Exp()``.'''
        return input.Exp().scale()

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Log of the RxSO3 identity of the requested size.'''
        return RxSO3_type.Log(RxSO3_type.identity(*size, **kwargs))

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Random rxso3 sample: an so3 sample followed by one scale entry.

        ``sigma`` is normalised to ``(rotation, scale)``: a non-sequence is
        repeated twice, otherwise a 2-sequence is required.
        '''
        if isinstance(sigma, Sequence):
            assert len(sigma)==2, 'Only accepts a tuple of sigma in size 1 or 2.'
        else:
            sigma = _pair(sigma)
        size = self.to_tuple(size)
        rot = so3_type.randn(*size, sigma=sigma[0], **kwargs).tensor()
        scl = sigma[1] * torch.randn(*(size + torch.Size([1])), **kwargs)
        sample = LieTensor(torch.cat([rot, scl], dim=-1), ltype=rxso3_type)
        sample.requires_grad_(requires_grad)
        return sample
# One shared descriptor instance per Lie type. Methods of the type classes
# above refer to these module-level names.
SO3_type = SO3Type()
so3_type = so3Type()
SE3_type = SE3Type()
se3_type = se3Type()
Sim3_type = Sim3Type()
sim3_type = sim3Type()
RxSO3_type = RxSO3Type()
rxso3_type = rxso3Type()

# Registries grouping the descriptors above into groups and algebras.
liegroup = [SO3_type, SE3_type, Sim3_type, RxSO3_type]
liealgebra = [so3_type, se3_type, sim3_type, rxso3_type]
[docs]class LieTensor(Tensor):
r""" A sub-class of :obj:`Tensor` to represent Lie Algebra and Lie Group.
Args:
data (:obj:`Tensor`, or :obj:`list`, or ':obj:`int`...'): A
:obj:`Tensor` object, or constructing a :obj:`Tensor`
object from :obj:`list`, which defines tensor data, or from
':obj:`int`...', which defines tensor shape.
The shape of :obj:`Tensor` object should be compatible
with Lie Type :obj:`ltype`, otherwise error will be raised.
ltype (:obj:`ltype`): Lie Type, either **Lie Group** or **Lie Algebra** is listed below:
Returns:
LieTensor corresponding to Lie Type :obj:`ltype`.
.. list-table:: List of :obj:`ltype` for **Lie Group**
:widths: 25 25 30 30
:header-rows: 1
* - Representation
- :obj:`ltype`
- :obj:`shape`
- Alias Class
* - Rotation
- :obj:`SO3_type`
- :obj:`(*, 4)`
- :meth:`SO3`
* - Translation + Rotation
- :obj:`SE3_type`
- :obj:`(*, 7)`
- :meth:`SE3`
* - Translation + Rotation + Scale
- :obj:`Sim3_type`
- :obj:`(*, 8)`
- :meth:`Sim3`
* - Rotation + Scale
- :obj:`RxSO3_type`
- :obj:`(*, 5)`
- :meth:`RxSO3`
.. list-table:: List of :obj:`ltype` for **Lie Algebra**
:widths: 25 25 30 30
:header-rows: 1
* - Representation
- :obj:`ltype`
- :obj:`shape`
- Alias Class
* - Rotation
- :obj:`so3_type`
- :obj:`(*, 3)`
- :meth:`so3`
* - Translation + Rotation
- :obj:`se3_type`
- :obj:`(*, 6)`
- :meth:`se3`
* - Translation + Rotation + Scale
- :obj:`sim3_type`
- :obj:`(*, 7)`
- :meth:`sim3`
* - Rotation + Scale
- :obj:`rxso3_type`
- :obj:`(*, 4)`
- :meth:`rxso3`
Note:
Two attributes :obj:`shape` and :obj:`lshape` are available for LieTensor.
The only difference is that :obj:`lshape` hides the last dimension of :obj:`shape`,
since :obj:`lshape` takes the data in the last dimension as a single :obj:`ltype` item.
See LieTensor method :meth:`lview` for more details.
Examples:
>>> import torch
>>> import pypose as pp
>>> data = torch.randn(3, 3, requires_grad=True, device='cuda:0')
>>> pp.LieTensor(data, ltype=pp.so3_type)
so3Type LieTensor:
tensor([[ 0.9520, 0.4517, 0.5834],
[-0.8106, 0.8197, 0.7077],
[-0.5743, 0.8182, -1.2104]], device='cuda:0', grad_fn=<AliasBackward0>)
Alias class for specific LieTensor is recommended:
>>> pp.so3(data)
so3Type LieTensor:
tensor([[ 0.9520, 0.4517, 0.5834],
[-0.8106, 0.8197, 0.7077],
[-0.5743, 0.8182, -1.2104]], device='cuda:0', grad_fn=<AliasBackward0>)
See more alias classes at `Table 1 for Lie Group <#id1>`_ and `Table 2 for Lie Algebra <#id2>`_.
Other constructors:
- From list.
>>> pp.so3([0, 0, 0])
so3Type LieTensor:
tensor([0., 0., 0.])
- From ints.
>>> pp.so3(2, 3)
so3Type LieTensor:
tensor([[0., 0., 0.],
[0., 0., 0.]])
Note:
Alias class for LieTensor is recommended.
For example, the following usage is equivalent:
- :obj:`pp.LieTensor(tensor, ltype=pp.so3_type)`
- :obj:`pp.so3(tensor)` (This is preferred).
Note:
All attributes from Tensor are available for LieTensor, e.g., :obj:`dtype`,
:obj:`device`, and :obj:`requires_grad`. See more details at
`tensor attributes <https://pytorch.org/docs/stable/tensor_attributes.html>`_.
Example:
>>> data = torch.randn(1, 3, dtype=torch.float64, device="cuda", requires_grad=True)
>>> pp.so3(data) # All Tensor attributes are available for LieTensor
so3Type LieTensor:
tensor([[-1.5948, 0.3113, -0.9807]], device='cuda:0', dtype=torch.float64,
grad_fn=<AliasBackward0>)
Note:
In most of the cases, Lie Group is expected to be used, therefore we only provide
`converting functions <https://pypose.org/docs/main/convert/>`_ between Lie Groups
and other data structures, e.g., transformation matrix, Euler angle, etc. The users
can convert data between Lie Group and Lie algebra with :obj:`Exp` and :obj:`Log`.
"""
def __init__(self, *data, ltype:LieType):
    # ``data`` is not used here; the tensor itself is built in ``__new__``.
    # This only checks that the last dimension matches ``ltype.dimension``
    # and then records the type on the instance.
    link = 'https://pypose.org/docs/main/generated/pypose.LieTensor'
    assert self.shape[-1:] == ltype.dimension, (
        'The last dimension of a LieTensor has to be '
        'corresponding to their LieType. More details go to {}. If this error happens in an '
        'optimization process, where LieType is not a necessary structure, we suggest to '
        'call .tensor() to convert a LieTensor to Tensor before passing it to an optimizer. '
        'If this still happens, create an issue on GitHub please.'
    ).format(link)
    self.ltype = ltype
@staticmethod
def __new__(cls, *data, ltype):
tensor = data[0] if isinstance(data[0], Tensor) else Tensor(*data)
return Tensor.as_subclass(tensor, LieTensor)
def __repr__(self):
    # Prefix the Tensor repr with "<ltype class name> <tensor class name>:"
    # when an ltype has been attached; otherwise use the Tensor repr as-is.
    text = super().__repr__()
    if not hasattr(self, 'ltype'):
        return text
    header = self.ltype.__class__.__name__ + ' %s:\n' % (self.__class__.__name__)
    return header + text
def new_empty(self, size, *, dtype=None, layout=None, device=None, pin_memory=None,
              requires_grad=None):
    r'''
    Create an uninitialised tensor of the given ``size`` with the same
    class as ``self``.

    ``dtype``, ``layout``, ``device`` and ``requires_grad`` default to the
    values of ``self`` when left as ``None``; ``pin_memory`` is passed to
    ``torch.empty`` unchanged. The :obj:`ltype` of ``self`` is copied onto
    the result when ``self`` has one.
    '''
    options = dict(
        dtype=self.dtype if dtype is None else dtype,
        layout=self.layout if layout is None else layout,
        device=self.device if device is None else device,
        pin_memory=pin_memory,
        requires_grad=self.requires_grad if requires_grad is None else requires_grad,
    )
    out = Tensor.as_subclass(torch.empty(size, **options), type(self))
    if hasattr(self, 'ltype'):
        out.ltype = self.ltype
    return out
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    r'''
    Run ``func`` through ``Tensor.__torch_function__`` and, for functions
    named in ``HANDLED_FUNCTIONS``, re-wrap plain Tensor results as
    LieTensor carrying the :obj:`ltype` of the first LieTensor input.

    Args:
        func: the torch function or method being dispatched.
        types: the argument types that implement ``__torch_function__``.
        args: positional arguments of the call.
        kwargs: keyword arguments of the call, or ``None``.

    Returns:
        The result of ``func``, with Tensor outputs converted to LieTensor
        when ``func`` is handled and an input ``ltype`` is available. A
        warning is issued if a wrapped result's last dimension does not
        match ``ltype.dimension``.
    '''
    # NOTE(review): the previous condition ``t is LieTensor or Parameter``
    # was always truthy, so every entry of ``types`` was replaced by Tensor.
    # That effective behaviour is kept and made explicit here. Confirm
    # whether leaving other types untouched was the real intent.
    ltypes = tuple(Tensor for _ in types)
    data = Tensor.__torch_function__(func, ltypes, args, kwargs)
    if data is None or getattr(func, '__name__', None) not in HANDLED_FUNCTIONS:
        return data
    # Look in keyword arguments as well as positional ones, and fall back to
    # the unwrapped result instead of raising when no ltype is available.
    flat, _ = tree_flatten((args, {} if kwargs is None else kwargs))
    ltype = next((a.ltype for a in flat
                  if isinstance(a, LieTensor) and hasattr(a, 'ltype')), None)
    if ltype is None:
        return data

    def wrap(t):
        if isinstance(t, Tensor) and not isinstance(t, cls):
            lt = Tensor.as_subclass(t, LieTensor)
            lt.ltype = ltype
            if lt.shape[-1:] != lt.ltype.dimension:
                link = 'https://pypose.org/docs/main/generated/pypose.LieTensor'
                warnings.warn('Tensor Shape Invalid by calling {}, ' \
                    'go to {}'.format(func, link))
            return lt
        return t

    return tree_map(wrap, data)
@property
def lshape(self) -> torch.Size:
    r'''
    Shape of the LieTensor with the last dimension dropped.

    Returns:
        torch.Size

    Note:
        - This is :obj:`shape` without its final entry, because the last
          dimension is treated as a single :obj:`ltype` item.
        - The size of that last dimension is available as
          :obj:`LieTensor.ltype.dimension`.

    Examples:
        >>> x = pp.randn_SE3(2)
        >>> x.lshape
        torch.Size([2])
        >>> x.shape
        torch.Size([2, 7])
        >>> x.ltype.dimension
        torch.Size([7])
    '''
    full = self.shape
    return full[:-1]
def lview(self, *shape) -> LieTensor:
    r'''
    Returns a new LieTensor with the same data as the self tensor but of a different
    :obj:`lshape`.

    Args:
        shape (torch.Size or int...): the desired size, given either as separate
            integers or as a single tuple / list / :obj:`torch.Size`.

    Returns:
        A new lieGroup tensor sharing with the same data as the self tensor but of a
        different shape.

    Note:
        The only difference from :meth:`view` is the last dimension is hidden.

        See `Tensor.view <https://tinyurl.com/mrds8nmd>`_
        for its usage.

    Examples:
        >>> x = pp.randn_so3(2,2)
        >>> x.shape
        torch.Size([2, 2, 3])
        >>> x.lview(-1).lshape
        torch.Size([4])
        >>> x.lview((4,)).lshape
        torch.Size([4])
    '''
    # Accept ``lview((2, 2))`` / ``lview(torch.Size([2, 2]))`` as documented;
    # torch.Size is a tuple subclass, so it is covered by this check.
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        shape = tuple(shape[0])
    # Append the hidden ltype dimension to the requested lshape.
    return self.view(*shape, *self.ltype.dimension)
def Exp(self) -> LieTensor:
    r'''
    Delegates to ``self.ltype.Exp`` with this tensor as argument.

    See :meth:`pypose.Exp`
    '''
    handler = self.ltype.Exp
    return handler(self)
def Log(self) -> LieTensor:
    r'''
    Delegates to ``self.ltype.Log`` with this tensor as argument.

    See :meth:`pypose.Log`
    '''
    handler = self.ltype.Log
    return handler(self)
def Inv(self) -> LieTensor:
    r'''
    Delegates to ``self.ltype.Inv`` with this tensor as argument.

    See :meth:`pypose.Inv`
    '''
    handler = self.ltype.Inv
    return handler(self)
def Act(self, p: Tensor) -> Tensor:
    r'''
    Delegates to ``self.ltype.Act`` with this tensor and ``p``.

    See :meth:`pypose.Act`
    '''
    handler = self.ltype.Act
    return handler(self, p)
def add(self, other, alpha=1):
    r'''
    Out-of-place addition: clones this tensor and applies :meth:`add_` on the
    clone with ``alpha * other``.

    See :meth:`pypose.add`
    '''
    duplicate = self.clone()
    return duplicate.add_(other=alpha * other)
def add_(self, other, alpha=1):
    r'''
    Delegates to ``self.ltype.add_`` with this tensor and ``alpha * other``
    (passed as the ``other`` keyword).

    See :meth:`pypose.add_`
    '''
    handler = self.ltype.add_
    return handler(self, other=alpha * other)
def __add__(self, other):
    r'''
    Operator ``+``: forwards to :meth:`add` with ``other`` as keyword argument.
    '''
    total = self.add(other=other)
    return total
def __mul__(self, other):
    r'''
    Operator ``*``: delegates to ``self.ltype.Mul`` with this tensor and ``other``.

    See :meth:`pypose.mul`
    '''
    handler = self.ltype.Mul
    return handler(self, other)
def mul(self, other):
    r'''
    Delegates to ``self.ltype.Mul`` with this tensor and ``other``.

    See :meth:`pypose.mul`
    '''
    handler = self.ltype.Mul
    return handler(self, other)
@overload
def __matmul__(self, other: LieTensor) -> LieTensor: ...
@overload
def __matmul__(self, other: Tensor) -> Tensor: ...
def __matmul__(self, other: Tensor):
    r'''
    Operator ``@``: multiplies with another LieTensor through ``self.ltype.Mul``;
    any other operand is passed to :meth:`Act`.

    See :meth:`pypose.matmul`
    '''
    if not isinstance(other, LieTensor):
        # Same with: self.ltype.matrix(self) @ other
        return self.Act(other)
    return self.ltype.Mul(self, other)
def Retr(self, a: LieTensor) -> LieTensor:
    r'''
    Delegates to ``self.ltype.Retr`` with this tensor and ``a``.

    See :meth:`pypose.Retr`
    '''
    handler = self.ltype.Retr
    return handler(self, a)
def Adj(self, a: LieTensor | Tensor) -> LieTensor:
    r'''
    Delegates to ``self.ltype.Adj`` with this tensor and ``a``.

    See :meth:`pypose.Adj`
    '''
    handler = self.ltype.Adj
    return handler(self, a)
def AdjT(self, a: LieTensor | Tensor) -> LieTensor:
    r'''
    Delegates to ``self.ltype.AdjT`` with this tensor and ``a``.

    See :meth:`pypose.AdjT`
    '''
    handler = self.ltype.AdjT
    return handler(self, a)
def Jinvp(self, p: LieTensor | Tensor) -> LieTensor:
    r'''
    Delegates to ``self.ltype.Jinvp`` with this tensor and ``p``.

    See :meth:`pypose.Jinvp`
    '''
    handler = self.ltype.Jinvp
    return handler(self, p)
def Jr(self):
    r'''
    Delegates to ``self.ltype.Jr`` with this tensor as argument.

    See :meth:`pypose.Jr`
    '''
    handler = self.ltype.Jr
    return handler(self)
def tensor(self) -> Tensor:
    r'''
    View this object as a plain :obj:`Tensor` via :meth:`Tensor.as_subclass`.

    See :meth:`pypose.tensor`
    '''
    plain = Tensor.as_subclass(self, Tensor)
    return plain
def matrix(self) -> Tensor:
    r'''
    Delegates to ``self.ltype.matrix`` with this tensor as argument.

    See :meth:`pypose.matrix`
    '''
    handler = self.ltype.matrix
    return handler(self)
def translation(self) -> Tensor:
    r'''
    Delegates to ``self.ltype.translation`` with this tensor as argument.

    See :meth:`pypose.translation`
    '''
    handler = self.ltype.translation
    return handler(self)
def rotation(self) -> LieTensor:
    r'''
    Delegates to ``self.ltype.rotation`` with this tensor as argument.

    See :meth:`pypose.rotation`
    '''
    handler = self.ltype.rotation
    return handler(self)
def scale(self) -> Tensor:
    r'''
    Delegates to ``self.ltype.scale`` with this tensor as argument.

    See :meth:`pypose.scale`
    '''
    handler = self.ltype.scale
    return handler(self)
def euler(self, eps=2e-4) -> Tensor:
    r'''
    Convert the rotation part to (roll, pitch, yaw) angles.

    The components at indices 0..3 of ``self.rotation().tensor()`` are read as
    ``x, y, z, w``.

    Args:
        eps: margin used to detect the singular case; the regular formulas are
            applied where ``|sin(pitch)| < 1 - eps``.

    Returns:
        Tensor whose last dimension stacks roll, pitch and yaw.

    See :meth:`pypose.euler`
    '''
    quat = self.rotation().tensor()
    qx, qy, qz, qw = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]
    sq_x, sq_y, sq_z, sq_w = qx*qx, qy*qy, qz*qz, qw*qw

    roll_num = 2 * (qw * qx + qy * qz)
    roll_den = (sq_w + sq_z) - (sq_x + sq_y)
    # Divided by the squared norm of the four components.
    sin_pitch = 2 * (qw * qy - qz * qx) / (sq_x + sq_y + sq_z + sq_w)
    yaw_num = 2 * (qw * qz + qx * qy)
    yaw_den = (sq_w + sq_x) - (sq_y + sq_z)

    # singularity when pitch angle ~ +/-pi/2: roll is set to zero there and
    # yaw is taken from the alternative formula below
    regular = sin_pitch.abs() < 1. - eps
    roll_regular = torch.atan2(roll_num, roll_den)
    roll_singular = torch.zeros_like(roll_num)
    yaw_regular = torch.atan2(yaw_num, yaw_den)
    yaw_singular = -2 * pm(sin_pitch) * torch.atan2(qx, qw)

    roll = torch.where(regular, roll_regular, roll_singular)
    pitch = torch.asin(sin_pitch.clamp(-1, 1))
    yaw = torch.where(regular, yaw_regular, yaw_singular)
    return torch.stack([roll, pitch, yaw], dim=-1)
def identity_(self):
    r'''
    Inplace set the LieTensor to identity (delegates to ``self.ltype.identity_``).

    Return:
        LieTensor: the :obj:`self` LieTensor

    Note:
        The translation part, if there is, is set to zeros, while the
        rotation part is set to identity quaternion.

    Example:
        >>> x = pp.randn_SO3(2)
        >>> x
        SO3Type LieTensor:
        tensor([[-0.0724,  0.1970,  0.0022,  0.9777],
                [ 0.3492,  0.4998, -0.5310,  0.5885]])
        >>> x.identity_()
        SO3Type LieTensor:
        tensor([[0., 0., 0., 1.],
                [0., 0., 0., 1.]])
    '''
    handler = self.ltype.identity_
    return handler(self)
def cumops(self, dim: int, ops):
    r"""
    Delegates to ``self.ltype.cumops`` with this tensor, ``dim`` and ``ops``.

    See :func:`pypose.cumops`
    """
    handler = self.ltype.cumops
    return handler(self, dim, ops)
def cummul(self, dim: int, left = True) -> LieTensor:
    r"""
    Delegates to ``self.ltype.cummul`` with this tensor, ``dim`` and ``left``.

    See :func:`pypose.cummul`
    """
    handler = self.ltype.cummul
    return handler(self, dim, left)
def cumprod(self, dim: int, left = True) -> LieTensor:
    r"""
    Delegates to ``self.ltype.cumprod`` with this tensor, ``dim`` and ``left``.

    See :func:`pypose.cumprod`
    """
    handler = self.ltype.cumprod
    return handler(self, dim, left)
def cumops_(self, dim: int, ops):
    r"""
    Delegates to ``self.ltype.cumops_`` with this tensor, ``dim`` and ``ops``.

    Inplace version of :func:`pypose.cumops`
    """
    handler = self.ltype.cumops_
    return handler(self, dim, ops)
def cummul_(self, dim: int, left = True) -> LieTensor:
    r"""
    Delegates to ``self.ltype.cummul_`` with this tensor, ``dim`` and ``left``.

    Inplace version of :func:`pypose.cummul`
    """
    handler = self.ltype.cummul_
    return handler(self, dim, left)
def cumprod_(self, dim: int, left = True) -> LieTensor:
    r"""
    Delegates to ``self.ltype.cumprod_`` with this tensor, ``dim`` and ``left``.

    Inplace version of :func:`pypose.cumprod`
    """
    handler = self.ltype.cumprod_
    return handler(self, dim, left)
class Parameter(LieTensor, nn.Parameter):
    r"""
    :class:`Parameter` wraps a :class:`torch.Tensor` or :class:`pypose.LieTensor`.
    When setting `sjac=True`, it tracks all operations performed on the tensor,
    allowing PyPose's sparse backend to build structured sparse Jacobians.

    Args:
        data (Tensor or LieTensor): parameter data.
        requires_grad (bool, optional): if the parameter requires
            gradient. Default: ``True``
        sjac (bool, optional): if ``True``, sparse Jacobian tracing is enabled.
            Default: ``False``

    Use `sjac=True` together with :func:`pypose.autograd.function.parallel_for_sparse_jacobian`
    (alias :func:`psjac`) for any :class:`pypose.Parameter` that needs to be optimized with
    the sparse backend, when the optimization model is instantiated.

    .. note::
        Recommend to use alias :func:`pypose.autograd.function.psjac` for brevity.

    .. warning::
        Wrap the original batched tensor before performing any operation.
        If you use a regular tensor or LieTensor instead,
        the sparse backend will not recover the Jacobian for the tensor.

    .. note::
        :class:`Parameter` does not change numerical results. It only adds
        tracing information for sparse Jacobian construction when setting `sjac=True`.

    .. admonition:: Example

        Below, ``self.poses`` and ``self.points_3d`` are the optimization variables
        in a bundle-adjustment model.
        ``pp.Parameter(..., sjac=True)`` records the indexing and reprojection operations
        so the sparse backend knows how each entry in the output depends on these variables.
        This includes tracing :func:`psjac` functions as well,
        so the dependency through ``project`` is preserved.
        ``observations``, ``camera_indices``, and ``point_indices``
        do not need to be wrapped in ``Parameter`` because they are fixed inputs,
        not optimization variables.

        .. code-block:: python

            import pypose as pp
            from torch import nn
            from pypose.autograd.function import psjac

            class Reproj(nn.Module):
                def __init__(self, poses, points_3d):
                    # self.points_3d: tensor (P, 3), self.poses: pp.SE3 (C, 7)
                    super().__init__()
                    self.poses = pp.Parameter(poses, sjac=True)
                    self.points_3d = pp.Parameter(points_3d, sjac=True)

                @psjac
                def project(points, poses):
                    # points: tensor (N, 3), poses: pp.SE3 (N, 7)
                    # returns: (N, 2)
                    points = poses.Act(points)
                    return -points[..., :2] / points[..., 2].unsqueeze(-1)

                def forward(self, observations, camera_indices, point_indices):
                    # observations: tensor (N, 2), camera_indices: (N,), point_indices: (N,)
                    poses = self.poses[camera_indices]
                    points = self.points_3d[point_indices]
                    points_proj = Reproj.project(points, poses)
                    return points_proj - observations
    """
    def __init__(self, data=None, requires_grad=True, sjac=False):
        # Only the ltype is copied over here; no parent initializer is invoked.
        if hasattr(data, 'ltype'):
            self.ltype = data.ltype

    def __new__(cls, data=None, requires_grad=True, sjac=False):
        if data is None:
            data = torch.tensor([])
        if sjac:
            # Imported lazily so the tracing backend is only needed on this path.
            from .. import _require_backend_attr
            TrackingTensor = _require_backend_attr(
                "bae.autograd.function",
                "TrackingTensor",
                "pypose.Parameter(..., sjac=True)",
            )
            data = TrackingTensor(data)
            # The traced data is returned as a plain nn.Parameter.
            return nn.Parameter(data, requires_grad)
        if isinstance(data, LieTensor):
            param = Tensor._make_subclass(cls, data.tensor(), requires_grad)
            param.ltype = data.ltype
            param._is_param = True
            return param
        # Non-Lie data falls back to a plain nn.Parameter.
        return nn.Parameter(data, requires_grad)

    def __deepcopy__(self, memo):
        # Reuse a copy already made during this deepcopy pass.
        if id(self) in memo:
            return memo[id(self)]
        # Forward requires_grad explicitly: the constructor defaults it to True,
        # which would silently turn a frozen parameter into a trainable one.
        result = type(self)(self.clone(memory_format=torch.preserve_format),
                            self.requires_grad)
        memo[id(self)] = result
        return result
@contextmanager
def retain_ltype():
    r'''
    Context manager that temporarily patches a few PyTorch functions so that
    their result keeps the :obj:`ltype` of a LieTensor given as first positional
    argument.

    Inside the ``with`` block, each function listed in ``TO_BE_WRAPPED`` is
    replaced (in the module named by its ``__module__``) with a wrapper that
    re-types the result as :obj:`LieTensor` and copies the ``ltype`` over. The
    functions found on entry are put back on exit, also when an exception is
    raised inside the block.
    '''
    # Local import: keeps this block self-contained without touching file imports.
    from functools import wraps

    # save the original PyTorch functions
    TO_BE_WRAPPED = {
        torch.autograd.forward_ad.make_dual,
        torch._functorch.eager_transforms._wrap_tensor_for_grad,
        torch._functorch.vmap._add_batch_dim,
    }
    # The swap below looks the target module up through ``__module__``; force it
    # so that the attribute patched is ``torch._functorch.vmap._add_batch_dim``.
    torch._functorch.vmap._add_batch_dim.__module__ = 'torch._functorch.vmap'

    def wrap_function(func):
        # ``wraps`` keeps ``__module__``/``__name__`` of the wrapped function, so a
        # nested ``retain_ltype`` still resolves the right module attribute.
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Guard against keyword-only calls (no positional argument at all).
            ltype = args[0].ltype if args and isinstance(args[0], LieTensor) else None
            res = func(*args, **kwargs)
            if ltype is not None:
                res = Tensor.as_subclass(res, LieTensor)
                res.ltype = ltype
            return res
        return wrapper

    try:
        # swap the original PyTorch functions with the wrapper
        for func in TO_BE_WRAPPED:
            module, name = func.__module__, func.__name__
            module = importlib.import_module(module)
            setattr(module, name, wrap_function(func))
        yield
    finally:
        # restore what was in place when the context was entered
        for func in TO_BE_WRAPPED:
            module, name = func.__module__, func.__name__
            module = importlib.import_module(module)
            setattr(module, name, func)