import numpy as np
import torch
import argparse
import time

from torch import nn
from aimbotEnv import Aimbot
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise a layer in place and hand it back.

    The weight is filled with an orthogonal matrix scaled by ``std`` (the gain),
    and every bias entry is set to ``bias_const``. Returning the same object lets
    the call wrap a layer inline, e.g. inside ``nn.Sequential``.
    """
    weight, bias = layer.weight, layer.bias
    nn.init.orthogonal_(weight, gain=std)
    nn.init.constant_(bias, val=bias_const)
    return layer


class PPOAgent(nn.Module):
    """PPO actor-critic that keeps a separate set of heads for every target type.

    Column 0 of each observation row holds the target type id. That id selects
    which hidden network, discrete actor, continuous actor, log-std parameter
    and critic process the row, so ``args.target_num`` independent head sets
    live inside one module.
    """

    def __init__(
            self,
            env: "Aimbot",
            this_args: argparse.Namespace,
            device: torch.device,
    ):
        """Build the per-target networks from the environment spec and arguments.

        Args:
            env: environment exposing the Unity observation/action layout
                (``unity_observation_shape``, ``unity_action_size``,
                ``unity_agent_num``, ``unity_discrete_*``, ``unity_continuous_size``).
            this_args: parsed arguments; this constructor reads ``train``,
                ``target_num``, the ``*_state_size`` fields and ``total_target_size``.
            device: device that tensors created in :meth:`gae` are moved to.
        """
        super().__init__()
        self.device = device
        self.args = this_args
        self.train_agent = self.args.train  # True: sample actions; False: act greedily
        self.target_num = self.args.target_num  # number of target types == number of head sets
        self.unity_observation_shape = env.unity_observation_shape
        self.unity_action_size = env.unity_action_size
        self.state_size = self.unity_observation_shape[0]
        self.agent_num = env.unity_agent_num
        self.target_size = self.args.target_state_size
        self.time_state_size = self.args.time_state_size
        self.gun_state_size = self.args.gun_state_size
        self.my_state_size = self.args.my_state_size
        self.ray_state_size = env.unity_observation_shape[0] - self.args.total_target_size
        self.state_size_without_ray = self.args.total_target_size
        # Observation size excluding the target, time and gun state segments.
        self.head_input_size = (
                env.unity_observation_shape[0] - self.target_size - self.time_state_size - self.gun_state_size
        )

        self.unity_discrete_type = env.unity_discrete_type  # discrete columns at the front of an action row
        self.discrete_size = env.unity_discrete_size  # total discrete logits over all branches
        self.discrete_shape = list(env.unity_discrete_branches)  # logits per discrete branch
        self.continuous_size = env.unity_continuous_size

        # One MLP trunk per target type; every head below consumes its 64-d output.
        self.hidden_networks = nn.ModuleList(
            [
                nn.Sequential(
                    layer_init(nn.Linear(self.state_size, 128)),
                    nn.LeakyReLU(),
                    layer_init(nn.Linear(128, 64)),
                    nn.LeakyReLU(),
                )
                for _ in range(self.target_num)
            ]
        )

        self.actor_dis = nn.ModuleList(
            [layer_init(nn.Linear(64, self.discrete_size), std=0.5) for _ in range(self.target_num)]
        )
        self.actor_mean = nn.ModuleList(
            [layer_init(nn.Linear(64, self.continuous_size), std=0.5) for _ in range(self.target_num)]
        )
        # State-independent log standard deviation of the continuous policy, per target.
        self.actor_logstd = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.continuous_size)) for _ in range(self.target_num)]
        )
        self.critic = nn.ModuleList(
            [layer_init(nn.Linear(64, 1), std=1) for _ in range(self.target_num)]
        )

    @staticmethod
    def _per_target(modules: nn.ModuleList, inputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Apply ``modules[target[i]]`` to ``inputs[i]`` for every row and stack the results."""
        return torch.stack([modules[int(t)](row) for t, row in zip(target, inputs)])

    def get_value(self, state: torch.Tensor):
        """Return the critic value for every row of ``state``.

        Args:
            state: batch of observations whose column 0 is the target type id.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        target = state[:, 0].to(torch.int32)
        hidden_output = self._per_target(self.hidden_networks, state, target)
        return self._per_target(self.critic, hidden_output, target)

    def get_actions_value(self, state: torch.Tensor, actions=None):
        """Evaluate the policy and critic on ``state``, sampling or scoring actions.

        Args:
            state: batch of observations whose column 0 is the target type id.
            actions: optional action rows with the ``unity_discrete_type``
                discrete columns first and the continuous columns after them.
                When ``None``, actions are sampled (``train_agent`` True) or
                chosen greedily (argmax logits / Gaussian mean).

        Returns:
            ``(actions, dis_log_prob, dis_entropy, con_log_prob, con_entropy, value)``.
            The log-probs and entropies are summed over action components, so
            there is one value per row. ``value`` has shape ``(batch, 1)``.
        """
        target = state[:, 0].to(torch.int32)
        # Each row is routed through the networks that belong to its own target type.
        hidden_output = self._per_target(self.hidden_networks, state, target)

        # Discrete: one Categorical per branch, carved out of the concatenated logits.
        dis_logits = self._per_target(self.actor_dis, hidden_output, target)
        split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)
        multi_categoricals = [Categorical(logits=logits) for logits in split_logits]

        # Continuous: diagonal Gaussian with a per-target, state-independent std.
        actions_mean = self._per_target(self.actor_mean, hidden_output, target)
        action_logstd = torch.stack(
            [torch.squeeze(self.actor_logstd[int(t)], 0) for t in target]
        )
        action_std = torch.exp(action_logstd)
        con_probs = Normal(actions_mean, action_std)

        criticV = self._per_target(self.critic, hidden_output, target)

        if actions is None:
            if self.train_agent:
                # Explore: sample from the current distributions.
                dis_act = torch.stack([ctgr.sample() for ctgr in multi_categoricals])
                con_act = con_probs.sample()
            else:
                # Exploit: most likely discrete choice and the Gaussian mean.
                dis_act = torch.stack([torch.argmax(logits, dim=1) for logits in split_logits])
                con_act = actions_mean
            actions = torch.cat([dis_act.T, con_act], dim=1)
        else:
            # Transpose so each row of dis_act is one branch, matching multi_categoricals.
            dis_act = actions[:, 0: self.unity_discrete_type].T
            con_act = actions[:, self.unity_discrete_type:]
        dis_log_prob = torch.stack(
            [ctgr.log_prob(act) for act, ctgr in zip(dis_act, multi_categoricals)]
        )
        dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])
        return (
            actions,
            dis_log_prob.sum(0),
            dis_entropy.sum(0),
            con_probs.log_prob(con_act).sum(1),
            con_probs.entropy().sum(1),
            criticV,
        )

    def train_net(self, this_train_ind: int, ppo_memories, optimizer) -> tuple:
        """Run ``args.epochs`` epochs of clipped-PPO minibatch updates for one target type.

        Args:
            this_train_ind: index into the per-target buffers of ``ppo_memories``
                and into ``args.policy_coef``/``entropy_coef``/``critic_coef``/``loss_coef``.
            ppo_memories: rollout storage with ``obs``, ``dis_logprobs``,
                ``con_logprobs``, ``actions``, ``advantages``, ``returns`` and
                ``values``, each indexable by ``this_train_ind``.
            optimizer: object exposing ``zero_grad()``/``step()``; presumably an
                optimizer over ``self.parameters()``.

        Returns:
            ``(v_loss, dis_pg_loss, con_pg_loss, loss, entropy_loss)`` from the
            last minibatch that was optimised.

        Raises:
            RuntimeError: if the total loss of a minibatch contains NaN.
        """
        # flatten the batch
        b_obs = ppo_memories.obs[this_train_ind].reshape((-1,) + self.unity_observation_shape)
        b_dis_logprobs = ppo_memories.dis_logprobs[this_train_ind].reshape(-1)
        b_con_logprobs = ppo_memories.con_logprobs[this_train_ind].reshape(-1)
        b_actions = ppo_memories.actions[this_train_ind].reshape((-1,) + (self.unity_action_size,))
        b_advantages = ppo_memories.advantages[this_train_ind].reshape(-1)
        b_returns = ppo_memories.returns[this_train_ind].reshape(-1)
        b_values = ppo_memories.values[this_train_ind].reshape(-1)
        b_size = b_obs.size()[0]
        b_index = np.arange(b_size)

        for epoch in range(self.args.epochs):
            print("epoch:", epoch, end="")
            np.random.shuffle(b_index)
            for start in range(0, b_size, self.args.minibatchSize):
                print(".", end="")
                end = start + self.args.minibatchSize
                mb_index = b_index[start:end]
                # A 1-sample minibatch would give a NaN std when normalising advantages.
                if np.size(mb_index) <= 1:
                    break
                mb_advantages = b_advantages[mb_index]

                if self.args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                            mb_advantages.std() + 1e-8
                    )

                (
                    _,
                    new_dis_logprob,
                    dis_entropy,
                    new_con_logprob,
                    con_entropy,
                    new_value,
                ) = self.get_actions_value(b_obs[mb_index], b_actions[mb_index])
                # Probability ratios between the current and the behaviour policy.
                dis_log_ratio = new_dis_logprob - b_dis_logprobs[mb_index]
                dis_ratio = dis_log_ratio.exp()
                con_log_ratio = new_con_logprob - b_con_logprobs[mb_index]
                con_ratio = con_log_ratio.exp()

                # KL-based early stopping (approx_kl vs args.target_kl) is currently disabled.

                # Clipped surrogate objective, separately for discrete and continuous parts.
                dis_pg_loss_orig = -mb_advantages * dis_ratio
                dis_pg_loss_clip = -mb_advantages * torch.clamp(
                    dis_ratio, 1 - self.args.clip_coef, 1 + self.args.clip_coef
                )
                dis_pg_loss = torch.max(dis_pg_loss_orig, dis_pg_loss_clip).mean()
                con_pg_loss_orig = -mb_advantages * con_ratio
                con_pg_loss_clip = -mb_advantages * torch.clamp(
                    con_ratio, 1 - self.args.clip_coef, 1 + self.args.clip_coef
                )
                con_pg_loss = torch.max(con_pg_loss_orig, con_pg_loss_clip).mean()

                # Value loss (optionally clipped around the old value estimates).
                new_value = new_value.view(-1)
                if self.args.clip_vloss:
                    v_loss_unclipped = (new_value - b_returns[mb_index]) ** 2
                    v_clipped = b_values[mb_index] + torch.clamp(
                        new_value - b_values[mb_index],
                        -self.args.clip_coef,
                        self.args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_index]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((new_value - b_returns[mb_index]) ** 2).mean()

                # NOTE(review): entropy is *added* to a loss that is minimised, which lowers
                # entropy unless args.entropy_coef is negative; standard PPO subtracts it.
                # Confirm the sign convention of entropy_coef before changing this.
                entropy_loss = dis_entropy.mean() + con_entropy.mean()
                loss = (
                               dis_pg_loss * self.args.policy_coef[this_train_ind]
                               + con_pg_loss * self.args.policy_coef[this_train_ind]
                               + entropy_loss * self.args.entropy_coef[this_train_ind]
                               + v_loss * self.args.critic_coef[this_train_ind]
                       ) * self.args.loss_coef[this_train_ind]

                if torch.isnan(loss).any():
                    print("LOSS Include NAN!!!")
                    # isnan must be applied before .any(); the reverse order tests a bool.
                    for term_name, term in (
                        ("dis_pg_loss", dis_pg_loss),
                        ("con_pg_loss", con_pg_loss),
                        ("entropy_loss", entropy_loss),
                        ("v_loss", v_loss),
                    ):
                        if torch.isnan(term).any():
                            print(f"{term_name} include nan")
                    raise RuntimeError("LOSS Include NAN!!!")

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.parameters(), self.args.max_grad_norm)
                optimizer.step()

        return v_loss, dis_pg_loss, con_pg_loss, loss, entropy_loss

    def gae(
            self,
            rewards: torch.Tensor,
            dones: torch.Tensor,
            values: torch.Tensor,
            next_obs: torch.Tensor,
            next_done: torch.Tensor,
    ) -> tuple:
        """Compute advantages and returns for a rollout, bootstrapping from ``next_obs``.

        With ``args.gae`` true this is Generalized Advantage Estimation using
        ``args.gamma`` and ``args.gaeLambda``. Otherwise it computes discounted
        returns directly and sets advantages to ``returns - values``.

        Args:
            rewards, dones, values: per-step tensors indexed by time along dim 0
                (presumably ``(steps, num_agents)`` — TODO confirm with caller).
                ``dones[t + 1]`` is used as the terminal flag after step ``t``.
            next_obs: observations following the last step, used for the bootstrap value.
            next_done: terminal flags following the last step.

        Returns:
            ``(advantages, returns)``, both shaped like ``rewards``.
        """
        with torch.no_grad():
            next_value = self.get_value(next_obs).reshape(1, -1)
            data_size = rewards.size()[0]
            if self.args.gae:
                advantages = torch.zeros_like(rewards).to(self.device)
                last_gae_lam = 0
                for t in reversed(range(data_size)):
                    if t == data_size - 1:
                        next_non_terminal = 1.0 - next_done
                        next_values = next_value
                    else:
                        next_non_terminal = 1.0 - dones[t + 1]
                        next_values = values[t + 1]
                    delta = rewards[t] + self.args.gamma * next_values * next_non_terminal - values[t]
                    advantages[t] = last_gae_lam = (
                            delta + self.args.gamma * self.args.gaeLambda * next_non_terminal * last_gae_lam
                    )
                returns = advantages + values
            else:
                returns = torch.zeros_like(rewards).to(self.device)
                for t in reversed(range(data_size)):
                    if t == data_size - 1:
                        next_non_terminal = 1.0 - next_done
                        next_return = next_value
                    else:
                        next_non_terminal = 1.0 - dones[t + 1]
                        next_return = returns[t + 1]
                    returns[t] = rewards[t] + self.args.gamma * next_non_terminal * next_return
                advantages = returns - values
        return advantages, returns