import atexit
import random
import time
import uuid

import numpy as np
import torch
import torch.optim as optim

from aimbotEnv import Aimbot
from aimbotEnv import AimbotSideChannel
from ppoagent import PPOAgent
from airecorder import WandbRecorder
from aimemory import PPOMem
from aimemory import Targets
from arguments import parse_args

# side channel uuid
SIDE_CHANNEL_UUID = uuid.UUID("8bbfb62a-99b4-457c-879d-b78b69066b5e")
# experiment names used by the W&B / TensorBoard recorder
GAME_NAME = "Aimbot_Hybrid_V3"
GAME_TYPE = "Mix_Verification"

if __name__ == "__main__":
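    # Build the Unity environment, create or load a PPO agent, then alternate
    # rollout collection and policy updates until total_timesteps is reached.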
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    # best mean reward seen so far, used to decide when to checkpoint
    best_reward = -1

    # Initialize side channel, environment, and agent
    aimbot_side_channel = AimbotSideChannel(SIDE_CHANNEL_UUID)
    env = Aimbot(
        env_path=args.path,
        worker_id=args.workerID,
        base_port=args.baseport,
        side_channels=[aimbot_side_channel])
    if args.load_dir is None:
        agent = PPOAgent(
            env=env,
            this_args=args,
            device=device,
        ).to(device)
    else:
        agent = torch.load(args.load_dir)
        # optionally freeze the view network so only the remaining layers are fine-tuned
        if args.freeze_viewnet:
            for p in agent.viewNetwork.parameters():
                p.requires_grad = False
            print("VIEW NETWORK FROZEN")
        print("Loaded agent from", args.load_dir)
        # print the model structure (note: eval() also switches the module to eval mode)
        print(agent.eval())
    # optimizer
    optimizer = optim.Adam(agent.parameters(), lr=args.lr, eps=1e-5)
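    # a single Adam optimizer is shared by all target-specific training updates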
    # Tensorboard and WandB Recorder
    run_name = f"{GAME_TYPE}_{args.seed}_{int(time.time())}"
    wdb_recorder = WandbRecorder(GAME_NAME, GAME_TYPE, run_name, args)

    @atexit.register
    def save_model():
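        # runs at interpreter exit: close the Unity environment and checkpoint the model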
        # close env
        env.close()
        if args.save_model:
            # save the model on exit
            save_dir = "../PPO-Model/" + run_name + "_last.pt"
            torch.save(agent, save_dir)
            print("save model to " + save_dir)

    # start the game
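    # each update consumes args.datasetSize samples for one of args.target_num targets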
    total_update_step = args.target_num * args.total_timesteps // args.datasetSize
    target_steps = [0 for i in range(args.target_num)]
    start_time = time.time()
    state, _, done = env.reset()
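    # state holds one observation per Unity agent; state[0][0] encodes the current target type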

    # initialize AI memories
    ppo_memories = PPOMem(
        args=args,
        unity_agent_num=env.unity_agent_num,
        device=device,
    )
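    # PPOMem keeps one rollout buffer per target type; a target is trained once its buffer reaches args.datasetSize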

    # MAIN LOOP: alternate between collecting rollouts and training
    for total_steps in range(total_update_step):
        # linearly anneal the learning rate from args.lr down to args.target_lr over all updates
        if args.annealLR:
            final_lr_ratio = args.target_lr / args.lr
            frac = 1.0 - ((total_steps + 1.0) / total_update_step)
            lr_now = args.lr * (final_lr_ratio + (1.0 - final_lr_ratio) * frac)
            optimizer.param_groups[0]["lr"] = lr_now
        else:
            lr_now = args.lr

        # show the learning rate at the start of each episode
        print(f"new episode {total_steps}, learning rate = {lr_now}")
        step = 0
        training = False
        train_queue = []
        last_reward = [0. for i in range(env.unity_agent_num)]
        # ROLLOUT LOOP: step the agent in the environment until a training buffer is full
        while True:
            # if the target type (state[0][0]) is stay (4), send an all-zero action and skip the agent
            if state[0][0] == 4:
                next_state, reward, next_done = env.step(env.all_zero_action)
                state, done = next_state, next_done
                continue
            # on a decision step (and the target type is not stay), let the agent choose the action
            if step % args.decision_period == 0:
                step += 1
                # Choose action by agent
                with torch.no_grad():
                    # predict actions
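                    # the hybrid policy returns discrete and continuous actions with separate
                    # log-probs plus the critic's value estimate; unused outputs are discarded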
                    action, dis_logprob, _, con_logprob, _, value = agent.get_actions_value(
                        torch.Tensor(state).to(device)
                    )
                    value = value.flatten()

                # move results from GPU to CPU for the environment step and the memory buffer
                action_cpu = action.cpu().numpy()
                dis_logprob_cpu = dis_logprob.cpu().numpy()
                con_logprob_cpu = con_logprob.cpu().numpy()
                value_cpu = value.cpu().numpy()
                # Environment step
                next_state, reward, next_done = env.step(action_cpu)
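                # reward and next_done come back as per-Unity-agent batches (one entry per env.unity_agent_num)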

                # save memories
                if args.train:
                    ppo_memories.save_memories(
                        now_step=step,
                        agent=agent,
                        state=state,
                        action_cpu=action_cpu,
                        dis_logprob_cpu=dis_logprob_cpu,
                        con_logprob_cpu=con_logprob_cpu,
                        reward=reward,
                        done=done,
                        value_cpu=value_cpu,
                        last_reward=last_reward,
                        next_done=next_done,
                        next_state=next_state,
                    )
                    # check if any training dataset is full and ready to train
                    for i in range(args.target_num):
                        if ppo_memories.obs[i].size()[0] >= args.datasetSize:
                            # start train NN
                            train_queue.append(i)
                    if len(train_queue) > 0:
                        # break the rollout loop and start training
                        break
                # update state before the next decision step
                state, done = next_state, next_done
            else:
                step += 1
                # between decision points, repeat the last predicted action
                next_state, reward, next_done = env.step(action_cpu)
                # save memories
                if args.train:
                    ppo_memories.save_memories(
                        now_step=step,
                        agent=agent,
                        state=state,
                        action_cpu=action_cpu,
                        dis_logprob_cpu=dis_logprob_cpu,
                        con_logprob_cpu=con_logprob_cpu,
                        reward=reward,
                        done=done,
                        value_cpu=value_cpu,
                        last_reward=last_reward,
                        next_done=next_done,
                        next_state=next_state,
                    )
                # update state even when not training, so the next decision uses fresh observations
                state = next_state
                last_reward = reward

        if args.train:
            # train mode on
            mean_reward_list = []  # for WANDB
            # loop all training queue
            for this_train_ind in train_queue:
                # per-target training timer
                train_start_time = time.time()
                target_steps[this_train_ind] += 1
                # train agent
                (
                    v_loss,
                    dis_pg_loss,
                    con_pg_loss,
                    loss,
                    entropy_loss
                ) = agent.train_net(
                    this_train_ind=this_train_ind,
                    ppo_memories=ppo_memories,
                    optimizer=optimizer
                )
                targetName = Targets(this_train_ind).name
                print(f"training done for target {targetName}")
                # record the mean reward before clearing the buffer
                target_reward_mean = np.mean(ppo_memories.rewards[this_train_ind].to("cpu").detach().numpy().copy())
                mean_reward_list.append(target_reward_mean)

                # clear this target training set buffer
                ppo_memories.clear_training_datasets(this_train_ind)
                # record rewards for plotting purposes
                wdb_recorder.add_target_scalar(
                    targetName,
                    this_train_ind,
                    v_loss,
                    dis_pg_loss,
                    con_pg_loss,
                    loss,
                    entropy_loss,
                    target_reward_mean,
                    target_steps,
                )
                print(f"episode over Target{targetName} mean reward:", target_reward_mean)
            TotalRewardMean = np.mean(mean_reward_list)
            wdb_recorder.add_global_scalar(
                TotalRewardMean,
                optimizer.param_groups[0]["lr"],
                total_steps,
            )
            # print the time spent training, in seconds
            print("cost time:", time.time() - train_start_time)
            # New Record!
            if TotalRewardMean > best_reward and args.save_model:
                best_reward = TotalRewardMean
                saveDir = "../PPO-Model/" + run_name + "_" + str(TotalRewardMean) + ".pt"
                torch.save(agent, saveDir)
        else:
            # train mode off
            mean_reward_list = []  # for WANDB
            # when not training, just record the rewards and clear the buffer
            for this_train_ind in train_queue:
                target_steps[this_train_ind] += 1
                targetName = Targets(this_train_ind).name
                target_reward_mean = np.mean(ppo_memories.rewards[this_train_ind].to("cpu").detach().numpy().copy())
                mean_reward_list.append(target_reward_mean)
                print(f"Target{targetName} step {target_steps[this_train_ind]}")

                # clear this target training set buffer
                ppo_memories.clear_training_datasets(this_train_ind)

                # record rewards for plotting purposes
                wdb_recorder.writer.add_scalar(f"Target{targetName}/Reward", target_reward_mean,
                                               target_steps[this_train_ind])
                wdb_recorder.add_win_ratio(targetName, target_steps[this_train_ind])
                print(f"episode over Target{targetName} mean reward:", target_reward_mean)
            TotalRewardMean = np.mean(mean_reward_list)
            wdb_recorder.writer.add_scalar("GlobalCharts/TotalRewardMean", TotalRewardMean, total_steps)

    saveDir = "../PPO-Model/" + run_name + "_last.pt"
    torch.save(agent, saveDir)
    env.close()
    wdb_recorder.writer.close()