import argparse
import uuid

from distutils.util import strtobool
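# NOTE (assumption, not part of the original file): distutils was removed from the
# standard library in Python 3.12, so the import above fails on newer interpreters.
# A minimal drop-in replacement with the same behaviour would look like:
#
#     def strtobool(val: str) -> int:
#         val = val.lower()
#         if val in ("y", "yes", "t", "true", "on", "1"):
#             return 1
#         if val in ("n", "no", "f", "false", "off", "0"):
#             return 0
#         raise ValueError(f"invalid truth value {val!r}")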

DEFAULT_SEED = 9331
ENV_PATH = "../Build/3.4/Aimbot-ParallelEnv"
WAND_ENTITY = "koha9"
WORKER_ID = 1
BASE_PORT = 1000

# tensorboard names

# max round steps per agent is 2500 / DECISION_PERIOD, i.e. 25 seconds per round
TOTAL_STEPS = 3150000
BATCH_SIZE = 512
MAX_TRAINNING_DATASETS = 6000
DECISION_PERIOD = 1
LEARNING_RATE = 1.5e-4
GAMMA = 0.99
GAE_LAMBDA = 0.95
EPOCHS = 3
CLIP_COEF = 0.11
LOSS_COEF = [1.0, 1.0, 1.0, 1.0]  # per-task weights: free, go, attack, defence
POLICY_COEF = [1.0, 1.0, 1.0, 1.0]
ENTROPY_COEF = [0.05, 0.05, 0.05, 0.05]
CRITIC_COEF = [0.5, 0.5, 0.5, 0.5]
TARGET_LEARNING_RATE = 1e-6

FREEZE_VIEW_NETWORK = False
ANNEAL_LEARNING_RATE = True
CLIP_VLOSS = True
NORM_ADV = False
TRAIN = True
SAVE_MODEL = True
WANDB_TACK = True
# LOAD_DIR = None  # set LOAD_DIR to None to train from scratch
LOAD_DIR = "../PPO-Model/GotoOnly-Level1234_9331_1697122986/8.853553.pt"

# Unity Environment Parameters
TARGET_STATE_SIZE = 6
INAREA_STATE_SIZE = 1
TIME_STATE_SIZE = 1
GUN_STATE_SIZE = 1
MY_STATE_SIZE = 4
# total target-side observation size: 6 + 1 + 1 + 1 + 4 = 13
TOTAL_T_SIZE = TARGET_STATE_SIZE + INAREA_STATE_SIZE + TIME_STATE_SIZE + GUN_STATE_SIZE + MY_STATE_SIZE
BASE_WINREWARD = 999
BASE_LOSEREWARD = -999
TARGETNUM = 4
ENV_TIMELIMIT = 30
RESULT_BROADCAST_RATIO = 1 / ENV_TIMELIMIT
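# Worked example (assumption, mirroring the --result-broadcast-ratio help text below):
# when a round is won, the broadcast reward is scaled by the remaining round time,
#   r = RESULT_BROADCAST_RATIO * remaining_time
# so with ENV_TIMELIMIT = 30, winning with 15 s left gives r = (1 / 30) * 15 = 0.5.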

save_model_this_episode = False


def is_save_model() -> bool:
    """Return whether the model should be saved for the current episode."""
    return save_model_this_episode


def set_save_model(save_model: bool) -> None:
    """Set whether the model should be saved for the current episode."""
    print("set save model to", save_model)
    global save_model_this_episode
    save_model_this_episode = save_model


def parse_args():
    # fmt: off
    # pytorch and environment parameters
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED,
                        help="seed of the experiment")
    parser.add_argument("--path", type=str, default=ENV_PATH,
                        help="environment path")
    parser.add_argument("--workerID", type=int, default=WORKER_ID,
                        help="unity worker ID")
    parser.add_argument("--baseport", type=int, default=BASE_PORT,
                        help="port to connect to Unity environment")
    parser.add_argument("--lr", type=float, default=LEARNING_RATE,
                        help="the default learning rate of the optimizer")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--total-timesteps", type=int, default=TOTAL_STEPS,
                        help="total timesteps of the experiment")
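    # Note on the boolean flags in this parser (usage assumption): because they are
    # declared with nargs="?" and const=True, passing the bare flag (e.g. "--cuda")
    # turns the option on, "--cuda false" turns it off, and omitting it keeps the default.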

    # model parameters
    parser.add_argument("--train", type=lambda x: bool(strtobool(x)), default=TRAIN, nargs="?", const=True,
                        help="train the model or not")
    parser.add_argument("--freeze-viewnet", type=lambda x: bool(strtobool(x)), default=FREEZE_VIEW_NETWORK, nargs="?", const=True,
                        help="freeze the view network or not")
    parser.add_argument("--datasetSize", type=int, default=MAX_TRAINNING_DATASETS,
                        help="training dataset size; training starts once enough data has been collected")
    parser.add_argument("--minibatchSize", type=int, default=BATCH_SIZE,
                        help="mini batch size")
    parser.add_argument("--epochs", type=int, default=EPOCHS,
                        help="the K epochs to update the policy")
    parser.add_argument("--annealLR", type=lambda x: bool(strtobool(x)), default=ANNEAL_LEARNING_RATE, nargs="?", const=True,
                        help="toggle learning rate annealing for policy and value networks")
    parser.add_argument("--wandb-track", type=lambda x: bool(strtobool(x)), default=WANDB_TACK, nargs="?", const=True,
                        help="track the run with wandb")
    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=SAVE_MODEL, nargs="?", const=True,
                        help="save the model or not")
    parser.add_argument("--wandb-entity", type=str, default=WAND_ENTITY,
                        help="the entity (team) of wandb's project")
    parser.add_argument("--load-dir", type=str, default=LOAD_DIR,
                        help="load model directory")
    parser.add_argument("--decision-period", type=int, default=DECISION_PERIOD,
                        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--result-broadcast-ratio", type=float, default=RESULT_BROADCAST_RATIO,
                        help="broadcast result when a round is won: r = result-broadcast-ratio * remainTime")
    # target_learning_rate
    parser.add_argument("--target-lr", type=float, default=TARGET_LEARNING_RATE,
                        help="target value to which the learning rate is annealed")

    # POLICY_COEF ENTROPY_COEF CRITIC_COEF LOSS_COEF
    parser.add_argument("--policy-coef", type=float, nargs="+", default=POLICY_COEF,
                        help="coefficients of the policy loss, one per task")
    parser.add_argument("--entropy-coef", type=float, nargs="+", default=ENTROPY_COEF,
                        help="coefficients of the entropy loss, one per task")
    parser.add_argument("--critic-coef", type=float, nargs="+", default=CRITIC_COEF,
                        help="coefficients of the critic loss, one per task")
    parser.add_argument("--loss-coef", type=float, nargs="+", default=LOSS_COEF,
                        help="coefficients of the total loss, one per task")
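    # Example (assumption): with nargs="+", the per-task coefficients can be overridden
    # from the command line, e.g.
    #   --policy-coef 1.0 1.0 0.5 0.5
    # giving one value per task in the order: free, go, attack, defence.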

    # GAE loss
    parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="use GAE for advantage computation")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=NORM_ADV, nargs="?", const=True,
                        help="toggles advantages normalization")
    parser.add_argument("--gamma", type=float, default=GAMMA,
                        help="the discount factor gamma")
    parser.add_argument("--gaeLambda", type=float, default=GAE_LAMBDA,
                        help="the lambda for the general advantage estimation")
    parser.add_argument("--clip-coef", type=float, default=CLIP_COEF,
                        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=CLIP_VLOSS, nargs="?", const=True,
                        help="toggles whether or not to use a clipped loss for the value function, as per the paper")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
                        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
                        help="the target KL divergence threshold")

    # environment parameters
    parser.add_argument("--target-num", type=int, default=TARGETNUM,
                        help="the number of targets")
    parser.add_argument("--env-timelimit", type=int, default=ENV_TIMELIMIT,
                        help="the time limit of each round")
    parser.add_argument("--base-win-reward", type=int, default=BASE_WINREWARD,
                        help="the base reward for winning a round")
    parser.add_argument("--base-lose-reward", type=int, default=BASE_LOSEREWARD,
                        help="the base reward for losing a round")
    parser.add_argument("--target-state-size", type=int, default=TARGET_STATE_SIZE,
                        help="the size of the target state")
    parser.add_argument("--time-state-size", type=int, default=TIME_STATE_SIZE,
                        help="the size of the time state")
    parser.add_argument("--gun-state-size", type=int, default=GUN_STATE_SIZE,
                        help="the size of the gun state")
    parser.add_argument("--my-state-size", type=int, default=MY_STATE_SIZE,
                        help="the size of the agent's own state")
    parser.add_argument("--total-target-size", type=int, default=TOTAL_T_SIZE,
                        help="the total size of the target state")
    # fmt: on
    args = parser.parse_args()
    return args
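

# Minimal usage sketch (assumption; the actual training entry point lives elsewhere):
# parses the CLI arguments defined above and toggles the save-model flag accordingly.
if __name__ == "__main__":
    args = parse_args()
    print(f"seed={args.seed}, lr={args.lr}, total timesteps={args.total_timesteps}")
    set_save_model(args.save_model)
    print("save model this episode:", is_save_model())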