Commit e99e3d25 authored by Jesse Mapel's avatar Jesse Mapel Committed by Summer Stapleton
Browse files

Adds Transformation object (#162)

* First pass at transformation

* More transformation tests

* Updated to set _parent instead of parent
parent 17db1c82
Loading
Loading
Loading
Loading

ale/transformation.py

0 → 100644
+136 −0
Original line number Diff line number Diff line
import numpy as np

from ale.rotation import ConstantRotation

class FrameNode():
    """
    A single frame in a frame tree. This class is largely adapted from the Node
    class in the vispy scene graph implementation.

    Attributes
    __________
    id : int
         The NAIF ID code for the frame
    parent : FrameNode
             The parent node in the frame tree
    children : List of FrameNode
               The children nodes in the frame tree
    rotation : ConstantRotation or TimeDependentRotation
               The rotation from this frame to the frame of the parent node
    """

    def __init__(self, id, parent=None, rotation=None):
        """
        Construct a frame node with or without a parent. If a parent is specified
        then a rotation from this frame to the parent node's frame must be
        specified and visa-versa.

        Parameters
        ----------
        id : int
             The NAIF ID code for the frame
        parent : FrameNode
                 The parent node in the frame chain
        rotation : ConstantRotation or TimeDependentRotation
                   The rotation from this frame to the frame of the parent node
        """

        if (parent is None) != (rotation is None):
            raise TypeError("parent and rotation must both be entered or both be None.")

        self.children = []
        self.id = id
        if parent is not None:
            self._parent = parent
        if rotation is not None:
            self.rotation = rotation

    def __del__(self):
        """
        Custom deletor for a FrameNode. The child node is always responsible
        for updating the parent nodes. So this removes this node from the
        children of its parent node.
        """
        if self.parent is not None:
            self.parent.children.remove(self)

    @property
    def parent(self):
        """
        The parent node of this node. Returns None if this is a root node.
        """
        if hasattr(self, '_parent'):
            return self._parent
        else:
            return None

    @parent.setter
    def parent(self, new_parent):
        """
        Sets a new parent node. The child node is always responsible for
        updating the parent node. So, this removes this node from the children
        of the old parent and adds it to the children of the new parent.
        """
        if self.parent is not None:
            self.parent.children.remove(self)

        new_parent.children.append(self)
        self._parent = new_parent

    def parent_nodes(self):
        """
        Returns the ordered list of parents starting with this node going to
        the root node.
        """
        chain = [self]
        current_parent = self.parent
        while current_parent is not None:
            chain.append(current_parent)
            current_parent = current_parent.parent
        return chain

    def path_to(self, other):
        """
        Returns the path to another node as two lists. The first list
        starts with this node and ends with the common parent. The second
        list contains the remainder of the path.

        Parameters
        ----------
        other : FrameNode
                The other node to find the path to.
        """
        parents_1 = self.parent_nodes()
        parents_2 = other.parent_nodes()
        common_parent = None
        for node in parents_1:
            if node in parents_2:
                common_parent = node
                break
        if common_parent is None:
            raise RuntimeError('No common parent between nodes')

        first_path = parents_1[:parents_1.index(common_parent)+1]
        second_path = parents_2[:parents_2.index(common_parent)][::-1]

        return first_path, second_path

    def rotation_to(self, other):
        """
        Returns the rotation to another node. Returns the identity rotation
        if the other node is this node.

        Parameters
        ----------
        other : FrameNode
                The other node to find the rotation to.
        """
        if other == self:
            return ConstantRotation(np.array([0, 0, 0, 1]), self.id, other.id)
        forward_path, reverse_path = self.path_to(other)
        rotations = [node.rotation for node in forward_path[:-1]]
        rotations.extend([node.rotation.inverse() for node in reverse_path])
        rotation = rotations[0]
        for next_rotation in rotations[1:]:
            rotation = next_rotation * rotation
        return rotation
+119 −0
Original line number Diff line number Diff line
import pytest

import numpy as np
from ale.rotation import ConstantRotation
from ale.transformation import FrameNode

@pytest.fixture
def frame_tree():
    """
    Test frame tree structure:

          1
         / \
        /   \
       2     4
      /
     /
    3
    """
    rotations = [
        ConstantRotation(np.array([1, 0, 0, 0]), 2, 1),
        ConstantRotation(np.array([1.0/np.sqrt(2), 0, 0, 1.0/np.sqrt(2)]), 3, 2),
        ConstantRotation(np.array([1.0/np.sqrt(2), 0, 0, 1.0/np.sqrt(2)]), 4, 1)
    ]
    root_node = FrameNode(1)
    child_node_1 = FrameNode(2, parent = root_node, rotation = rotations[0])
    child_node_2 = FrameNode(3, parent = child_node_1, rotation = rotations[1])
    child_node_3 = FrameNode(4, parent = root_node, rotation = rotations[2])
    nodes = [
        root_node,
        child_node_1,
        child_node_2,
        child_node_3
    ]
    return (nodes, rotations)

def test_parent_nodes(frame_tree):
    nodes, _ = frame_tree
    root_parents = nodes[0].parent_nodes()
    child_1_parents = nodes[1].parent_nodes()
    child_2_parents = nodes[2].parent_nodes()
    child_3_parents = nodes[3].parent_nodes()

    assert root_parents == [nodes[0]]
    assert child_1_parents == [nodes[1], nodes[0]]
    assert child_2_parents == [nodes[2], nodes[1], nodes[0]]
    assert child_3_parents == [nodes[3], nodes[0]]

def test_path_to_parent(frame_tree):
    nodes, _ = frame_tree
    forward_path, reverse_path = nodes[2].path_to(nodes[0])
    assert forward_path == [nodes[2], nodes[1], nodes[0]]
    assert reverse_path == []

def test_path_to_common_parent(frame_tree):
    nodes, _ = frame_tree
    forward_path, reverse_path = nodes[2].path_to(nodes[3])
    assert forward_path == [nodes[2], nodes[1], nodes[0]]
    assert reverse_path == [nodes[3]]

def test_path_to_child(frame_tree):
    nodes, _ = frame_tree
    forward_path, reverse_path = nodes[0].path_to(nodes[3])
    assert forward_path == [nodes[0]]
    assert reverse_path == [nodes[3]]

def test_path_to_self():
    node = FrameNode(1)
    forward_path, reverse_path = node.path_to(node)
    assert forward_path == [node]
    assert reverse_path == []


def test_parent_rotation(frame_tree):
    nodes, rotations = frame_tree
    child_to_root = nodes[1].rotation_to(nodes[0])
    root_to_child = nodes[0].rotation_to(nodes[1])

    assert child_to_root.source == 2
    assert child_to_root.dest == 1
    np.testing.assert_equal(child_to_root.quat, rotations[0].quat)
    assert root_to_child.source == 1
    assert root_to_child.dest == 2
    np.testing.assert_equal(root_to_child.quat, rotations[0].inverse().quat)

def test_grand_parent_rotation(frame_tree):
    nodes, rotations = frame_tree
    child_2_to_root = nodes[2].rotation_to(nodes[0])
    root_to_child_2 = nodes[0].rotation_to(nodes[2])

    assert child_2_to_root.source == 3
    assert child_2_to_root.dest == 1
    expected_rotation_1 = rotations[0] * rotations[1]
    np.testing.assert_equal(child_2_to_root.quat, expected_rotation_1.quat)
    assert root_to_child_2.source == 1
    assert root_to_child_2.dest == 3
    expected_rotation_2 = rotations[1].inverse() * rotations[0].inverse()
    np.testing.assert_equal(root_to_child_2.quat, expected_rotation_2.quat)

def test_common_parent_rotation(frame_tree):
    nodes, rotations = frame_tree
    child_2_to_child_3 = nodes[2].rotation_to(nodes[3])
    child_3_to_child_2 = nodes[3].rotation_to(nodes[2])

    assert child_2_to_child_3.source == 3
    assert child_2_to_child_3.dest == 4
    expected_rotation_1 = rotations[2].inverse() * rotations[0] * rotations[1]
    np.testing.assert_equal(child_2_to_child_3.quat, expected_rotation_1.quat)
    assert child_3_to_child_2.source == 4
    assert child_3_to_child_2.dest == 3
    expected_rotation_2 = rotations[1].inverse() * rotations[0].inverse() * rotations[2]
    np.testing.assert_equal(child_3_to_child_2.quat, expected_rotation_2.quat)

def test_self_rotation():
    node = FrameNode(1)
    rotation = node.rotation_to(node)
    assert rotation.source == 1
    assert rotation.dest == 1
    np.testing.assert_equal(rotation.quat, np.array([0, 0, 0, 1]))