Commit b3fa8451 authored by Jay's avatar Jay Committed by jay
Browse files

Adds testing for the save and load functionality for features when called via the graph object

parent f1ef39dc
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -48,7 +48,6 @@ class CandidateGraph(nx.Graph):
        node_labels = {}
        self.node_name_map = {}
        self.graph_masks = pd.DataFrame()

        # the node_name is the relative path for the image
        for node_name, node in self.nodes_iter(data=True):
            image_name = os.path.basename(node_name)
@@ -56,13 +55,11 @@ class CandidateGraph(nx.Graph):

            # Replace the default node dict with an object
            self.node[node_name] = Node(image_name, image_path)

            # fill the dictionary used for relabelling nodes with relative path keys
            node_labels[node_name] = self.node_counter
            # fill the dictionary used for mapping base name to node index
            self.node_name_map[self.node[node_name].image_name] = self.node_counter
            self.node_counter += 1

        nx.relabel_nodes(self, node_labels, copy=False)

        # Add the Edge class as a edge data structure
+9 −3
Original line number Diff line number Diff line
@@ -54,9 +54,9 @@ class Node(dict, MutableMapping):
        Image PATH: {}
        Number Keypoints: {}
        Available Masks : {}
        Type: {}
        """.format(None, self.image_name, self.image_path,
                   self.nkeypoints, self.masks)

                   self.nkeypoints, self.masks, self.__class__)
    @property
    def handle(self):
        if not getattr(self, '_handle', None):
@@ -77,8 +77,14 @@ class Node(dict, MutableMapping):
    @property
    def masks(self):
        mask_lookup = {'suppression': 'suppression'}

        if not hasattr(self, '_keypoints'):
            warnings.warn('Keypoints have note been extracted')
            return

        if not hasattr(self, '_masks'):
            self._masks = pd.DataFrame()
            self._masks = pd.DataFrame(index=self._keypoints.index)

        # If the mask is coming form another object that tracks
        # state, dynamically draw the mask from the object.
        for c in self._masks.columns:
+28 −1
Original line number Diff line number Diff line
@@ -66,7 +66,7 @@ class TestCandidateGraph(unittest.TestCase):
        subgraph_list = self.graph.connected_subgraphs()
        self.assertEqual(len(subgraph_list), 1)

    def test_save_load(self):
    def test_save_load_graph(self):
        self.graph.save('test_save.cg')
        loaded = self.graph.from_graph('test_save.cg')
        self.assertEqual(self.graph.node[0].nkeypoints, loaded.node[0].nkeypoints)
@@ -78,6 +78,31 @@ class TestCandidateGraph(unittest.TestCase):

        os.remove('test_save.cg')

    def test_save_load_features(self):
        for i in ['all_out.hdf', 'one_out.hdf']:
            try:
                os.remove(i)
            except: pass

        graph = self.graph.copy()
        graph.extract_features(extractor_parameters={'nfeatures': 10})
        graph.save_features('all_out.hdf')
        graph.save_features('one_out.hdf', nodes=[1])
        graph_no_features = self.graph.copy()
        graph_no_features.load_features('one_out.hdf', nodes=[1])
        self.assertEqual(graph.node[1].get_keypoints().all().all(),
                         graph_no_features.node[1].get_keypoints().all().all())

        graph_no_features.load_features('all_out.hdf')
        for n in graph.nodes():
            self.assertEqual(graph.node[n].get_keypoints().all().all(),
                             graph_no_features.node[n].get_keypoints().all().all())
        for i in ['all_out.hdf', 'one_out.hdf']:
            try:
                os.remove(i)
            except: pass


    def test_fromlist(self):
        mock_list = ['AS15-M-0295_SML.png', 'AS15-M-0296_SML.png', 'AS15-M-0297_SML.png',
                     'AS15-M-0298_SML.png', 'AS15-M-0299_SML.png', 'AS15-M-0300_SML.png']
@@ -87,6 +112,8 @@ class TestCandidateGraph(unittest.TestCase):
        n = network.CandidateGraph.from_filelist(get_path('adjacency.lis'), get_path('Apollo15'))
        self.assertEqual(len(n.nodes()), 6)



    def tearDown(self):
        pass

+22 −1
Original line number Diff line number Diff line
@@ -2,8 +2,11 @@ import os
import sys
sys.path.insert(0, os.path.abspath('..'))

import numpy as np
import unittest
import warnings

import numpy as np
import pandas as pd

from autocnet.examples import get_path
from autocnet.fileio.io_gdal import GeoDataset
@@ -35,6 +38,24 @@ class TestNode(unittest.TestCase):
        self.assertIsInstance(self.node.descriptors[0], np.ndarray)
        self.assertEqual(10, self.node.nkeypoints)

        # Test the setter
        self.node.nkeypoints = 11
        self.assertEqual(11, self.node.nkeypoints)

    def test_masks(self):
        # Assert a warning raise here
        with warnings.catch_warnings(record=True) as w:
            masks = self.node.masks
            self.assertEqual(len(w), 1)
            self.assertEqual(w[0].category, UserWarning)

        image = self.node.get_array()
        self.node.extract_features(image, extractor_parameters={'nfeatures':5})
        self.assertIsInstance(self.node.masks, pd.DataFrame)
        # Create an artificial mask
        self.node.masks = ('foo', np.array([0,0,1,1,1], dtype=np.bool))
        self.assertEqual(self.node.masks['foo'].sum(), 3)

    def test_convex_hull_ratio_fail(self):
        # Convex hull computation is checked lower in the hull computation
        self.assertRaises(AttributeError, self.node.coverage_ratio)