import argparse
import wandb
import time
import numpy as np
import random
import uuid
import torch
import torch.nn as nn
import torch.optim as optim

from AimbotEnv import Aimbot
from tqdm import tqdm
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
from distutils.util import strtobool
from torch.utils.tensorboard import SummaryWriter
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.side_channel import (
    SideChannel,
    IncomingMessage,
    OutgoingMessage,
)
from typing import List

# Best mean training reward reached so far; the training loop writes a new
# checkpoint whenever an update's mean reward exceeds it.
bestReward = 0

# ---------------------------------------------------------------------------
# Environment / connection defaults (each is the default of a CLI option)
# ---------------------------------------------------------------------------
DEFAULT_SEED = 9331
ENV_PATH = (
    "../Build/Build-ParallelEnv-Target-OffPolicy-SingleStack-SideChannel-ExtremeReward"
    "/Aimbot-ParallelEnv"
)
SIDE_CHANNEL_UUID = uuid.UUID("8bbfb62a-99b4-457c-879d-b78b69066b5e")
WAND_ENTITY = "koha9"
WORKER_ID = 1
BASE_PORT = 1000

# ---------------------------------------------------------------------------
# PPO hyper-parameters -- double-check every value before launching a run.
# Author's note: max round steps per agent is 2500 / DECISION_PERIOD (25 seconds).
# ---------------------------------------------------------------------------
TOTAL_STEPS = 6_000_000
BATCH_SIZE = 512
MAX_TRAINNING_DATASETS = 8_000
DECISION_PERIOD = 1
LEARNING_RATE = 1e-3
GAMMA = 0.99
GAE_LAMBDA = 0.95
EPOCHS = 4
CLIP_COEF = 0.1
POLICY_COEF = 1.0
ENTROPY_COEF = 0.01
CRITIC_COEF = 0.5
TARGET_LEARNING_RATE = 5e-5

# Feature toggles
ANNEAL_LEARNING_RATE = True
CLIP_VLOSS = True
NORM_ADV = True
TRAIN = True

WANDB_TACK = False
# Checkpoint to resume from; set LOAD_DIR = None to start from a freshly built agent.
LOAD_DIR = "../PPO-Model/Aimbot-target-last.pt"

