import itertools
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional

import numpy as np
import pyarrow as pa

import datasets
from datasets.features.features import (
    Array2D,
    Array3D,
    Array4D,
    Array5D,
    Features,
    LargeList,
    List,
    Value,
    _ArrayXD,
    _arrow_to_datasets_dtype,
)
from datasets.table import cast_table_to_features


if TYPE_CHECKING:
    import h5py

logger = datasets.utils.logging.get_logger(__name__)

EXTENSIONS = [".h5", ".hdf5"]


@dataclass
class HDF5Config(datasets.BuilderConfig):
    """BuilderConfig for HDF5.

    Carries the two options the HDF5 builder in this module reads from
    ``self.config``: the row batch size and an optional explicit schema.
    """

    # Number of rows per generated table. When unset, the builder falls back
    # to its writer batch size and then to the number of rows in the file.
    batch_size: Optional[int] = None
    # Explicit schema handed to ``DatasetInfo``. When None, the builder infers
    # the features from the first HDF5 file it opens.
    features: Optional[datasets.Features] = None


class HDF5(datasets.ArrowBasedBuilder):
    """ArrowBasedBuilder that converts HDF5 files to Arrow tables using the HF extension types.

    Features are taken from ``config.features`` when given; otherwise they are
    inferred from the first HDF5 file that is opened.
    """

    BUILDER_CONFIG_CLASS = HDF5Config

    def _info(self):
        """Return the dataset info, seeded with the configured features (may be None)."""
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        """Build one ``SplitGenerator`` per entry of ``config.data_files``.

        Raises:
            ValueError: if no data files were configured.
        """
        import h5py

        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        dl_manager.download_config.extract_on_the_fly = True
        data_files = dl_manager.download_and_extract(self.config.data_files)
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]

            files = [dl_manager.iter_files(file) for file in files]
            # Infer features from the first file, only if none are known yet.
            if self.info.features is None:
                first_file = next(itertools.chain.from_iterable(files), None)
                if first_file is not None:
                    with h5py.File(first_file, "r") as h5:
                        self.info.features = _recursive_infer_features(h5)
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
        return splits

    def _generate_tables(self, files):
        """Yield ``(key, pyarrow.Table)`` pairs for every batch of every file.

        The key is ``"{file_idx}_{start}"`` where ``start`` is the first row of
        the batch. Files without any rows are skipped with a warning.

        Raises:
            ValueError: re-raised (after being logged) when reading a file fails
                with a ``ValueError``, e.g. on mismatched dataset lengths.
        """
        import h5py

        batch_size_cfg = self.config.batch_size
        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
            try:
                with h5py.File(file, "r") as h5:
                    # Infer features and lengths from first file
                    if self.info.features is None:
                        self.info.features = _recursive_infer_features(h5)
                    num_rows = _check_dataset_lengths(h5, self.info.features)
                    # None means no matching dataset; 0 means the datasets are empty.
                    # Either way there is nothing to yield. Zero must be caught here:
                    # with no batch size configured it would become the range() step
                    # below, and range() rejects a step of 0 with a ValueError.
                    if not num_rows:
                        logger.warning(f"File {file} contains no data, skipping...")
                        continue
                    # `_writer_batch_size` is not defined in this file; it is expected
                    # to come from the builder base class.
                    effective_batch = batch_size_cfg or self._writer_batch_size or num_rows
                    for start in range(0, num_rows, effective_batch):
                        end = min(start + effective_batch, num_rows)
                        pa_table = _recursive_load_arrays(h5, self.info.features, start, end)
                        # Defensive: skip a batch for which nothing could be loaded.
                        if pa_table is None:
                            logger.warning(f"File {file} contains no data, skipping...")
                            continue
                        yield f"{file_idx}_{start}", cast_table_to_features(pa_table, self.info.features)
            except ValueError as e:
                logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                raise


# ┌───────────┐
# │  Complex  │
# └───────────┘


def _is_complex_dtype(dtype: np.dtype) -> bool:
    if dtype.kind == "c":
        return True
    if dtype.subdtype is not None:
        return _is_complex_dtype(dtype.subdtype[0])
    return False


