import torch
import numpy as np
import argparse
from ppoagent import PPOAgent
from enum import Enum

# Round target types; used to label and index the per-target training datasets.
class Targets(Enum):
    Free = 0
    Go = 1
    Attack = 2
    Defence = 3
    Num = 4
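
# Usage note: save_memories below reads the round's target type from the first
# element of each observation (state[i, 0]) and logs it via Targets(...).name,
# e.g. Targets(2).name == "Attack".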

class PPOMem:
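    """Rollout memory for PPO.

    Collects per-agent trajectory buffers while rounds are running and, once a
    round finishes, assembles them into per-target-type training datasets
    (observations, actions, log-probs, rewards, values, advantages, returns).
    """
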
    def __init__(
        self,
        args: argparse.Namespace,
        unity_agent_num: int,
        device: torch.device,
    ) -> None:
        self.target_num = args.target_num
        self.data_set_size = args.datasetSize
        self.result_broadcast_ratio = args.result_broadcast_ratio
        self.decision_period = args.decision_period
        self.unity_agent_num = unity_agent_num

        self.base_lose_reward = args.base_lose_reward
        self.base_win_reward = args.base_win_reward
        self.target_state_size = args.target_state_size
        self.device = device

        # per-agent trajectory buffers (one list per Unity agent)
        self.ob_bf = [[] for _ in range(self.unity_agent_num)]
        self.act_bf = [[] for _ in range(self.unity_agent_num)]
        self.dis_logprobs_bf = [[] for _ in range(self.unity_agent_num)]
        self.con_logprobs_bf = [[] for _ in range(self.unity_agent_num)]
        self.rewards_bf = [[] for _ in range(self.unity_agent_num)]
        self.dones_bf = [[] for _ in range(self.unity_agent_num)]
        self.values_bf = [[] for _ in range(self.unity_agent_num)]

        # initialize empty per-target-type training datasets (filled when a round finishes)
        self.obs = [torch.tensor([]).to(device) for _ in range(self.target_num)]  # (TARGETNUM,n,env.unity_observation_size)
        self.actions = [torch.tensor([]).to(device) for _ in range(self.target_num)]  # (TARGETNUM,n,env.unity_action_size)
        self.dis_logprobs = [torch.tensor([]).to(device) for _ in range(self.target_num)]  # (TARGETNUM,n,1)
        self.con_logprobs = [torch.tensor([]).to(device) for _ in range(self.target_num)]  # (TARGETNUM,n,1)
        self.rewards = [torch.tensor([]).to(device) for _ in range(self.target_num)]  # (TARGETNUM,n,1)
        self.values = [torch.tensor([]).to(device) for _ in range(self.target_num)]  # (TARGETNUM,n,1)
        self.advantages = [torch.tensor([]).to(device) for _ in range(self.target_num)]  # (TARGETNUM,n,1)
        self.returns = [torch.tensor([]).to(device) for _ in range(self.target_num)]  # (TARGETNUM,n,1)

    def broad_cast_end_reward(self, rewardBF: list, remainTime: float) -> torch.Tensor:
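        """Adjust the terminal reward of a finished round.

        Removes the base win/lose reward from the last entry and, on a win,
        broadcasts a remaining-time bonus across the whole trajectory before
        returning the rewards as a tensor on ``self.device``.
        """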
        thisRewardBF = rewardBF.copy()
        if rewardBF[-1] <= -500:
            # lost round: strip the base lose reward, no broadcast
            thisRewardBF[-1] = rewardBF[-1] - self.base_lose_reward
        elif rewardBF[-1] >= 500:
            # won round: strip the base win reward, then broadcast a
            # remaining-time bonus over the whole trajectory
            print("mean round reward:", sum(thisRewardBF) / len(thisRewardBF))
            thisRewardBF[-1] = rewardBF[-1] - self.base_win_reward
            thisRewardBF = (np.asarray(thisRewardBF) + (remainTime * self.result_broadcast_ratio)).tolist()
        else:
            print("WARNING: round ended without a win/lose result reward:", rewardBF[-1])
        return torch.tensor(thisRewardBF, dtype=torch.float32).to(self.device)

    def save_memories(
        self,
        now_step: int,
        agent: PPOAgent,
        state: np.ndarray,
        action_cpu: np.ndarray,
        dis_logprob_cpu: np.ndarray,
        con_logprob_cpu: np.ndarray,
        reward: list,
        done: list,
        value_cpu: np.ndarray,
        last_reward: list,
        next_done: list,
        next_state: np.ndarray,
    ):
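        """Store one environment step for every Unity agent.

        On a decision step (or when an agent's round just ended) the transition
        is appended to that agent's trajectory buffer. When a round ends, GAE
        advantages and returns are computed via ``agent.gae``, the finished
        trajectory is appended to the training dataset of the round's target
        type, and the per-agent buffer is cleared.
        """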
        for i in range(self.unity_agent_num):
            if now_step % self.decision_period == 0 or next_done[i]:
                # save to the buffer only on a decision step or when this agent just finished a round
                self.ob_bf[i].append(state[i])
                self.act_bf[i].append(action_cpu[i])
                self.dis_logprobs_bf[i].append(dis_logprob_cpu[i])
                self.con_logprobs_bf[i].append(con_logprob_cpu[i])
                self.dones_bf[i].append(done[i])
                self.values_bf[i].append(value_cpu[i])
                if now_step % self.decision_period == 0:
                    # decision step: include the reward carried over from the skipped steps
                    self.rewards_bf[i].append(reward[i] + last_reward[i])
                else:
                    # round ended off a decision step: only add this step's reward
                    self.rewards_bf[i].append(reward[i])
            if next_done[i]:
                # round finished: compute advantages/returns and move the
                # trajectory into the training dataset of its target type
                remainTime = state[i, self.target_state_size]
                roundTargetType = int(state[i, 0])
                thisRewardsTensor = self.broad_cast_end_reward(self.rewards_bf[i], remainTime)
                adv, rt = agent.gae(
                    rewards=thisRewardsTensor,
                    dones=torch.tensor(self.dones_bf[i], dtype=torch.float32).to(self.device),
                    values=torch.tensor(self.values_bf[i]).to(self.device),
                    next_obs=torch.tensor(next_state[i]).to(self.device).unsqueeze(0),
                    next_done=torch.tensor([next_done[i]], dtype=torch.float32).to(self.device),
                )
                # send memories to training datasets
                self.obs[roundTargetType] = torch.cat((self.obs[roundTargetType], torch.tensor(np.array(self.ob_bf[i])).to(self.device)), 0)
                self.actions[roundTargetType] = torch.cat((self.actions[roundTargetType], torch.tensor(np.array(self.act_bf[i])).to(self.device)), 0)
                self.dis_logprobs[roundTargetType] = torch.cat((self.dis_logprobs[roundTargetType], torch.tensor(np.array(self.dis_logprobs_bf[i])).to(self.device)), 0)
                self.con_logprobs[roundTargetType] = torch.cat((self.con_logprobs[roundTargetType], torch.tensor(np.array(self.con_logprobs_bf[i])).to(self.device)), 0)
                self.rewards[roundTargetType] = torch.cat((self.rewards[roundTargetType], thisRewardsTensor), 0)
                self.values[roundTargetType] = torch.cat((self.values[roundTargetType], torch.tensor(np.array(self.values_bf[i])).to(self.device)), 0)
                self.advantages[roundTargetType] = torch.cat((self.advantages[roundTargetType], adv), 0)
                self.returns[roundTargetType] = torch.cat((self.returns[roundTargetType], rt), 0)

                # clear buffers
                self.clear_buffers(i)
                print(f"train dataset {Targets(roundTargetType).name} added:{self.obs[roundTargetType].size()[0]}/{self.data_set_size}")

    def clear_buffers(self, ind: int):
        # clear buffers
        self.ob_bf[ind] = []
        self.act_bf[ind] = []
        self.dis_logprobs_bf[ind] = []
        self.con_logprobs_bf[ind] = []
        self.rewards_bf[ind] = []
        self.dones_bf[ind] = []
        self.values_bf[ind] = []

    def clear_training_datasets(self, ind: int):
        # clear training datasets
        self.obs[ind] = torch.tensor([]).to(self.device)
        self.actions[ind] = torch.tensor([]).to(self.device)
        self.dis_logprobs[ind] = torch.tensor([]).to(self.device)
        self.con_logprobs[ind] = torch.tensor([]).to(self.device)
        self.rewards[ind] = torch.tensor([]).to(self.device)
        self.values[ind] = torch.tensor([]).to(self.device)
        self.advantages[ind] = torch.tensor([]).to(self.device)
        self.returns[ind] = torch.tensor([]).to(self.device)
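

# A minimal smoke-test sketch (not part of the training pipeline). The
# hyperparameter values below are illustrative placeholders, not the project's
# actual configuration.
if __name__ == "__main__":
    example_args = argparse.Namespace(
        target_num=Targets.Num.value,
        datasetSize=6000,
        result_broadcast_ratio=0.1,
        decision_period=8,
        base_lose_reward=-1000,
        base_win_reward=1000,
        target_state_size=6,
    )
    mem = PPOMem(example_args, unity_agent_num=2, device=torch.device("cpu"))
    # A won round: the +1000 base reward is stripped from the last step and a
    # remaining-time bonus (30.0 * 0.1 = 3.0) is broadcast over every step,
    # yielding tensor([4., 5., 3.]).
    print(mem.broad_cast_end_reward([1.0, 2.0, 1000.0], remainTime=30.0))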