Commit f76e0b3e authored by jlaura's avatar jlaura
Browse files

Merge pull request #58 from Kelvinrr/master

Adding Minimum Spanning Tree Method to CandidateGraph
parents cfbcb66b 74bd8602
Loading
Loading
Loading
Loading
+40 −0
Original line number Diff line number Diff line
@@ -45,6 +45,7 @@ class CandidateGraph(nx.Graph):
        self.node_counter = 0
        node_labels = {}
        self.node_name_map = {}
        self.graph_masks = {}

        # the node_name is the relative path for the image
        for node_name, node in self.nodes_iter(data=True):
@@ -295,6 +296,45 @@ class CandidateGraph(nx.Graph):
        """
        _, self.clusters = func(self, *args, **kwargs)

    def apply_func_to_edges(self, func, *args, graph_mask_keys=[], **kwargs):
        """
        Iterates over edges using an optional mask and and applies the given function.
        If func is not an attribute of Edge, raises AttributeError
        Parameters
        ----------
        func : string
               function to be called on every edge
        graph_mask_keys : list
                          of keys in graph_masks
        """

        merged_graph_mask = self.graph_masks[graph_mask_keys].all(axis=1)
        edges_to_iter = merged_graph_mask[merged_graph_mask].index

        for s, d in edges_to_iter:
            curr_edge = self.get_edge_data(s, d)
            try:
                function = getattr(curr_edge, func)
            except:
                raise AttributeError('The passed function is not an attribute of Edge')
            else:
                function(*args, **kwargs)

    def minimum_spanning_tree(self):
        """
        Calculates the minimum spanning tree of the graph

        Returns
        -------

         : DataFrame
           boolean mask for edges in the minimum spanning tree
        """

        self.graph_masks = pd.DataFrame(False, index=self.edges(), columns=['mst'])
        mst = nx.minimum_spanning_tree(self)
        self.graph_masks['mst'][mst.edges()] = True

    def symmetry_checks(self):
        """
        Perform a symmetry check on all edges in the graph
+41 −0
Original line number Diff line number Diff line
@@ -40,6 +40,23 @@ class TestCandidateGraph(unittest.TestCase):
    def test_island_nodes(self):
        self.assertEqual(len(self.disconnected_graph.island_nodes()), 1)

    def test_apply_func_to_edges(self):
        graph = self.graph.copy()
        graph.minimum_spanning_tree()

        try:
            graph.apply_func_to_edges('incorrect_func')
        except AttributeError:
            pass

        graph.extract_features(extractor_parameters={'nfeatures': 500})
        graph.match_features()
        graph.apply_func_to_edges("symmetry_check", graph_mask_keys=['mst'])

        self.assertFalse(graph[0][2].masks['symmetry'].all())
        self.assertFalse(graph[0][1].masks['symmetry'].all())
        self.assertTrue(graph[1][2].masks['symmetry'].all())

    def test_connected_subgraphs(self):
        subgraph_list = self.disconnected_graph.connected_subgraphs()
        self.assertEqual(len(subgraph_list), 2)
@@ -99,3 +116,27 @@ class TestEdge(unittest.TestCase):
    def setUpClass(cls):
        cls.graph = network.CandidateGraph.from_adjacency(get_path('adjacency.json'))


class TestGraphMasks(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.test_dict = {"0": ["4", "2", "1", "3"],
                         "1": ["0", "3", "2", "6", "5"],
                         "2": ["1", "0", "3", "4", "7"],
                         "3": ["2", "0", "1", "5"],
                         "4": ["2", "0"],
                         "5": ["1", "3"],
                         "6": ["1"],
                         "7": ["2"]}

        cls.graph = network.CandidateGraph.from_adjacency(cls.test_dict)
        cls.graph.minimum_spanning_tree()
        cls.mst_graph = cls.graph.copy()

        for s, d, edge in cls.graph.edges_iter(data=True):
            if not cls.graph.graph_masks['mst'][(s, d)]:
                cls.mst_graph.remove_edge(s, d)

    def test_mst_output(self):
        self.assertEqual(self.mst_graph.nodes(), self.graph.nodes())
        self.assertEqual(self.mst_graph.number_of_edges(), self.graph.number_of_edges()-5)
 No newline at end of file