# Shared round statistics keyed by game mode; the side channel increments them
# when Unity reports a finished round.
TotalRounds = dict.fromkeys(("Go", "Attack", "Free"), 0)
WinRounds = dict.fromkeys(("Go", "Attack", "Free"), 0)


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Every option defaults to the matching module-level constant, so a bare run
    uses those values. Boolean options accept an explicit value
    (``--train false``) or none at all (``--train`` alone means True, via
    ``nargs="?"`` with ``const=True``).

    Returns:
        argparse.Namespace with one attribute per option.
    """

    def str2bool(value):
        """Parse a truth value with the same vocabulary as ``distutils.util.strtobool``.

        ``distutils`` was removed from the standard library in Python 3.12, so the
        boolean flags no longer depend on it.
        """
        lowered = value.lower()
        if lowered in ("y", "yes", "t", "true", "on", "1"):
            return True
        if lowered in ("n", "no", "f", "false", "off", "0"):
            return False
        raise argparse.ArgumentTypeError(f"invalid truth value {value!r}")

    # fmt: off
    # pytorch and environment parameters
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED,
                        help="seed of the experiment")
    parser.add_argument("--path", type=str, default=ENV_PATH,
                        help="environment path")
    parser.add_argument("--workerID", type=int, default=WORKER_ID,
                        help="unity worker ID")
    parser.add_argument("--baseport", type=int, default=BASE_PORT,
                        help="port to connect to Unity environment")
    parser.add_argument("--lr", type=float, default=LEARNING_RATE,
                        help="the learning rate of optimizer")
    parser.add_argument("--cuda", type=str2bool, default=True, nargs="?", const=True,
                        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--total-timesteps", type=int, default=TOTAL_STEPS,
                        help="total timesteps of the experiments")

    # model parameters
    parser.add_argument("--train", type=str2bool, default=TRAIN, nargs="?", const=True,
                        help="Train Model or not")
    parser.add_argument("--datasetSize", type=int, default=MAX_TRAINNING_DATASETS,
                        help="number of finished-round transitions to collect before each training update")
    parser.add_argument("--minibatchSize", type=int, default=BATCH_SIZE,
                        help="mini batch size")
    parser.add_argument("--epochs", type=int, default=EPOCHS,
                        help="the K epochs to update the policy")
    parser.add_argument("--annealLR", type=str2bool, default=ANNEAL_LEARNING_RATE, nargs="?", const=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--wandb-track", type=str2bool, default=WANDB_TACK, nargs="?", const=True,
                        help="track on the wandb")
    parser.add_argument("--wandb-entity", type=str, default=WAND_ENTITY,
                        help="the entity (team) of wandb's project")
    parser.add_argument("--load-dir", type=str, default=LOAD_DIR,
                        help="load model directory")
    parser.add_argument("--decision-period", type=int, default=DECISION_PERIOD,
                        help="query the policy every N environment steps; the last action is repeated in between")

    # GAE loss
    parser.add_argument("--gae", type=str2bool, default=True, nargs="?", const=True,
                        help="Use GAE for advantage computation")
    parser.add_argument("--norm-adv", type=str2bool, default=NORM_ADV, nargs="?", const=True,
                        help="Toggles advantages normalization")
    parser.add_argument("--gamma", type=float, default=GAMMA,
                        help="the discount factor gamma")
    parser.add_argument("--gaeLambda", type=float, default=GAE_LAMBDA,
                        help="the lambda for the general advantage estimation")
    parser.add_argument("--clip-coef", type=float, default=CLIP_COEF,
                        help="the surrogate clipping coefficient")
    parser.add_argument("--policy-coef", type=float, default=POLICY_COEF,
                        help="coefficient of the policy")
    parser.add_argument("--ent-coef", type=float, default=ENTROPY_COEF,
                        help="coefficient of the entropy")
    parser.add_argument("--critic-coef", type=float, default=CRITIC_COEF,
                        help="coefficient of the value function")
    parser.add_argument("--clip-vloss", type=str2bool, default=CLIP_VLOSS, nargs="?", const=True,
                        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
                        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
                        help="the target KL divergence threshold (parsed but not used by the training loop)")
    # fmt: on
    return parser.parse_args()


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` with gain ``std`` and fill its bias with ``bias_const``.

    Layers without a bias (e.g. ``nn.Linear(..., bias=False)``) are supported;
    only their weight is initialised.

    Returns:
        The same ``layer`` object, so the call can wrap a constructor inline.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # bias is None for bias-less layers; constant_ would fail on it.
    if getattr(layer, "bias", None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class PPOAgent(nn.Module):
    """Actor-critic network with a hybrid (multi-discrete + continuous) action head.

    A shared MLP trunk (observation -> 700 -> 500 -> 256, ReLU) feeds:
      * ``actor_dis``: concatenated logits for every discrete branch, split per
        branch using ``env.unity_discrete_branches``;
      * ``actor_mean`` / ``actor_logstd``: a diagonal Gaussian over the continuous
        actions, with a learned state-independent log-std;
      * ``critic``: a scalar state value.
    """

    def __init__(self, env: Aimbot):
        super().__init__()
        # NOTE: the training script saves/loads whole pickled modules
        # (torch.save(agent) / torch.load), so any attribute added here would be
        # missing from previously saved checkpoints.
        self.discrete_size = env.unity_discrete_size
        self.discrete_shape = list(env.unity_discrete_branches)
        self.continuous_size = env.unity_continuous_size

        self.network = nn.Sequential(
            layer_init(nn.Linear(np.array(env.unity_observation_shape).prod(), 700)),
            nn.ReLU(),
            layer_init(nn.Linear(700, 500)),
            nn.ReLU(),
            layer_init(nn.Linear(500, 256)),
            nn.ReLU(),
        )
        # Small init std keeps the initial policy close to uniform / zero-mean.
        self.actor_dis = layer_init(nn.Linear(256, self.discrete_size), std=0.01)
        self.actor_mean = layer_init(nn.Linear(256, self.continuous_size), std=0.01)
        self.actor_logstd = nn.Parameter(torch.zeros(1, self.continuous_size))
        self.critic = layer_init(nn.Linear(256, 1), std=1)

    def get_value(self, state: torch.Tensor):
        """Return the critic's value estimate for ``state``."""
        return self.critic(self.network(state))

    def get_actions_value(self, state: torch.Tensor, actions=None, stochastic=None):
        """Pick (or re-evaluate) actions and return their log-probs, entropies and the state value.

        Args:
            state: observation batch; logits are split along dim 1, so a leading
                batch dimension is expected.
            actions: previously taken actions to re-evaluate, laid out as
                ``[discrete branch indices..., continuous values...]`` per row
                (the same layout this method produces). ``None`` picks new actions.
            stochastic: only used when ``actions`` is None. True samples from the
                distributions, False takes the most likely discrete action and the
                Gaussian mean. ``None`` keeps the legacy behaviour of reading the
                script-level ``args.train`` flag.

        Returns:
            ``(actions, discrete log-prob, discrete entropy, continuous log-prob,
            continuous entropy, value)``. Log-probs and entropies are summed over
            action dimensions (one entry per row); value is the critic output.
        """
        hidden = self.network(state)
        # discrete: one Categorical per branch
        dis_logits = self.actor_dis(hidden)
        split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)
        multi_categoricals = [Categorical(logits=branch_logits) for branch_logits in split_logits]
        # continuous: diagonal Gaussian with a shared log-std
        actions_mean = self.actor_mean(hidden)
        action_logstd = self.actor_logstd.expand_as(actions_mean)
        action_std = torch.exp(action_logstd)
        con_probs = Normal(actions_mean, action_std)

        # One discrete column per categorical branch -- exactly how new actions
        # are assembled below, so re-evaluated actions are split the same way.
        n_discrete = len(self.discrete_shape)

        if actions is None:
            if stochastic is None:
                stochastic = args.train
            if stochastic:
                # select actions by sampling the distributions
                disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])
                conAct = con_probs.sample()
            else:
                # select the most likely actions
                disAct = torch.stack([torch.argmax(logit, dim=1) for logit in split_logits])
                conAct = actions_mean
            actions = torch.cat([disAct.T, conAct], dim=1)
        else:
            disAct = actions[:, :n_discrete].T
            conAct = actions[:, n_discrete:]
        dis_log_prob = torch.stack(
            [ctgr.log_prob(act) for act, ctgr in zip(disAct, multi_categoricals)]
        )
        dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])
        return (
            actions,
            dis_log_prob.sum(0),
            dis_entropy.sum(0),
            con_probs.log_prob(conAct).sum(1),
            con_probs.entropy().sum(1),
            self.critic(hidden),
        )


