Commit a80c4259 authored by jlaura's avatar jlaura
Browse files

Merge pull request #89 from Kelvinrr/subgraph_filter

simple node/edge filter
parents ecb2d8e8 af03753b
Loading
Loading
Loading
Loading
+35 −0
Original line number Diff line number Diff line
@@ -716,3 +716,38 @@ class CandidateGraph(nx.Graph):
                   not edge.matches.empty]

        return self.create_edge_subgraph(matches)

    def filter_nodes(self, func, *args, **kwargs):
        """
        Filters graph and returns a sub-graph from matches. Mimics
        python's filter() function

        Parameters
        ----------
        func : function which returns bool used to filter out nodes

        Returns
        -------
        : Object
          A networkX graph object

        """
        nodes = [n for n, d in self.nodes_iter(data=True) if func(d, *args, **kwargs)]
        return self.create_node_subgraph(nodes)

    def filter_edges(self, func, *args, **kwargs):
        """
        Filters graph and returns a sub-graph from matches. Mimics
        python's filter() function

        Parameters
        ----------
        func : function which returns bool used to filter out edges

        Returns
        -------
        : Object
          A networkX graph object
        """
        edges = [(u, v) for u, v, edge in self.edges_iter(data=True) if func(edge, *args, **kwargs)]
        return self.create_edge_subgraph(edges)
+15 −0
Original line number Diff line number Diff line
@@ -114,6 +114,21 @@ class TestCandidateGraph(unittest.TestCase):
        node_sub = g.create_node_subgraph([0,1])
        self.assertEqual(len(node_sub), 2)

    def test_filter(self):
        def edge_func(edge):
            return hasattr(edge, 'matches') and not edge.matches.empty

        graph = self.graph.copy()
        test_sub_graph = graph.create_node_subgraph([0, 1])
        test_sub_graph.extract_features(extractor_parameters={'nfeatures': 500})
        test_sub_graph.match_features(k=2)

        filtered_nodes = graph.filter_nodes(lambda node: hasattr(node, 'descriptors'))
        filtered_edges = graph.filter_edges(edge_func)

        self.assertEqual(filtered_nodes.number_of_nodes(), test_sub_graph.number_of_nodes())
        self.assertEqual(filtered_edges.number_of_edges(), test_sub_graph.number_of_edges())

    def test_subgraph_from_matches(self):
        test_sub_graph = self.graph.create_node_subgraph([0, 1])
        test_sub_graph.extract_features(extractor_parameters={'nfeatures': 500})