import numpy as np
import torch
import argparse
import time

from torch import nn
from aimbotEnv import Aimbot
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
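    """Orthogonally initialize the layer's weight with gain `std`, fill its bias with `bias_const`, and return the layer."""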
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, bias_const)
    return layer


class PPOAgent(nn.Module):
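    """Multi-head PPO agent for the Aimbot Unity environment.

    A shared ray-view encoder feeds per-target sub-networks: every target type gets its own
    target-state encoder, middle layer, discrete and continuous actor heads, log-std parameter,
    and critic head. The head to use for each state is selected at runtime from the target
    index stored in the first element of the state vector.
    """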
    def __init__(
        self,
        env: Aimbot,
        this_args: argparse.Namespace,
        device: torch.device,
    ):
        super(PPOAgent, self).__init__()
        self.device = device
        self.args = this_args
        self.train_agent = self.args.train
        self.target_num = self.args.target_num
        self.unity_observation_shape = env.unity_observation_shape
        self.unity_action_size = env.unity_action_size
        self.state_size = self.unity_observation_shape[0]
        self.agent_num = env.unity_agent_num
        self.target_size = self.args.target_state_size
        self.time_state_size = self.args.time_state_size
        self.gun_state_size = self.args.gun_state_size
        self.my_state_size = self.args.my_state_size
        self.ray_state_size = env.unity_observation_shape[0] - self.args.total_target_size
        self.state_size_without_ray = self.args.total_target_size
        self.head_input_size = (
            env.unity_observation_shape[0] - self.target_size - self.time_state_size - self.gun_state_size
        )  # head input size, excluding the target, time, and gun state

        self.unity_discrete_type = env.unity_discrete_type
        self.discrete_size = env.unity_discrete_size
        self.discrete_shape = list(env.unity_discrete_branches)
        self.continuous_size = env.unity_continuous_size

        self.view_network = nn.Sequential(layer_init(nn.Linear(self.ray_state_size, 200)), nn.LeakyReLU())
        self.target_networks = nn.ModuleList(
            [
                nn.Sequential(layer_init(nn.Linear(self.state_size_without_ray, 100)), nn.LeakyReLU())
                for i in range(self.target_num)
            ]
        )
        self.middle_networks = nn.ModuleList(
            [
                nn.Sequential(layer_init(nn.Linear(300, 200)), nn.LeakyReLU())
                for i in range(self.target_num)
            ]
        )
        self.actor_dis = nn.ModuleList(
            [layer_init(nn.Linear(200, self.discrete_size), std=0.5) for i in range(self.target_num)]
        )
        self.actor_mean = nn.ModuleList(
            [layer_init(nn.Linear(200, self.continuous_size), std=0.5) for i in range(self.target_num)]
        )
        # self.actor_logstd = nn.ModuleList([layer_init(nn.Linear(200, self.continuous_size), std=1) for i in range(targetNum)])
        # self.actor_logstd = nn.Parameter(torch.zeros(1, self.continuous_size))
        self.actor_logstd = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(self.target_num)]
        )  # nn.Parameter(torch.zeros(1, self.continuous_size))
        self.critic = nn.ModuleList(
            [layer_init(nn.Linear(200, 1), std=1) for i in range(self.target_num)]
        )

    def get_value(self, state: torch.Tensor):
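        """Return the critic value for each state, routed through that state's per-target heads."""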
        target = state[:, 0].to(torch.int32)  # target type index stored in the first state element
        this_state_num = target.size()[0]
        view_input = state[:, -self.ray_state_size :]  # all ray input
        target_input = state[:, : self.state_size_without_ray]
        view_layer = self.view_network(view_input)
        target_layer = torch.stack(
            [self.target_networks[target[i]](target_input[i]) for i in range(this_state_num)]
        )
        middle_input = torch.cat([view_layer, target_layer], dim=1)
        middle_layer = torch.stack(
            [self.middle_networks[target[i]](middle_input[i]) for i in range(this_state_num)]
        )
        criticV = torch.stack(
            [self.critic[target[i]](middle_layer[i]) for i in range(this_state_num)]
        )  # self.critic
        return criticV

    def get_actions_value(self, state: torch.Tensor, actions=None):
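        """Run the policy and value heads for a batch of states.

        The discrete action space is treated as multi-discrete: the discrete logits are split
        into one Categorical per branch of `discrete_shape`. Returns the (sampled or given)
        actions, the summed discrete log-prob and entropy, the summed continuous log-prob and
        entropy, and the critic value.
        """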
        target = state[:, 0].to(torch.int32)  # target type index stored in the first state element
        this_state_num = target.size()[0]
        view_input = state[:, -self.ray_state_size :]  # all ray input
        target_input = state[:, : self.state_size_without_ray]
        view_layer = self.view_network(view_input)
        target_layer = torch.stack(
            [self.target_networks[target[i]](target_input[i]) for i in range(this_state_num)]
        )
        middle_input = torch.cat([view_layer, target_layer], dim=1)
        middle_layer = torch.stack(
            [self.middle_networks[target[i]](middle_input[i]) for i in range(this_state_num)]
        )

        # discrete
        # iterate over the batch and, for each state, route through the output network that matches its target type
        dis_logits = torch.stack(
            [self.actor_dis[target[i]](middle_layer[i]) for i in range(this_state_num)]
        )
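        # split the concatenated logits into one chunk per discrete branch and build an
        # independent Categorical per branch (multi-discrete action space)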
        split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)
        multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]
        # continuous
        actions_mean = torch.stack(
            [self.actor_mean[target[i]](middle_layer[i]) for i in range(this_state_num)]
        )  # self.actor_mean(hidden)
        # action_logstd = torch.stack([self.actor_logstd[target[i]](middleLayer[i]) for i in range(thisStateNum)]) # self.actor_logstd(hidden)
        # action_logstd = self.actor_logstd.expand_as(actions_mean) # self.actor_logstd.expand_as(actions_mean)
        action_logstd = torch.stack(
            [torch.squeeze(self.actor_logstd[target[i]], 0) for i in range(this_state_num)]
        )
        action_std = torch.exp(action_logstd)
        con_probs = Normal(actions_mean, action_std)
        # critic
        criticV = torch.stack(
            [self.critic[target[i]](middle_layer[i]) for i in range(this_state_num)]
        )  # self.critic

        if actions is None:
            if self.train_agent:
                # training: sample actions from the probability distributions
                disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])
                conAct = con_probs.sample()
            else:
                # inference: deterministic selection (argmax for discrete, mean for continuous)
                # is currently disabled, so actions are still sampled from the distributions
                # disAct = torch.stack([torch.argmax(logit, dim=1) for logit in split_logits])
                # conAct = actions_mean
                disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])
                conAct = con_probs.sample()
            actions = torch.cat([disAct.T, conAct], dim=1)
        else:
            disAct = actions[:, 0 : self.unity_discrete_type].T
            conAct = actions[:, self.unity_discrete_type :]
        dis_log_prob = torch.stack(
            [ctgr.log_prob(act) for act, ctgr in zip(disAct, multi_categoricals)]
        )
        dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])
        return (
            actions,
            dis_log_prob.sum(0),
            dis_entropy.sum(0),
            con_probs.log_prob(conAct).sum(1),
            con_probs.entropy().sum(1),
            criticV,
        )

    def train_net(self, this_train_ind: int, ppo_memories, optimizer) -> tuple:
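        """Run the PPO update for one training index (head/agent group).

        Flattens the stored rollout, then for `args.epochs` epochs iterates over shuffled
        minibatches, computing the clipped surrogate loss for the discrete and continuous
        policies, the (optionally clipped) value loss, and an entropy bonus, and taking an
        optimizer step with global gradient-norm clipping.

        Returns (v_loss, dis_pg_loss, con_pg_loss, loss, entropy_loss) from the last minibatch.
        """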
        start_time = time.time()
        # flatten the batch
        b_obs = ppo_memories.obs[this_train_ind].reshape((-1,) + self.unity_observation_shape)
        b_dis_logprobs = ppo_memories.dis_logprobs[this_train_ind].reshape(-1)
        b_con_logprobs = ppo_memories.con_logprobs[this_train_ind].reshape(-1)
        b_actions = ppo_memories.actions[this_train_ind].reshape((-1,) + (self.unity_action_size,))
        b_advantages = ppo_memories.advantages[this_train_ind].reshape(-1)
        b_returns = ppo_memories.returns[this_train_ind].reshape(-1)
        b_values = ppo_memories.values[this_train_ind].reshape(-1)
        b_size = b_obs.size()[0]
        # optimizing the policy and value network
        b_inds = np.arange(b_size)
        
        for epoch in range(self.args.epochs):
            print("epoch:", epoch, end="")
            # shuffle all datasets
            np.random.shuffle(b_inds)
            for start in range(0, b_size, self.args.minibatchSize):
                print(".",end="")
                end = start + self.args.minibatchSize
                mb_inds = b_inds[start:end]
                if np.size(mb_inds) <= 1:
                    break
                mb_advantages = b_advantages[mb_inds]

                # normalize advantages
                if self.args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                        mb_advantages.std() + 1e-8
                    )

                (
                    _,
                    new_dis_logprob,
                    dis_entropy,
                    new_con_logprob,
                    con_entropy,
                    newvalue,
                ) = self.get_actions_value(b_obs[mb_inds], b_actions[mb_inds])
                # discrete ratio
                dis_logratio = new_dis_logprob - b_dis_logprobs[mb_inds]
                dis_ratio = dis_logratio.exp()
                # continuous ratio
                con_logratio = new_con_logprob - b_con_logprobs[mb_inds]
                con_ratio = con_logratio.exp()

                """
                # early stop
                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]
                """

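                # PPO clipped surrogate: take the elementwise max of the unclipped and clipped
                # negated objectives, i.e. -min(r * A, clip(r, 1 - eps, 1 + eps) * A), then average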
                # discrete Policy loss
                dis_pg_loss_orig = -mb_advantages * dis_ratio
                dis_pg_loss_clip = -mb_advantages * torch.clamp(
                    dis_ratio, 1 - self.args.clip_coef, 1 + self.args.clip_coef
                )
                dis_pg_loss = torch.max(dis_pg_loss_orig, dis_pg_loss_clip).mean()
                # continuous Policy loss
                con_pg_loss_orig = -mb_advantages * con_ratio
                con_pg_loss_clip = -mb_advantages * torch.clamp(
                    con_ratio, 1 - self.args.clip_coef, 1 + self.args.clip_coef
                )
                con_pg_loss = torch.max(con_pg_loss_orig, con_pg_loss_clip).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if self.args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -self.args.clip_coef,
                        self.args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                # total loss
                entropy_loss = dis_entropy.mean() + con_entropy.mean()
                loss = (
                    dis_pg_loss * self.args.policy_coef[this_train_ind]
                    + con_pg_loss * self.args.policy_coef[this_train_ind]
                    + entropy_loss * self.args.entropy_coef[this_train_ind]
                    + v_loss * self.args.critic_coef[this_train_ind]
                )*self.args.loss_coef[this_train_ind]

                if torch.isnan(loss).any():
                    print("LOSS contains NaN!!!")
                    if torch.isnan(dis_pg_loss).any():
                        print("dis_pg_loss contains nan")
                    if torch.isnan(con_pg_loss).any():
                        print("con_pg_loss contains nan")
                    if torch.isnan(entropy_loss).any():
                        print("entropy_loss contains nan")
                    if torch.isnan(v_loss).any():
                        print("v_loss contains nan")
                    raise ValueError("loss contains NaN")

                optimizer.zero_grad()
                loss.backward()
                # Clips gradient norm of an iterable of parameters.
                nn.utils.clip_grad_norm_(self.parameters(), self.args.max_grad_norm)
                optimizer.step()

            """
            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break
            """
        return (v_loss, dis_pg_loss, con_pg_loss, loss, entropy_loss)

    def gae(
        self,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        values: torch.Tensor,
        next_obs: torch.Tensor,
        next_done: torch.Tensor,
    ) -> tuple:
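        """Compute advantages and returns for a rollout, without tracking gradients.

        With `args.gae` enabled, uses Generalized Advantage Estimation:
            delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
            A_t     = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}
        Otherwise falls back to discounted returns, with advantages = returns - values.
        """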
        # GAE
        with torch.no_grad():
            next_value = self.get_value(next_obs).reshape(1, -1)
            data_size = rewards.size()[0]
            if self.args.gae:
                advantages = torch.zeros_like(rewards).to(self.device)
                last_gae_lam = 0
                for t in reversed(range(data_size)):
                    if t == data_size - 1:
                        nextnonterminal = 1.0 - next_done
                        next_values = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        next_values = values[t + 1]
                    delta = rewards[t] + self.args.gamma * next_values * nextnonterminal - values[t]
                    advantages[t] = last_gae_lam = (
                        delta + self.args.gamma * self.args.gaeLambda * nextnonterminal * last_gae_lam
                    )
                returns = advantages + values
            else:
                returns = torch.zeros_like(rewards).to(self.device)
                for t in reversed(range(data_size)):
                    if t == data_size - 1:
                        nextnonterminal = 1.0 - next_done
                        next_return = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        next_return = returns[t + 1]
                    returns[t] = rewards[t] + self.args.gamma * nextnonterminal * next_return
                advantages = returns - values
        return advantages, returns
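
# --- Illustrative usage sketch (not part of the module) ---
# The Aimbot constructor arguments, the argparse field names (e.g. args.lr), and the
# ppo_memories layout below are assumptions based on how they are accessed in this file,
# not the repository's actual entry point.
#
#   env = Aimbot(...)                                      # Unity environment wrapper
#   agent = PPOAgent(env, this_args=args, device=device).to(device)
#   optimizer = torch.optim.Adam(agent.parameters(), lr=args.lr)
#
#   # after collecting a rollout and computing per-step values:
#   advantages, returns = agent.gae(rewards, dones, values, next_obs, next_done)
#   # ...store obs/actions/logprobs/advantages/returns/values into ppo_memories, then:
#   v_loss, dis_pg_loss, con_pg_loss, loss, entropy_loss = agent.train_net(
#       this_train_ind=i, ppo_memories=ppo_memories, optimizer=optimizer
#   )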