import argparse
import uuid

# distutils was removed in Python 3.12; fall back to a local parser with the same
# semantics if the import is unavailable.
try:
    from distutils.util import strtobool
except ImportError:
    def strtobool(val):
        val = val.lower()
        if val in ("y", "yes", "t", "true", "on", "1"):
            return 1
        if val in ("n", "no", "f", "false", "off", "0"):
            return 0
        raise ValueError(f"invalid truth value {val!r}")

DEFAULT_SEED = 9331
ENV_PATH = "../Build/3.5/Aimbot-ParallelEnv"
WAND_ENTITY = "koha9"
WORKER_ID = 1
BASE_PORT = 1000

# tensorboard names

# Max round steps per agent: 2500 / DECISION_PERIOD, i.e. 25 seconds per round.
TOTAL_STEPS = 3150000
BATCH_SIZE = 512
MAX_TRAINNING_DATASETS = 6000
DECISION_PERIOD = 1
LEARNING_RATE = 1.5e-4
GAMMA = 0.99
GAE_LAMBDA = 0.95
EPOCHS = 3
CLIP_COEF = 0.11
# Per-target coefficients, ordered: free, go, attack, defence
LOSS_COEF = [1.0, 1.0, 1.0, 1.0]
POLICY_COEF = [0.8, 0.8, 0.8, 0.8]
ENTROPY_COEF = [0.05, 0.05, 0.05, 0.05]
CRITIC_COEF = [0.8, 0.8, 0.8, 0.8]
TARGET_LEARNING_RATE = 1e-6

FREEZE_VIEW_NETWORK = False
ANNEAL_LEARNING_RATE = True
CLIP_VLOSS = True
NORM_ADV = False
TRAIN = True
SAVE_MODEL = True
WANDB_TACK = True
LOAD_DIR = None
# LOAD_DIR = "../PPO-Model/GotoOnly-Level0123_9331_1696965321/5.1035867.pt"

# Unity Environment Parameters
TARGET_STATE_SIZE = 6
INAREA_STATE_SIZE = 1
TIME_STATE_SIZE = 1
GUN_STATE_SIZE = 1
MY_STATE_SIZE = 4
TOTAL_T_SIZE = TARGET_STATE_SIZE + INAREA_STATE_SIZE + TIME_STATE_SIZE + GUN_STATE_SIZE + MY_STATE_SIZE  # 6 + 1 + 1 + 1 + 4 = 13
BASE_WINREWARD = 999
BASE_LOSEREWARD = -999
TARGETNUM = 4
ENV_TIMELIMIT = 30
RESULT_BROADCAST_RATIO = 1 / ENV_TIMELIMIT

# Module-level flag: whether the model should be saved at the end of the current episode.
save_model_this_episode = False

def is_save_model():
    """Return whether the model should be saved this episode."""
    return save_model_this_episode


def set_save_model(save_model: bool):
    """Set the per-episode save-model flag."""
    print("set save model to", save_model)
    global save_model_this_episode
    save_model_this_episode = save_model
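
# Assumed interaction with the training loop (illustrative only; the real caller lives
# outside this file): the trainer flips the flag when a checkpoint should be written,
# and the saving code queries and resets it, e.g.:
#
#   set_save_model(True)       # e.g. after a new best mean reward
#   if is_save_model():
#       ...                    # save the checkpoint, then
#       set_save_model(False)  # reset the flag for the next episode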

def parse_args():
    """Parse command-line arguments, using the module-level constants above as defaults."""
    # fmt: off
    # pytorch and environment parameters
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED,
                        help="seed of the experiment")
    parser.add_argument("--path", type=str, default=ENV_PATH,
                        help="enviroment path")
    parser.add_argument("--workerID", type=int, default=WORKER_ID,
                        help="unity worker ID")
    parser.add_argument("--baseport", type=int, default=BASE_PORT,
                        help="port to connect to Unity environment")
    parser.add_argument("--lr", type=float, default=LEARNING_RATE,
                        help="the default learning rate of optimizer")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--total-timesteps", type=int, default=TOTAL_STEPS,
                        help="total timesteps of the experiments")

    # model parameters
    parser.add_argument("--train",type=lambda x: bool(strtobool(x)), default=TRAIN, nargs="?", const=True,
                        help="Train Model or not")
    parser.add_argument("--freeze-viewnet", type=lambda x: bool(strtobool(x)), default=FREEZE_VIEW_NETWORK, nargs="?", const=True,
                        help="freeze view network or not")
    parser.add_argument("--datasetSize", type=int, default=MAX_TRAINNING_DATASETS,
                        help="training dataset size,start training while dataset collect enough data")
    parser.add_argument("--minibatchSize", type=int, default=BATCH_SIZE,
                        help="nimi batch size")
    parser.add_argument("--epochs", type=int, default=EPOCHS,
                        help="the K epochs to update the policy")
    parser.add_argument("--annealLR", type=lambda x: bool(strtobool(x)), default=ANNEAL_LEARNING_RATE, nargs="?", const=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--wandb-track", type=lambda x: bool(strtobool(x)), default=WANDB_TACK, nargs="?", const=True,
                        help="track on the wandb")
    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=SAVE_MODEL, nargs="?", const=True,
                        help="save model or not")
    parser.add_argument("--wandb-entity", type=str, default=WAND_ENTITY,
                        help="the entity (team) of wandb's project")
    parser.add_argument("--load-dir", type=str, default=LOAD_DIR,
                        help="load model directory")
    parser.add_argument("--decision-period", type=int, default=DECISION_PERIOD,
                        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--result-broadcast-ratio", type=float, default=RESULT_BROADCAST_RATIO,
                        help="broadcast result when win round is reached,r=result-broadcast-ratio*remainTime")
    # target_learning_rate
    parser.add_argument("--target-lr", type=float, default=TARGET_LEARNING_RATE,
                        help="target value of downscaling the learning rate")

    # per-target loss coefficients (nargs="+" so CLI values stay lists, matching the list defaults)
    parser.add_argument("--policy-coef", type=float, nargs="+", default=POLICY_COEF,
                        help="per-target coefficients of the policy loss")
    parser.add_argument("--entropy-coef", type=float, nargs="+", default=ENTROPY_COEF,
                        help="per-target coefficients of the entropy loss")
    parser.add_argument("--critic-coef", type=float, nargs="+", default=CRITIC_COEF,
                        help="per-target coefficients of the critic (value) loss")
    parser.add_argument("--loss-coef", type=float, nargs="+", default=LOSS_COEF,
                        help="per-target coefficients of the total loss")

    # GAE and PPO loss parameters
    parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="Use GAE for advantage computation")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=NORM_ADV, nargs="?", const=True,
                        help="Toggles advantages normalization")
    parser.add_argument("--gamma", type=float, default=GAMMA,
                        help="the discount factor gamma")
    parser.add_argument("--gaeLambda", type=float, default=GAE_LAMBDA,
                        help="the lambda for the general advantage estimation")
    parser.add_argument("--clip-coef", type=float, default=CLIP_COEF,
                        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=CLIP_VLOSS, nargs="?", const=True,
                        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
                        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
                        help="the target KL divergence threshold")
    # environment parameters
    parser.add_argument("--target-num", type=int, default=TARGETNUM,
                        help="the number of targets")
    parser.add_argument("--env-timelimit", type=int, default=ENV_TIMELIMIT,
                        help="the time limit of each round")
    parser.add_argument("--base-win-reward", type=int, default=BASE_WINREWARD,
                        help="the base reward of win round")
    parser.add_argument("--base-lose-reward", type=int, default=BASE_LOSEREWARD,
                        help="the base reward of lose round")
    parser.add_argument("--target-state-size", type=int, default=TARGET_STATE_SIZE,
                        help="the size of target state")
    parser.add_argument("--time-state-size", type=int, default=TIME_STATE_SIZE,
                        help="the size of time state")
    parser.add_argument("--gun-state-size", type=int, default=GUN_STATE_SIZE,
                        help="the size of gun state")
    parser.add_argument("--my-state-size", type=int, default=MY_STATE_SIZE,
                        help="the size of my state")
    parser.add_argument("--total-target-size", type=int, default=TOTAL_T_SIZE,
                        help="the size of total target state")
    # fmt: on
    args = parser.parse_args()
    return args
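

if __name__ == "__main__":
    # Illustrative sketch (an assumption, not part of the original training pipeline):
    # parse the CLI arguments and print the resolved configuration as a quick sanity check.
    args = parse_args()
    for key, value in sorted(vars(args).items()):
        print(f"{key}: {value}")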