From 0dbe2013ae5324eab6332131c6c5b4a7e9410469 Mon Sep 17 00:00:00 2001 From: Koha9 Date: Tue, 1 Nov 2022 19:11:45 +0900 Subject: [PATCH] weight and bias sync added weight and bias sync added --- .gitignore | 2 + Aimbot-PPO-Python/Pytorch/AimbotEnv.py | 2 +- Aimbot-PPO-Python/Pytorch/ppo.py | 38 +++++++- Aimbot-PPO-Python/Pytorch/testarea.ipynb | 118 ----------------------- 4 files changed, 40 insertions(+), 120 deletions(-) diff --git a/.gitignore b/.gitignore index 270a7f4..1292001 100644 --- a/.gitignore +++ b/.gitignore @@ -78,6 +78,8 @@ crashlytics-build.properties /Aimbot-PPO-Python/__pycache__/ /Aimbot-PPO-Python/Tensorflow/__pycache__/ /Aimbot-PPO-Python/Pytorch/__pycache__/ +/Aimbot-PPO-Python/Pytorch/runs/ +/Aimbot-PPO-Python/Pytorch/wandb/ /Aimbot-PPO-Python/Backup/ /Aimbot-PPO-Python/Build-MultiScene-WithLoad/ /Aimbot-PPO-Python/Build-CloseEnemyCut/ diff --git a/Aimbot-PPO-Python/Pytorch/AimbotEnv.py b/Aimbot-PPO-Python/Pytorch/AimbotEnv.py index 0f68631..ae1384b 100644 --- a/Aimbot-PPO-Python/Pytorch/AimbotEnv.py +++ b/Aimbot-PPO-Python/Pytorch/AimbotEnv.py @@ -99,7 +99,7 @@ class Aimbot(gym.Env): # create continuous actions from actions list continuousActions = actions[:,self.unity_discrete_size :] """ - continuousActions = np.asanyarray([[0.0], [0.0], [0.0], [0.0]]) + continuousActions = np.asanyarray([[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]) # create actionTuple thisActionTuple = ActionTuple(continuous=continuousActions, discrete=discreteActions) # take action to env diff --git a/Aimbot-PPO-Python/Pytorch/ppo.py b/Aimbot-PPO-Python/Pytorch/ppo.py index 597327e..f01ca62 100644 --- a/Aimbot-PPO-Python/Pytorch/ppo.py +++ b/Aimbot-PPO-Python/Pytorch/ppo.py @@ -1,4 +1,5 @@ import argparse +import wandb import time import numpy as np import random @@ -14,6 +15,7 @@ from torch.utils.tensorboard import SummaryWriter DEFAULT_SEED = 9331 ENV_PATH = "../Build-ParallelEnv/Aimbot-ParallelEnv" +WAND_ENTITY = "koha9" WORKER_ID = 1 BASE_PORT = 2002 @@ -60,6 +62,8 @@ def parse_args(): help="the K epochs to update the policy") parser.add_argument("--annealLR", type=lambda x: bool(strtobool(x)), default=ANNEAL_LEARNING_RATE, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") # GAE parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, help="Use GAE for advantage computation") @@ -139,6 +143,26 @@ if __name__ == "__main__": agent = PPOAgent(env).to(device) optimizer = optim.Adam(agent.parameters(), lr=args.lr, eps=1e-5) + # Tensorboard and WandB Recorder + game_name = "Aimbot" + run_name = f"{game_name}__{args.seed}__{int(time.time())}" + wandb.init( + project=run_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" + % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + # Memory Record obs = torch.zeros((args.stepNum, env.unity_agent_num) + env.unity_observation_shape).to(device) actions = torch.zeros((args.stepNum, env.unity_agent_num) + (env.unity_discrete_type,)).to( @@ -168,7 +192,6 @@ if __name__ == "__main__": # MAIN LOOP: run agent in environment for step in range(args.stepNum): - print(step) global_step += 1 * env.unity_agent_num obs[step] = next_obs dones[step] = next_done @@ -289,3 +312,16 @@ if __name__ == "__main__": if args.target_kl is not None: if approx_kl > args.target_kl: break + # record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + env.close() + writer.close() diff --git a/Aimbot-PPO-Python/Pytorch/testarea.ipynb b/Aimbot-PPO-Python/Pytorch/testarea.ipynb index 7b273ec..44a1f64 100644 --- a/Aimbot-PPO-Python/Pytorch/testarea.ipynb +++ b/Aimbot-PPO-Python/Pytorch/testarea.ipynb @@ -303,124 +303,6 @@ "(128, 4) + env.unity_observation_shape\n", "env.reset()" ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[0 0 0 0]\n", - " [0 0 0 0]\n", - " [0 0 0 0]\n", - " [0 0 0 0]]\n", - "[[0]\n", - " [0]\n", - " [0]\n", - " [0]]\n", - "[[0 0 0]\n", - " [0 0 0]\n", - " [0 0 0]\n", - " [0 0 0]]\n" - ] - }, - { - "data": { - "text/plain": [ - "([array([ 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , -15.994009 , 1. , -26.322788 , 1. ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 2. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 1.3519633, 1.6946528,\n", - " 2.3051548, 3.673389 , 9.067246 , 17.521473 , 21.727095 ,\n", - " 22.753294 , 24.167128 , 25.905216 , 18.35725 , 21.02278 ,\n", - " 21.053417 , 0. , -15.994003 , 1. , -26.322784 ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 1. , 2. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 1. , 1.3519667,\n", - " 1.6946585, 2.3051722, 3.6734192, 9.067533 , 17.521563 ,\n", - " 21.727148 , 22.753365 , 24.167217 , 25.905317 , 18.358263 ,\n", - " 21.022812 , 21.053455 , 0. ], dtype=float32),\n", - " array([ 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , -1.8809433, 1. , -25.66834 , 1. ,\n", - " 2. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 16.768637 , 23.414627 ,\n", - " 22.04486 , 21.050663 , 20.486784 , 20.486784 , 21.050665 ,\n", - " 15.049731 , 11.578419 , 9.695194 , 20.398016 , 20.368341 ,\n", - " 20.398016 , 0. , -1.8809433, 1. , -25.66834 ,\n", - " 1. , 2. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 1. , 16.768671 ,\n", - " 23.414669 , 22.044899 , 21.050697 , 20.486813 , 20.486813 ,\n", - " 21.050694 , 15.049746 , 11.578423 , 9.695195 , 20.398046 ,\n", - " 20.368372 , 20.398046 , 0. ], dtype=float32),\n", - " array([ 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , -13.672583 , 1. , -26.479263 , 1. ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 5.3249803, 6.401276 ,\n", - " 8.374101 , 12.8657875, 21.302414 , 21.30242 , 21.888742 ,\n", - " 22.92251 , 24.346794 , 26.09773 , 21.210114 , 21.179258 ,\n", - " 21.210117 , 0. , -13.672583 , 1. , -26.479263 ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 1. , 5.3249855,\n", - " 6.4012837, 8.374114 , 12.865807 , 21.302446 , 21.30245 ,\n", - " 21.888773 , 22.922543 , 24.346823 , 26.097757 , 21.210148 ,\n", - " 21.17929 , 21.21015 , 0. ], dtype=float32),\n", - " array([ 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , -4.9038744, 1. , -25.185507 , 1. ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 20.33171 , 22.859762 ,\n", - " 21.522427 , 20.551746 , 20.00118 , 20.001116 , 20.551594 ,\n", - " 21.5222 , 17.707508 , 14.86889 , 19.914494 , 19.885508 ,\n", - " 19.914463 , 0. , -4.9038773, 1. , -25.185507 ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 1. , 1. ,\n", - " 1. , 1. , 1. , 1. , 20.331783 ,\n", - " 22.85977 , 21.522448 , 20.551773 , 20.00121 , 20.001146 ,\n", - " 20.551619 , 21.522217 , 17.707582 , 14.868943 , 19.914528 ,\n", - " 19.88554 , 19.914494 , 0. ], dtype=float32)],\n", - " [[-0.05], [-0.05], [-0.05], [-0.05]],\n", - " [[False], [False], [False], [False]])" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy as np\n", - "actions = np.zeros_like(np.arange(16).reshape(4, 4))\n", - "print(actions)\n", - "env.step(actions)" - ] } ], "metadata": {