Commit 760fb8df authored by jlaura's avatar jlaura Committed by GitHub
Browse files

Fixes 144 (#146)

* Fixes 144

* Supports read/write, warn on pointLog

* Adds tests and fixes bugs that the tests identified

* Updates for comments

* updates tests
parent bc105554
Loading
Loading
Loading
Loading
+81 −2
Original line number Diff line number Diff line
from enum import IntEnum
from time import gmtime, strftime
import warnings

import pandas as pd
import numpy as np
@@ -13,6 +15,7 @@ from plio.utils.utils import xstr, find_in_dict
HEADERSTARTBYTE = 65536
DEFAULTUSERNAME = 'None'


def write_filelist(lst, path="fromlist.lis"):
    """
    Writes a filelist to a file so it can be used in ISIS3.
@@ -29,6 +32,73 @@ def write_filelist(lst, path="fromlist.lis"):
        handle.write('\n')
    return


class MeasureMessageType(IntEnum):
    """
    An enum to mirror the ISIS3 MeasureLogData enum.
    """
    GoodnessOfFit = 2
    MinimumPixelZScore = 3
    MaximumPixelZScore = 4
    PixelShift = 5
    WholePixelCorrelation = 6
    SubPixelCorrelation = 7 

class MeasureLog():
    
    def __init__(self, messagetype, value):
        """
        A protobuf compliant measure log object.
        
        Parameters
        ----------
        messagetype : int or str
                      Either the integer or string representation from the MeasureMessageType enum
                      
        value : int or float
                The value to be stored in the message log
        """
        if isinstance(messagetype, int):
            # by value
            self.messagetype = MeasureMessageType(messagetype)
        else:
            # by name
            self.messagetype = MeasureMessageType[messagetype]
        
        if not isinstance(value, (float, int)):
            raise TypeError(f'{value} is not a numeric type')
        self.value = value
        
    def __repr__(self):
        return f'{self.messagetype.name}: {self.value}'
        
    def to_protobuf(self, version=2):
        """
        Return protobuf compliant measure log object representation
        of this class.
        
        Returns
        -------
        log_message : obj
                      MeasureLogData object suitable to append to a MeasureLog
                      repeated field.
        """
        # I do not see a better way to get to the inner MeasureLogData obj than this
        # imports were not working because it looks like these need to instantiate off
        # an object
        if version == 2:
            log_message = cnf.ControlPointFileEntryV0002().Measure().MeasureLogData()
        elif version == 5:
            log_message = cnp5.ControlPointFileEntryV0005().Measure().MeasureLogData()
        log_message.doubleDataValue = self.value
        log_message.doubleDataType = self.messagetype
        return log_message

    @classmethod
    def from_protobuf(cls, protobuf):
        return cls(protobuf.doubleDataType, protobuf.doubleDataValue)


class IsisControlNetwork(pd.DataFrame):

    # normal properties
@@ -171,7 +241,6 @@ class IsisStore(object):
            for s in pbuf_header.pointMessageSizes:
                cp.ParseFromString(self._handle.read(s))
                pt = [getattr(cp, i) for i in self.point_attrs if i != 'measures']

                for measure in cp.measures:
                    meas = pt + [getattr(measure, j) for j in self.measure_attrs]
                    pts.append(meas)
@@ -211,6 +280,10 @@ class IsisStore(object):
        if 'aprioriline' in df.columns:
            df['aprioriline'] -= 0.5
            df['apriorisample'] -= 0.5

        # Munge the MeasureLogData into Python objs
        df['measureLog'] = df['measureLog'].apply(lambda x: [MeasureLog.from_protobuf(i) for i in x])
        
        df.header = pvl_header
        return df

@@ -266,6 +339,10 @@ class IsisStore(object):
                # Un-mangle common attribute names between points and measures
                df_attr = self.point_field_map.get(attr, attr)
                if df_attr in g.columns:
                    if df_attr == 'pointLog':
                        # Currently pointLog is not supported.
                        warnings.warn('The pointLog field is currently unsupported. Any pointLog data will not be saved.')
                        continue
                    # As per protobuf docs for assigning to a repeated field.
                    if df_attr == 'aprioriCovar' or df_attr == 'adjustedCovar':
                        arr = g.iloc[0][df_attr]
@@ -290,8 +367,10 @@ class IsisStore(object):
                    # Un-mangle common attribute names between points and measures
                    df_attr = self.measure_field_map.get(attr, attr)
                    if df_attr in g.columns:
                        if df_attr == 'measureLog':
                            [getattr(measure_spec, attr).extend([i.to_protobuf()]) for i in m[df_attr]]
                        # If field is repeated you must extend instead of assign
                        if cnf._CONTROLPOINTFILEENTRYV0002_MEASURE.fields_by_name[attr].label == 3:
                        elif cnf._CONTROLPOINTFILEENTRYV0002_MEASURE.fields_by_name[attr].label == 3:
                            getattr(measure_spec, attr).extend(m[df_attr])
                        else:
                            setattr(measure_spec, attr, attrtype(m[df_attr]))
+106 −94
Original line number Diff line number Diff line
@@ -31,97 +31,109 @@ def test_cnet_read(cnet_file):
        assert proto_field not in df.columns
        assert mangled_field in df.columns

class TestWriteIsisControlNetwork(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.npts = 5
        serial_times = {295: '1971-07-31T01:24:11.754',
                        296: '1971-07-31T01:24:36.970'}
        cls.serials = {i:'APOLLO15/METRIC/{}'.format(j) for i, j in enumerate(serial_times.values())}
        columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index', 'pointLog', 'measureLog']

        data = []
        for i in range(cls.npts):
            data.append((i, 2, cls.serials[0], 2, 0, 0, 0, [], []))
            data.append((i, 2, cls.serials[1], 2, 0, 0, 1, [], []))

        df = pd.DataFrame(data, columns=columns)

        cls.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
        cls.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
        io_controlnetwork.to_isis(df, 'test.net', mode='wb', targetname='Moon')

        cls.header_message_size = 78
        cls.point_start_byte = 65614 # 66949

    def test_create_buffer_header(self):
@pytest.mark.parametrize('messagetype, value', [
                         (2, 0.5),
                         (3, 0.5),
                         (4, -0.25),
                         (5, 1e6),
                         (6, 1),
                         (7, -1e10),
                         ('GoodnessOfFit', 0.5),
                         ('MinimumPixelZScore', 0.25)
])
def test_MeasureLog(messagetype, value):
    l = io_controlnetwork.MeasureLog(messagetype, value)
    if isinstance(messagetype, int):
        assert l.messagetype == io_controlnetwork.MeasureMessageType(messagetype)
    elif isinstance(messagetype, str):
        assert l.messagetype == io_controlnetwork.MeasureMessageType[messagetype]
        
    assert l.value == value
    assert isinstance(l.to_protobuf, object)

def test_log_error():
    with pytest.raises(TypeError) as err:
        io_controlnetwork.MeasureLog(2, 'foo')

def test_to_protobuf():
    value = 1.25
    int_dtype = 2
    log = io_controlnetwork.MeasureLog(int_dtype, value)
    proto = log.to_protobuf()
    assert proto.doubleDataType == int_dtype
    assert proto.doubleDataValue == value

@pytest.fixture
def cnet_dataframe(tmpdir):
    npts = 5
    serial_times = {295: '1971-07-31T01:24:11.754',
                    296: '1971-07-31T01:24:36.970'}
    serials = {i:'APOLLO15/METRIC/{}'.format(j) for i, j in enumerate(serial_times.values())}
        columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index']
    columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index', 'pointLog', 'measureLog']

    data = []
        for i in range(self.npts):
            data.append((i, 2, serials[0], 2, 0, 0, 0))
            data.append((i, 2, serials[1], 2, 0, 0, 1))
    for i in range(npts):
        data.append((i, 2, serials[0], 2, 0, 0, 0, [], []))
        data.append((i, 2, serials[1], 2, 0, 0, 1, [], [io_controlnetwork.MeasureLog(2, 0.5)]))

    df = pd.DataFrame(data, columns=columns)

        self.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
        self.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
        io_controlnetwork.to_isis(df, 'test.net', mode='wb', targetname='Moon')
    df.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    df.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    io_controlnetwork.to_isis(df, tmpdir.join('test.net'), mode='wb', targetname='Moon')

        self.header_message_size = 78
        self.point_start_byte = 65614 # 66949
    df.header_message_size = 78
    df.point_start_byte = 65614 # 66949
    df.npts = npts
    df.measure_size = 149  # Size of each measure in bytes
    df.serials = serials
    return df

def test_create_buffer_header(cnet_dataframe, tmpdir):
    with open(tmpdir.join('test.net'), 'rb') as f:
        
        with open('test.net', 'rb') as f:
        f.seek(io_controlnetwork.HEADERSTARTBYTE)
            raw_header_message = f.read(self.header_message_size)
        raw_header_message = f.read(cnet_dataframe.header_message_size)
        header_protocol = cnf.ControlNetFileHeaderV0002()
        header_protocol.ParseFromString(raw_header_message)
        #Non-repeating
        #self.assertEqual('None', header_protocol.networkId)
            self.assertEqual('Moon', header_protocol.targetName)
            self.assertEqual(io_controlnetwork.DEFAULTUSERNAME,
                             header_protocol.userName)
            self.assertEqual(self.creation_date,
                             header_protocol.created)
            self.assertEqual('None', header_protocol.description)
            self.assertEqual(self.modified_date, header_protocol.lastModified)
        assert 'Moon' == header_protocol.targetName
        assert io_controlnetwork.DEFAULTUSERNAME == header_protocol.userName
        assert cnet_dataframe.creation_date == header_protocol.created
        assert 'None' == header_protocol.description
        assert cnet_dataframe.modified_date == header_protocol.lastModified
        #Repeating
            self.assertEqual([135] * self.npts, header_protocol.pointMessageSizes)

    def test_create_point(self):
        assert [cnet_dataframe.measure_size] * cnet_dataframe.npts == header_protocol.pointMessageSizes

        with open('test.net', 'rb') as f:
            f.seek(self.point_start_byte)
            for i, length in enumerate([135] * self.npts):
def test_create_point(cnet_dataframe, tmpdir):
    with open(tmpdir.join('test.net'), 'rb') as f:
        f.seek(cnet_dataframe.point_start_byte)
        for i, length in enumerate([cnet_dataframe.measure_size] * cnet_dataframe.npts):
            point_protocol = cnf.ControlPointFileEntryV0002()
            raw_point = f.read(length)
            point_protocol.ParseFromString(raw_point)
                self.assertEqual(str(i), point_protocol.id)
                self.assertEqual(2, point_protocol.type)
                for m in point_protocol.measures:
                    self.assertTrue(m.serialnumber in self.serials.values())
                    self.assertEqual(2, m.type)

    def test_create_pvl_header(self):
        pvl_header = pvl.load('test.net')
            assert str(i) == point_protocol.id
            assert 2 == point_protocol.type
            print(len(point_protocol.measures))
            for i, m in enumerate(point_protocol.measures):
                assert m.serialnumber in cnet_dataframe.serials.values()
                assert 2 == m.type
                assert len(m.log) == i

def test_create_pvl_header(cnet_dataframe, tmpdir):
    with open(tmpdir.join('test.net'), 'rb') as f:
        pvl_header = pvl.load(f)

    npoints = find_in_dict(pvl_header, 'NumberOfPoints')
        self.assertEqual(5, npoints)
    assert 5 == npoints

    mpoints = find_in_dict(pvl_header, 'NumberOfMeasures')
        self.assertEqual(10, mpoints)
    assert 10 == mpoints

    points_bytes = find_in_dict(pvl_header, 'PointsBytes')
        self.assertEqual(675, points_bytes)
    assert 745 == points_bytes

    points_start_byte = find_in_dict(pvl_header, 'PointsStartByte')
        self.assertEqual(self.point_start_byte, points_start_byte)
    assert cnet_dataframe.point_start_byte == points_start_byte
    @classmethod
    def tearDownClass(cls):
        os.remove('test.net')