# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
from typing import Any, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge


class HRALayer(BaseTunerLayer):
    """Shared state and helpers for HRA adapter layers.

    For each adapter name this mixin keeps:

    * ``hra_r``: the rank ``r``, i.e. the number of columns of the adapter matrix.
    * ``hra_apply_GS``: whether Gram-Schmidt orthogonalization is applied to those columns.
    * ``hra_u``: the trainable matrix of shape ``(fan_in, r)``, where ``fan_in`` is ``in_features`` for an
      ``nn.Linear`` base layer and ``in_channels * kernel_h * kernel_w`` for an ``nn.Conv2d`` base layer.
    """

    # All names of layers that may contain (trainable) adapter weights
    adapter_layer_names = ("hra_u",)
    # All names of other parameters that may contain adapter-related parameters
    other_param_names = ("hra_r", "hra_apply_GS")

    def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        """Initialize empty adapter containers and read the feature sizes of ``base_layer``.

        Raises:
            ValueError: if the (unwrapped) base layer is neither ``nn.Linear`` nor ``nn.Conv2d``.
        """
        self.base_layer = base_layer
        self.hra_r = {}
        self.hra_apply_GS = {}
        self.hra_u = nn.ParameterDict({})
        # Mark the weight as unmerged
        self._disable_adapters = False
        self.merged_adapters = []
        # flag to enable/disable casting of input to weight dtype during forward call
        self.cast_input_dtype_enabled = True
        self.kwargs = kwargs

        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            self.in_features, self.out_features = base_layer.in_features, base_layer.out_features
        elif isinstance(base_layer, nn.Conv2d):
            self.in_features, self.out_features = base_layer.in_channels, base_layer.out_channels
        else:
            raise ValueError(f"Unsupported layer type {type(base_layer)}")

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        apply_GS: bool,
        init_weights: bool,
        **kwargs,
    ) -> None:
        """Internal function to create hra adapter

        Args:
            adapter_name (`str`): Name for the adapter to add.
            r (`int`): Rank for the added adapter.
            apply_GS (`bool`): Whether to apply Gram-Schmidt orthogonalization or not.
            init_weights (`bool`): Whether to use the symmetric initialization (`True`) or a plain random one.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: if ``r`` is not positive.
            TypeError: if the base layer type is not supported.
        """
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        self.hra_r[adapter_name] = r
        self.hra_apply_GS[adapter_name] = apply_GS

        # Determine shape of HRA weights
        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            fan_in = self.in_features
        elif isinstance(base_layer, nn.Conv2d):
            # The Conv2d adapter flattens the conv weight to (out_channels, in_channels * kernel_h * kernel_w).
            # Use both kernel dimensions so that non-square kernels (e.g. 1x3) get a correctly sized adapter.
            fan_in = self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[1]
        else:
            raise TypeError(f"HRA is not implemented for base layers of type {type(base_layer).__name__}")
        self.hra_u[adapter_name] = nn.Parameter(torch.empty(fan_in, r), requires_grad=True)

        # Initialize weights
        if init_weights:
            self.reset_hra_parameters(adapter_name)
        else:
            self.reset_hra_parameters_random(adapter_name)

        # Move new weights to device
        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters)

    def reset_hra_parameters(self, adapter_name: str):
        """Symmetric initialization of ``hra_u[adapter_name]``.

        For an even rank, half of the columns are drawn with Kaiming-uniform and each one is repeated twice
        (``repeat_interleave`` along dim 1), so consecutive columns are identical. In the non-Gram-Schmidt product
        of reflections ``I - 2 u u^T``, two identical reflections cancel, so the adapter starts out as the identity.
        For an odd rank this is impossible; a warning is emitted and a plain random init is used instead.

        NOTE(review): with ``apply_GS=True`` the duplicated columns are linearly dependent, so Gram-Schmidt
        normalizes a (numerically) near-zero residual — confirm this init is intended for GS mode.
        """
        if self.hra_r[adapter_name] % 2 != 0:
            warnings.warn("The symmetric initialization can NOT be performed when r is odd!")
            nn.init.kaiming_uniform_(self.hra_u[adapter_name], a=math.sqrt(5))
        else:
            shape = self.hra_u[adapter_name].shape
            half_u = torch.zeros(shape[0], shape[1] // 2)
            nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
            self.hra_u[adapter_name] = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1))

    def reset_hra_parameters_random(self, adapter_name: str):
        """Initialize every column of ``hra_u[adapter_name]`` independently with Kaiming-uniform."""
        nn.init.kaiming_uniform_(self.hra_u[adapter_name], a=math.sqrt(5))

    def scale_layer(self, scale: float) -> None:
        """Scaling is not supported by HRA; warn (once per active HRA adapter) unless ``scale == 1``."""
        if scale == 1:
            return

        for active_adapter in self.active_adapters:
            if active_adapter not in self.hra_u.keys():
                continue

            warnings.warn("Scaling operation for HRA not supported! Automatically set scale to 1.")

    def unscale_layer(self, scale=None) -> None:
        """Unscaling is not supported by HRA; warn once per active HRA adapter."""
        for active_adapter in self.active_adapters:
            if active_adapter not in self.hra_u.keys():
                continue

            warnings.warn("Unscaling operation for HRA not supported! Keeping scale at 1.")


