Commit 8b3ed1ad authored by Jay's avatar Jay Committed by jay
Browse files

Test updates.

parent 84326679
Loading
Loading
Loading
Loading
+21 −0
Original line number Diff line number Diff line
@@ -54,3 +54,24 @@ class TestC(unittest.TestCase):

    def test_creation_date(self):
        self.assertEqual(self.C.creationdate, strftime("%Y-%m-%d %H:%M:%S", gmtime()))

    def test_point_subpixel(self):
        for k, v in self.C.point_to_correspondence.items():
            self.assertFalse(k.subpixel)
            k.subpixel = True
            self.assertTrue(k.subpixel)
            break

    def test_equalities(self):
        points = []
        correspondences = []
        for k, v in self.C.point_to_correspondence.items():
            points.append(k)
            correspondences.extend(v)
        self.assertEqual(points[0], points[0])
        self.assertNotEqual(points[-1], points[1])
        self.assertEqual(correspondences[0][0], correspondences[0][0])

    def test_to_dataframe(self):
        self.C.to_dataframe()
+0 −136
Original line number Diff line number Diff line
@@ -530,142 +530,6 @@ class CandidateGraph(nx.Graph):
                matches = edge.matches
            self.cn.add_correspondences(edge, matches)

    def to_cnet(self, clean_keys=[], isis_serials=False):
        """
        Generate a control network (C) object from a graph

        Parameters
        ----------

        clean_keys : list
             of strings identifying the masking arrays to use, e.g. ratio, symmetry

        isis_serials : bool
                       Replace the node ID (nid) values with an ISIS
                       serial number. Default False

        Returns
        -------

        merged_cnet : C
                      A control network object
        """

        def _validate_cnet(cnet):
            """
            Once the control network is aggregated from graph edges,
            ensure that a given correspondence in a given image does
            not match multiple correspondences in a different image.

            Parameters
            ----------
            cnet : C
                   control network object

            Returns
            -------
             : C
               the cleaned control network
            """
            mask = np.zeros(len(cnet), dtype=bool)
            counter = 0
            for i, group in cnet.groupby('pid'):
                group_size = len(group)
                if len(group) != len(group['nid'].unique()):
                    mask[counter: counter + group_size] = False
                else:
                    mask[counter: counter + group_size] = True
                counter += group_size

            return cnet[mask]

        merged_cnet = None

        for source, destination, edge in self.edges_iter(data=True):
            matches = edge.matches

            # Merge all of the masks
            if clean_keys:
                matches, mask = edge._clean(clean_keys)

            subpixel = False
            point_type = 2
            if 'subpixel' in clean_keys:
                subpixel = True
                point_type = 3

            kp1 = self.node[source].get_keypoints()
            kp2 = self.node[destination].get_keypoints()
            pt_idx = 0
            values = []
            for i, (idx, row) in enumerate(matches.iterrows()):
                # Composite matching key (node_id, point_id)
                m1_pid = int(row['source_idx'])
                m2_pid = int(row['destination_idx'])
                m1 = (source, int(row['source_idx']))
                m2 = (destination, int(row['destination_idx']))


                values.append([kp1.loc[m1_pid]['x'],
                               kp1.loc[m1_pid]['y'],
                               m1,
                               pt_idx,
                               source,
                               idx,
                               point_type])

                if subpixel:
                    kp2x = kp2.loc[m2_pid]['x'] + row['x_offset']
                    kp2y = kp2.loc[m2_pid]['y'] + row['y_offset']
                else:
                    kp2x = kp2.loc[m2_pid]['x']
                    kp2y = kp2.loc[m2_pid]['y']

                values.append([kp2x,
                               kp2y,
                               m2,
                               pt_idx,
                               destination,
                               idx,
                               point_type])
                pt_idx += 1

            columns = ['x', 'y', 'idx', 'pid', 'nid', 'mid', 'point_type']

            cnet = CorrespondenceNetwork(values, columns=columns)

            if merged_cnet is None:
                merged_cnet = cnet.copy(deep=True)
            else:
                pid_offset = merged_cnet['pid'].max() + 1  # Get the current max point index
                cnet[['pid']] += pid_offset

                # Inner merge on the dataframe identifies common points
                common = pd.merge(merged_cnet, cnet, how='inner', on='idx', left_index=True, suffixes=['_r',
                                                                                                      '_l'])

                # Iterate over the points to be merged and merge them in.
                for i, r in common.iterrows():
                    new_pid = r['pid_r']
                    update_pid = r['pid_l']
                    cnet.loc[cnet['pid'] == update_pid, ['pid']] = new_pid  # Update the point ids

                # Perform the concat
                merged_cnet = pd.concat([merged_cnet, cnet])
                merged_cnet.drop_duplicates(['idx', 'pid'], keep='first', inplace=True)

        # Final validation to remove any correspondence with multiple correspondences in the same image
        merged_cnet = _validate_cnet(merged_cnet)

        # If the user wants ISIS serial numbers, replace the nid with the serial.
        if isis_serials is True:
            nid_to_serial = {}
            for i, node in self.nodes_iter(data=True):
                nid_to_serial[i] = node.isis_serial
            merged_cnet.replace({'nid': nid_to_serial}, inplace=True)

        return merged_cnet

    def to_json_file(self, outputfile):
        """
        Write the edge structure to a JSON adjacency list
+10 −6
Original line number Diff line number Diff line
@@ -53,17 +53,20 @@ class TestCandidateGraph(unittest.TestCase):
        except AttributeError:
            pass

        mst_graph.extract_features(extractor_parameters={'nfeatures': 500})

        mst_graph.extract_features(extractor_parameters={'nfeatures': 50})
        mst_graph.match_features()
        mst_graph.apply_func_to_edges("symmetry_check")

        # Test passing the func by signature
        mst_graph.apply_func_to_edges(graph[0][1].symmetry_check)

        self.assertFalse(graph[0][2].masks['symmetry'].all())
        self.assertFalse(graph[0][1].masks['symmetry'].all())

        try:
            self.assertTrue(graph[1][2].masks['symmetry'].all())
        except:
            pass
        except: pass

    def test_connected_subgraphs(self):
        subgraph_list = self.disconnected_graph.connected_subgraphs()
@@ -178,6 +181,7 @@ class TestCandidateGraph(unittest.TestCase):
        self.assertEqual(sorted(mst_graph.nodes()), sorted(graph.nodes()))
        self.assertEqual(len(mst_graph.edges()), len(graph.edges())-5)


    def tearDown(self):
        pass
    def test_triangular_cycles(self):
        cycles = self.graph.compute_triangular_cycles()
        # Node order is variable, length is not
        self.assertEqual(len(cycles), 1)
+2 −4
Original line number Diff line number Diff line
@@ -160,10 +160,8 @@ class SpatialSuppression(Observable):

    @property
    def nvalid(self):
        try:
        return self.mask.sum()
        except:
            return None


    @property
    def error_k(self):
+8 −0
Original line number Diff line number Diff line
@@ -94,3 +94,11 @@ class TestSpatialSuppression(unittest.TestCase):
        self.suppression_obj.suppress()
        self.assertIn(self.suppression_obj.mask.sum(), list(range(27, 34)))

        with warnings.catch_warnings(record=True) as w:
            self.suppression_obj.k = 101
            self.suppression_obj.suppress()
            self.assertEqual(len(w), 1)
            self.assertTrue(issubclass(w[0].category, UserWarning))


Loading