Source code for nxtomomill.utils.frameappender

# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2016-2022 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/

__authors__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "03/02/2022"


import os
from typing import Union

import h5py
import h5py._hl.selections as selection
import numpy
from h5py import h5s as h5py_h5s
from silx.io.url import DataUrl
from silx.io.utils import get_data, h5py_read_dataset
from silx.io.utils import open as open_hdf5
from tomoscan.esrf.scan.utils import cwd_context
from tomoscan.io import HDF5File

from nxtomo.io import to_target_rel_path
from nxtomomill.utils.h5pyutils import from_data_url_to_virtual_source
from nxtomomill.utils.hdf5 import DatasetReader


class FrameAppender:
    """
    Class to insert 2D frame(s) into an existing dataset
    """

    def __init__(
        self,
        data: Union[numpy.ndarray, DataUrl],
        file_path,
        data_path,
        where,
        logger=None,
    ):
        if where not in ("start", "end"):
            raise ValueError("`where` should be `start` or `end`")
        if not isinstance(
            data, (DataUrl, numpy.ndarray, list, tuple, h5py.VirtualSource)
        ):
            raise TypeError(
                f"data should be a DataUrl, numpy array, list, tuple or h5py.VirtualSource, not {type(data)}"
            )
        self.data = data
        self.file_path = os.path.abspath(file_path)
        self.data_path = data_path
        self.where = where
        self.logger = logger

    def process(self) -> None:
        """
        Main function. Will start the insertion of the frame(s).
        """
        with HDF5File(self.file_path, mode="a") as h5s:
            if self.data_path in h5s:
                self._add_to_existing_dataset(h5s)
            else:
                self._create_new_dataset(h5s)
            if self.logger:
                self.logger.info(f"data added to {self.data_path}@{self.file_path}")

    def _add_to_existing_virtual_dataset(self, h5s):
        if (
            h5py.version.hdf5_version_tuple[0] <= 1
            and h5py.version.hdf5_version_tuple[1] < 12
        ):
            if self.logger:
                self.logger.warning(
                    "You are working on a virtual dataset "
                    "with an HDF5 version < 1.12. Frames "
                    "you want to change might be "
                    "modified depending on the working "
                    "directory without notice. "
                    "See https://github.com/silx-kit/silx/issues/3277"
                )
        if isinstance(self.data, h5py.VirtualSource):
            self.__insert_virtual_source_in_vds(h5s=h5s, new_virtual_source=self.data)
        elif isinstance(self.data, DataUrl):
            if self.logger is not None:
                self.logger.debug(
                    f"Update virtual dataset: {self.data_path}@{self.file_path}"
                )
            # store the DataUrl in the current virtual dataset
            url = self.data

            def check_dataset(dataset_frm_url):
                """Check that the dataset is valid and whether it needs a reshape."""
                data_need_reshape = False
                if dataset_frm_url.ndim not in (2, 3):
                    raise ValueError(
                        f"{url.path()} should point to a 2D or 3D dataset"
                    )
                if dataset_frm_url.ndim == 2:
                    new_shape = 1, dataset_frm_url.shape[0], dataset_frm_url.shape[1]
                    if self.logger is not None:
                        self.logger.info(
                            f"reshape provided data to 3D (from {dataset_frm_url.shape} to {new_shape})"
                        )
                    data_need_reshape = True
                return data_need_reshape

            loaded_dataset = None
            if url.data_slice() is None:
                # case where we can avoid loading the data in memory
                with DatasetReader(url) as data_frm_url:
                    data_need_reshape = check_dataset(data_frm_url)
                # FIXME: avoid keeping some file open. not clear why this is needed
                data_frm_url = None
            else:
                data_frm_url = get_data(url)
                data_need_reshape = check_dataset(data_frm_url)
                loaded_dataset = data_frm_url

            if url.data_slice() is None and not data_need_reshape:
                # case where we can avoid loading the data in memory
                with DatasetReader(self.data) as data_frm_url:
                    self.__insert_url_in_vds(h5s, url, data_frm_url)
                # FIXME: avoid keeping some file open. not clear why this is needed
                data_frm_url = None
            else:
                if loaded_dataset is None:
                    data_frm_url = get_data(url)
                else:
                    data_frm_url = loaded_dataset
                self.__insert_url_in_vds(h5s, url, data_frm_url)
        else:
            raise TypeError(
                "Provided data is a numpy array while the given "
                "dataset path is a virtual dataset. "
                "You must store the data somewhere else "
                "and provide a DataUrl"
            )

    def __insert_url_in_vds(self, h5s, url, data_frm_url):
        if data_frm_url.ndim == 2:
            dim_2, dim_1 = data_frm_url.shape
            data_frm_url = data_frm_url.reshape(1, dim_2, dim_1)
        elif data_frm_url.ndim == 3:
            _, dim_2, dim_1 = data_frm_url.shape
        else:
            raise ValueError("data to add is expected to be 2D or 3D")
        new_virtual_source = h5py.VirtualSource(
            path_or_dataset=url.file_path(),
            name=url.data_path(),
            shape=data_frm_url.shape,
        )
        if url.data_slice() is not None:
            # in this case we have to build a FancySelection
            with open_hdf5(os.path.abspath(url.file_path())) as h5sd:
                dst = h5sd[url.data_path()]
                sel = selection.select(
                    h5sd[url.data_path()].shape, url.data_slice(), dst
                )
                new_virtual_source.sel = sel

        self.__insert_virtual_source_in_vds(
            h5s=h5s, new_virtual_source=new_virtual_source, relative_path=True
        )

    def __insert_virtual_source_in_vds(
        self, h5s, new_virtual_source: h5py.VirtualSource, relative_path=True
    ):
        if not isinstance(new_virtual_source, h5py.VirtualSource):
            raise TypeError(
                f"{new_virtual_source} is expected to be an instance of h5py.VirtualSource and not {type(new_virtual_source)}"
            )
        if not len(new_virtual_source.shape) == 3:
            raise ValueError(
                f"virtual source shape is expected to be 3D and not {len(new_virtual_source.shape)}D."
            )
        # preprocess the VirtualSource to ensure it has a relative path
        if relative_path:
            vds_file_path = to_target_rel_path(new_virtual_source.path, self.file_path)
            new_virtual_source_sel = new_virtual_source.sel
            new_virtual_source = h5py.VirtualSource(
                path_or_dataset=vds_file_path,
                name=new_virtual_source.name,
                shape=new_virtual_source.shape,
                dtype=new_virtual_source.dtype,
            )
            new_virtual_source.sel = new_virtual_source_sel

        virtual_sources_len = []
        virtual_sources = []
        # we need to recreate the VirtualSource(s): they are not
        # stored or available from the API
        for vs_info in h5s[self.data_path].virtual_sources():
            length, vs = self._recreate_vs(vs_info=vs_info, vds_file=self.file_path)
            virtual_sources.append(vs)
            virtual_sources_len.append(length)

        n_frames = h5s[self.data_path].shape[0] + new_virtual_source.shape[0]
        data_type = h5s[self.data_path].dtype

        if self.where == "start":
            virtual_sources.insert(0, new_virtual_source)
            virtual_sources_len.insert(0, new_virtual_source.shape[0])
        else:
            virtual_sources.append(new_virtual_source)
            virtual_sources_len.append(new_virtual_source.shape[0])

        # create the new virtual dataset
        layout = h5py.VirtualLayout(
            shape=(
                n_frames,
                new_virtual_source.shape[-2],
                new_virtual_source.shape[-1],
            ),
            dtype=data_type,
        )
        last = 0
        for v_source, vs_len in zip(virtual_sources, virtual_sources_len):
            layout[last : vs_len + last] = v_source
            last += vs_len
        if self.data_path in h5s:
            del h5s[self.data_path]
        h5s.create_virtual_dataset(self.data_path, layout)

    def _add_to_existing_none_virtual_dataset(self, h5s):
        """
        For now, when we want to add data *to a non-virtual dataset*
        we always duplicate the data if it is provided from a DataUrl.
        We could create a virtual dataset as well but this seems too
        complicated for a use case that we don't really have at the moment.

        :param h5s:
        """
        if self.logger is not None:
            self.logger.debug(f"Update dataset: {self.data_path}@{self.file_path}")
        if isinstance(self.data, (numpy.ndarray, list, tuple)):
            new_data = self.data
        else:
            url = self.data
            new_data = get_data(url)

        if isinstance(new_data, numpy.ndarray):
            if not new_data.shape[1:] == h5s[self.data_path].shape[1:]:
                raise ValueError(
                    f"Data shapes are incoherent: {new_data.shape} vs {h5s[self.data_path].shape}"
                )
            new_shape = (
                new_data.shape[0] + h5s[self.data_path].shape[0],
                new_data.shape[1],
                new_data.shape[2],
            )
            data_to_store = numpy.empty(new_shape)
            if self.where == "start":
                data_to_store[: new_data.shape[0]] = new_data
                data_to_store[new_data.shape[0] :] = h5py_read_dataset(
                    h5s[self.data_path]
                )
            else:
                data_to_store[: h5s[self.data_path].shape[0]] = h5py_read_dataset(
                    h5s[self.data_path]
                )
                data_to_store[h5s[self.data_path].shape[0] :] = new_data
        else:
            assert isinstance(
                self.data, (list, tuple)
            ), f"Unmanaged data type {type(self.data)}"
            # a tuple has no `extend`: work on a mutable copy
            new_data = list(new_data)
            o_data = h5s[self.data_path]
            o_data = list(h5py_read_dataset(o_data))
            if self.where == "start":
                new_data.extend(o_data)
                data_to_store = numpy.asarray(new_data)
            else:
                o_data.extend(new_data)
                data_to_store = numpy.asarray(o_data)

        del h5s[self.data_path]
        h5s[self.data_path] = data_to_store

    def _add_to_existing_dataset(self, h5s):
        """Add the frame to an existing dataset"""
        if h5s[self.data_path].is_virtual:
            self._add_to_existing_virtual_dataset(h5s=h5s)
        else:
            self._add_to_existing_none_virtual_dataset(h5s=h5s)

    def _create_new_dataset(self, h5s):
        """
        Needs to create a new dataset. In this case the policy is:
           - if a DataUrl is provided then we create a virtual dataset
           - if a numpy array is provided then we create a 'standard' dataset
        """
        if isinstance(self.data, DataUrl):
            url = self.data
            url_file_path = to_target_rel_path(url.file_path(), self.file_path)
            url = DataUrl(
                file_path=url_file_path,
                data_path=url.data_path(),
                scheme=url.scheme(),
                data_slice=url.data_slice(),
            )
            with cwd_context(os.path.dirname(self.file_path)):
                vs, vs_shape, data_type = from_data_url_to_virtual_source(url)
                layout = h5py.VirtualLayout(shape=vs_shape, dtype=data_type)
                layout[:] = vs
                h5s.create_virtual_dataset(self.data_path, layout)
        elif isinstance(self.data, h5py.VirtualSource):
            virtual_source = self.data
            layout = h5py.VirtualLayout(
                shape=virtual_source.shape,
                dtype=virtual_source.dtype,
            )
            # convert the path to a relative path
            vds_file_path = to_target_rel_path(virtual_source.path, self.file_path)
            virtual_source_rel_path = h5py.VirtualSource(
                path_or_dataset=vds_file_path,
                name=virtual_source.name,
                shape=virtual_source.shape,
                dtype=virtual_source.dtype,
            )
            virtual_source_rel_path.sel = virtual_source.sel
            layout[:] = virtual_source_rel_path
            h5s.create_virtual_dataset(self.data_path, layout)
        elif not isinstance(self.data, numpy.ndarray):
            raise TypeError(
                f"self.data should be an instance of DataUrl, a numpy array or a VirtualSource. Not {type(self.data)}"
            )
        else:
            h5s[self.data_path] = self.data

    @staticmethod
    def _recreate_vs(vs_info, vds_file):
        """Simple util to retrieve a h5py.VirtualSource from virtual source
        information. To understand this function clearly you might first have a
        look at the use case exposed in the issue:
        https://gitlab.esrf.fr/tomotools/nxtomomill/-/issues/40
        """
        with cwd_context(os.path.dirname(vds_file)):
            dataset_file_path = vs_info.file_name
            # in case the virtual source is in the same file
            if dataset_file_path == ".":
                dataset_file_path = vds_file
            with open_hdf5(dataset_file_path) as vs_node:
                dataset = vs_node[vs_info.dset_name]
                select_bounds = vs_info.vspace.get_select_bounds()
                left_bound = select_bounds[0]
                right_bound = select_bounds[1]
                length = right_bound[0] - left_bound[0] + 1
                # warning: for now the step is not managed with virtual
                # datasets
                virtual_source = h5py.VirtualSource(
                    vs_info.file_name,
                    vs_info.dset_name,
                    shape=dataset.shape,
                )
                # here we could provide the dataset but we won't, to
                # ensure the file path stays relative
                type_code = vs_info.src_space.get_select_type()
                # check for unlimited selections in case the selection is a
                # regular hyperslab, which is the only allowed case for
                # h5s.UNLIMITED to be in the selection
                if (
                    type_code == h5py_h5s.SEL_HYPERSLABS
                    and vs_info.src_space.is_regular_hyperslab()
                ):
                    (
                        source_start,
                        stride,
                        count,
                        block,
                    ) = vs_info.src_space.get_regular_hyperslab()
                    source_end = source_start[0] + length
                    sel = selection.select(
                        dataset.shape,
                        slice(source_start[0], source_end),
                        dataset=dataset,
                    )
                    virtual_source.sel = sel
                return (
                    length,
                    virtual_source,
                )
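
A minimal usage sketch (not part of the original module): appending two in-memory frames at the start of a dataset. The file name "example.h5" and dataset path "entry/data" are illustrative assumptions, not values mandated by nxtomomill.

import numpy

from nxtomomill.utils.frameappender import FrameAppender

# two 10x10 frames to prepend; target file and dataset are hypothetical
frames = numpy.zeros((2, 10, 10), dtype=numpy.float32)
FrameAppender(
    data=frames,
    file_path="example.h5",
    data_path="entry/data",
    where="start",
).process()

Passing a silx DataUrl (or an h5py.VirtualSource) instead of an array avoids duplicating the data: the class then creates or extends a virtual dataset, as implemented in _create_new_dataset and _add_to_existing_virtual_dataset.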
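
The virtual dataset rebuild in __insert_virtual_source_in_vds follows h5py's standard VDS pattern: a VirtualLayout with the final shape is filled block by block from VirtualSource objects. A standalone sketch of that concatenation pattern, assuming two hypothetical source files "a.h5" and "b.h5":

import h5py
import numpy

# create two hypothetical source files, each holding two 4x4 frames
for name in ("a.h5", "b.h5"):
    with h5py.File(name, "w") as f:
        f["data"] = numpy.zeros((2, 4, 4), dtype="float32")

sources = [
    h5py.VirtualSource("a.h5", "data", shape=(2, 4, 4)),
    h5py.VirtualSource("b.h5", "data", shape=(2, 4, 4)),
]
layout = h5py.VirtualLayout(shape=(4, 4, 4), dtype="float32")
last = 0
for vs in sources:
    # map each source's block of frames into the layout
    layout[last : last + vs.shape[0]] = vs
    last += vs.shape[0]
with h5py.File("vds.h5", "w") as f:
    f.create_virtual_dataset("data", layout)

Prepending (where == "start") versus appending only changes where the new VirtualSource is inserted in that list before the layout is filled.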