def _create_complex_features(dset) -> Features:
    """Describe a complex dataset as a ``{"real", "imag"}`` pair of float features.

    For a sub-array dtype the element dtype and shape come from
    ``dset.dtype.subdtype``; otherwise the dataset's own dtype and its shape
    without the leading (row) dimension are used. complex64 maps to float32;
    complex128 and any other complex dtype map to float64 (the latter with a
    warning).
    """
    sub = dset.dtype.subdtype
    if sub is None:
        dtype, data_shape = dset.dtype, dset.shape[1:]
    else:
        dtype, data_shape = sub

    if dtype == np.complex64:
        # two float32s
        value_type = Value("float32")
    else:
        if not (dtype == np.complex128):
            logger.warning(f"Found complex dtype {dtype} that is not supported. Converting to float64...")
        # two float64s
        value_type = Value("float64")

    return Features({part: _create_sized_feature_impl(data_shape, value_type) for part in ("real", "imag")})


def _convert_complex_to_nested(arr: np.ndarray) -> pa.StructArray:
    """Split a complex array into a struct array with ``real`` and ``imag`` children.

    Each part is converted with ``numpy_to_pyarrow_listarray`` before the two
    are combined.
    """
    to_list_array = datasets.features.features.numpy_to_pyarrow_listarray
    real_part = to_list_array(arr.real)
    imag_part = to_list_array(arr.imag)
    return pa.StructArray.from_arrays([real_part, imag_part], names=["real", "imag"])


# ┌────────────┐
# │  Compound  │
# └────────────┘


def _is_compound_dtype(dtype: np.dtype) -> bool:
    return dtype.kind == "V"


@dataclass
class _CompoundGroup:
    dset: "h5py.Dataset"
    data: np.ndarray = None

    def items(self):
        for field_name in self.dset.dtype.names:
            field_dtype = self.dset.dtype[field_name]
            yield field_name, _CompoundField(self.data, field_name, field_dtype)


@dataclass
class _CompoundField:
    data: Optional[np.ndarray]
    name: str
    dtype: np.dtype
    shape: tuple[int, ...] = field(init=False)

    def __post_init__(self):
        self.shape = (len(self.data) if self.data is not None else 0,) + self.dtype.shape

    def __getitem__(self, key):
        return self.data[key][self.name]


def _create_compound_features(dset) -> Features:
    """Infer features for a compound dataset by treating its fields as group members."""
    return _recursive_infer_features(_CompoundGroup(dset))


def _convert_compound_to_nested(arr, dset) -> pa.StructArray:
    """Convert already-read compound rows ``arr`` into a nested Arrow array.

    The field features are inferred from ``dset``; the rows are then loaded
    through the same recursive loader used for groups, over the whole of
    ``arr``.
    """
    field_features = _create_compound_features(dset)
    wrapped_rows = _CompoundGroup(dset, data=arr)
    return _recursive_load_arrays(wrapped_rows, field_features, 0, len(arr))


# ┌───────────────────┐
# │  Variable-Length  │
# └───────────────────┘


def _is_vlen_dtype(dtype: np.dtype) -> bool:
    if dtype.metadata and "vlen" in dtype.metadata:
        return True
    return False


def _create_vlen_features(dset) -> "Value | List":
    """Return the feature for a variable-length dataset.

    The element type is read from the ``"vlen"`` entry of the dtype metadata.
    ``str`` and ``bytes`` elements become ``Value("string")``; any other
    element type becomes an unsized ``List`` of the matching ``Value``.
    """
    element_type = dset.dtype.metadata["vlen"]
    if element_type in (str, bytes):
        return Value("string")
    return List(_np_to_pa_to_hf_value(element_type))


def _convert_vlen_to_array(arr: np.ndarray) -> pa.Array:
    """Convert a variable-length array with ``numpy_to_pyarrow_listarray``."""
    to_list_array = datasets.features.features.numpy_to_pyarrow_listarray
    return to_list_array(arr)


# ┌───────────┐
# │  Generic  │
# └───────────┘


def _recursive_infer_features(h5_obj) -> Features:
    """Infer a ``Features`` mapping for every member of a group-like object.

    Groups are handled recursively and datasets individually; members that are
    neither, and members whose inferred feature is falsy (e.g. an empty
    group), are left out.
    """
    inferred = {}
    for name, member in h5_obj.items():
        if _is_group(member):
            feature = _recursive_infer_features(member)
        elif _is_dataset(member):
            feature = _infer_feature(member)
        else:
            continue
        if feature:
            inferred[name] = feature

    return Features(inferred)


