weight and bias sync added

weight and bias sync added
This commit is contained in:
Koha9 2022-11-01 19:11:45 +09:00
parent 7497ffcb0f
commit 0dbe2013ae
4 changed files with 40 additions and 120 deletions

2
.gitignore vendored
View File

@ -78,6 +78,8 @@ crashlytics-build.properties
/Aimbot-PPO-Python/__pycache__/
/Aimbot-PPO-Python/Tensorflow/__pycache__/
/Aimbot-PPO-Python/Pytorch/__pycache__/
/Aimbot-PPO-Python/Pytorch/runs/
/Aimbot-PPO-Python/Pytorch/wandb/
/Aimbot-PPO-Python/Backup/
/Aimbot-PPO-Python/Build-MultiScene-WithLoad/
/Aimbot-PPO-Python/Build-CloseEnemyCut/

View File

@ -99,7 +99,7 @@ class Aimbot(gym.Env):
# create continuous actions from actions list
continuousActions = actions[:,self.unity_discrete_size :]
"""
continuousActions = np.asanyarray([[0.0], [0.0], [0.0], [0.0]])
continuousActions = np.asanyarray([[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]])
# create actionTuple
thisActionTuple = ActionTuple(continuous=continuousActions, discrete=discreteActions)
# take action to env

View File

@ -1,4 +1,5 @@
import argparse
import wandb
import time
import numpy as np
import random
@ -14,6 +15,7 @@ from torch.utils.tensorboard import SummaryWriter
DEFAULT_SEED = 9331
ENV_PATH = "../Build-ParallelEnv/Aimbot-ParallelEnv"
WAND_ENTITY = "koha9"
WORKER_ID = 1
BASE_PORT = 2002
@ -60,6 +62,8 @@ def parse_args():
help="the K epochs to update the policy")
parser.add_argument("--annealLR", type=lambda x: bool(strtobool(x)), default=ANNEAL_LEARNING_RATE, nargs="?", const=True,
help="Toggle learning rate annealing for policy and value networks")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
# GAE
parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Use GAE for advantage computation")
@ -139,6 +143,26 @@ if __name__ == "__main__":
agent = PPOAgent(env).to(device)
optimizer = optim.Adam(agent.parameters(), lr=args.lr, eps=1e-5)
# Tensorboard and WandB Recorder
game_name = "Aimbot"
run_name = f"{game_name}__{args.seed}__{int(time.time())}"
wandb.init(
project=run_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s"
% ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# Memory Record
obs = torch.zeros((args.stepNum, env.unity_agent_num) + env.unity_observation_shape).to(device)
actions = torch.zeros((args.stepNum, env.unity_agent_num) + (env.unity_discrete_type,)).to(
@ -168,7 +192,6 @@ if __name__ == "__main__":
# MAIN LOOP: run agent in environment
for step in range(args.stepNum):
print(step)
global_step += 1 * env.unity_agent_num
obs[step] = next_obs
dones[step] = next_done
@ -289,3 +312,16 @@ if __name__ == "__main__":
if args.target_kl is not None:
if approx_kl > args.target_kl:
break
# record rewards for plotting purposes
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
env.close()
writer.close()

View File

@ -303,124 +303,6 @@
"(128, 4) + env.unity_observation_shape\n",
"env.reset()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0 0 0 0]\n",
" [0 0 0 0]\n",
" [0 0 0 0]\n",
" [0 0 0 0]]\n",
"[[0]\n",
" [0]\n",
" [0]\n",
" [0]]\n",
"[[0 0 0]\n",
" [0 0 0]\n",
" [0 0 0]\n",
" [0 0 0]]\n"
]
},
{
"data": {
"text/plain": [
"([array([ 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , -15.994009 , 1. , -26.322788 , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 2. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1.3519633, 1.6946528,\n",
" 2.3051548, 3.673389 , 9.067246 , 17.521473 , 21.727095 ,\n",
" 22.753294 , 24.167128 , 25.905216 , 18.35725 , 21.02278 ,\n",
" 21.053417 , 0. , -15.994003 , 1. , -26.322784 ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 2. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1.3519667,\n",
" 1.6946585, 2.3051722, 3.6734192, 9.067533 , 17.521563 ,\n",
" 21.727148 , 22.753365 , 24.167217 , 25.905317 , 18.358263 ,\n",
" 21.022812 , 21.053455 , 0. ], dtype=float32),\n",
" array([ 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , -1.8809433, 1. , -25.66834 , 1. ,\n",
" 2. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 16.768637 , 23.414627 ,\n",
" 22.04486 , 21.050663 , 20.486784 , 20.486784 , 21.050665 ,\n",
" 15.049731 , 11.578419 , 9.695194 , 20.398016 , 20.368341 ,\n",
" 20.398016 , 0. , -1.8809433, 1. , -25.66834 ,\n",
" 1. , 2. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 16.768671 ,\n",
" 23.414669 , 22.044899 , 21.050697 , 20.486813 , 20.486813 ,\n",
" 21.050694 , 15.049746 , 11.578423 , 9.695195 , 20.398046 ,\n",
" 20.368372 , 20.398046 , 0. ], dtype=float32),\n",
" array([ 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , -13.672583 , 1. , -26.479263 , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 5.3249803, 6.401276 ,\n",
" 8.374101 , 12.8657875, 21.302414 , 21.30242 , 21.888742 ,\n",
" 22.92251 , 24.346794 , 26.09773 , 21.210114 , 21.179258 ,\n",
" 21.210117 , 0. , -13.672583 , 1. , -26.479263 ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 5.3249855,\n",
" 6.4012837, 8.374114 , 12.865807 , 21.302446 , 21.30245 ,\n",
" 21.888773 , 22.922543 , 24.346823 , 26.097757 , 21.210148 ,\n",
" 21.17929 , 21.21015 , 0. ], dtype=float32),\n",
" array([ 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , -4.9038744, 1. , -25.185507 , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 20.33171 , 22.859762 ,\n",
" 21.522427 , 20.551746 , 20.00118 , 20.001116 , 20.551594 ,\n",
" 21.5222 , 17.707508 , 14.86889 , 19.914494 , 19.885508 ,\n",
" 19.914463 , 0. , -4.9038773, 1. , -25.185507 ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 20.331783 ,\n",
" 22.85977 , 21.522448 , 20.551773 , 20.00121 , 20.001146 ,\n",
" 20.551619 , 21.522217 , 17.707582 , 14.868943 , 19.914528 ,\n",
" 19.88554 , 19.914494 , 0. ], dtype=float32)],\n",
" [[-0.05], [-0.05], [-0.05], [-0.05]],\n",
" [[False], [False], [False], [False]])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"actions = np.zeros_like(np.arange(16).reshape(4, 4))\n",
"print(actions)\n",
"env.step(actions)"
]
}
],
"metadata": {