Weights & Biases (wandb) sync added
Weights & Biases (wandb) sync added
This commit is contained in:
parent
7497ffcb0f
commit
0dbe2013ae
2
.gitignore
vendored
2
.gitignore
vendored
@ -78,6 +78,8 @@ crashlytics-build.properties
|
||||
/Aimbot-PPO-Python/__pycache__/
|
||||
/Aimbot-PPO-Python/Tensorflow/__pycache__/
|
||||
/Aimbot-PPO-Python/Pytorch/__pycache__/
|
||||
/Aimbot-PPO-Python/Pytorch/runs/
|
||||
/Aimbot-PPO-Python/Pytorch/wandb/
|
||||
/Aimbot-PPO-Python/Backup/
|
||||
/Aimbot-PPO-Python/Build-MultiScene-WithLoad/
|
||||
/Aimbot-PPO-Python/Build-CloseEnemyCut/
|
||||
|
@ -99,7 +99,7 @@ class Aimbot(gym.Env):
|
||||
# create continuous actions from actions list
|
||||
continuousActions = actions[:,self.unity_discrete_size :]
|
||||
"""
|
||||
continuousActions = np.asanyarray([[0.0], [0.0], [0.0], [0.0]])
|
||||
continuousActions = np.asanyarray([[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]])
|
||||
# create actionTuple
|
||||
thisActionTuple = ActionTuple(continuous=continuousActions, discrete=discreteActions)
|
||||
# take action to env
|
||||
|
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import wandb
|
||||
import time
|
||||
import numpy as np
|
||||
import random
|
||||
@ -14,6 +15,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
DEFAULT_SEED = 9331
|
||||
ENV_PATH = "../Build-ParallelEnv/Aimbot-ParallelEnv"
|
||||
WAND_ENTITY = "koha9"
|
||||
WORKER_ID = 1
|
||||
BASE_PORT = 2002
|
||||
|
||||
@ -60,6 +62,8 @@ def parse_args():
|
||||
help="the K epochs to update the policy")
|
||||
parser.add_argument("--annealLR", type=lambda x: bool(strtobool(x)), default=ANNEAL_LEARNING_RATE, nargs="?", const=True,
|
||||
help="Toggle learning rate annealing for policy and value networks")
|
||||
parser.add_argument("--wandb-entity", type=str, default=None,
|
||||
help="the entity (team) of wandb's project")
|
||||
# GAE
|
||||
parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
|
||||
help="Use GAE for advantage computation")
|
||||
@ -139,6 +143,26 @@ if __name__ == "__main__":
|
||||
agent = PPOAgent(env).to(device)
|
||||
optimizer = optim.Adam(agent.parameters(), lr=args.lr, eps=1e-5)
|
||||
|
||||
# Tensorboard and WandB Recorder
|
||||
game_name = "Aimbot"
|
||||
run_name = f"{game_name}__{args.seed}__{int(time.time())}"
|
||||
wandb.init(
|
||||
project=run_name,
|
||||
entity=args.wandb_entity,
|
||||
sync_tensorboard=True,
|
||||
config=vars(args),
|
||||
name=run_name,
|
||||
monitor_gym=True,
|
||||
save_code=True,
|
||||
)
|
||||
|
||||
writer = SummaryWriter(f"runs/{run_name}")
|
||||
writer.add_text(
|
||||
"hyperparameters",
|
||||
"|param|value|\n|-|-|\n%s"
|
||||
% ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
|
||||
)
|
||||
|
||||
# Memory Record
|
||||
obs = torch.zeros((args.stepNum, env.unity_agent_num) + env.unity_observation_shape).to(device)
|
||||
actions = torch.zeros((args.stepNum, env.unity_agent_num) + (env.unity_discrete_type,)).to(
|
||||
@ -168,7 +192,6 @@ if __name__ == "__main__":
|
||||
|
||||
# MAIN LOOP: run agent in environment
|
||||
for step in range(args.stepNum):
|
||||
print(step)
|
||||
global_step += 1 * env.unity_agent_num
|
||||
obs[step] = next_obs
|
||||
dones[step] = next_done
|
||||
@ -289,3 +312,16 @@ if __name__ == "__main__":
|
||||
if args.target_kl is not None:
|
||||
if approx_kl > args.target_kl:
|
||||
break
|
||||
# record rewards for plotting purposes
|
||||
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
|
||||
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
|
||||
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
|
||||
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
|
||||
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
|
||||
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
|
||||
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
|
||||
print("SPS:", int(global_step / (time.time() - start_time)))
|
||||
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
|
||||
|
||||
env.close()
|
||||
writer.close()
|
||||
|
@ -303,124 +303,6 @@
|
||||
"(128, 4) + env.unity_observation_shape\n",
|
||||
"env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[0 0 0 0]\n",
|
||||
" [0 0 0 0]\n",
|
||||
" [0 0 0 0]\n",
|
||||
" [0 0 0 0]]\n",
|
||||
"[[0]\n",
|
||||
" [0]\n",
|
||||
" [0]\n",
|
||||
" [0]]\n",
|
||||
"[[0 0 0]\n",
|
||||
" [0 0 0]\n",
|
||||
" [0 0 0]\n",
|
||||
" [0 0 0]]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"([array([ 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , -15.994009 , 1. , -26.322788 , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 2. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1.3519633, 1.6946528,\n",
|
||||
" 2.3051548, 3.673389 , 9.067246 , 17.521473 , 21.727095 ,\n",
|
||||
" 22.753294 , 24.167128 , 25.905216 , 18.35725 , 21.02278 ,\n",
|
||||
" 21.053417 , 0. , -15.994003 , 1. , -26.322784 ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 2. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1.3519667,\n",
|
||||
" 1.6946585, 2.3051722, 3.6734192, 9.067533 , 17.521563 ,\n",
|
||||
" 21.727148 , 22.753365 , 24.167217 , 25.905317 , 18.358263 ,\n",
|
||||
" 21.022812 , 21.053455 , 0. ], dtype=float32),\n",
|
||||
" array([ 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , -1.8809433, 1. , -25.66834 , 1. ,\n",
|
||||
" 2. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 16.768637 , 23.414627 ,\n",
|
||||
" 22.04486 , 21.050663 , 20.486784 , 20.486784 , 21.050665 ,\n",
|
||||
" 15.049731 , 11.578419 , 9.695194 , 20.398016 , 20.368341 ,\n",
|
||||
" 20.398016 , 0. , -1.8809433, 1. , -25.66834 ,\n",
|
||||
" 1. , 2. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 16.768671 ,\n",
|
||||
" 23.414669 , 22.044899 , 21.050697 , 20.486813 , 20.486813 ,\n",
|
||||
" 21.050694 , 15.049746 , 11.578423 , 9.695195 , 20.398046 ,\n",
|
||||
" 20.368372 , 20.398046 , 0. ], dtype=float32),\n",
|
||||
" array([ 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , -13.672583 , 1. , -26.479263 , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 5.3249803, 6.401276 ,\n",
|
||||
" 8.374101 , 12.8657875, 21.302414 , 21.30242 , 21.888742 ,\n",
|
||||
" 22.92251 , 24.346794 , 26.09773 , 21.210114 , 21.179258 ,\n",
|
||||
" 21.210117 , 0. , -13.672583 , 1. , -26.479263 ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 5.3249855,\n",
|
||||
" 6.4012837, 8.374114 , 12.865807 , 21.302446 , 21.30245 ,\n",
|
||||
" 21.888773 , 22.922543 , 24.346823 , 26.097757 , 21.210148 ,\n",
|
||||
" 21.17929 , 21.21015 , 0. ], dtype=float32),\n",
|
||||
" array([ 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , 0. , 0. , 0. , 0. ,\n",
|
||||
" 0. , -4.9038744, 1. , -25.185507 , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 20.33171 , 22.859762 ,\n",
|
||||
" 21.522427 , 20.551746 , 20.00118 , 20.001116 , 20.551594 ,\n",
|
||||
" 21.5222 , 17.707508 , 14.86889 , 19.914494 , 19.885508 ,\n",
|
||||
" 19.914463 , 0. , -4.9038773, 1. , -25.185507 ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 1. ,\n",
|
||||
" 1. , 1. , 1. , 1. , 20.331783 ,\n",
|
||||
" 22.85977 , 21.522448 , 20.551773 , 20.00121 , 20.001146 ,\n",
|
||||
" 20.551619 , 21.522217 , 17.707582 , 14.868943 , 19.914528 ,\n",
|
||||
" 19.88554 , 19.914494 , 0. ], dtype=float32)],\n",
|
||||
" [[-0.05], [-0.05], [-0.05], [-0.05]],\n",
|
||||
" [[False], [False], [False], [False]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"actions = np.zeros_like(np.arange(16).reshape(4, 4))\n",
|
||||
"print(actions)\n",
|
||||
"env.step(actions)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
Loading…
Reference in New Issue
Block a user