Commit 84326679 authored by Jay's avatar Jay Committed by jay
Browse files

Updates correspondence network generation and IO.

parent 16609f4a
Loading
Loading
Loading
Loading
+70 −1
Original line number Diff line number Diff line
import collections
from time import gmtime, strftime


class Point(object):
    """
    An n-image correspondence container class to store
    information common to all identical correspondences across
    an image set.

    Attributes
    ----------
    point_id : int
               A unique identifier for the given point

    subpixel : bool
               Whether or not the point has been subpixel registered

    point_type : int
                 An ISIS identifier for the type of the point
                 as defined in the ISIS protobuf spec.

    correspondences : list
                      of image correspondences
    """
    __slots__ = '_subpixel', 'point_id', 'point_type', 'correspondences'

@@ -35,7 +52,29 @@ class Point(object):


class Correspondence(object):
    """
    A single correspondence (image measure).

    Attributes
    ----------

    id : int
         The index of the point in a matches dataframe (stored as an edge attribute)

    x : float
        The x coordinate of the measure in image space

    y : float
        The y coordinate of the measure in image space

    measure_type : int
                   The ISIS measure type as per the protobuf spec

    serial : str
             A unique serial number for the image the measure corresponds to
             In the case of an ISIS cube, this is a valid ISIS serial number,
             else, None.
    """
    __slots__ = 'id', 'x', 'y', 'measure_type', 'serial'

    def __init__(self, id, x, y, measure_type=2, serial=None):
@@ -45,7 +84,6 @@ class Correspondence(object):
        self.measure_type = measure_type
        self.serial = serial


    def __repr__(self):
        """Return the correspondence identifier as a string."""
        return str(self.id)

@@ -57,6 +95,37 @@ class Correspondence(object):


class CorrespondenceNetwork(object):
    """
    A container of points and associated correspondences.  The primary
    data structures are point_to_correspondence and correspondence_to_point.
    These two attributes store the mapping between point and correspondences.

    Attributes
    ----------
    point_to_correspondence : dict
                              with key equal to an instance of the Point class and
                              values equal to a list of Correspondences.

    correspondence_to_point : dict
                              with key equal to a correspondence identifier (not the class) and
                              value equal to a unique point_id (not an instance of the Point class).
                              This attribute serves as a low memory reverse lookup table

    point_id : int
               The current 'new' point id if an additional point were to be added

    n_points : int
               The number of points in the CorrespondenceNetwork

    n_measures : int
                 The number of Correspondences in the CorrespondenceNetwork

    creationdate : str
                   The date the instance of this class was first instantiated

    modifieddate : str
                   The date this class last had correspondences and/or points added
    """
    def __init__(self):
        self.point_to_correspondence = collections.defaultdict(list)
        self.correspondence_to_point = {}
+31 −10
Original line number Diff line number Diff line
@@ -3,7 +3,13 @@ import sys
from time import gmtime, strftime
import unittest

from unittest.mock import Mock, MagicMock

from autocnet.graph.edge import Edge
from autocnet.graph.node import Node

import numpy as np
import pandas as pd

sys.path.insert(0, os.path.abspath('..'))

@@ -12,24 +18,39 @@ from autocnet.control import control

