Commit 3129dc23 authored by Tyler Thatcher's avatar Tyler Thatcher Committed by jlaura
Browse files

Moved remaining pyhat files (#120)

parent 2f896671
Loading
Loading
Loading
Loading

plio/io/hcube.py

deleted100644 → 0
+0 −208
Original line number Diff line number Diff line
import numpy as np
import gdal

from ..utils.indexing import _LocIndexer, _iLocIndexer
from libpyhat.transform.continuum import continuum_correction
from libpyhat.transform.continuum import polynomial, linear, regression


class HCube(object):
    """
    A Mixin class for use with the io_gdal.GeoDataset class
    to optionally add support for spectral labels, label
    based indexing, and lazy loading for reads.
    """
    def __init__(self, data = [], wavelengths = []):
        if len(data) != 0:
            self._data = data
        if len(wavelengths) != 0:
            self._wavelengths = wavelengths

    @property
    def wavelengths(self):
        if not hasattr(self, '_wavelengths'):
            try:
                info = gdal.Info(self.file_name, format='json')
                wavelengths = [float(j) for i, j in sorted(info['metadata'][''].items(),
                                 key=lambda x: float(x[0].split('_')[-1]))]
                self._original_wavelengths = wavelengths
                self._wavelengths = np.round(wavelengths, self.tolerance)
            except:
                self._wavelengths = []
        return self._wavelengths

    @property
    def data(self):
        if not hasattr(self, '_data'):
            try:
                key = (slice(None, None, None),
                       slice(None, None, None),
                       slice(None, None, None))
                data = self._read(key)
            except Exception as e:
                print(e)
                data = []
            self._data = data

        return self._data

    @property
    def tolerance(self):
        return getattr(self, '_tolerance', 2)

    @tolerance.setter
    def tolerance(self, val):
        if isinstance(val, int):
            self._tolerance = val
            self._reindex()
        else:
            raise TypeError

    def _reindex(self):
        if self._original_wavelengths is not None:
            self._wavelengths = np.round(self._original_wavelengths, decimals=self.tolerance)

    def __getitem__(self, key):
        i = _iLocIndexer(self)
        return i[key]

    @property
    def loc(self):
        return _LocIndexer(self)

    @property
    def iloc(self):
        return _iLocIndexer(self)

    def reduce(self, how = np.mean, axis = (1, 2)):
        """
        Parameters
        ----------
        how : function
              Function to apply across along axises of the hcube

        axis : tuple
               List of axis to apply a given function along

        Returns
        -------
        new_hcube : Object
                    A new hcube object with the reduced data set
        """
        res = how(self.data, axis = axis)

        new_hcube = HCube(res, self.wavelengths)
        return new_hcube

    def continuum_correct(self, nodes, correction_nodes = np.array([]), correction = linear,
                          axis=0, adaptive=False, window=3, **kwargs):
        """
        Parameters
        ----------

        nodes : list
                A list of wavelengths for the continuum to be corrected along

        correction_nodes : list
                           A list of nodes to limit the correction between

        correction : function
                     Function specifying the type of correction to perform
                     along the continuum

        axis : int
               Axis to apply the continuum correction on

        adaptive : boolean
                   ?

        window : int
                 ?

        Returns
        -------

        new_hcube : Object
                    A new hcube object with the corrected dataset
        """

        continuum_data = continuum_correction(self.data, self.wavelengths, nodes = nodes,
                                              correction_nodes = correction_nodes, correction = correction,
                                              axis = axis, adaptive = adaptive,
                                              window = window, **kwargs)

        new_hcube = HCube(continuum_data[0], self.wavelengths)
        return new_hcube


    def clip_roi(self, x, y, band, tolerance=2):
        """
        Parameters
        ----------

        x : tuple
            Lower and upper bound along the x axis for clipping

        y : tuple
            Lower and upper bound along the y axis for clipping

        band : tuple
               Lower and upper band along the z axis for clipping

        tolerance : int
                    Tolerance given for trying to find wavelengths
                    between the upper and lower bound

        Returns
        -------

        new_hcube : Object
                    A new hcube object with the clipped dataset
        """
        wavelength_clip = []
        for wavelength in self.wavelengths:
            wavelength_upper = wavelength + tolerance
            wavelength_lower = wavelength - tolerance
            if wavelength_upper > band[0] and wavelength_lower < band[1]:
                wavelength_clip.append(wavelength)

        key = (wavelength_clip, slice(*x), slice(*y))
        data_clip = _LocIndexer(self)[key]

        new_hcube = HCube(np.copy(data_clip), np.array(wavelength_clip))
        return new_hcube

    def _read(self, key):
        ifnone = lambda a, b: b if a is None else a

        y = key[1]
        x = key[2]
        if isinstance(x, slice):
            xstart = ifnone(x.start,0)
            xstop = ifnone(x.stop,self.raster_size[0])
            xstep = xstop - xstart
        else:
            raise TypeError("Loc style access elements must be slices, e.g., [:] or [10:100]")
        if isinstance(y, slice):
            ystart = ifnone(y.start, 0)
            ystop = ifnone(y.stop, self.raster_size[1])
            ystep = ystop - ystart
        else:
            raise TypeError("Loc style access elements must be slices, e.g., [:] or [10:100]")

        pixels = (xstart, ystart, xstep, ystep)
        if isinstance(key[0], (int, np.integer)):
            return self.read_array(band=int(key[0]+1), pixels=pixels)

        elif isinstance(key[0], slice):
            # Given some slice iterate over the bands and get the bands and pixel space requested
            arrs = []
            for band in list(list(range(1, self.nbands + 1))[key[0]]):
                arrs.append(self.read_array(band, pixels = pixels))
            return np.stack(arrs)

        else:
            arrs = []
            for b in key[0]:
                arrs.append(self.read_array(band=int(b+1), pixels=pixels))
        return np.stack(arrs)

plio/io/io_crism.py

deleted100644 → 0
+0 −61
Original line number Diff line number Diff line
import os
import numpy as np
from .io_gdal import GeoDataset
from .hcube import HCube

try:
    from libpyhat.derived import crism
    from libpyhat.derived.utils import get_derived_funcs
    libpyhat_enabled = True
except:
    print('No libpyhat module. Unable to attach derived product functions')
    libpyhat_enabled = False

import gdal


class Crism(GeoDataset, HCube):
    """
    An Crism specific reader with the spectral mixin.
    """
    def __init__(self, file_name):

        GeoDataset.__init__(self, file_name)
        HCube.__init__(self)

        self.derived_funcs = {}

        if libpyhat_enabled:
            self.derived_funcs = get_derived_funcs(crism)

    def __getattr__(self, name):
        try:
            func = self.derived_funcs[name]

            setattr(self, name, func.__get__(self))
            return getattr(self, name)

        except KeyError as keyerr:
            raise AttributeError("'Crism' object has no attribute '{}'".format(name)) from None

    @property
    def wavelengths(self):
        if not hasattr(self, '_wavelengths'):
            try:
                info = gdal.Info(self.file_name, format='json')
                wv = dict((k,v) for (k,v) in info['metadata'][''].items() if 'Band' in k) # Only get the 'Band_###' keys
                wavelengths = [float(j.split(" ")[0]) for i, j in sorted(wv.items(),
                                key=lambda x: int(x[0].split('_')[-1]))]
                self._original_wavelengths = wavelengths
                self._wavelengths = np.round(wavelengths, self.tolerance)
            except:
                self._wavelengths = []
        return self._wavelengths

def open(input_data):
    if os.path.splitext(input_data)[-1] == 'hdr':
        # GDAL wants the img, but many users aim at the .hdr
        input_data = os.path.splitext(input_data)[:-1] + '.img'
    ds = Crism(input_data)

    return ds
+0 −67
Original line number Diff line number Diff line
import os
import numpy as np
from .io_gdal import GeoDataset
from .hcube import HCube

try:
    from libpyhat.derived import m3
    from libpyhat.derived.utils import get_derived_funcs
    libpyhat_enabled = True
except:
    print('No libpyhat module. Unable to attach derived product functions')
    libpyhat_enabled = False

import gdal


class M3(GeoDataset, HCube):
    """
    An M3 specific reader with the spectral mixin.
    """
    def __init__(self, file_name):

        GeoDataset.__init__(self, file_name)
        HCube.__init__(self)

        self.derived_funcs = {}

        if libpyhat_enabled:
            self.derived_funcs = get_derived_funcs(m3)

    def __getattr__(self, name):
        try:
            func = self.derived_funcs[name]

            setattr(self, name, func.__get__(self))
            return getattr(self, name)

        except KeyError as keyerr:
            raise AttributeError("'M3' object has no attribute '{}'".format(name)) from None

    @property
    def wavelengths(self):
        if not hasattr(self, '_wavelengths'):
            try:
                info = gdal.Info(self.file_name, format='json')
                if 'Resize' in info['metadata']['']['Band_1']:
                    wavelengths = [float(j.split(' ')[-1].replace('(','').replace(')', '')) for\
                                  i,j in sorted(info['metadata'][''].items(),
                                  key=lambda x: float(x[0].split('_')[-1]))]
                    # This is a geotiff translated from the PDS IMG
                else:
                    # This is a PDS IMG
                    wavelengths = [float(j) for i, j in sorted(info['metadata'][''].items(),
                                    key=lambda x: float(x[0].split('_')[-1]))]
                self._original_wavelengths = wavelengths
                self._wavelengths = np.round(wavelengths, self.tolerance)
            except:
                self._wavelengths = []
        return self._wavelengths

def open(input_data):
    if os.path.splitext(input_data)[-1] == 'hdr':
        # GDAL wants the img, but many users aim at the .hdr
        input_data = os.path.splitext(input_data)[:-1] + '.img'
    ds = M3(input_data)

    return ds