class HRALinear(nn.Module, HRALayer):
    """
    HRA implemented in a dense layer.

    The adapted weight is ``W @ H``: ``W`` is the frozen base weight of shape ``(out_features, in_features)`` and
    ``H`` is the ``(in_features, in_features)`` product of the transforms returned by :meth:`get_delta_weight` for
    every active HRA adapter.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        apply_GS: bool = False,
        init_weights: Union[bool, str] = True,
        **kwargs,
    ) -> None:
        super().__init__()
        HRALayer.__init__(self, base_layer, **kwargs)
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, apply_GS, init_weights, **kwargs)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
                Defaults to `None`.

        Raises:
            ValueError: if ``safe_merge`` is set and the merged weight contains non-finite values.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        for active_adapter in adapter_names:
            if active_adapter in self.hra_u.keys():
                base_layer = self.get_base_layer()
                orig_dtype = base_layer.weight.dtype
                if safe_merge:
                    # Note that safe_merge will be slower than the normal merge
                    # because of the copy operation.
                    orig_weight = base_layer.weight.data.clone()
                    delta_weight = self.get_delta_weight(active_adapter)
                    orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)

                    if not torch.isfinite(orig_weight).all():
                        raise ValueError(
                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                        )

                    base_layer.weight.data = orig_weight.to(orig_dtype)
                else:
                    delta_weight = self.get_delta_weight(active_adapter)
                    new_weight = torch.mm(base_layer.weight.data.to(delta_weight.dtype), delta_weight)
                    base_layer.weight.data = new_weight.to(orig_dtype)
                self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.

        Adapters are removed in the reverse order they were merged, each by right-multiplying with the inverse
        transform (``get_delta_weight(..., reverse=True)``).
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            base_layer = self.get_base_layer()
            orig_dtype = base_layer.weight.dtype
            if active_adapter in self.hra_u.keys():
                orig_weight = base_layer.weight.data.clone()
                delta_weight = self.get_delta_weight(active_adapter, reverse=True)
                new_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)
                base_layer.weight.data = new_weight.to(orig_dtype)

    def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor:
        """Build the ``(d, d)`` transform of ``adapter_name``, ``d`` being the number of rows of ``hra_u``.

        Without Gram-Schmidt, the columns ``u_i`` of ``hra_u`` are normalized and the product of reflections
        ``(I - 2 u_0 u_0^T) (I - 2 u_1 u_1^T) ... (I - 2 u_{r-1} u_{r-1}^T)`` is returned. ``reverse=True``
        multiplies them in the opposite order, which gives the inverse because each reflection is its own inverse.

        With Gram-Schmidt, the columns are orthonormalized into ``W`` and ``I - 2 W W^T`` is returned. That matrix
        is its own inverse, so ``reverse`` has no effect in this mode.
        """
        rank = self.hra_r[adapter_name]
        apply_GS = self.hra_apply_GS[adapter_name]
        opt_u = self.hra_u[adapter_name]
        shape = opt_u.shape

        if apply_GS:
            weight = [(opt_u[:, 0] / opt_u[:, 0].norm()).view(-1, 1)]
            for i in range(1, rank):
                ui = opt_u[:, i].view(-1, 1)
                # Remove the components along all previously orthonormalized columns.
                for j in range(i):
                    ui = ui - (weight[j].t() @ ui) * weight[j]
                weight.append((ui / ui.norm()).view(-1, 1))
            weight = torch.cat(weight, dim=1)
            weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype) - 2 * weight @ weight.t()

        else:
            opt_u = opt_u / opt_u.norm(dim=0)
            weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype)
            if reverse:
                indices = range(rank - 1, -1, -1)
            else:
                indices = range(rank)

            for i in indices:
                ui = opt_u[:, i].view(-1, 1)
                # weight <- weight @ (I - 2 u u^T), computed without materializing the reflection.
                weight = weight - 2 * weight @ ui @ ui.t()

        return weight

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Apply the base linear layer with the active HRA transforms folded into its weight.

        The result is cast back to the dtype of ``x``.
        """
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            # Read weight AND bias from the unwrapped base layer; `self.base_layer` may itself be a wrapper.
            base_layer = self.get_base_layer()
            new_weight = torch.eye(self.in_features, device=x.device)

            for active_adapter in self.active_adapters:
                if active_adapter not in self.hra_u.keys():
                    continue
                delta_weight = self.get_delta_weight(active_adapter)
                new_weight = torch.mm(new_weight.to(delta_weight.dtype), delta_weight)

            orig_weight = base_layer.weight.data
            orig_weight = self._cast_input_dtype(orig_weight, new_weight.dtype)
            new_weight = torch.mm(orig_weight, new_weight)
            bias = self._cast_input_dtype(base_layer.bias, new_weight.dtype)

            if self.cast_input_dtype_enabled:
                x = self._cast_input_dtype(x, new_weight.dtype)
            else:
                x = x.to(base_layer.weight.data.dtype)
            result = F.linear(input=x, weight=new_weight, bias=bias)

        result = result.to(previous_dtype)
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "hra." + rep