class TestC(unittest.TestCase):

    def setUp(self):
        x = list(range(10))
        y = list(range(10))
        pid = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2]
        nid = [1, 2, 1, 2, 1, 2, 1, 2, 1, 2]
    @classmethod
    def setUpClass(cls):
        npts = 10
        coords = pd.DataFrame(np.arange(npts * 2).reshape(-1, 2))
        source = np.zeros(npts)
        destination = np.ones(npts)
        pid = np.arange(npts)

        matches = pd.DataFrame(np.vstack((source, pid, destination, pid)).T, columns=['source_image',
                                                                                      'source_idx',
                                                                                      'destination_image',
                                                                                      'destination_idx'])

        edge = Mock(spec=Edge)
        edge.source = Mock(spec=Node)
        edge.destination = Mock(spec=Node)
        edge.source.isis_serial = None
        edge.destination.isis_serial = None
        edge.source.get_keypoint_coordinates = MagicMock(return_value=coords)
        edge.destination.get_keypoint_coordinates = MagicMock(return_value=coords)

        data = np.array([x, y, pid, nid]).T
        cls.C = control.CorrespondenceNetwork()
        cls.C.add_correspondences(edge, matches)

        self.C = control.C(data, columns=['x', 'y', 'pid', 'nid'])

    def test_n_point(self):
        self.assertEqual(self.C.n, 4)
        self.assertEqual(self.C.n_points, 10)

    def test_n_measures(self):
        self.assertEqual(self.C.m, 10)
        self.assertEqual(self.C.n_measures, 20)

    def test_modified_date(self):
        self.assertEqual(self.C.modifieddate, 'Not modified')
        self.assertIsInstance(self.C.modifieddate, str)

    def test_creation_date(self):
        self.assertEqual(self.C.creationdate, strftime("%Y-%m-%d %H:%M:%S", gmtime()))
+2 −2
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ def write_filelist(lst, path="fromlist.lis"):
        handle.write('\n')
    return