def _infer_feature(dset):
    """Pick the feature builder matching the dataset's dtype.

    Checked in order: complex, compound, variable-length, and finally the
    plain fixed-shape case. Compound detection relies solely on
    ``_is_compound_dtype``, the same predicate ``_load_array`` uses when
    reading the data, so schema inference and loading classify a dataset
    identically.
    """
    dtype = dset.dtype
    if _is_complex_dtype(dtype):
        return _create_complex_features(dset)
    if _is_compound_dtype(dtype):
        return _create_compound_features(dset)
    if _is_vlen_dtype(dtype):
        return _create_vlen_features(dset)
    return _create_sized_feature(dset)


def _load_array(dset, path: str, start: int, end: int) -> pa.Array:
    """Read rows ``start:end`` of ``dset`` and convert them to an Arrow array.

    Dispatches on the dataset dtype: variable-length, complex, compound, and
    otherwise plain numeric data.

    Raises:
        ValueError: for object-dtype datasets that are not variable-length.
    """
    rows = dset[start:end]
    dtype = dset.dtype

    if _is_vlen_dtype(dtype):
        return _convert_vlen_to_array(rows)
    if _is_complex_dtype(dtype):
        return _convert_complex_to_nested(rows)
    if _is_compound_dtype(dtype):
        return _convert_compound_to_nested(rows, dset)
    if dtype.kind == "O":
        raise ValueError(
            f"Object dtype dataset '{path}' is not supported. "
            "For variable-length data, please use h5py.vlen_dtype() "
            "when creating the HDF5 file. "
            "See: https://docs.h5py.org/en/stable/special.html#variable-length-strings"
        )

    # If any non-batch dimension is zero, emit an unsized pa.list_
    # to avoid creating FixedSizeListArray with list_size=0.
    if 0 in dset.shape[1:]:
        list_type = pa.list_(pa.from_numpy_dtype(dtype))
        return pa.array([[] for _ in rows], type=list_type)

    return datasets.features.features.numpy_to_pyarrow_listarray(rows)


def _recursive_load_arrays(h5_obj, features: Features, start: int, end: int):
    """Load rows ``start:end`` of every member of ``h5_obj`` listed in ``features``.

    Returns:
        A ``pa.Table`` when ``h5_obj`` is an HDF5 file. For nested group-like
        objects, a ``pa.StructArray`` (wrapped with ``pa.chunked_array`` if any
        child was a ``pa.ChunkedArray``), or None when nothing was loaded.

    Raises:
        ValueError: if a member is neither a group nor a dataset.
    """
    loaded = {}
    for name, member in h5_obj.items():
        if name not in features:
            continue
        if _is_group(member):
            column = _recursive_load_arrays(member, features[name], start, end)
        elif _is_dataset(member):
            column = _load_array(member, name, start, end)
        else:
            raise ValueError(f"Unexpected type {type(member)}")

        if column is not None:
            loaded[name] = column

    if _is_file(h5_obj):
        return pa.Table.from_pydict(loaded)

    if not loaded:
        return None

    # Struct children must be plain arrays, so chunked children are flattened
    # first; remember whether any was chunked to wrap the result again.
    any_chunked = False
    child_names, child_arrays = [], []
    for name, column in loaded.items():
        if isinstance(column, pa.ChunkedArray):
            any_chunked = True
            column = column.combine_chunks()
        child_names.append(name)
        child_arrays.append(column)

    struct = pa.StructArray.from_arrays(child_arrays, names=child_names)
    if any_chunked:
        return pa.chunked_array(struct)
    return struct


# ┌─────────────┐
# │  Utilities  │
# └─────────────┘


def _create_sized_feature(dset):
    """Build the feature for a plain dataset from its dtype and per-row shape.

    The per-row shape is the dataset shape without its leading dimension.
    """
    per_row_shape = dset.shape[1:]
    return _create_sized_feature_impl(per_row_shape, _np_to_pa_to_hf_value(dset.dtype))