def GAE(agent, args, rewards, dones, values, next_obs, next_done):
    """Compute advantages and returns for one trajectory.

    Args:
        agent: object whose ``get_value(obs)`` gives the critic estimate used to
            bootstrap after the last step.
        args: namespace providing ``gae`` (use GAE vs. plain discounted returns),
            ``gamma`` and ``gaeLambda``.
        rewards, dones, values: per-step 1-D tensors of equal length. ``dones[t]``
            is the done flag stored alongside step ``t``; step ``t`` is cut off
            from step ``t + 1`` when ``dones[t + 1]`` is set.
        next_obs: observation following the last step (bootstrap input).
        next_done: done flag following the last step; 1 disables bootstrapping.

    Returns:
        ``(advantages, returns)`` with the same length as ``rewards``. Both are
        allocated with ``zeros_like(rewards)`` and so live on ``rewards``' device.
    """
    with torch.no_grad():
        next_value = agent.get_value(next_obs).reshape(1, -1)
        data_size = rewards.size()[0]
        if args.gae:
            advantages = torch.zeros_like(rewards)
            lastgaelam = 0
            for t in reversed(range(data_size)):
                if t == data_size - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = (
                    delta + args.gamma * args.gaeLambda * nextnonterminal * lastgaelam
                )
            returns = advantages + values
        else:
            returns = torch.zeros_like(rewards)
            for t in reversed(range(data_size)):
                if t == data_size - 1:
                    nextnonterminal = 1.0 - next_done
                    next_return = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    next_return = returns[t + 1]
                returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
            advantages = returns - values
    return advantages, returns

