Commit ebc55f33 authored by Jay's avatar Jay Committed by Jason R Laura
Browse files

Adds optional dependency / functionality to utilize VLFeat SIFT implementation...

Adds optional dependency / functionality to utilize VLFeat SIFT implementation for keypoint and descriptor extraction.
parent cf726099
Loading
Loading
Loading
Loading
+24 −14
Original line number Diff line number Diff line
@@ -217,7 +217,12 @@ class Node(dict, MutableMapping):
                 kwargs passed to autocnet.feature_extractor.extract_features

        """
        keypoint_objs, descriptors = fe.extract_features(array, **kwargs)
        keypoint_objs, self.descriptors = fe.extract_features(array, **kwargs)
        if self.descriptors.dtype != np.float32:
            self.descriptors = self.descriptors.astype(np.float32)

        # OpenCV returned keypoint objects
        if isinstance(keypoint_objs, list):
            keypoints = np.empty((len(keypoint_objs), 7), dtype=np.float32)
            for i, kpt in enumerate(keypoint_objs):
                octave = kpt.octave & 8
@@ -230,7 +235,12 @@ class Node(dict, MutableMapping):
            self._keypoints = pd.DataFrame(keypoints, columns=['x', 'y', 'response', 'size',
                                                               'angle', 'octave', 'layer'])
            self.nkeypoints = len(self._keypoints)
        self.descriptors = descriptors.astype(np.float32)

        # VLFeat returned keypoint objects
        elif isinstance(keypoint_objs, np.ndarray):
            # Swap columns for value style access, vl_feat returns y, x
            keypoint_objs[:, 0], keypoint_objs[:, 1] = keypoint_objs[:, 1], keypoint_objs[:, 0].copy()
            self._keypoints = pd.DataFrame(keypoint_objs, columns=['x', 'y', 'size', 'angle'])

    def load_features(self, in_path):
        """
+15 −3
Original line number Diff line number Diff line
import cv2

try:
    import cyvlfeat as vl
    vlfeat = True
except:
    vlfeat = False
    pass


def extract_features(array, method='orb', extractor_parameters=None):
    """
@@ -28,6 +35,11 @@ def extract_features(array, method='orb', extractor_parameters=None):
                 'sift': cv2.xfeatures2d.SIFT_create,
                 'surf': cv2.xfeatures2d.SURF_create,
                 'orb': cv2.ORB_create}
    if vlfeat:
        detectors['vl_sift'] = vl.sift.sift

    if 'vl_' in method:
        return detectors[method](array, compute_descriptor=True, float_descriptors=True, **extractor_parameters)
    else:
        detector = detectors[method](**extractor_parameters)
        return detector.detectAndCompute(array, None)
+7 −0
Original line number Diff line number Diff line
@@ -32,3 +32,10 @@ class TestFeatureExtractor(unittest.TestCase):
        self.assertIn(len(features[0]), range(8, 12))
        self.assertIsInstance(features[0][0], type(cv2.KeyPoint()))
        self.assertIsInstance(features[1][0], np.ndarray)

    def test_extract_vlfeat(self):
        kps, descriptors = feature_extractor.extract_features(self.data_array,
                                                              method='vl_sift',
                                                              extractor_parameters={})
        self.assertIsInstance(kps, np.ndarray)
        self.assertEqual(descriptors.dtype, np.float32)