Source code for intake_xarray.image


from intake.source.base import PatternMixin
from intake.source.utils import reverse_formats
from .base import DataSourceMixin, Schema


def _coerce_shape(array, shape):
    """ Trim or pad array to match desired shape"""
    import numpy as np

    if len(shape) != 2:
        raise ValueError('coerce_shape must be an iterable of len 2')

    target_shape = tuple(shape)
    actual_shape = array.shape
    ndims = len(actual_shape)

    if actual_shape[:2] == target_shape:
        # no trimming or padding needed
        return array

    # do any necessary trimming first
    for i, (a, t) in enumerate(zip(actual_shape[:2], target_shape)):
        if a > t:
            if i == 0:
                if ndims == 2:
                    array = array[:t, :]
                else:
                    array = array[:t, :, :]
            else:
                if ndims == 2:
                    array = array[:, :t]
                else:
                    array = array[:, :t, :]

    if array.shape[:2] == target_shape:
        # only needed trimming
        return array

    # create array of zeros and fill with trimmed value array
    if ndims == 2:
        new_array = np.zeros(target_shape, dtype=array.dtype)
        new_array[:array.shape[0], :array.shape[1]] = array
    else:
        new_array = np.zeros((target_shape[0],
                              target_shape[1],
                              actual_shape[2]), dtype=array.dtype)
        new_array[:array.shape[0], :array.shape[1], :] = array

    return new_array
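

# Hedged example (not part of the original module): a small, importable
# demonstration of ``_coerce_shape``'s trim-then-pad behaviour on a made-up
# array. Rows beyond the target height are trimmed; missing columns are
# zero-padded.
def _demo_coerce_shape():
    import numpy as np

    a = np.arange(24, dtype='uint8').reshape(4, 6)
    out = _coerce_shape(a, (3, 8))

    assert out.shape == (3, 8)
    assert (out[:, :6] == a[:3, :]).all()  # trimmed content is preserved
    assert (out[:, 6:] == 0).all()         # padded columns are zero-filled
    return out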


def _dask_imread(files, imread=None, preprocess=None, coerce_shape=None):
    """ Read a stack of images into a dask array """
    from dask.array import Array
    from dask.base import tokenize
    from functools import partial

    if not imread:
        from skimage.io import imread

    def _imread(open_file):
        with open_file as f:
            return imread(f)

    def add_leading_dimension(x):
        return x[None, ...]

    filenames = [f.path for f in files]

    name = 'imread-%s' % tokenize(filenames)

    if coerce_shape is not None:
        reshape = partial(_coerce_shape, shape=coerce_shape)

    with files[0] as f:
        sample = imread(f)
    if coerce_shape is not None:
        sample = reshape(sample)
    if preprocess:
        sample = preprocess(sample)

    keys = [(name, i) + (0,) * len(sample.shape)
            for i in range(len(files))]

    if coerce_shape is not None:
        if preprocess:
            values = [(add_leading_dimension,
                       (preprocess,
                        (reshape,
                         (_imread, f))))
                      for f in files]
        else:
            values = [(add_leading_dimension,
                       (reshape,
                        (_imread, f)))
                      for f in files]
    elif preprocess:
        values = [(add_leading_dimension,
                   (preprocess,
                    (_imread, f)))
                  for f in files]
    else:
        values = [(add_leading_dimension,
                   (_imread, f))
                  for f in files]
    dsk = dict(zip(keys, values))

    chunks = ((1, ) * len(files), ) + tuple((d, ) for d in sample.shape)

    return Array(dsk, name, chunks, sample.dtype)
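

# Hedged sketch (not part of the original module): driving ``_dask_imread``
# directly. ``fsspec.open_files`` is assumed to stand in for the OpenFile
# objects the function expects (the module itself goes through
# ``dask.bytes``), and the glob path is hypothetical.
def _demo_dask_imread():
    import fsspec

    files = fsspec.open_files('data/frames/*.png')  # hypothetical path
    stack = _dask_imread(files)
    # One chunk per file along the leading axis; no image data is read
    # until ``.compute()`` is called.
    print(stack.shape, stack.chunks)
    return stack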


def reader(file, chunks, imread=None, preprocess=None, coerce_shape=None):
    """Read a file object and output an dask xarray object

    NOTE: inspired by dask.array.image.imread but altering the input to accept
    a just one file object.

    Parameters
    ----------
    file : OpenFile
        File object
    chunks : int or dict
        Chunks is used to load the new dataset into dask
        arrays. ``chunks={}`` loads the dataset with dask using a single
        chunk for all arrays.
    imread : function (optional)
        Optionally provide custom imread function.
        Function should expect a file object and produce a numpy array.
        Defaults to ``skimage.io.imread``.
    preprocess : function (optional)
        Optionally provide custom function to preprocess the image.
        Function should expect a numpy array for a single image.
    coerce_shape : iterable of len 2 (optional)
        Optionally coerce the shape of the height and width of the image
        by setting `coerce_shape` to desired shape.

    Returns
    -------
    Dask xarray.DataArray of the image. Treated as one chunk unless
    chunks kwarg is specified.
    """
    import numpy as np
    from xarray import DataArray

    if not imread:
        from skimage.io import imread

    with file as f:
        array = imread(f)
    if coerce_shape is not None:
        array = _coerce_shape(array, shape=coerce_shape)
    if preprocess:
        array = preprocess(array)

    ny, nx = array.shape[:2]
    coords = {'y': np.arange(ny),
              'x': np.arange(nx)}
    dims = ('y', 'x')

    if len(array.shape) == 3:
        nchannel = array.shape[2]
        coords['channel'] = np.arange(nchannel)
        dims += ('channel',)

    return DataArray(array, coords=coords, dims=dims).chunk(chunks=chunks)
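

# Hedged usage sketch (not part of the original module): reading a single
# image into a chunked DataArray. The file path is hypothetical.
def _demo_reader():
    import fsspec

    file, = fsspec.open_files('data/RGB.tif')  # hypothetical single file
    da = reader(file, chunks={})
    # dims are ('y', 'x'), plus 'channel' for multi-band images
    print(da.dims, da.shape)
    return da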


def multireader(files, chunks, concat_dim, **kwargs):
    """Read a stack of images into a dask xarray object

    NOTE: copied from dask.array.image.imread, but altered to accept
    a list of file objects.

    Parameters
    ----------
    files : iter
        List of file objects, where each file contains data with the same
        shape. If this is not the case, use ``preprocess`` to coerce the
        data into a consistent shape.
    chunks : int or dict
        Chunks is used to load the new dataset into dask
        arrays. ``chunks={}`` loads the dataset with dask using a single
        chunk for all arrays.
    concat_dim : str or iterable
        Dimension over which to concatenate. If iterable, all fields must be
        part of the pattern.
    imread : function (optional)
        Optionally provide custom imread function.
        Function should expect a file object and produce a numpy array.
        Defaults to ``skimage.io.imread``.
    preprocess : function (optional)
        Optionally provide custom function to preprocess the image.
        Function should expect a numpy array for a single image.
    coerce_shape : iterable of len 2 (optional)
        Optionally coerce the shape of the height and width of the image
        by setting `coerce_shape` to desired shape.

    Returns
    -------
    Dask xarray.DataArray of all images stacked along the first dimension.
    All images will be treated as individual chunks unless
    chunks kwarg is specified.
    """
    import numpy as np
    from xarray import DataArray

    dask_array = _dask_imread(files, **kwargs)

    ny, nx = dask_array.shape[1:3]
    coords = {'y': np.arange(ny),
              'x': np.arange(nx)}
    if isinstance(concat_dim, list):
        dims = ('dim_0', 'y', 'x')
    else:
        dims = (concat_dim, 'y', 'x')

    if len(dask_array.shape) == 4:
        nchannel = dask_array.shape[3]
        coords['channel'] = np.arange(nchannel)
        dims += ('channel',)

    return DataArray(dask_array, coords=coords, dims=dims).chunk(chunks=chunks)
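

# Hedged usage sketch (not part of the original module): stacking a
# hypothetical glob of same-shaped frames along a named dimension.
def _demo_multireader():
    import fsspec

    files = fsspec.open_files('data/frame_*.png')  # hypothetical glob
    da = multireader(files, chunks={}, concat_dim='frame')
    # Each file becomes one slice (and one dask chunk) along 'frame'.
    print(da.dims)
    return da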


class ImageSource(DataSourceMixin, PatternMixin):
    """Open an xarray dataset from image files.

    This creates an xarray.DataArray or an xarray.Dataset.
    See http://scikit-image.org/docs/dev/api/skimage.io.html#skimage.io.imread
    for the file formats supported.

    NOTE: Although ``skimage.io.imread`` is used by default, any reader
    function which accepts a file object and outputs a numpy array can be
    used instead.

    Parameters
    ----------
    urlpath : str or iterable, location of data
        May be a local path, or remote path if including a protocol specifier
        such as ``'s3://'``. May include glob wildcards or format pattern
        strings. Must be a format supported by ``skimage.io.imread`` or
        user-supplied ``imread``. Some examples:
            - ``{{ CATALOG_DIR }}/data/RGB.tif``
            - ``s3://data/*.jpeg``
            - ``https://example.com/image.png``
            - ``s3://data/Images/{{ landuse }}/{{ '%02d' % id }}.tif``
    chunks : int or dict
        Chunks is used to load the new dataset into dask
        arrays. ``chunks={}`` loads the dataset with dask using a single
        chunk for all arrays.
    path_as_pattern : bool or str, optional
        Whether to treat the path as a pattern (ie. ``data_{field}.tif``)
        and create new coordinates in the output corresponding to pattern
        fields. If str, is treated as pattern to match on. Default is True.
    concat_dim : str or iterable
        Dimension over which to concatenate. If iterable, all fields must be
        part of the pattern.
    imread : function (optional)
        Optionally provide custom imread function.
        Function should expect a file object and produce a numpy array.
        Defaults to ``skimage.io.imread``.
    preprocess : function (optional)
        Optionally provide custom function to preprocess the image.
        Function should expect a numpy array for a single image and return
        a numpy array.
    coerce_shape : iterable of len 2 (optional)
        Optionally coerce the shape of the height and width of the image
        by setting `coerce_shape` to desired shape.
    """
    name = 'xarray_image'

    def __init__(self, urlpath, chunks=None, concat_dim='concat_dim',
                 metadata=None, path_as_pattern=True,
                 storage_options=None, **kwargs):
        self.path_as_pattern = path_as_pattern
        self.urlpath = urlpath
        self.chunks = chunks
        self.concat_dim = concat_dim
        self.storage_options = storage_options or {}
        self._kwargs = kwargs
        self._ds = None
        super(ImageSource, self).__init__(metadata=metadata)

    def _open_files(self, files):
        """
        This function is called when the data source refers to more
        than one file, either as a list or a glob. It sets up the
        dask graph for opening the files.

        Parameters
        ----------
        files : iter
            List of file objects
        """
        import pandas as pd
        from xarray import DataArray

        out = multireader(files, self.chunks, self.concat_dim, **self._kwargs)
        if not self.pattern:
            return out

        coords = {}
        filenames = [f.path for f in files]
        field_values = reverse_formats(self.pattern, filenames)

        if isinstance(self.concat_dim, list):
            if not set(field_values.keys()).issuperset(set(self.concat_dim)):
                raise KeyError('All concat_dims should be in pattern.')
            index = pd.MultiIndex.from_tuples(
                zip(*(field_values[dim] for dim in self.concat_dim)),
                names=self.concat_dim)
            coords = {
                k: DataArray(v, dims=('dim_0'))
                for k, v in field_values.items() if k not in self.concat_dim
            }
            out = (out.assign_coords(dim_0=index, **coords)  # use the index
                      .unstack().chunk(self.chunks))  # unstack along new index
            return out.transpose(*self.concat_dim,  # reorder dims
                                 *filter(lambda x: x not in self.concat_dim,
                                         out.dims))
        else:
            coords = {
                k: DataArray(v, dims=self.concat_dim)
                for k, v in field_values.items()
            }
            return out.assign_coords(**coords).chunk(self.chunks)

    def _open_dataset(self):
        """
        Main entry function that finds a set of files and passes them to the
        reader.
        """
        from dask.bytes import open_files

        files = open_files(self.urlpath, **self.storage_options)
        if len(files) == 0:
            raise Exception("No files found at {}".format(self.urlpath))
        if len(files) == 1:
            self._ds = reader(files[0], self.chunks, **self._kwargs)
        else:
            self._ds = self._open_files(files)

    def _get_schema(self):
        """Make schema object, which embeds xarray object and some details"""
        import xarray as xr
        import msgpack
        from .xarray_container import serialize_zarr_ds

        self.urlpath, *_ = self._get_cache(self.urlpath)

        if self._ds is None:
            self._open_dataset()

        # convert to dataset for serialization
        ds2 = xr.Dataset({'raster': self._ds})

        metadata = {
            'dims': dict(ds2.dims),
            'data_vars': {k: list(ds2[k].coords)
                          for k in ds2.data_vars.keys()},
            'coords': tuple(ds2.coords.keys()),
            'array': 'raster'
        }
        if getattr(self, 'on_server', False):
            metadata['internal'] = serialize_zarr_ds(ds2)
        for k, v in self._ds.attrs.items():
            try:
                # ensure only sending serializable attrs from remote
                msgpack.packb(v)
                metadata[k] = v
            except TypeError:
                pass

        self._schema = Schema(
            datashape=None,
            dtype=str(self._ds.dtype),
            shape=self._ds.shape,
            npartitions=self._ds.data.npartitions,
            extra_metadata=metadata)
        return self._schema
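

# Hedged catalog-style usage (not part of the original module): because
# ``name = 'xarray_image'``, intake exposes this driver as
# ``intake.open_xarray_image`` once intake-xarray is installed. The path and
# the pattern field 'band' are made up for illustration.
def _demo_image_source():
    import intake

    src = intake.open_xarray_image('data/landsat_{band}.tif',
                                   chunks={}, concat_dim='band')
    # The '{band}' pattern field becomes a coordinate along concat_dim.
    return src.to_dask()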