class AimbotSideChannel(SideChannel):
    """Side channel for exchanging messages with the Unity Aimbot environment.

    Incoming messages are ``|``-separated strings:
      * ``result|<mode>|<outcome>`` increments the module-level
        ``TotalRounds[<mode>]`` and, when ``<outcome>`` is ``Win``,
        ``WinRounds[<mode>]``;
      * messages whose first field is ``Error`` are printed.
    The ``send_*`` helpers queue single typed values for Unity.
    """

    def __init__(self, channel_id: uuid.UUID) -> None:
        super().__init__(channel_id)

    def on_message_received(self, msg: IncomingMessage) -> None:
        """Handle one message from Unity (required by the SideChannel interface).

        This runs inside the environment step, so malformed or unexpected
        messages are reported instead of raising and aborting training.
        """
        message = msg.read_string()
        fields = message.split("|")
        if fields[0] == "result":
            if len(fields) < 3:
                print(f"Malformed result message: {message}")
                return
            mode, outcome = fields[1], fields[2]
            # Unknown modes get their own counters instead of raising KeyError.
            TotalRounds[mode] = TotalRounds.get(mode, 0) + 1
            if outcome == "Win":
                WinRounds[mode] = WinRounds.get(mode, 0) + 1
        elif fields[0] == "Error":
            print(message)

    # Send helpers
    def send_string(self, data: str) -> None:
        """Queue a string to be sent to the C# side."""
        msg = OutgoingMessage()
        msg.write_string(data)
        super().queue_message_to_send(msg)

    def send_bool(self, data: bool) -> None:
        """Queue a bool to be sent to the C# side."""
        msg = OutgoingMessage()
        msg.write_bool(data)
        super().queue_message_to_send(msg)

    def send_int(self, data: int) -> None:
        """Queue an int (written as int32) to be sent to the C# side."""
        msg = OutgoingMessage()
        msg.write_int32(data)
        super().queue_message_to_send(msg)

    def send_float(self, data: float) -> None:
        """Queue a float (written as float32) to be sent to the C# side."""
        msg = OutgoingMessage()
        msg.write_float32(data)
        super().queue_message_to_send(msg)

    def send_float_list(self, data: List[float]) -> None:
        """Queue a list of floats (written as float32 values) to be sent to the C# side."""
        msg = OutgoingMessage()
        msg.write_float32_list(data)
        super().queue_message_to_send(msg)