def to_isis(path, C, mode='w', version=VERSION,
            headerstartbyte=HEADERSTARTBYTE,
            networkid='None', targetname='None',
@@ -82,7 +83,6 @@ def to_isis(path, C, mode='w', version=VERSION,
                                                                           point_sizes)
            # Write the buffer header
            store.write(buffer_header, HEADERSTARTBYTE)

            # Then write the points, so we know where to start writing, + 1 to avoid overwrite
            point_start_offset = HEADERSTARTBYTE + buffer_header_size
            for i, point in enumerate(point_messages):
+49 −44
Original line number Diff line number Diff line
import os
from time import gmtime, strftime
import unittest
import sys
sys.path.insert(0, os.path.abspath('..'))

import unittest
from unittest.mock import Mock, MagicMock
import numpy as np
import pandas as pd
import pvl
@@ -12,44 +10,50 @@ from .. import io_controlnetwork
from .. import ControlNetFileV0002_pb2 as cnf

from autocnet.utils.utils import find_in_dict
from autocnet.control.control import C

class TestWriteIsisControlNetwork(unittest.TestCase):

    def setUp(self):
        """
        Not 100% sure how to mock in the DF without creating lots of methods...
        """

        serial_times = {295: '1971-07-31T01:24:11.754',
                   296: '1971-07-31T01:24:36.970',
                   297: '1971-07-31T01:25:02.243',
                   298: '1971-07-31T01:25:27.457',
                   299: '1971-07-31T01:25:52.669',
                   300: '1971-07-31T01:26:17.923'}
        self.serials = ['APOLLO15/METRIC/{}'.format(i) for i in serial_times.values()]


        x = list(range(5))
        y = list(range(5))
        pid = [0,0,1,1,1]
        idx = pid
        serials = [self.serials[0], self.serials[1], self.serials[2],
                   self.serials[2], self.serials[3]]
from autocnet.control.control import CorrespondenceNetwork
from autocnet.graph.edge import Edge
from autocnet.graph.node import Node

sys.path.insert(0, os.path.abspath('..'))

        columns = ['x', 'y', 'idx', 'pid', 'nid', 'point_type']
        self.data_length = 5

        data = [x,y, idx, pid, serials, [2] * self.data_length]
class TestWriteIsisControlNetwork(unittest.TestCase):

        self.creation_time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
        cnet = C(data, index=columns).T
    @classmethod
    def setUpClass(cls):

        serial_times = {295: '1971-07-31T01:24:11.754',
                        296: '1971-07-31T01:24:36.970'}
        cls.serials = ['APOLLO15/METRIC/{}'.format(i) for i in serial_times.values()]

        # Create an edge and a set of matches
        cls.npts = 5
        coords = pd.DataFrame(np.arange(cls.npts * 2).reshape(-1, 2))
        source = np.zeros(cls.npts)
        destination = np.ones(cls.npts)
        pid = np.arange(cls.npts)

        matches = pd.DataFrame(np.vstack((source, pid, destination, pid)).T, columns=['source_image',
                                                                                      'source_idx',
                                                                                      'destination_image',
                                                                                      'destination_idx'])

        edge = Mock(spec=Edge)
        edge.source = Mock(spec=Node)
        edge.destination = Mock(spec=Node)
        edge.source.isis_serial = cls.serials[0]
        edge.destination.isis_serial = cls.serials[1]
        edge.source.get_keypoint_coordinates = MagicMock(return_value=coords)
        edge.destination.get_keypoint_coordinates = MagicMock(return_value=coords)

        cnet = CorrespondenceNetwork()
        cnet.add_correspondences(edge, matches)
        cls.creation_date = cnet.creationdate
        cls.modified_date = cnet.modifieddate
        io_controlnetwork.to_isis('test.net', cnet, mode='wb', targetname='Moon')

        self.header_message_size = 85
        self.point_start_byte = 65621
        cls.header_message_size = 98
        cls.point_start_byte = 65634

    def test_create_buffer_header(self):
        with open('test.net', 'rb') as f:
@@ -63,20 +67,20 @@ class TestWriteIsisControlNetwork(unittest.TestCase):
            self.assertEqual('Moon', header_protocol.targetName)
            self.assertEqual(io_controlnetwork.DEFAULTUSERNAME,
                             header_protocol.userName)
            self.assertEqual(self.creation_time,
            self.assertEqual(self.creation_date,
                             header_protocol.created)
            self.assertEqual('None', header_protocol.description)
            self.assertEqual('Not modified', header_protocol.lastModified)
            self.assertEqual(self.modified_date, header_protocol.lastModified)

            #Repeating
            self.assertEqual([135, 199], header_protocol.pointMessageSizes)
            self.assertEqual([135] * self.npts, header_protocol.pointMessageSizes)

    def test_create_point(self):
        with open('test.net', 'rb') as f:

            with open('test.net', 'rb') as f:
                f.seek(self.point_start_byte)
                for i, length in enumerate([135, 199]):
                for i, length in enumerate([135] * self.npts):
                    point_protocol = cnf.ControlPointFileEntryV0002()
                    raw_point = f.read(length)
                    point_protocol.ParseFromString(raw_point)
@@ -90,16 +94,17 @@ class TestWriteIsisControlNetwork(unittest.TestCase):
        pvl_header = pvl.load('test.net')

        npoints = find_in_dict(pvl_header, 'NumberOfPoints')
        self.assertEqual(2, npoints)
        self.assertEqual(5, npoints)

        mpoints = find_in_dict(pvl_header, 'NumberOfMeasures')
        self.assertEqual(5, mpoints)
        self.assertEqual(10, mpoints)

        points_bytes = find_in_dict(pvl_header, 'PointsBytes')
        self.assertEqual(334, points_bytes)
        self.assertEqual(675, points_bytes)

        points_start_byte = find_in_dict(pvl_header, 'PointsStartByte')
        self.assertEqual(65621, points_start_byte)
        self.assertEqual(65634, points_start_byte)

    def tearDown(self):
    @classmethod
    def tearDownClass(cls):
        # Remove the control network file written during setUpClass once
        # every test in the class has run.
        os.remove('test.net')
 No newline at end of file
+86 −40
Original line number Diff line number Diff line
import itertools
import math
import os
import warnings

import dill as pickle
import networkx as nx
import numpy as np
import pandas as pd

from autocnet.fileio.io_gdal import GeoDataset
from autocnet.control.control import CorrespondenceNetwork
from autocnet.fileio import io_hdf
from autocnet.control.control import C
from autocnet.fileio import io_json
from autocnet.matcher.matcher import FlannMatcher
import autocnet.matcher.suppression_funcs as spf
from autocnet.fileio import io_utils
from autocnet.fileio.io_gdal import GeoDataset
from autocnet.graph import markov_cluster
from autocnet.graph.edge import Edge
from autocnet.graph.node import Node
from autocnet.graph import markov_cluster
from autocnet.matcher.matcher import FlannMatcher
from autocnet.vis.graph_view import plot_graph


@@ -35,6 +37,9 @@ class CandidateGraph(nx.Graph):
    clusters : dict
               of clusters with key as the cluster id and value as a
               list of node indices

    cn : object
         A control network object instantiated by calling generate_cnet.
    ----------
    """
    edge_attr_dict_factory = Edge
@@ -44,7 +49,6 @@ class CandidateGraph(nx.Graph):
        self.node_counter = 0
        node_labels = {}
        self.node_name_map = {}
        self.graph_masks = pd.DataFrame()

        for node_name in self.nodes():
            image_name = os.path.basename(node_name)
@@ -66,11 +70,6 @@ class CandidateGraph(nx.Graph):
            e = self.edge[s][d]
            e.source = self.node[s]
            e.destination = self.node[d]
            #del self.adj[d][s]

        # Add the Edge class as a edge data structure
        #for s, d, edge in self.edges_iter(data=True):
            #self.edge[s][d] = Edge(self.node[s], self.node[d])

    @classmethod
    def from_graph(cls, graph):
@@ -107,11 +106,8 @@ class CandidateGraph(nx.Graph):
        : object
          A Network graph object
        """
        if not isinstance(filelist, list):
            with open(filelist, 'r') as f:
                filelist = f.readlines()
                filelist = map(str.rstrip, filelist)
                filelist = filter(None, filelist)
        if isinstance(filelist, str):
            filelist = io_utils.file_to_list(filelist)

        # TODO: Reject unsupported file formats + work with more file formats
        if basepath:
@@ -121,23 +117,28 @@ class CandidateGraph(nx.Graph):

        # This is brute force for now, could swap to an RTree at some point.
        adjacency_dict = {}
        valid_datasets = []

        for i, j in itertools.permutations(datasets,2):
            if not i.file_name in adjacency_dict.keys():
        for i in datasets:
            adjacency_dict[i.file_name] = []
            if not j.file_name in adjacency_dict.keys():
                adjacency_dict[j.file_name] = []

            fp = i.footprint
            if fp and fp.IsValid():
                valid_datasets.append(i)
            else:
                warnings.warn('Missing or invalid geospatial data for {}'.format(i.base_name))

        # Grab the footprints and test for intersection
        for i, j in itertools.permutations(valid_datasets, 2):
            i_fp = i.footprint
            j_fp = j.footprint

            try:
                if j_fp and i_fp and i_fp.Intersects(j_fp):
                if i_fp.Intersects(j_fp):
                    adjacency_dict[i.file_name].append(j.file_name)
                    adjacency_dict[j.file_name].append(i.file_name)
            except:
                warnings.warn('No or incorrect geospatial information for {} and/or {}'.format(i, j))
                warnings.warn('Failed to calculated intersection between {} and {}'.format(i, j))

        return cls(adjacency_dict)

@@ -300,6 +301,11 @@ class CandidateGraph(nx.Graph):
            descriptors = node.descriptors
            # Load the neighbors of the current node into the FLANN matcher
            neighbors = self.neighbors(i)

            # if node has no neighbors, skip
            if not neighbors:
                continue

            for n in neighbors:
                neighbor_descriptors = self.node[n].descriptors
                self._fl.add(neighbor_descriptors, n)
@@ -365,31 +371,51 @@ class CandidateGraph(nx.Graph):
        """
        _, self.clusters = func(self, *args, **kwargs)

    def apply_func_to_edges(self, function, *args, graph_mask_keys=[], **kwargs):
    def compute_triangular_cycles(self):
        """
        Find all cycles of length 3.  This is similar
        to cycle_basis (networkX), but returns all triangular
        cycles, as opposed to only the basis cycles.

        Returns
        -------
        cycles : list
                 A list of cycles in the form [(a,b,c), (c,d,e)],
                 where letters indicate node identifiers

        Examples
        --------
        >>> g = CandidateGraph()
        >>> g.add_edges_from([(0,1), (0,2), (1,2), (0,3), (1,3), (2,3)])
        >>> g.compute_triangular_cycles()
        [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)]
        """
        # Materialize the edge tuples once.  The original per-iteration
        # `in self.edges()` tests rebuilt the edge list on every membership
        # check, making the loop O(E^2 * V); a set lookup keeps the output
        # identical while reducing the cost to O(E * V).
        edge_set = set(self.edges())
        cycles = []
        for s, d in self.edges_iter():
            for n in self.nodes():
                # Membership is tuple-order sensitive (matching the original
                # behavior), so each triangle is reported exactly once.
                if (s, n) in edge_set and (d, n) in edge_set:
                    cycles.append((s, d, n))
        return cycles

    def apply_func_to_edges(self, function, *args, **kwargs):
        """
        Iterates over edges using an optional mask and and applies the given function.
        If func is not an attribute of Edge, raises AttributeError

        Parameters
        ----------
        function : obj
                   function to be called on every edge

        graph_mask_keys : list
                          of keys in graph_masks
        """

        if graph_mask_keys:
            merged_graph_mask = self.graph_masks[graph_mask_keys].all(axis=1)
            edges_to_iter = merged_graph_mask[merged_graph_mask].index
        else:
            edges_to_iter = self.edges()

        if not isinstance(function, str):
            function = function.__name__

        for s, d in edges_to_iter:
            curr_edge = self.get_edge_data(s, d)
        for s, d, edge in self.edges_iter(data=True):
            try:
                func = getattr(curr_edge, function)
                func = getattr(edge, function)
            except:
                raise AttributeError(function, ' is not an attribute of Edge')
            else:
@@ -463,11 +489,8 @@ class CandidateGraph(nx.Graph):
           boolean mask for edges in the minimum spanning tree
        """

        graph_mask = pd.Series(False, index=self.edges())
        self.graph_masks['mst'] = graph_mask

        mst = nx.minimum_spanning_tree(self)
        self.graph_masks['mst'][mst.edges()] = True
        return self.create_edge_subgraph(mst.edges())

    def to_filelist(self):
        """
@@ -484,6 +507,29 @@ class CandidateGraph(nx.Graph):
            filelist.append(node.image_path)
        return filelist

    def generate_cnet(self, clean_keys=None):
        """
        Compute (or re-compute) a CorrespondenceNetwork attribute.

        The result is stored on ``self.cn``; any previously generated
        network is replaced.

        Parameters
        ----------
        clean_keys : list
                     of string clean keys used to mask correspondences.
                     If None or empty, all matches on each edge are used.
                     (Default of None replaces the original mutable ``[]``
                     default; ``if clean_keys:`` is falsy for both.)

        See Also
        --------
        autocnet.control.control.CorrespondenceNetwork
        """
        # Rebuild from scratch so repeated calls do not accumulate stale points.
        self.cn = CorrespondenceNetwork()

        for s, d, edge in self.edges_iter(data=True):
            if clean_keys:
                # Apply the requested masks to filter the raw matches.
                matches, _ = edge._clean(clean_keys)
            else:
                matches = edge.matches
            self.cn.add_correspondences(edge, matches)

    def to_cnet(self, clean_keys=[], isis_serials=False):
        """
        Generate a control network (C) object from a graph
@@ -586,7 +632,7 @@ class CandidateGraph(nx.Graph):

            columns = ['x', 'y', 'idx', 'pid', 'nid', 'mid', 'point_type']

            cnet = C(values, columns=columns)
            cnet = CorrespondenceNetwork(values, columns=columns)

            if merged_cnet is None:
                merged_cnet = cnet.copy(deep=True)
Loading