class HRAConv2d(nn.Module, HRALayer):
    """HRA implemented in Conv2d layer

    The conv weight ``(out_channels, in_channels, kernel_h, kernel_w)`` is flattened to
    ``(out_channels, in_channels * kernel_h * kernel_w)``, right-multiplied by the HRA transform(s) and reshaped
    back before the convolution. Grouped convolutions are not supported: the flattening assumes ``groups == 1``.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        apply_GS: bool = False,
        init_weights: Union[bool, str] = True,
        **kwargs,
    ):
        super().__init__()
        # Forward kwargs like HRALinear does, so `self.kwargs` is populated consistently.
        HRALayer.__init__(self, base_layer, **kwargs)
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, apply_GS, init_weights, **kwargs)

    def _flat_in_features(self) -> int:
        """Return ``in_channels * kernel_h * kernel_w`` of the base conv (both kernel dims, for non-square kernels)."""
        kernel_size = self.get_base_layer().kernel_size
        return self.in_features * kernel_size[0] * kernel_size[1]

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
                Defaults to `None`.

        Raises:
            ValueError: if ``safe_merge`` is set and the merged weight contains non-finite values.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        for active_adapter in adapter_names:
            if active_adapter in self.hra_u.keys():
                base_layer = self.get_base_layer()
                orig_dtype = base_layer.weight.dtype
                orig_shape = base_layer.weight.shape
                # reshape (not view) so non-contiguous weights, e.g. channels_last, can be flattened too.
                flat_weight = base_layer.weight.data.reshape(self.out_features, self._flat_in_features())
                delta_weight = self.get_delta_weight(active_adapter)
                # torch.mm allocates a new tensor, so the base weight stays untouched until the assignment below;
                # this is what makes the safe_merge check possible without an explicit copy.
                merged_weight = torch.mm(flat_weight.to(delta_weight.dtype), delta_weight).view(orig_shape)

                if safe_merge and not torch.isfinite(merged_weight).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                base_layer.weight.data = merged_weight.to(orig_dtype)
                self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.

        Adapters are removed in the reverse order they were merged, each by right-multiplying the flattened weight
        with the inverse transform (``get_delta_weight(..., reverse=True)``).
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            base_layer = self.get_base_layer()
            orig_dtype = base_layer.weight.dtype
            if active_adapter in self.hra_u.keys():
                orig_shape = base_layer.weight.shape
                flat_weight = base_layer.weight.data.reshape(self.out_features, self._flat_in_features())
                delta_weight = self.get_delta_weight(active_adapter, reverse=True)
                restored_weight = torch.mm(flat_weight.to(delta_weight.dtype), delta_weight).view(orig_shape)

                base_layer.weight.data = restored_weight.to(orig_dtype)

    def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor:
        """Build the ``(d, d)`` transform of ``adapter_name``, ``d`` being the number of rows of ``hra_u``.

        Without Gram-Schmidt, the columns ``u_i`` of ``hra_u`` are normalized and the product of reflections
        ``(I - 2 u_0 u_0^T) (I - 2 u_1 u_1^T) ... (I - 2 u_{r-1} u_{r-1}^T)`` is returned. ``reverse=True``
        multiplies them in the opposite order, which gives the inverse because each reflection is its own inverse.

        With Gram-Schmidt, the columns are orthonormalized into ``W`` and ``I - 2 W W^T`` is returned. That matrix
        is its own inverse, so ``reverse`` has no effect in this mode.
        """
        rank = self.hra_r[adapter_name]
        apply_GS = self.hra_apply_GS[adapter_name]
        opt_u = self.hra_u[adapter_name]
        shape = opt_u.shape

        if apply_GS:
            weight = [(opt_u[:, 0] / opt_u[:, 0].norm()).view(-1, 1)]
            for i in range(1, rank):
                ui = opt_u[:, i].view(-1, 1)
                # Remove the components along all previously orthonormalized columns.
                for j in range(i):
                    ui = ui - (weight[j].t() @ ui) * weight[j]
                weight.append((ui / ui.norm()).view(-1, 1))
            weight = torch.cat(weight, dim=1)
            weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype) - 2 * weight @ weight.t()

        else:
            opt_u = opt_u / opt_u.norm(dim=0)
            weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype)
            if reverse:
                indices = range(rank - 1, -1, -1)
            else:
                indices = range(rank)

            for i in indices:
                ui = opt_u[:, i].view(-1, 1)
                # weight <- weight @ (I - 2 u u^T), computed without materializing the reflection.
                weight = weight - 2 * weight @ ui @ ui.t()

        return weight

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Apply the base convolution with the active HRA transforms folded into its weight.

        The result is cast back to the dtype of ``x``.
        """
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            base_layer = self.get_base_layer()
            flat_in_features = self._flat_in_features()
            new_weight = torch.eye(flat_in_features, device=x.device)
            for active_adapter in self.active_adapters:
                if active_adapter not in self.hra_u.keys():
                    continue
                delta_weight = self.get_delta_weight(active_adapter)
                new_weight = torch.mm(new_weight.to(delta_weight.dtype), delta_weight)

            orig_weight = base_layer.weight.data.reshape(self.out_features, flat_in_features)
            orig_weight = self._cast_input_dtype(orig_weight, new_weight.dtype)
            bias = self._cast_input_dtype(base_layer.bias, new_weight.dtype)

            new_weight = torch.mm(orig_weight, new_weight).view(base_layer.weight.shape)

            if self.cast_input_dtype_enabled:
                x = self._cast_input_dtype(x, new_weight.dtype)
            else:
                x = x.to(base_layer.weight.data.dtype)
            # Pass the full stride/padding/dilation of the base layer instead of only their first element, so
            # asymmetric values, string padding such as "same", and dilated convolutions match the base nn.Conv2d.
            # NOTE(review): a non-"zeros" padding_mode of the base layer is still not reproduced here.
            result = F.conv2d(
                input=x,
                weight=new_weight,
                bias=bias,
                stride=base_layer.stride,
                padding=base_layer.padding,
                dilation=base_layer.dilation,
            )

        result = result.to(previous_dtype)
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "hra." + rep