if __name__ == "__main__":
    # Training entry point: collect finished rounds from the parallel Unity agents,
    # then run PPO updates on them until total_timesteps is used up.
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # Initialize environment, agent and optimizer
    aimBotsideChannel = AimbotSideChannel(SIDE_CHANNEL_UUID)
    env = Aimbot(
        envPath=args.path,
        workerID=args.workerID,
        basePort=args.baseport,
        side_channels=[aimBotsideChannel],
    )
    writer = None
    try:
        if args.load_dir is None:
            agent = PPOAgent(env).to(device)
        else:
            # torch.load unpickles arbitrary Python objects -- only load trusted checkpoints.
            # map_location lets a checkpoint saved on a GPU load on a CPU-only host, and
            # .to(device) keeps the network on the same device as the rollout tensors.
            # NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects whole
            # pickled modules like these checkpoints; pass weights_only=False there.
            agent = torch.load(args.load_dir, map_location=device).to(device)
            print("Load Agent", args.load_dir)
            print(agent.eval())

        optimizer = optim.Adam(agent.parameters(), lr=args.lr, eps=1e-5)

        # Tensorboard and WandB Recorder
        game_name = "Aimbot_Target"
        game_type = "OffPolicy"
        run_name = f"{game_name}_{game_type}_{args.seed}_{int(time.time())}"
        if args.wandb_track:
            wandb.init(
                project=game_name,
                entity=args.wandb_entity,
                sync_tensorboard=True,
                config=vars(args),
                name=run_name,
                monitor_gym=True,
                save_code=True,
            )

        writer = SummaryWriter(f"runs/{run_name}")
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s"
            % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )

        # Per-agent buffers for the round each agent is currently playing. A round only
        # enters the training dataset once it finishes; unfinished rounds carry over
        # to the next update.
        ob_bf = [[] for _ in range(env.unity_agent_num)]
        act_bf = [[] for _ in range(env.unity_agent_num)]
        dis_logprobs_bf = [[] for _ in range(env.unity_agent_num)]
        con_logprobs_bf = [[] for _ in range(env.unity_agent_num)]
        rewards_bf = [[] for _ in range(env.unity_agent_num)]
        dones_bf = [[] for _ in range(env.unity_agent_num)]
        values_bf = [[] for _ in range(env.unity_agent_num)]

        # Tensors forming one training dataset, in the order they are unpacked below.
        dataset_keys = (
            "obs", "actions", "dis_logprobs", "con_logprobs",
            "rewards", "values", "advantages", "returns",
        )

        def finish_round(agent_idx, dataset, last_state, last_done):
            """Run GAE over agent_idx's finished round, move it into ``dataset``, clear its buffers.

            Returns the number of transitions moved.
            """
            adv, rt = GAE(
                agent,
                args,
                torch.tensor(rewards_bf[agent_idx]).to(device),
                torch.Tensor(dones_bf[agent_idx]).to(device),
                torch.tensor(values_bf[agent_idx]).to(device),
                torch.tensor(last_state).to(device),
                torch.Tensor([last_done]).to(device),
            )
            dataset["obs"].append(torch.tensor(ob_bf[agent_idx]).to(device))
            dataset["actions"].append(torch.tensor(act_bf[agent_idx]).to(device))
            dataset["dis_logprobs"].append(torch.tensor(dis_logprobs_bf[agent_idx]).to(device))
            dataset["con_logprobs"].append(torch.tensor(con_logprobs_bf[agent_idx]).to(device))
            dataset["rewards"].append(torch.tensor(rewards_bf[agent_idx]).to(device))
            dataset["values"].append(torch.tensor(values_bf[agent_idx]).to(device))
            dataset["advantages"].append(adv)
            dataset["returns"].append(rt)
            n_moved = len(ob_bf[agent_idx])
            for buffer in (ob_bf, act_bf, dis_logprobs_bf, con_logprobs_bf, rewards_bf, dones_bf, values_bf):
                buffer[agent_idx] = []
            return n_moved

        # start the game
        total_update_step = args.total_timesteps // args.datasetSize
        global_step = 0
        start_time = time.time()
        state, _, done = env.reset()

        for total_steps in range(total_update_step):
            print("new episode")
            if args.annealLR:
                # Linearly anneal from args.lr on the first update to TARGET_LEARNING_RATE on the last.
                progress = total_steps / max(total_update_step - 1, 1)
                lrnow = args.lr + (TARGET_LEARNING_RATE - args.lr) * progress
                optimizer.param_groups[0]["lr"] = lrnow

            # empty training dataset; pieces are concatenated once collection ends
            dataset = {key: [] for key in dataset_keys}
            dataset_rows = 0

            # MAIN LOOP: run agents until enough finished-round transitions are collected
            env_step = 0
            while dataset_rows < args.datasetSize:
                if env_step % args.decision_period == 0:
                    # Decision step: query the policy for every agent.
                    global_step += env.unity_agent_num
                    with torch.no_grad():
                        action, dis_logprob, _, con_logprob, _, value = agent.get_actions_value(
                            torch.Tensor(state).to(device)
                        )
                        value = value.flatten()

                    # variables from GPU to CPU
                    action_cpu = action.cpu().numpy()
                    dis_logprob_cpu = dis_logprob.cpu().numpy()
                    con_logprob_cpu = con_logprob.cpu().numpy()
                    value_cpu = value.cpu().numpy()
                    # Environment step
                    next_state, reward, next_done = env.step(action_cpu)

                    for agent_idx in range(env.unity_agent_num):
                        ob_bf[agent_idx].append(state[agent_idx])
                        act_bf[agent_idx].append(action_cpu[agent_idx])
                        dis_logprobs_bf[agent_idx].append(dis_logprob_cpu[agent_idx])
                        con_logprobs_bf[agent_idx].append(con_logprob_cpu[agent_idx])
                        rewards_bf[agent_idx].append(reward[agent_idx])
                        dones_bf[agent_idx].append(done[agent_idx])
                        values_bf[agent_idx].append(value_cpu[agent_idx])
                        if next_done[agent_idx]:
                            dataset_rows += finish_round(
                                agent_idx, dataset, next_state[agent_idx], next_done[agent_idx]
                            )
                            print(f"train dataset added:{dataset_rows}/{args.datasetSize}")
                else:
                    # Skipped step: repeat the last action and credit its reward to the
                    # agent's latest recorded decision, so each stored transition covers
                    # every frame its action was applied.
                    next_state, reward, next_done = env.step(action_cpu)
                    for agent_idx in range(env.unity_agent_num):
                        if not rewards_bf[agent_idx]:
                            # No decision recorded yet in this agent's current round.
                            continue
                        rewards_bf[agent_idx][-1] += reward[agent_idx]
                        if next_done[agent_idx]:
                            dataset_rows += finish_round(
                                agent_idx, dataset, next_state[agent_idx], next_done[agent_idx]
                            )
                            print(f"train dataset added:{dataset_rows}/{args.datasetSize}")
                # Advance before (possibly) leaving the loop so the next rollout starts
                # from the observation the environment is actually in.
                state, done = next_state, next_done
                env_step += 1

            obs, actions, dis_logprobs, con_logprobs, rewards, values, advantages, returns = (
                torch.cat(dataset[key], 0) for key in dataset_keys
            )

            if args.train:
                # flatten the batch
                b_obs = obs.reshape((-1,) + env.unity_observation_shape)
                b_dis_logprobs = dis_logprobs.reshape(-1)
                b_con_logprobs = con_logprobs.reshape(-1)
                b_actions = actions.reshape((-1,) + (env.unity_action_size,))
                b_advantages = advantages.reshape(-1)
                b_returns = returns.reshape(-1)
                b_values = values.reshape(-1)
                b_size = b_obs.size()[0]
                # Optimizing the policy and value network
                b_inds = np.arange(b_size)
                for _epoch in range(args.epochs):
                    # shuffle all datasets
                    np.random.shuffle(b_inds)
                    for start in range(0, b_size, args.minibatchSize):
                        end = start + args.minibatchSize
                        mb_inds = b_inds[start:end]
                        mb_advantages = b_advantages[mb_inds]

                        # normalize advantages; std() of a single element is NaN, which
                        # would poison every weight, so a size-1 minibatch is left as is.
                        if args.norm_adv and len(mb_inds) > 1:
                            mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                                mb_advantages.std() + 1e-8
                            )

                        (
                            _,
                            new_dis_logprob,
                            dis_entropy,
                            new_con_logprob,
                            con_entropy,
                            newvalue,
                        ) = agent.get_actions_value(b_obs[mb_inds], b_actions[mb_inds])
                        # discrete ratio
                        dis_logratio = new_dis_logprob - b_dis_logprobs[mb_inds]
                        dis_ratio = dis_logratio.exp()
                        # continuous ratio
                        con_logratio = new_con_logprob - b_con_logprobs[mb_inds]
                        con_ratio = con_logratio.exp()

                        # NOTE: KL-based early stopping (--target-kl) is not implemented;
                        # the option is parsed but unused.

                        # discrete Policy loss
                        dis_pg_loss_orig = -mb_advantages * dis_ratio
                        dis_pg_loss_clip = -mb_advantages * torch.clamp(
                            dis_ratio, 1 - args.clip_coef, 1 + args.clip_coef
                        )
                        dis_pg_loss = torch.max(dis_pg_loss_orig, dis_pg_loss_clip).mean()
                        # continuous Policy loss
                        con_pg_loss_orig = -mb_advantages * con_ratio
                        con_pg_loss_clip = -mb_advantages * torch.clamp(
                            con_ratio, 1 - args.clip_coef, 1 + args.clip_coef
                        )
                        con_pg_loss = torch.max(con_pg_loss_orig, con_pg_loss_clip).mean()

                        # Value loss
                        newvalue = newvalue.view(-1)
                        if args.clip_vloss:
                            v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                            v_clipped = b_values[mb_inds] + torch.clamp(
                                newvalue - b_values[mb_inds],
                                -args.clip_coef,
                                args.clip_coef,
                            )
                            v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                            v_loss = 0.5 * v_loss_max.mean()
                        else:
                            v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                        # total loss
                        entropy_loss = dis_entropy.mean() + con_entropy.mean()
                        loss = (
                            dis_pg_loss * args.policy_coef
                            + con_pg_loss * args.policy_coef
                            - entropy_loss * args.ent_coef
                            + v_loss * args.critic_coef
                        )

                        optimizer.zero_grad()
                        loss.backward()
                        # Clips gradient norm of an iterable of parameters.
                        nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                        optimizer.step()

                # record rewards for plotting purposes (losses are from the last minibatch)
                rewardsMean = np.mean(rewards.cpu().numpy())
                writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
                writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
                writer.add_scalar("losses/dis_policy_loss", dis_pg_loss.item(), global_step)
                writer.add_scalar("losses/con_policy_loss", con_pg_loss.item(), global_step)
                writer.add_scalar("losses/total_loss", loss.item(), global_step)
                writer.add_scalar("losses/entropy_loss", entropy_loss.item(), global_step)
                print("episode over mean reward:", rewardsMean)
                writer.add_scalar(
                    "charts/SPS", int(global_step / (time.time() - start_time)), global_step
                )
                writer.add_scalar("charts/Reward", rewardsMean, global_step)
                for mode in ("Go", "Attack", "Free"):
                    # No ratio until at least one round of this mode has finished.
                    if TotalRounds[mode] > 0:
                        writer.add_scalar(
                            f"charts/{mode}WinRatio", WinRounds[mode] / TotalRounds[mode], global_step
                        )
                if rewardsMean > bestReward:
                    bestReward = rewardsMean
                    saveDir = "../PPO-Model/Target-700-500-256-hybrid-" + str(rewardsMean) + ".pt"
                    torch.save(agent, saveDir)
    finally:
        env.close()
        if writer is not None:
            writer.close()