Commit 03582b86 authored by Kelvinrr's avatar Kelvinrr
Browse files

added mst functionality to network.py

parent 5b423c31
Loading
Loading
Loading
Loading
+40 −0
Original line number Diff line number Diff line
@@ -45,6 +45,7 @@ class CandidateGraph(nx.Graph):
        self.node_counter = 0
        node_labels = {}
        self.node_name_map = {}
        self.graph_masks = {}

        # the node_name is the relative path for the image
        for node_name, node in self.nodes_iter(data=True):
@@ -291,6 +292,45 @@ class CandidateGraph(nx.Graph):
        """
        _, self.clusters = func(self, *args, **kwargs)

    def apply_func_to_edges(self, func, *args, graph_mask_keys=[], **kwargs):
        """
        Iterates over edges using an optional mask and and applies the given function.
        If func is not an attribute of Edge, raises AttributeError
        Parameters
        ----------
        func : string
               function to be called on every edge
        graph_mask_keys : list
                          of keys in graph_masks
        """

        merged_graph_mask = self.graph_masks[graph_mask_keys].all(axis=1)
        edges_to_iter = merged_graph_mask[merged_graph_mask].index

        for s, d in edges_to_iter:
            curr_edge = self.get_edge_data(s, d)
            try:
                function = getattr(curr_edge, func)
            except:
                raise AttributeError('The passed function is not an attribute of Edge')
            else:
                function(*args, **kwargs)

    def minimum_spanning_tree(self):
        """
        Calculates the minimum spanning tree of the graph

        Returns
        -------

         : DataFrame
           boolean mask for edges in the minimum spanning tree
        """

        self.graph_masks = pd.DataFrame(False, index=self.edges(), columns=['mst'])
        mst = nx.minimum_spanning_tree(self)
        self.graph_masks['mst'][mst.edges()] = True

    def symmetry_checks(self):
        """
        Perform a symmetry check on all edges in the graph
+24 −0
Original line number Diff line number Diff line
@@ -99,3 +99,27 @@ class TestEdge(unittest.TestCase):
    def setUpClass(cls):
        cls.graph = network.CandidateGraph.from_adjacency(get_path('adjacency.json'))


class TestMSTGraph(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.test_dict = {"0": ["4", "2", "1", "3"],
                         "1": ["0", "3", "2", "6", "5"],
                         "2": ["1", "0", "3", "4", "7"],
                         "3": ["2", "0", "1", "5"],
                         "4": ["2", "0"],
                         "5": ["1", "3"],
                         "6": ["1"],
                         "7": ["2"]}

        cls.graph = network.CandidateGraph.from_adjacency(cls.test_dict)
        cls.graph.minimum_spanning_tree()
        cls.mst_graph = cls.graph.copy()

        for s, d, edge in cls.graph.edges_iter(data=True):
            if not cls.graph.graph_masks['mst'][(s, d)]:
                cls.mst_graph.remove_edge(s, d)

    def test_mst_output(self):
        self.assertEqual(self.mst_graph.nodes(), self.graph.nodes())
        self.assertEqual(self.mst_graph.number_of_edges(), self.graph.number_of_edges()-5)
 No newline at end of file