Shortcuts

Source code for pypose.module.lqr

import torch
import torch.nn as nn
from .. import bmv, bvmv


class LQR(nn.Module):
    r'''
    Linear Quadratic Regulator (LQR) with Dynamic Programming.

    Args:
        system (:obj:`instance`): The system to be solved by LQR.
        Q (:obj:`Tensor`): The weight matrix of the quadratic term, given per time step
            with shape ``(B, T, N, N)``. A 3-D tensor ``(B, N, N)`` is repeated ``T``
            times along a new time dimension.
        p (:obj:`Tensor`): The weight vector of the first-order term, given per time step
            with shape ``(B, T, N)``. A 2-D tensor ``(B, N)`` is repeated ``T`` times
            along a new time dimension.
        T (:obj:`int`): Time steps of system.

    A discrete-time linear system can be described as:

    .. math::
        \begin{align*}
            \mathbf{x}_{t+1} &= \mathbf{A}_t\mathbf{x}_t + \mathbf{B}_t\mathbf{u}_t
                                                                    + \mathbf{c}_{1t} \\
            \mathbf{y}_t &= \mathbf{C}_t\mathbf{x}_t + \mathbf{D}_t\mathbf{u}_t
                                                                    + \mathbf{c}_{2t} \\
        \end{align*}

    where :math:`\mathbf{x}`, :math:`\mathbf{u}` are the state and input of the linear
    system; :math:`\mathbf{y}` is the observation of the linear system;
    :math:`\mathbf{A}`, :math:`\mathbf{B}` are the state matrix and input matrix of the
    linear system; :math:`\mathbf{C}`, :math:`\mathbf{D}` are the output matrix and
    observation matrix of the linear system; :math:`\mathbf{c}_{1}`,
    :math:`\mathbf{c}_{2}` are the constant input and constant output of the linear
    system. The subscript :math:`\cdot_{t}` denotes the time step.

    LQR finds the optimal nominal trajectory :math:`\mathbf{\tau}_{1:T}^*` =
    :math:`\begin{Bmatrix} \mathbf{x}_t, \mathbf{u}_t \end{Bmatrix}_{1:T}`
    for the linear system of the optimization problem:

    .. math::
        \begin{align*}
          \mathbf{\tau}_{1:T}^* = \mathop{\arg\min}\limits_{\tau_{1:T}}
                                    \sum\limits_t\frac{1}{2}
                                    \mathbf{\tau}_t^\top\mathbf{Q}_t\mathbf{\tau}_t
                                    + \mathbf{p}_t^\top\mathbf{\tau}_t \\
          \mathrm{s.t.} \quad \mathbf{x}_1 = \mathbf{x}_{\text{init}}, \\
          \mathbf{x}_{t+1} = \mathbf{F}_t\mathbf{\tau}_t + \mathbf{c}_{1t} \\
        \end{align*}

    where :math:`\mathbf{\tau}_t` = :math:`\begin{bmatrix} \mathbf{x}_t \\ \mathbf{u}_t
    \end{bmatrix}`, :math:`\mathbf{F}_t` = :math:`\begin{bmatrix} \mathbf{A}_t &
    \mathbf{B}_t \end{bmatrix}`.

    The LQR process can be summarised as a backward and a forward recursion.

    - The backward recursion.

      For :math:`t` = :math:`T` to 1 (at :math:`t` = :math:`T` the terms involving
      :math:`\mathbf{V}_{t+1}` and :math:`\mathbf{v}_{t+1}` are absent):

        .. math::
            \begin{align*}
                \mathbf{Q}_t &= \mathbf{Q}_t + \mathbf{F}_t^\top\mathbf{V}_{t+1}\mathbf{F}_t \\
                \mathbf{q}_t &= \mathbf{p}_t + \mathbf{F}_t^\top\mathbf{V}_{t+1}
                                    \mathbf{c}_{1t} + \mathbf{F}_t^\top\mathbf{v}_{t+1} \\
                \mathbf{K}_t &= -\mathbf{Q}_{\mathbf{u}_t, \mathbf{u}_t}^{-1}
                                                     \mathbf{Q}_{\mathbf{u}_t, \mathbf{x}_t} \\
                \mathbf{k}_t &= -\mathbf{Q}_{\mathbf{u}_t, \mathbf{u}_t}^{-1}
                                                                    \mathbf{q}_{\mathbf{u}_t} \\
                \mathbf{V}_t &= \mathbf{Q}_{\mathbf{x}_t, \mathbf{x}_t}
                    + \mathbf{Q}_{\mathbf{x}_t, \mathbf{u}_t}\mathbf{K}_t
                    + \mathbf{K}_t^\top\mathbf{Q}_{\mathbf{u}_t, \mathbf{x}_t}
                    + \mathbf{K}_t^\top\mathbf{Q}_{\mathbf{u}_t, \mathbf{u}_t}\mathbf{K}_t \\
                \mathbf{v}_t &= \mathbf{q}_{\mathbf{x}_t}
                    + \mathbf{Q}_{\mathbf{x}_t, \mathbf{u}_t}\mathbf{k}_t
                    + \mathbf{K}_t^\top\mathbf{q}_{\mathbf{u}_t}
                    + \mathbf{K}_t^\top\mathbf{Q}_{\mathbf{u}_t, \mathbf{u}_t}\mathbf{k}_t \\
            \end{align*}

    - The forward recursion.

      For :math:`t` = 1 to :math:`T`:

        .. math::
            \begin{align*}
                \mathbf{u}_t &= \mathbf{K}_t\mathbf{x}_t + \mathbf{k}_t \\
                \mathbf{x}_{t+1} &= \mathbf{A}_t\mathbf{x}_t + \mathbf{B}_t\mathbf{u}_t
                                                                    + \mathbf{c}_{1t} \\
            \end{align*}

    Then quadratic costs of the system over the time horizon:

        .. math::
            \mathbf{c} \left( \mathbf{\tau}_t \right) = \frac{1}{2}
            \mathbf{\tau}_t^\top\mathbf{Q}_t\mathbf{\tau}_t + \mathbf{p}_t^\top\mathbf{\tau}_t

    Note:
        The discrete-time system to be solved by LQR can be either a linear
        time-invariant (:meth:`LTI`) system or a linear time-varying (:meth:`LTV`) system.

    From the learning perspective, this can be interpreted as a module with unknown
    parameters :math:`\begin{Bmatrix} \mathbf{Q}, \mathbf{p}, \mathbf{F}, \mathbf{c}_1
    \end{Bmatrix}`, which can be integrated into a larger end-to-end learning system.

    Note:
        The implementation is based on page 24-32 of `Optimal Control and Planning
        <http://rll.berkeley.edu/deeprlcourse/f17docs/lecture_8_model_based_planning.pdf>`_.

    Example:
        >>> import torch
        >>> import pypose as pp
        >>> n_batch, T = 2, 5
        >>> n_state, n_ctrl = 4, 3
        >>> n_sc = n_state + n_ctrl
        >>> Q = torch.randn(n_batch, T, n_sc, n_sc)
        >>> Q = torch.matmul(Q.mT, Q)
        >>> p = torch.randn(n_batch, T, n_sc)
        >>> r = 0.2 * torch.randn(n_state, n_state)
        >>> A = torch.tile(torch.eye(n_state) + r, (n_batch, 1, 1))
        >>> B = torch.randn(n_batch, n_state, n_ctrl)
        >>> C = torch.tile(torch.eye(n_state), (n_batch, 1, 1))
        >>> D = torch.tile(torch.zeros(n_state, n_ctrl), (n_batch, 1, 1))
        >>> c1 = torch.tile(torch.randn(n_state), (n_batch, 1))
        >>> c2 = torch.tile(torch.zeros(n_state), (n_batch, 1))
        >>> x_init = torch.randn(n_batch, n_state)
        >>>
        >>> lti = pp.module.LTI(A, B, C, D, c1, c2)
        >>> lqr = pp.module.LQR(lti, Q, p, T)
        >>> x, u, cost = lqr(x_init)
        >>> print("x = ", x)
        >>> print("u = ", u)
        x =  tensor([[[-0.2633, -0.3466,  2.3803, -0.0423],
                      [ 0.1849, -1.3884,  1.0898, -1.6229],
                      [ 1.2138, -0.7161,  0.2954, -0.6819],
                      [ 1.4840, -1.1249, -1.0302,  0.9805],
                      [-0.3477, -1.7063,  4.6494,  2.6780]],
                     [[-0.9744,  0.4976,  0.0603, -0.5258],
                      [-0.6356,  0.0539,  0.7264, -0.5048],
                      [-0.2275, -0.1649,  0.3872, -0.4614],
                      [ 0.2697, -0.3576,  0.0999, -0.4594],
                      [ 0.3916, -2.0832,  0.0701, -0.5407]]])
        u =  tensor([[[ 1.0405,  0.1586, -0.1282],
                      [-1.4845, -0.5745,  0.2523],
                      [-0.6322, -0.3281, -0.3620],
                      [-1.6768,  2.4054, -0.1047],
                      [-1.7948,  3.5269,  9.0703]],
                     [[-0.1795,  0.9153,  1.7066],
                      [ 0.0814,  0.4004,  0.7114],
                      [ 0.0435,  0.5782,  1.0127],
                      [-0.3017, -0.2897,  0.7251],
                      [-0.0728,  0.7290, -0.3117]]])
    '''
    def __init__(self, system, Q, p, T):
        super().__init__()
        self.system = system
        self.Q, self.p, self.T = Q, p, T

        # Time-invariant weights: insert a time dimension and repeat it T times.
        if self.Q.ndim == 3:
            self.Q = torch.tile(self.Q.unsqueeze(-3), (1, self.T, 1, 1))

        if self.p.ndim == 2:
            self.p = torch.tile(self.p.unsqueeze(-2), (1, self.T, 1))

        assert self.Q.shape[:-1] == self.p.shape, "Shape not compatible."
        assert self.Q.size(-1) == self.Q.size(-2), "Shape not compatible."
        # Given the first assertion, Q.ndim == p.ndim + 1, so requiring both
        # ranks is equivalent to the previous `or` form but states the intent.
        assert self.Q.ndim == 4 and self.p.ndim == 3, "Shape not compatible."
        assert self.Q.device == self.p.device, "device not compatible."
        assert self.Q.dtype == self.p.dtype, "tensor data type not compatible."

    def forward(self, x_init):
        r'''
        Performs LQR for the linear system.

        Args:
            x_init (:obj:`Tensor`): The initial state of the system.

        Returns:
            Tuple of :obj:`Tensor`: The solved state sequence :math:`\mathbf{x}`
            (the first :math:`T` states, starting from ``x_init``), the solved input
            sequence :math:`\mathbf{u}`, and the associated quadratic costs
            :math:`\mathbf{c}` summed over the time horizon.
        '''
        K, k = self.lqr_backward()
        x, u, cost = self.lqr_forward(x_init, K, k)
        return x, u, cost

    def lqr_backward(self):
        r'''
        Backward recursion: computes the feedback gains ``K`` and the feedforward
        terms ``k`` for every time step, from the last step to the first.

        Returns:
            Tuple of :obj:`Tensor`: ``K`` with shape ``batch + (T, nc, ns)`` and ``k``
            with shape ``batch + (T, nc)``, where ``ns`` and ``nc`` are the last two
            sizes of ``system.B``.
        '''
        # Q: (B*, T, N, N), p: (B*, T, N); everything before the time axis is batch.
        batch_shape = self.p.shape[:-2]
        ns, nc = self.system.B.size(-2), self.system.B.size(-1)
        dtype, device = self.p.dtype, self.p.device
        K = torch.zeros(batch_shape + (self.T, nc, ns), dtype=dtype, device=device)
        k = torch.zeros(batch_shape + (self.T, nc), dtype=dtype, device=device)

        for t in range(self.T - 1, -1, -1):
            if t == self.T - 1:
                # Last step: there is no value function from a later step to add.
                Qt = self.Q[..., t, :, :]
                qt = self.p[..., t, :]
            else:
                # V and v were produced by the previous iteration (step t + 1).
                self.system.set_refpoint(t=t)
                F = torch.cat((self.system.A, self.system.B), dim=-1)
                FtV = F.mT @ V
                Qt = self.Q[..., t, :, :] + FtV @ F
                qt = self.p[..., t, :] + bmv(F.mT, v)
                if self.system.c1 is not None:
                    qt = qt + bmv(FtV, self.system.c1)

            # Split into state (first ns entries) and input (remaining) partitions.
            Qxx, Qxu = Qt[..., :ns, :ns], Qt[..., :ns, ns:]
            Qux, Quu = Qt[..., ns:, :ns], Qt[..., ns:, ns:]
            qx, qu = qt[..., :ns], qt[..., ns:]
            Quu_inv = torch.linalg.pinv(Quu)

            K[..., t, :, :] = Kt = - Quu_inv @ Qux
            k[..., t, :] = kt = - bmv(Quu_inv, qu)

            V = Qxx + Qxu @ Kt + Kt.mT @ Qux + Kt.mT @ Quu @ Kt
            v = qx + bmv(Qxu, kt) + bmv(Kt.mT, qu) + bmv(Kt.mT @ Quu, kt)

        return K, k

    def lqr_forward(self, x_init, K, k):
        r'''
        Forward recursion: rolls the system out from ``x_init`` using the inputs
        ``u_t = K_t x_t + k_t`` and accumulates the quadratic cost.

        Args:
            x_init (:obj:`Tensor`): The initial state; must be 2-D.
            K (:obj:`Tensor`): Feedback gains from :meth:`lqr_backward`.
            k (:obj:`Tensor`): Feedforward terms from :meth:`lqr_backward`.

        Returns:
            Tuple of :obj:`Tensor`: The states ``x`` with shape ``batch + (T, ns)``
            (the state reached after the last input is dropped), the inputs ``u``
            with shape ``batch + (T, nc)``, and ``cost`` with shape ``batch``.
        '''
        assert x_init.device == K.device == k.device, "device not compatible."
        assert x_init.dtype == K.dtype == k.dtype, "tensor data type not compatible."
        assert x_init.ndim == 2, "Shape not compatible."
        batch_shape = self.p.shape[:-2]
        ns, nc = self.system.B.size(-2), self.system.B.size(-1)
        dtype, device = self.p.dtype, self.p.device
        u = torch.zeros(batch_shape + (self.T, nc), dtype=dtype, device=device)
        cost = torch.zeros(batch_shape, dtype=dtype, device=device)
        # One extra slot holds the state produced by the final input.
        x = torch.zeros(batch_shape + (self.T + 1, ns), dtype=dtype, device=device)
        x[..., 0, :] = x_init
        xt = x_init

        self.system.set_refpoint(t=0)
        for t in range(self.T):
            Kt, kt = K[..., t, :, :], k[..., t, :]
            u[..., t, :] = ut = bmv(Kt, xt) + kt
            xut = torch.cat((xt, ut), dim=-1)
            x[..., t + 1, :] = xt = self.system(xt, ut)[0]
            cost = cost + 0.5 * bvmv(xut, self.Q[..., t, :, :], xut) \
                        + (xut * self.p[..., t, :]).sum(-1)

        return x[..., 0:-1, :], u, cost

Docs

Access documentation for PyPose

View Docs

Tutorials

Get started with tutorials and examples

View Tutorials

Get Started

Find resources and learn how to get started with PyPose

View Resources