def _create_sized_feature_impl(dset_shape, value_feature):
    """Wrap ``value_feature`` in the container matching ``dset_shape``.

    Any zero-sized dimension yields an unsized ``List`` (with a warning); an
    empty shape yields ``value_feature`` itself; one dimension yields a
    fixed-length ``List``; two to five dimensions yield the matching
    ``ArrayXD`` feature.

    Raises:
        TypeError: if the shape has more than 5 dimensions.
    """
    dtype_str = value_feature.dtype
    if 0 in dset_shape:
        logger.warning(
            f"HDF5 to Arrow: Found a dataset with shape {dset_shape} and dtype {dtype_str} that has a dimension with size 0. Shape information will be lost in the conversion to List({value_feature})."
        )
        return List(value_feature)

    rank = len(dset_shape)
    if rank > 5:
        raise TypeError(f"Array{rank}D not supported. Maximum 5 dimensions allowed.")
    if rank == 0:
        return value_feature
    if rank == 1:
        return List(value_feature, length=dset_shape[0])
    array_cls = _sized_arrayxd(rank)
    return array_cls(shape=dset_shape, dtype=dtype_str)


def _sized_arrayxd(rank: int):
    """Return the ``ArrayXD`` feature class for ``rank`` (2 to 5); KeyError otherwise."""
    classes_by_rank = {2: Array2D, 3: Array3D, 4: Array4D, 5: Array5D}
    return classes_by_rank[rank]


def _np_to_pa_to_hf_value(numpy_dtype: np.dtype) -> Value:
    """Map a NumPy dtype to a ``Value`` feature by going through its Arrow type."""
    arrow_type = pa.from_numpy_dtype(numpy_dtype)
    hf_dtype = _arrow_to_datasets_dtype(arrow_type)
    return Value(dtype=hf_dtype)


def _first_dataset(h5_obj, features: Features, prefix=""):
    """Return the slash-joined path of the first dataset covered by ``features``.

    Members are visited in ``items()`` order and groups are searched depth
    first. Returns None when no such dataset exists.
    """
    for name, member in h5_obj.items():
        if name not in features:
            continue
        member_path = f"{prefix}{name}"
        if _is_group(member):
            nested = _first_dataset(member, features[name], prefix=f"{member_path}/")
            if nested is not None:
                return nested
        elif _is_dataset(member):
            return member_path
    return None


def _check_dataset_lengths(h5_obj, features: Features) -> Optional[int]:
    """Return the common row count of the datasets in ``h5_obj``.

    The reference length is the first dimension of the first dataset covered
    by ``features`` (possibly nested inside groups).

    Returns:
        The number of rows, or None when no dataset covered by ``features``
        exists.

    Raises:
        ValueError: if a top-level dataset covered by ``features`` has a
            different length.
    """
    first_path = _first_dataset(h5_obj, features)
    if first_path is None:
        return None

    num_rows = h5_obj[first_path].shape[0]
    # NOTE(review): only the direct members of h5_obj are compared here;
    # datasets inside nested groups are not length-checked - confirm intended.
    for path, dset in h5_obj.items():
        if path not in features or not _is_dataset(dset):
            continue
        if dset.shape[0] != num_rows:
            raise ValueError(f"Dataset '{path}' has length {dset.shape[0]} but expected {num_rows}")
    return num_rows


def _is_group(h5_obj) -> bool:
    """Return True for an ``h5py.Group`` or the ``_CompoundGroup`` stand-in."""
    import h5py

    group_types = (h5py.Group, _CompoundGroup)
    return isinstance(h5_obj, group_types)


def _is_dataset(h5_obj) -> bool:
    """Return True for an ``h5py.Dataset`` or the ``_CompoundField`` stand-in."""
    import h5py

    dataset_types = (h5py.Dataset, _CompoundField)
    return isinstance(h5_obj, dataset_types)


def _is_file(h5_obj) -> bool:
    """Return True if ``h5_obj`` is an ``h5py.File``."""
    import h5py

    file_type = h5py.File
    return isinstance(h5_obj, file_type)


def _has_zero_dimensions(feature):
    """Return True if ``feature`` (or a feature nested in it) has a zero-sized dimension.

    ``ArrayXD`` features are checked through their shape, ``List`` through its
    length and inner feature, ``LargeList`` through its inner feature only.
    Every other feature yields False.
    """
    if isinstance(feature, _ArrayXD):
        return 0 in feature.shape
    if isinstance(feature, List):
        if feature.length == 0:
            return True
        return _has_zero_dimensions(feature.feature)
    if isinstance(feature, LargeList):
        return _has_zero_dimensions(feature.feature)
    return False
