Aimbot-PPO/Aimbot-PPO-Python/GAIL-Main.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import time\n",
"import datetime\n",
"import aimBotEnv\n",
"\n",
"from GAIL import GAIL\n",
"from GAILConfig import GAILConfig\n",
"from PPOConfig import PPOConfig\n",
"from GAILMem import GAILMem\n",
"from GAILHistory import GAILHistory\n",
"from IPython.display import clear_output\n",
"from tqdm.notebook import tqdm as tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Attempts to allocate only the GPU memory needed for allocation\n",
"physical_devices = tf.config.list_physical_devices(\"GPU\")\n",
"tf.config.experimental.set_memory_growth(physical_devices[0], True)\n",
"tf.random.set_seed(9331)"
]
},
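{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: `set_memory_growth` has to be called before the GPU is initialized, and indexing `physical_devices[0]` raises an `IndexError` on a machine with no visible GPU. A more defensive variant (a minimal sketch only, not what the cell above actually ran) could look like this:\n",
"\n",
"```python\n",
"# Sketch: enable on-demand GPU memory allocation only when a GPU is visible.\n",
"gpus = tf.config.list_physical_devices(\"GPU\")\n",
"for gpu in gpus:\n",
"    tf.config.experimental.set_memory_growth(gpu, True)\n",
"if not gpus:\n",
"    print(\"No GPU found; running on CPU.\")\n",
"```"
]
},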
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"√√√√√Enviroment Initialized Success√√√√√\n",
"√√√√√Buffer Initialized Success√√√√√\n",
"√√√√√Buffer Initialized Success√√√√√\n",
"---------thisPPO Params---------\n",
"self.stateSize = 90\n",
"self.disActShape = [3, 3, 2]\n",
"self.disActSize 3\n",
"self.disOutputSize 8\n",
"self.conActSize = 1\n",
"self.conActRange = 10\n",
"self.conOutputSize = 2\n",
"---------thisPPO config---------\n",
"self.NNShape = [512, 256, 128]\n",
"self.criticLR = 0.002\n",
"self.actorLR = 0.002\n",
"self.gamma = 0.99\n",
"self.lmbda = 0.95\n",
"self.clipRange = 0.2\n",
"self.entropyWeight = 0.01\n",
"self.trainEpochs = 10\n",
"self.saveDir = GAIL-Model/1015-0101/\n",
"self.loadModelDir = None\n",
"---------Actor Model Create Success---------\n",
"Model: \"model_1\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" stateInput (InputLayer) [(None, 90)] 0 [] \n",
" \n",
" dense0 (Dense) (None, 512) 46592 ['stateInput[0][0]'] \n",
" \n",
" dense1 (Dense) (None, 256) 131328 ['dense0[0][0]'] \n",
" \n",
" dense2 (Dense) (None, 128) 32896 ['dense1[0][0]'] \n",
" \n",
" muOut (Dense) (None, 1) 129 ['dense2[0][0]'] \n",
" \n",
" sigmaOut (Dense) (None, 1) 129 ['dense2[0][0]'] \n",
" \n",
" disAct0 (Dense) (None, 3) 387 ['dense2[0][0]'] \n",
" \n",
" disAct1 (Dense) (None, 3) 387 ['dense2[0][0]'] \n",
" \n",
" disAct2 (Dense) (None, 2) 258 ['dense2[0][0]'] \n",
" \n",
" tf.math.multiply (TFOpLambda) (None, 1) 0 ['muOut[0][0]'] \n",
" \n",
" tf.math.add (TFOpLambda) (None, 1) 0 ['sigmaOut[0][0]'] \n",
" \n",
" totalOut (Concatenate) (None, 10) 0 ['disAct0[0][0]', \n",
" 'disAct1[0][0]', \n",
" 'disAct2[0][0]', \n",
" 'tf.math.multiply[0][0]', \n",
" 'tf.math.add[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 212,106\n",
"Trainable params: 212,106\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"---------Critic Model Create Success---------\n",
"Model: \"model\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" stateInput (InputLayer) [(None, 90)] 0 \n",
" \n",
" dense0 (Dense) (None, 512) 46592 \n",
" \n",
" dense1 (Dense) (None, 256) 131328 \n",
" \n",
" dense2 (Dense) (None, 128) 32896 \n",
" \n",
" dense (Dense) (None, 1) 129 \n",
" \n",
"=================================================================\n",
"Total params: 210,945\n",
"Trainable params: 210,945\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"ENV_PATH = \"./Build-CloseEnemyCut/Aimbot-PPO\"\n",
"EXPERT_DIR = \"GAIL-Expert-Data/1014-1302/pack-24957-RE.npz\"\n",
"WORKER_ID = 1\n",
"BASE_PORT = 200\n",
"MAX_BUFFER_SIZE = 2048\n",
"\n",
"MAX_EP = 1000000000\n",
"STACKSTATESSIZE = 3\n",
"STACKINTERCE = 29\n",
"\n",
"env = aimBotEnv.makeEnv(\n",
" envPath=ENV_PATH,\n",
" workerID=WORKER_ID,\n",
" basePort=BASE_PORT,\n",
" stackSize=STACKSTATESSIZE,\n",
" stackIntercal=STACKINTERCE,\n",
")\n",
"\n",
"STATE_SIZE = env.STATE_SIZE\n",
"DISACT_SHAPE = env.DISCRETE_SHAPE\n",
"CONACT_SIZE = env.CONTINUOUS_SIZE\n",
"CONACT_RANGE = 10\n",
"\n",
"ppoConf = PPOConfig(\n",
" NNShape=[512, 256, 128],\n",
" actorLR=2e-3,\n",
" criticLR=2e-3,\n",
" gamma=0.99,\n",
" lmbda=0.95,\n",
" clipRange=0.20,\n",
" entropyWeight=1e-2,\n",
" trainEpochs=10,\n",
" saveDir=\"GAIL-Model/\" + datetime.datetime.now().strftime(\"%m%d-%H%M\") + \"/\",\n",
" loadModelDir=None,\n",
")\n",
"gailConf = GAILConfig(\n",
" discrimNNShape=[256, 128],\n",
" discrimLR=1e-3,\n",
" discrimTrainEpochs=10,\n",
" discrimSaveDir=\"GAIL-Model/\" + datetime.datetime.now().strftime(\"%m%d-%H%M\") + \"/\",\n",
" ppoConfig=ppoConf\n",
")\n",
"\n",
"agentMem = GAILMem()\n",
"expertMem = GAILMem()\n",
"expertMem.loadMemFile(EXPERT_DIR)\n",
"gailHis = GAILHistory()\n",
"gail = GAIL(\n",
" stateSize=STATE_SIZE,\n",
" disActShape=DISACT_SHAPE,\n",
" conActSize=CONACT_SIZE,\n",
" conActRange=CONACT_RANGE,\n",
" gailConfig=gailConf,\n",
")\n"
]
},
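{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell runs the GAIL loop: roll out the current policy for `MAX_BUFFER_SIZE` steps, train the discriminator to separate expert (state, action) pairs from agent pairs, relabel the agent transitions with discriminator-based surrogate rewards via `gail.inference`, and then run the PPO update on those rewards. The exact reward formula is defined inside the `GAIL` module; a common convention, shown here purely as an illustrative sketch (`surrogate_reward` and `discriminator` are hypothetical names, not this project's API), rewards the agent for fooling the discriminator:\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"def surrogate_reward(discriminator, states, actions, eps=1e-8):\n",
"    # discriminator(s, a) is assumed to output the probability, in (0, 1),\n",
"    # that a (state, action) pair comes from the expert demonstrations.\n",
"    d = discriminator(tf.concat([states, actions], axis=-1))\n",
"    # The reward grows as agent behaviour becomes indistinguishable from the expert's.\n",
"    return -tf.math.log(1.0 - d + eps)\n",
"```"
]
},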
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABM4AAALyCAYAAAAi4Zi5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzdd1xV9ePH8RcXEEEEZDpQpDQ0NXPl3iijxPp+v61fZcumZdv2Ti21MsuypVmmNsyvOAD33qY5cpEMEWQjU4R77+8P7CZfMTWUw3g/Hw8e3HvGve/DETy8Oed87Ox8g6yIiIiIiIiIiIhIOSajA4iIiIiIiIiIiFRHKs5EREREREREREQqoOJMRERERERERESkAirOREREREREREREKqDiTEREREREREREpAIqzkRERERERERERCqg4kxEpIo18fNj8/Jo7E36ESwiIiK1y6vPPcND995Nx/bt+GHGV5f89VcunE/TJo0v+euKiJyLg9EBREQupVefe4ahgwZQUlpKSUkpBw8f5v1PPiXhaJLR0URERETqjN/27uPWe0de8tcdNOymf7zu/FkzGffBZLb9uvMSJhKR2k6nO4hIrTPrx58ZNOwmIm67g/SMTF5+5inDsuisMhEREZHK0zGViBhFZ5yJyCUzf9ZMfo5cSFjwYJo1acKy1WuY9vUMXh3zDNe0b8fvBw7y0ltjycvPp13bNjzx8IMEBrTgeGoaH346jV9/2w3A9SFDuOuWm/Hx8SbnxAm+m/sT/128BIDOHa/hjReeY+68+dx56y1YLBY+mz6DxTHLzspTfOoUy9esZdyrL9umeXt58sxjj3Jth/YUFZ1k7rz5/PjfBdRzdGTpf39m+O13cSI3l3v+7zZG3n0XQ2+6mcLCQh68ZwQuzs5M/uxzenW/jofuHYF/kybkFxSyMDqGr76dBZRdhjn/+5mMnfQh94+4g5TjqYx69nlGjbyP60OGUFBYyOyffimX8/qhQ7jvrv/Dw92dEydy+XzGTGJWrrpcu0lERETkkrmq1ZW8/MxT+Ddryqat27BarcBfx2wRt98FwF233szNNw2ngYsLGZlZTJzyCdt37sJkMnHXrTczLCyERh4eHE06xpjX3yQtPYPNy6OZOGUqt/3rRuzt7fnXXfeweXk0/xlxL0nJKbz63DOcLD5J08aN6dihPbF/HOGFN99hxG23ED40mKzsHF4b9y6HYv/4221wdHRk1Mj7GNy/HwAr1qxl6lfTKSkpwd3NjVfHPEPH9u2wWKzEJSTwyNPPYbVaz7lNdnZ23HnrzQwPD6Whqyvbdu5iwuQp5OblU8/RkZeeeYqe13XFZDJx9Fgyz778Glk5OZd1P4nIP6fiTEQuqYF9ezN6zIvY29szc9pUglpdydj3PyQ+IZEPxr3NLTcNJ3JJNB+MfYs33p3I5m3b6dbpWsa//gq33vsAOSdOkJ1zgmdeeZ1jKSl0uqYDH457m/0HD3EwNhYAT09PGjRowLDb7uC6Lp0Z/9rLrN2wibz8/HJZ6td3YujAASQlJwNgZ2fHpLffZO3GTbw69l18fbz5eMJ4EpKS2LJ9B/sPHqJzxw6sWreBTtd04HhqKh3bXc2mbdvpdE0H5s6bD8DJkyd5671JHIlP4MqWLZkyYRyHYv9g7cZNtvfu1LEDt933AFaLleHhofTu0Z0RD4/i5MmTjH/91XIZnx71MPeOeoLEpCS8PD1xa9jwMu8lERERkcpzcHDgvTdf44df/stP/42kX6+evP3yC3z3w0/llmvh789/hkdw36jRZGRm0cTPD9PpM8hu/8+/GDJoAE+/9BqJSUm0uiKQk8XFtnX79+7J/Y8/QXHxqQozDO7fjydeeJm4+AQ+GPc2X035kC+//Y4pn3/JA3ffxRMPP8ioZ5//2+245/9uo/3VbRjx8KNYrTDhrde5947b+eKbb/m/m/9NWnoGof++FYD2bdtgtVr/dptuvjGC/r178sjTY8g5cYKnRz3Cs48/xmvj3iV86BBcG7gQcftdlJSU0PrKKzh5quJtE5HqQee7isgl9dP8SLJyckjPzOS3vXvZd+AAh2L/4FRJCWs2bOSqVlcSGjyIjVu22f4qufXXnew/dJhe13UDYOOWrRxLSQFg5+49bNnxKx07tLO9R2lpKdO/+x6z2cymrdsoLDpJi+b+tvn/d/O/Wfbfn1kZOZ+O7dvx5rsTAbg66Co8PNyZPms2paWlJKccZ8GSaIYM6G97r07XXIO9yUSrKwL5cf4COnXsQD1HR9oGXcXOPXsA+PW33fwRF4/VaiU2Lo6lK1fTqWOHcl+Hr76dxcmTxRSfOsXg/v344Zf5pKVnkJuXz8w5P5Rb1mK1cmVgAE716pGZlUVcQsIl3isiIiIil177tm1wsHdg7rz5mM1mVq1bz+8HD521nMVixtHRkcCAAOzt7UlJTbUd60WEhfL5jJkkJpXdjzb2SBy5uXm2dWfO+YHcvHyKz1EurdmwkYOHY23HmqdOnSJq2QosFgvLV6/hqlZXnnc7QgYP5OvvZpOdc4KcEyf4+rvvCQseDJQdd3p7edLEzxez2cxve/edd5v+dcP1TJs+k/SMDEpKSvjq21kM6tcHe5OJUnMpbm5uNG/aFIvFwsHDsRQWFl7EV11EqprOOBORS+rM08yLi0+RlV3+uYuzM439fBnUvy99ena3zXOwd2DHrt8A6NmtK/ePuIPmzfwxmeyo7+TEH3HxtmVzc3MxWyxnvG4xLs7Otuezf5rH5zNm4ufrw+Tx79CiuT+xcXE09vPF28uLZf/92basyWTitz17gbLi7ImHHySodSv+iItn646dvPzsU7S/ui1Jycm2g7h2bYJ4dOR9XNEyAEdHRxwdHVm5Zl25r0NqWrrtsbeXF6npGbbnx1PTbI9PnizmlXfGc8fN/+alZ55i977fmTLtCw1mICIiItWet5cX6ZmZ5aadeZzzp6TkFCZ/No2RI+4kMKAFW7bv4KNpX5CRmYWfjzfHklPO+R5p6ennnAecdaxZ/li0GOczjhH/bjvOzH08NRVvL08Avv/xZ0aOuJOP3hsHwH8XR/Hd3B//dpsa+/ny3huvYjl92SqA2WLBs1EjopatwM/Hh7dfeQHXBq7ErFjJZ9O/wWw2nzeniBhDxZmIVLnU9Ayil69g/AcfnTXP0dGR8a+/wpvvTWLtxk2YzWbee/M17OzsLv590tL5cOo0Xh3zLBs2byE1LYOUlOPcfM/9FS6/e9/vtPBvRv8+vdm5ew/xiYn4+frQ67pu7Pxtj225N196np8XLOSpF1/hVEkJTz7yEB7u7uVf7IwDpcyssoPCPzX29Sm36JbtO9iyfQdO9erx0L138+LTT/LwU89e9PaKiIiIVKXMrCx8vLzKTWvs62M78+pMS1euZunK1bi4uPDCk6MZNfJ+3nxvIqnpGTRr2oQj8RWfcW8945jqcsnIzKSxn6/trH8/X18yMrMAKCwqYsrnXzLl8y+5omUAn0x8j/0HD7F9566/3aaxkz5g977fK3y/r7/7nq+/+54mfn58M
O4tEo4msTA65rJvp4j8M7pUU0SqXPTyFfTp0Z3uXbtgMpmo5+hI547X4OPtjaODA46OjuScOIHZbKZnt65079L5H7/X1l93kpGZyfDrw/j94EEKi4q469abcapXD5PJxBUtA2gbdBVQ9lfJA4dj+U/EDbaBCvbs289NN4Szc/dfxZmLswu5uXmcKinh6qCrGDpo4N9mWLFmLbfcNBwfb28aurpy12232OZ5enjQt1cP6td34lRJCYVFRVjOOJtOREREpLra8/t+zBYzt9w0HHt7ewb06c3VbYLOWq6Fvz9dru2Io6Mjp06dovhUMRZr2fFOZFQ0D91zN82bNQWgVWAgbm6X736vDvb21HN0tH3Ym0wsW7WGe++4HQ93d9zd3Lj/rjuIXrESgN7dr8O/aRMA8gsKsFgsWCyWv92m+YsW8/B999DY1xcAD3d3+vbqAZQNmnBlYEtMJhMFhQWUlppt64lI9aQzzkSkyqWlZzDmtTcZ9cD9vPXSC1gsZn4/eIgJH31MYVERH0ydxthXX8L
"text/plain": [
"<Figure size 1512x936 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fb112ad158f248a68f0950ef142f1951",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2048 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"bestReward = 0\n",
"for ep in range(MAX_EP):\n",
" # get sample\n",
" state, _, _, _, _ = env.reset()\n",
" totalRewards = []\n",
" totalReward = 0\n",
" saveNow = 0\n",
" for step in tqdm(range(MAX_BUFFER_SIZE)):\n",
" actions, predictResult = gail.getActions(state)\n",
" nextState, reward, done, _, saveNow = env.step(actions)\n",
" agentMem.saveMems(\n",
" state=state, actorProb=predictResult, action=actions, reward=reward, done=done\n",
" )\n",
" state = nextState\n",
" totalReward += reward\n",
" if done:\n",
" totalRewards.append(totalReward)\n",
" totalReward = 0\n",
" state, _, _, _, _ = env.reset()\n",
" # add reward to history\n",
" totalRewards.append(totalReward)\n",
" # get all memory data\n",
" demoStates, _, demoActions, _, _ = expertMem.getRandomSample(MAX_BUFFER_SIZE)\n",
" agentStates = agentMem.getStates()\n",
" agentActions = agentMem.getActions()\n",
" agentActorProbs = agentMem.getActorProbs()\n",
" agentDones = agentMem.getDones()\n",
" # train discriminator\n",
" discrimLosses, demoAcc, agentAcc = gail.trainDiscriminator(\n",
" demoStates, demoActions, agentStates, agentActions\n",
" )\n",
" # get disriminator predict rewards\n",
" discrimRewards = gail.inference(agentStates, agentActions)\n",
" # train agentPPO\n",
" actorLosses, criticLosses = gail.trainPPO(\n",
" agentStates, agentActorProbs, agentActions, discrimRewards, agentDones, nextState\n",
" )\n",
" gailHis.saveHis(\n",
" np.mean(totalRewards), discrimLosses, actorLosses, criticLosses, demoAcc, agentAcc\n",
" )\n",
" clear_output()\n",
" gailHis.drawHis()\n",
" # got best reward?\n",
" if np.mean(totalRewards)>=bestReward:\n",
" bestReward = np.mean(totalRewards)\n",
" gail.saveWeights(np.mean(totalRewards))\n",
" agentMem.clearMem()\n"
]
},
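{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, `gail.trainPPO` receives the discriminator rewards in place of the environment rewards, together with the done flags and the final `nextState` for bootstrapping. With the `gamma=0.99` and `lmbda=0.95` configured above, a standard GAE(lambda) advantage computation over such a rollout (an illustrative sketch only; the project's own implementation may differ) looks like this:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def gae_advantages(rewards, values, dones, next_value, gamma=0.99, lmbda=0.95):\n",
"    # Generalized Advantage Estimation over one rollout of surrogate rewards.\n",
"    advantages = np.zeros(len(rewards), dtype=np.float32)\n",
"    last_adv = 0.0\n",
"    for t in reversed(range(len(rewards))):\n",
"        mask = 1.0 - float(dones[t])  # stop bootstrapping across episode ends\n",
"        delta = rewards[t] + gamma * next_value * mask - values[t]\n",
"        last_adv = delta + gamma * lmbda * mask * last_adv\n",
"        advantages[t] = last_adv\n",
"        next_value = values[t]\n",
"    return advantages\n",
"```"
]
}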
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.7 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "86e2db13b09bd6be22cb599ea60c1572b9ef36ebeaa27a4c8e961d6df315ac32"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}