Commit 17bb701a authored by Adam Paquette's avatar Adam Paquette Committed by jay
Browse files

Updated sytactic sugar doc strings

parent 6ef0d288
Loading
Loading
Loading
Loading
+85 −69
Original line number Diff line number Diff line
import itertools
import math
import os
import warnings

import dill as pickle
import networkx as nx
import numpy as np
import pandas as pd

from autocnet.control.control import CorrespondenceNetwork
from autocnet.fileio.io_gdal import GeoDataset
from autocnet.fileio import io_hdf
from autocnet.control.control import C
from autocnet.fileio import io_json
from autocnet.fileio import io_utils
from autocnet.fileio.io_gdal import GeoDataset
from autocnet.graph import markov_cluster
from autocnet.matcher.matcher import FlannMatcher
import autocnet.matcher.suppression_funcs as spf
from autocnet.graph.edge import Edge
from autocnet.graph.node import Node
from autocnet.matcher.matcher import FlannMatcher
from autocnet.graph import markov_cluster
from autocnet.vis.graph_view import plot_graph


@@ -46,6 +44,7 @@ class CandidateGraph(nx.Graph):
        self.node_counter = 0
        node_labels = {}
        self.node_name_map = {}
        self.graph_masks = pd.DataFrame()

        for node_name in self.nodes():
            image_name = os.path.basename(node_name)
@@ -67,6 +66,11 @@ class CandidateGraph(nx.Graph):
            e = self.edge[s][d]
            e.source = self.node[s]
            e.destination = self.node[d]
            #del self.adj[d][s]

        # Add the Edge class as a edge data structure
        #for s, d, edge in self.edges_iter(data=True):
            #self.edge[s][d] = Edge(self.node[s], self.node[d])

    @classmethod
    def from_graph(cls, graph):
