Commit f2f72069 authored by jcwbacker's avatar jcwbacker
Browse files

Updated documentation, naming conventions and tests for io_gdal.

parent 5e7ac635
Loading
Loading
Loading
Loading
+331 −101
Original line number Diff line number Diff line
import os

import numpy as np
import numpy
from osgeo import gdal
from osgeo import osr

from autocnet.fileio import extract_metadata as em
from autocnet.fileio import extract_metadata

NP2GDAL_CONVERSION = {
  "uint8": 1,
@@ -26,91 +26,148 @@ for k, v in iter(NP2GDAL_CONVERSION.items()):

GDAL2NP_CONVERSION[1] = 'int8'

class GeoDataSet(object):
class GeoDataset(object):
    """
    Geospatial dataset object
    Geospatial dataset object that represents.

    Parameters
    ----------
    filename : str
               The path to the file
    file_name : str
                The name of the input image, including its full path.

    Attributes
    ----------

    basename : str
               The basename extracted from the full path
    base_name : str
                The base name of the input image, extracted from the full path.

    geotransform : object
                   OGR geotransformation object

    standardparallels : list
                        of the standard parallels

    unittype : str
               Name of the unit, e.g. 'm' or 'ft' used by the raster

    spatialreference : object
                       OSR spatial reference object
                   Represents the geotransform reference OGR object.

    geospatial_coordinate_system : object
                                   OSR geospatial coordinate reference object
                                   Represents the geospatial coordinate system OSR object.

    latlon_extent : list
                    of tuples in the form (llat, llon), (ulat, ulon)
                    of two tuples to describe that latitide/longitude boundaries. This list is in the form [(lowerlat, lowerlon), (upperlat, upperlon)].

    pixel_width : float
                  The width of the image pixels (i.e. length in the x-direction).

    pixel_height : float
                   The height of the image pixels (i.e. length in the y-direction).

    extent : list
             of tuples in the form (minx, miny), (maxx, maxy)
    spatial_reference : object
                        Represents the OSR spatial reference system OSR object.

    xpixelsize : float
                 Size of the x-pixel
    standard_parallels : list
                         of the standard parallels used by the map projection.

    ypixelsize : float
                 Size of the y-pixel
    unit_type : str
                Name of the unit used by the raster, e.g. 'm' or 'ft'.

    xrotation : float
                Rotation of the x-axis
                The geotransform coefficient that represents the rotation about the x-axis.

    xy_extent : list
                of two tuples to describe the x/y boundaries. This list is in the form [(minx, miny), (maxx, maxy)].

    yrotation : float
                Rotation of the y-axis
                The geotransform coefficient that represents the rotation about the y-axis.

    """
    def __init__(self, filename):
        self.filename = filename
        self.ds = gdal.Open(filename)
    def __init__(self, file_name):
        """
        Initialization method to set the file name and open the file using GDAL.

        Parameters
        ----------
        file_name : str
                   The file name to set and open.

        """
        self.file_name = file_name
        self.dataset = gdal.Open(file_name)
    
    def __repr__(self):
        return os.path.basename(self.filename)
        return os.path.basename(self.file_name)

    @property
    def basename(self):
        if not getattr(self, '_basename', None):
            self._basename = os.path.splitext(os.path.basename(self.filename))[0]
        return self._basename
    def base_name(self):
        """
        Gets the base name of the file (without the full directory path).

        Returns
        -------
        _base_name : str
                     The base file name.

        """
        if not getattr(self, '_base_name', None):
            self._base_name = os.path.splitext(os.path.basename(self.file_name))[0]
        return self._base_name

    @property
    def geotransform(self):
        """
        Gets an array of size 6 containing the affine transformation coefficients for transforming from raw sample/line to projected x/y.

        xproj = geotransform[0] + sample * geotransform[1] + line * geotransform[2]
        yproj = geotransform[3] + sample * geotransform[4] + line * geotransform[5]

        Returns
        -------
        _geotransform : array
                        of transformation coefficients.

        """
        if not getattr(self, '_geotransform', None):
            self._geotransform = self.ds.GetGeoTransform()
            self._geotransform = self.dataset.GetGeoTransform()
        return self._geotransform

    @property
    def standardparallels(self):
        if not getattr(self, '_standardparallels', None):
            self._standardparallels = em.get_standard_parallels(self.spatialreference)
        return self._standardparallels
    def standard_parallels(self):
        """
        Gets the list of standard parallels found in the metadata using the spatial reference for this GeoDataset.

        Returns
        -------
        _standard_parallels : list
                              of standard parallels.

        """
        if not getattr(self, '_standard_parallels', None):
            self._standard_parallels = extract_metadata.get_standard_parallels(self.spatial_reference)
        return self._standard_parallels

    @property
    def unittype(self):
        if not getattr(self, '_unittype', None):
            self._unittype = self.ds.GetRasterBand(1).GetUnitType()
        return self._unittype
    def unit_type(self):
        """
        Gets the type of units the raster data is stored in. For example, this might be meters, kilometers, feet, etc.

        Returns
        -------
        _unit_type : str
                     The units for this data set.

        """
        if not getattr(self, '_unit_type', None):
            self._unit_type = self.dataset.GetRasterBand(1).GetUnitType()
        return self._unit_type

    @property
    def spatialreference(self):
    def spatial_reference(self):
        """
        Gets the spatial reference system (SRS) and sets the geospatial coordinate system (GCS).

        Returns
        -------
        _srs : object
               The spatial reference system. 
        
        """
        if not getattr(self, '_srs', None):
            self._srs = osr.SpatialReference()
            self._srs.ImportFromWkt(self.ds.GetProjection())
            self._srs.ImportFromWkt(self.dataset.GetProjection())
            try:
                self._srs.MorphToESRI()
                self._srs.MorphFromESRI()
@@ -122,125 +179,254 @@ class GeoDataSet(object):

    @property
    def geospatial_coordinate_system(self):
        """
        Gets the geospatial coordinate system (GCS).

        Returns
        -------
        _gcs : object
               The geospatial coordinate system. 
        
        """
        if not getattr(self, '_gcs', None):
            self._gcs = self.spatialreference.CloneGeogCS()
            self._gcs = self.spatial_reference.CloneGeogCS()
        return self._gcs

    @property
    def latlon_extent(self):
        if not getattr(self, '_latlonextent', None):
        """
        Gets the size two list of tuples containing the latitide/longitude boundaries. This list is in the form [(lowerlat, lowerlon), (upperlat, upperlon)].

        Returns
        -------
        _latlon_extent : list
                         [(lowerlat, lowerlon), (upperlat, upperlon)]
        
        """
        if not getattr(self, '_latlon_extent', None):
            ext = self.extent
            llat, llon = self.pixel_to_latlon(ext[0][0], ext[0][1])
            ulat, ulon = self.pixel_to_latlon(ext[1][0], ext[1][1])
            self._latlonextent = [(llat, llon), (ulat, ulon)]
        return self._latlonextent
            self._latlon_extent = [(llat, llon), (ulat, ulon)]
        return self._latlon_extent

    @property
    def extent(self):
        if not getattr(self, '_extent', None):
    def xy_extent(self):
        """
        Gets the size two list of tuples containing the sample/line boundaries. This list is in the form [(minx, miny), (maxx, maxy)].

        Returns
        -------
        _xy_extent : list
                         [(minx, miny), (maxx, maxy)]
        
        """
        if not getattr(self, '_xy_extent', None):
            gt = self.geotransform
            minx = gt[0]
            maxy = gt[3]

            maxx = minx + gt[1] * self.ds.RasterXSize
            miny = maxy + gt[5] * self.ds.RasterYSize
            maxx = minx + gt[1] * self.dataset.RasterXSize
            miny = maxy + gt[5] * self.dataset.RasterYSize

            self._extent = [(minx, miny), (maxx, maxy)]
            self._xy_extent = [(minx, miny), (maxx, maxy)]

        return self._extent
        return self._xy_extent

    @property
    def xpixelsize(self):
    def pixel_width(self):
        """
        Get the pixel size of the input data
        Get the width of the pixels in the input image (i.e. the length in the x-direction).
        Note: This is the second value geotransform array.

        Returns
        -------
        _pixel_width : float
                       The width of each pixel.
        
        """
        if not getattr(self, '_xpixelsize', None):
            self._xpixelsize = self.geotransform[1]
        return self._xpixelsize
        if not getattr(self, '_pixel_width', None):
            self._pixel_width = self.geotransform[1]
        return self._pixel_width

    @property
    def ypixelsize(self):
        """
        The y-pixel size of the input data
    def pixel_height(self):
        """
        Get the height of the pixels in the input image (i.e the length in the y-direction).
        Note: This is the sixth (last) value geotransform array.

        if not getattr(self, '_ypixelsize', None):
            self._ypixelsize = self.geotransform[5]
        return self._ypixelsize
        Returns
        -------
        _pixel_height : float
                        The height of each pixel.
        
        """
        if not getattr(self, '_pixel_height', None):
            self._pixel_height = self.geotransform[5]
        return self._pixel_height

    @property
    def xrotation(self):
        """
        Get the geotransform rotation about the x-axis.
        Note: This is the third value geotransform array.

        Returns
        -------
        _xrotation : float
                     The geotransform coefficient representing rotation about the x-axis.
        
        """
        if not getattr(self, '_xrotation', None):
            self._xrotation = self.geotransform[2]
        return self._xrotation

    @property
    def yrotation(self):
        """
        Get the geotransform rotation about the y-axis.
        Note: This is the fifth value geotransform array.

        Returns
        -------
        _yrotation : float
                     The geotransform coefficient representing rotation about the y-axis.
        
        """
        if not getattr(self, '_yrotation', None):
            self._yrotation = self.geotransform[4]
        return self._yrotation

    @property
    def coordinate_transformation(self):
        """
        Gets the coordinate transformation from the spatial reference system to the geospatial coordinate system.

        Returns
        -------
        _ct : object
              The coordinate transformation. 
        
        """
        if not getattr(self, '_ct', None):
            self._ct = osr.CoordinateTransformation(self.spatialreference,
            self._ct = osr.CoordinateTransformation(self.spatial_reference,
                                                    self.geospatial_coordinate_system)
        return self._ct

    @property
    def inverse_coordinate_transformation(self):
        """
        Gets the coordinate transformation from the geospatial coordinate system to the spatial reference system.

        Returns
        -------
        _ict : object
               The inverse coordinate transformation.
        
        """
        if not getattr(self, '_ict', None):
                       self._ict = osr.CoordinateTransformation(self.geospatial_coordinate_system,
                                                                self.spatialreference)
                                                                self.spatial_reference)
        return self._ict

    @property
    def ndv(self, band=1):
        """
        Gets the no data value for the given band. This is used to indicate pixels that are not valid.

        Parameters
        ----------
        band : int
               The one-based index of the band. Default band=1.

        Returns
        -------
        _ndv : float
               Special value used to indicate invalid pixels.
        
        """
        if not getattr(self, '_ndv', None):
            self._ndv = self.ds.GetRasterBand(band).GetNoDataValue()
            self._ndv = self.dataset.GetRasterBand(band).GetNoDataValue()
        return self._ndv

    @property
    def scale(self):
        """
        Gets the name and value of the linear projection units of the spatial reference system. To transform a linear distance to meters, multiply by this value.
        If no units are available ("Meters", 1) will be returned.

        Returns
        -------
        _scale : tuple
                 A string/float tuple of the form (unit name, value)
                 
        """
        if not getattr(self, '_scale', None):
            unitname = self.spatialreference.GetLinearUnitsName()
            value = self.spatialreference.GetLinearUnits()
            unitname = self.spatial_reference.GetLinearUnitsName()
            value = self.spatial_reference.GetLinearUnits()
            self._scale = (unitname, value)
        return self._scale

    @property
    def spheroid(self):
        """
        Gets the spheroid found in the metadata using the spatial reference system. 

        Returns
        -------
        _spheroid : tuple
                    (semi-major, semi-minor, inverse flattening)
        
        """
        if not getattr(self, '_spheroid', None):
            self._spheroid = em.get_spheroid(self.spatialreference)
            self._spheroid = extract_metadata.get_spheroid(self.spatial_reference)
        return self._spheroid

    @property
    def rastersize(self):
        if not getattr(self, '_rastersize', None):
            self._rastersize = (self.ds.RasterXSize, self.ds.RasterYSize)
        return self._rastersize
    def raster_size(self):
        """
        Gets the dimensions of the raster, i.e. (number of samples, number of lines).

        Returns
        -------
        _raster_size : tuple
                       (x size, y size)
        
        """
        if not getattr(self, '_raster_size', None):
            self._raster_size = (self.dataset.RasterXSize, self.dataset.RasterYSize)
        return self._raster_size

    @property
    def central_meridian(self):
        """
        Gets the central meridian from the metadata.

        Returns
        -------
        _central_meridian : float

        """
        if not getattr(self, '_central_meridian', None):
            self._central_meridian = em.get_central_meridian(self.spatialreference)
            self._central_meridian = extract_metadata.get_central_meridian(self.spatial_reference)
        return self._central_meridian

    def pixel_to_latlon(self, x, y):
        """
        Convert from pixel space to lat/lon space
        Convert from pixel space (i.e. sample/line) to lat/lon space.

        Parameters
        ----------
        x : float
            x-coordinate
            x-coordinate to be transformed.
        y : float
            y-coordinates
            y-coordinate to be transformed.

        Returns
        -------
        lat, lon : tuple
                   Latitude, Longitude
                   (Latitude, Longitude) corresponding to the given (x,y).
        
        """
        gt = self.geotransform
        x = gt[0] + (x * gt[1]) + (y * gt[2])
@@ -250,28 +436,51 @@ class GeoDataSet(object):
        return lat, lon

    def latlon_to_pixel(self, lat, lon):
        """
        Convert from lat/lon space to pixel space (i.e. sample/line).

        Parameters
        ----------
        lat: float
             Latitude to be transformed.
        lon : float
              Longitude to be transformed.
        Returns
        -------
        x, y : tuple
               (Sample, line) position corresponding to the given (latitude, longitude).
        
        """
        gt = self.geotransform
        ulat, ulon, _ = self.inverse_coordinate_transformation.TransformPoint(lon, lat)
        x = (ulat - gt[0]) / gt[1]
        y = (ulon - gt[3]) / gt[5]
        return x, y

    def readarray(self, band=1, pixels=None, dtype='float32'):
    def read_array(self, band=1, pixels=None, dtype='float32'):
        """
        Extract the required data as a numpy array

        Parameters
        ----------
        band : int
               The image band number to be extracted as a numpy array. Default band=1.

        pixels : list
		  [start, ystart, xstop, ystop]
                 [start, ystart, xstop, ystop]. Default pixels=None.

        dtype : str
                numpy dtype, e.g. float32
                The numpy dtype for the output array. Default dtype='float32'.

        Returns
        -------
        array : NumPy array
                The dataset for the specified band.

        """
        band = self.ds.GetRasterBand(band)
        band = self.dataset.GetRasterBand(band)

        dtype = getattr(np, dtype)
        dtype = getattr(numpy, dtype)

        if pixels == None:
            array = band.ReadAsArray().astype(dtype)
@@ -284,9 +493,30 @@ class GeoDataSet(object):
                                          xextent, yextent).astype(dtype)
        return array

def array_to_raster(array, filename, projection=None,
def array_to_raster(array, file_name, projection=None,
                    geotransform=None, outformat='GTiff',
                    ndv=None):
    """

    Parameters
    ----------
    array : 

    file_name : str 

    projection : 
                 Default projection=None.

    geotransform : object 
                   Default geotransform=None.

    outformat : const char *
                Default outformat='GTiff'.

    ndv : 
          Default ndv=None.

    """
    driver = gdal.GetDriverByName(outformat)
    try:
        y, x, bands = array.shape
@@ -297,27 +527,27 @@ def array_to_raster(array, filename, projection=None,
        single = True

    #This is a crappy hard code to 32bit.
    ds = driver.Create(filename, x, y, bands, gdal.GDT_Float64)
    dataset = driver.Create(file_name, x, y, bands, gdal.GDT_Float64)

    if geotransform:
        ds.SetGeoTransform(geotransform)
        dataset.SetGeoTransform(geotransform)

    if projection:
        if isinstance(projection, str):
            ds.SetProjection(projection)
            dataset.SetProjection(projection)
        else:
            ds.SetProjection(projection.ExportToWkt())
            dataset.SetProjection(projection.ExportToWkt())

    if single == True:
        bnd = ds.GetRasterBand(1)
        bnd = dataset.GetRasterBand(1)
        if ndv != None:
            bnd.SetNoDataValue(ndv)
        bnd.WriteArray(array)
        ds.FlushCache()
        dataset.FlushCache()
    else:
        for i in range(1, bands + 1):
            bnd = ds.GetRasterBand(i)
            bnd = dataset.GetRasterBand(i)
            if ndv != None:
                bnd.SetNoDataValue(ndv)
            bnd.WriteArray(array[:,:,i - 1])
            ds.FlushCache()
            dataset.FlushCache()
+84 −81

File changed.

Preview size limit exceeded, changes collapsed.