@@ -103,8 +107,11 @@ class CandidateGraph(nx.Graph):
        : object
          A Network graph object
        """
        if isinstance(filelist, str):
            filelist = io_utils.file_to_list(filelist)
        if not isinstance(filelist, list):
            with open(filelist, 'r') as f:
                filelist = f.readlines()
                filelist = map(str.rstrip, filelist)
                filelist = filter(None, filelist)

        # TODO: Reject unsupported file formats + work with more file formats
        if basepath:
@@ -114,28 +121,23 @@ class CandidateGraph(nx.Graph):

        # This is brute force for now, could swap to an RTree at some point.
        adjacency_dict = {}
        valid_datasets = []

        for i in datasets:
        for i, j in itertools.permutations(datasets,2):
            if not i.file_name in adjacency_dict.keys():
                adjacency_dict[i.file_name] = []

            fp = i.footprint
            if fp and fp.IsValid():
                valid_datasets.append(i)
            else:
                warnings.warn('Missing or invalid geospatial data for {}'.format(i.base_name))
            if not j.file_name in adjacency_dict.keys():
                adjacency_dict[j.file_name] = []

            # Grab the footprints and test for intersection
        for i, j in itertools.permutations(valid_datasets, 2):
            i_fp = i.footprint
            j_fp = j.footprint

            try:
                if i_fp.Intersects(j_fp):
                if j_fp and i_fp and i_fp.Intersects(j_fp):
                    adjacency_dict[i.file_name].append(j.file_name)
                    adjacency_dict[j.file_name].append(i.file_name)
            except:
                warnings.warn('Failed to calculated intersection between {} and {}'.format(i, j))
                warnings.warn('No or incorrect geospatial information for {} and/or {}'.format(i, j))

        return cls(adjacency_dict)

@@ -298,11 +300,6 @@ class CandidateGraph(nx.Graph):
            descriptors = node.descriptors
            # Load the neighbors of the current node into the FLANN matcher
            neighbors = self.neighbors(i)

            # if node has no neighbors, skip
            if not neighbors:
                continue

            for n in neighbors:
                neighbor_descriptors = self.node[n].descriptors
                self._fl.add(neighbor_descriptors, n)
@@ -368,33 +365,7 @@ class CandidateGraph(nx.Graph):
        """
        _, self.clusters = func(self, *args, **kwargs)

    def compute_triangular_cycles(self):
        """
        Find all cycles of length 3.  This is similar
         to cycle_basis (networkX), but returns all cycles.
         As opposed to all basis cycles.

        Returns
        -------
        cycles : list
                 A list of cycles in the form [(a,b,c), (c,d,e)],
                 where letrers indicate node identifiers

        Examples
        --------
        >>> g = CandidateGraph()
        >>> g.add_edges_from([(0,1), (0,2), (1,2), (0,3), (1,3), (2,3)])
        >>> g.compute_triangular_cycles()
        [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)]
        """
        cycles = []
        for s, d in self.edges_iter():
            for n in self.nodes():
                if(s,n) in self.edges() and (d,n) in self.edges():
                    cycles.append((s,d,n))
        return cycles

    def apply_func_to_edges(self, function, *args, **kwargs):
    def apply_func_to_edges(self, function, *args, graph_mask_keys=[], **kwargs):
        """
        Iterates over edges using an optional mask and and applies the given function.
        If func is not an attribute of Edge, raises AttributeError
@@ -405,17 +376,71 @@ class CandidateGraph(nx.Graph):
        graph_mask_keys : list
                          of keys in graph_masks
        """

        if graph_mask_keys:
            merged_graph_mask = self.graph_masks[graph_mask_keys].all(axis=1)
            edges_to_iter = merged_graph_mask[merged_graph_mask].index
        else:
            edges_to_iter = self.edges()

        if not isinstance(function, str):
            function = function.__name__

        for s, d, edge in self.edges_iter(data=True):
        for s, d in edges_to_iter:
            curr_edge = self.get_edge_data(s, d)
            try:
                func = getattr(edge, function)
                func = getattr(curr_edge, function)
            except:
                raise AttributeError(function, ' is not an attribute of Edge')
            else:
                func(*args, **kwargs)

    def symmetry_checks(self):
        '''
        Apply a symmetry check to all edges in the graph
        '''
        self.apply_func_to_edges('symmetry_check')

    def ratio_checks(self, *args, **kwargs):
        '''
        Apply a ratio check to all edges in the graph
        See Also
        --------
        matcher.outlier_detector.DistanceRatio.compute
        outlier_detector.DistanceRatio.compute
        DistanceRatio.compute
        '''
        self.apply_func_to_edges('ratio_check', *args, **kwargs)

    def compute_homographies(self, *args, **kwargs):
        '''
        Compute homographies for all edges using identical parameters
        Parameters: method = '', clean_keys = [], pid=None
        '''
        self.apply_func_to_edges('compute_homography', *args, **kwargs)

    def compute_fundamental_matrices(self, *args, **kwargs):
        '''
        Compute fundmental matrices for all edges using identical parameters
        Parameters: clean_keys=[], method = '', reproj_threshold=5.0, confidence=0.99
        '''
        self.apply_func_to_edges('compute_fundamental_matrix', *args, **kwargs)

    def subpixel_register(self, *args, **kwargs):
        '''
        Compute subpixel offsets for all edges using identical parameters
        Parameters: clean_keys=[], threshold=0.8, upsampling=10, template_size=9, search_size=27, tiled=False,
        max_x_shift=1.0, max_y_shift=1.0
        '''
        self.apply_func_to_edges('subpixel_register', *args, **kwargs)

    def suppress(self, *args, **kwargs):
        '''
        Apply a metric of point suppression to the graph
        Parameters: clean_keys=[], min_radius=2, k=50, error_k=0.1
        '''
        self.apply_func_to_edges('suppress', *args, **kwargs)

    def minimum_spanning_tree(self):
        """
        Calculates the minimum spanning tree of the graph
@@ -427,8 +452,11 @@ class CandidateGraph(nx.Graph):
           boolean mask for edges in the minimum spanning tree
        """

        graph_mask = pd.Series(False, index=self.edges())
        self.graph_masks['mst'] = graph_mask

        mst = nx.minimum_spanning_tree(self)
        return self.create_edge_subgraph(mst.edges())
        self.graph_masks['mst'][mst.edges()] = True

    def to_filelist(self):
        """
@@ -445,18 +473,6 @@ class CandidateGraph(nx.Graph):
            filelist.append(node.image_path)
        return filelist

    def get_cnet(self, clean_keys=[]):
        cn = CorrespondenceNetwork()

        for s, d, edge in self.edges_iter(data=True):
            if clean_keys:
                matches, _ = edge._clean(clean_keys)
            else:
                matches = edge.matches
            cn.add_correspondences(edge, matches)

        return cn

    def to_cnet(self, clean_keys=[], isis_serials=False):
        """
        Generate a control network (C) object from a graph