Aimbot-PPO/Aimbot-PPO-Python/GAIL-Main.ipynb

308 lines
136 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import time\n",
"import datetime\n",
"import aimBotEnv\n",
"\n",
"from GAIL import GAIL\n",
"from GAILConfig import GAILConfig\n",
"from PPOConfig import PPOConfig\n",
"from GAILMem import GAILMem\n",
"from GAILHistory import GAILHistory\n",
"from IPython.display import clear_output\n",
"from tqdm.notebook import tqdm as tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Let TensorFlow grow GPU memory on demand instead of reserving all of it up front.\n",
"# Looping over every visible GPU (instead of indexing [0]) keeps this cell working\n",
"# on CPU-only machines and also configures multi-GPU hosts.\n",
"physical_devices = tf.config.list_physical_devices(\"GPU\")\n",
"for gpu in physical_devices:\n",
"    tf.config.experimental.set_memory_growth(gpu, True)\n",
"\n",
"# Seed TensorFlow and the NumPy global RNG so their random draws are repeatable.\n",
"# NOTE(review): confirm project modules (e.g. GAILMem sampling) draw from these RNGs.\n",
"SEED = 9331\n",
"tf.random.set_seed(SEED)\n",
"np.random.seed(SEED)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"√√√√√Enviroment Initialized Success√√√√√\n",
"√√√√√Buffer Initialized Success√√√√√\n",
"√√√√√Buffer Initialized Success√√√√√\n",
"---------thisPPO Params---------\n",
"self.stateSize = 93\n",
"self.disActShape = [3, 3, 2]\n",
"self.disActSize 3\n",
"self.disOutputSize 8\n",
"self.conActSize = 1\n",
"self.conActRange = 10\n",
"self.conOutputSize = 2\n",
"---------thisPPO config---------\n",
"self.NNShape = [512, 512, 256]\n",
"self.criticLR = 0.002\n",
"self.actorLR = 0.002\n",
"self.gamma = 0.99\n",
"self.lmbda = 0.95\n",
"self.clipRange = 0.2\n",
"self.entropyWeight = 0.005\n",
"self.trainEpochs = 5\n",
"self.saveDir = GAIL-Model/1020-0318/\n",
"self.loadModelDir = None\n",
"---------Actor Model Create Success---------\n",
"Model: \"model_1\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" stateInput (InputLayer) [(None, 93)] 0 [] \n",
" \n",
" dense0 (Dense) (None, 512) 48128 ['stateInput[0][0]'] \n",
" \n",
" dense1 (Dense) (None, 512) 262656 ['dense0[0][0]'] \n",
" \n",
" dense2 (Dense) (None, 256) 131328 ['dense1[0][0]'] \n",
" \n",
" muOut (Dense) (None, 1) 257 ['dense2[0][0]'] \n",
" \n",
" sigmaOut (Dense) (None, 1) 257 ['dense2[0][0]'] \n",
" \n",
" disAct0 (Dense) (None, 3) 771 ['dense2[0][0]'] \n",
" \n",
" disAct1 (Dense) (None, 3) 771 ['dense2[0][0]'] \n",
" \n",
" disAct2 (Dense) (None, 2) 514 ['dense2[0][0]'] \n",
" \n",
" tf.math.multiply (TFOpLambda) (None, 1) 0 ['muOut[0][0]'] \n",
" \n",
" tf.math.add (TFOpLambda) (None, 1) 0 ['sigmaOut[0][0]'] \n",
" \n",
" totalOut (Concatenate) (None, 10) 0 ['disAct0[0][0]', \n",
" 'disAct1[0][0]', \n",
" 'disAct2[0][0]', \n",
" 'tf.math.multiply[0][0]', \n",
" 'tf.math.add[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 444,682\n",
"Trainable params: 444,682\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"---------Critic Model Create Success---------\n",
"Model: \"model\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" stateInput (InputLayer) [(None, 93)] 0 \n",
" \n",
" dense0 (Dense) (None, 512) 48128 \n",
" \n",
" dense1 (Dense) (None, 512) 262656 \n",
" \n",
" dense2 (Dense) (None, 256) 131328 \n",
" \n",
" dense (Dense) (None, 1) 257 \n",
" \n",
"=================================================================\n",
"Total params: 442,369\n",
"Trainable params: 442,369\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# --- environment and expert-data settings ---\n",
"ENV_PATH = \"./Build-CloseEnemyCut/Aimbot-PPO\"\n",
"EXPERT_DIR = \"GAIL-Expert-Data/1015-0148/pack-53518.npz\"  # expert demonstration file\n",
"WORKER_ID = 1\n",
"BASE_PORT = 200\n",
"MAX_BUFFER_SIZE = 256  # agent steps collected per training iteration\n",
"\n",
"MAX_EP = 1_000_000_000\n",
"STACKSTATESSIZE = 3\n",
"STACKINTERCE = 29\n",
"\n",
"env = aimBotEnv.makeEnv(\n",
"    envPath=ENV_PATH,\n",
"    workerID=WORKER_ID,\n",
"    basePort=BASE_PORT,\n",
"    stackSize=STACKSTATESSIZE,\n",
"    stackIntercal=STACKINTERCE,\n",
")\n",
"\n",
"# Observation / action dimensions reported by the environment.\n",
"STATE_SIZE = env.STATE_SIZE\n",
"DISACT_SHAPE = env.DISCRETE_SHAPE\n",
"CONACT_SIZE = env.CONTINUOUS_SIZE\n",
"CONACT_RANGE = 10\n",
"\n",
"# Take the run timestamp ONCE so PPO and discriminator checkpoints share one folder;\n",
"# two separate now() calls can straddle a minute boundary and produce different paths.\n",
"SAVE_DIR = \"GAIL-Model/\" + datetime.datetime.now().strftime(\"%m%d-%H%M\") + \"/\"\n",
"\n",
"ppoConf = PPOConfig(\n",
"    NNShape=[512, 512, 256],\n",
"    actorLR=2e-3,\n",
"    criticLR=2e-3,\n",
"    gamma=0.99,\n",
"    lmbda=0.95,\n",
"    clipRange=0.20,\n",
"    entropyWeight=5e-3,\n",
"    trainEpochs=5,\n",
"    saveDir=SAVE_DIR,\n",
"    loadModelDir=None,\n",
")\n",
"gailConf = GAILConfig(\n",
"    discrimNNShape=[256, 128],\n",
"    discrimLR=1e-4,\n",
"    discrimTrainEpochs=5,\n",
"    discrimSaveDir=SAVE_DIR,\n",
"    ppoConfig=ppoConf,\n",
")\n",
"\n",
"# agentMem starts empty; expertMem is filled from the expert demonstration file.\n",
"agentMem = GAILMem()\n",
"expertMem = GAILMem()\n",
"expertMem.loadMemFile(EXPERT_DIR)\n",
"gailHis = GAILHistory()\n",
"gail = GAIL(\n",
"    stateSize=STATE_SIZE,\n",
"    disActShape=DISACT_SHAPE,\n",
"    conActSize=CONACT_SIZE,\n",
"    conActRange=CONACT_RANGE,\n",
"    gailConfig=gailConf,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"20777.3\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABMgAAALyCAYAAAAv/+j+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzdd3xT9ffH8VeSpm3S0j0oG9lbUFBAloACDhw/3AtxIyiKinugIuBEUfwiIop7IChDkS2gqCxBWbKhdC86kjTJ74+2sbUD0LZpm/fz8eBBc3Nzc3Ibyu3JOedjMMS0cSMiIiIiIiIiIuKjjN4OQERERERERERExJuUIBMREREREREREZ+mBJmIiIiIiIiIiPg0JchERERERERERMSnKUEmIiIiIiIiIiI+TQkyERERERERERHxaUqQiYhUkbjYWH76YQkmo37UioiISN3z+AP3c/vIG+nSsQOfzn6n0o+//Jt5NIirX+nHFREpi5+3AxAR+Tcef+B+zju3P478fByOfHbu3s1Lb7zJgUOHvR2aiIiIiE/Zsm07V468pdKPe+5Fl/7rx86bO4fnX36VXzZuqsSIRKQuU1mDiNRacz/7gnMvupSLr7qWpOQUHr1/nNdiUZWYiIiISOXQdZWIeIMqyETklM2bO4cvFnzD0EEDaRgXx9KVq5gxazaPP3g/nTt24I8dO3nkmefIOn6cDu3acs8dt9G8aROOJSTyypsz2LhlKwAXnD+Y668YQXR0FOkZGXzwyed8vXARAN26dOapCQ/wyZfzuO7KK3C5XLz17mwWfre0VDw2u50fVq3m+ccf9WyLiozg/rvv4vROHcnNzeOTL+fx2dfz8Teb+f7rLxh+9fVkZGZy0zVXccuN13PepSPIycnhtptuwGqx8Opbb9PrrB7cPvIGGsXFcTw7h2+WfMc7788FCton5304h+defIVRN1xL/LEERo9/iNG33MwF5w8mOyeHjz7/qkScF5w3mJuvv4aw0FAyMjJ5e/Ycvlu+oqq+TSIiIiKVqnXLFjx6/zgaNWzA+g2/4Ha7gb+v2y6++noArr9yBCMuHU6Q1UpySipTp73Br5s2YzQauf7KEVw09HzCw8I4dPgIDz75NIlJyfz0wxKmTpvOVZddgslk4rLrb+KnH5bwfzeM5PDReB5/4H7ybHk0qF+fLp06suevvUx4+lluuOoKhp03iNS0dJ54/gV27fmrwtdgNpsZfcvNDOzXF4Blq1Yz/Z13cTgchIaE8PiD99OlYwdcLjf7DhzgzvsewO12l/uaDAYD1105guHDhlAvOJhfNm1myqvTyMw6jr/ZzCP3j6NnjzMxGo0cOnKU8Y8+QWp6epV+n0Tk31GCTET+lQF9ejP2wYcxmUzMmTGdNi1b8NxLr7D/wEFefn4iV1w6nAWLlvDyc8/w1AtT+emXX+ne9XQmPfkYV468lfSMDNLSM7j/sSc5Eh9P186deOX5ify5cxc79+wBICIigqCgIC666lp6nNGNSU88yuq168k6frxELIGBAZw3oD+Hjx4FwGAw8OLEp1m9bj2PP/cCMdFRvD5lEgcOH+bnX3/jz5276NalEyvWrKVr504cS0igS4f2rP/lV7p27sQnX84DIC8vj2cmv8je/Qdo0awZ06Y8z649f7F63XrPc3ft0omrbr4Vt8vN8GFD6H32Wdxwx2jy8vKY9OTjJWK8b/QdjBx9DwcPHyYyIoKQevWq+LskIiIiUjn8/PyY/PQTfPrV13z+9QL69urJxEcn8MGnn5fYr0mjRvzf8Iu5efRYklNSiYuNxVhYEXb1/13G4HP7c98jT3Dw8GFantacPJvN89h+vXsyasw92Gz2MmMY2K8v90x4lH37D/Dy8xN5Z9orzHz/A6a9PZNbb7yee+64jdHjH6rwddx0zVV0bN+WG+64C7cbpjzzJCOvvZr/vfc+14y4nMSkZIZcfiUAHdu1xe12V/iaRlxyMf169+TO+x
4kPSOD+0bfyfgxd/PE8y8w7LzBBAdZufjq63E4HLRqcRp59rJfm4h4n2pXReRf+XzeAlLT00lKSWHLtm1s37GDXXv+wu5wsGrtOlq3bMGQQeey7udfPJ8wbti4iT937aZXj+4ArPt5A0fi4wHYtPV3fv5tI106dfA8R35+Pu9+8CFOp5P1G34hJzePJo0bee6/ZsTlLP36C5YvmEeXjh14+oWpALRv05qwsFDenfsR+fn5HI0/xvxFSxjcv5/nubp27ozJaKTlac35bN58unbphL/ZTLs2rdn0++8AbNyylb/27cftdrNn3z6+X76Srl06lTgP77w/l7w8Gza7nYH9+vLpV/NITEomM+s4cz7+tMS+LrebFs2bEuDvT0pqKvsOHKjk74qIiIhI1ejYri1+Jj8++XIeTqeTFWt+5I+du0rt53I5MZvNNG/aFJPJRHxCgud67+KhQ3h79hwOHi6YGbtn7z4yM7M8j53z8adkZh3HVk4SadXadezcvcdzvWm321m8dBkul4sfVq6idcsWJ3wd5w8cwKwPPiItPYP0jAxmffAhQwcNBAquPaMiI4iLjcHpdLJl2/YTvqbLLryAGe/OISk5GYfDwTvvz+XcvudgMhrJd+YTEhJC4wYNcLlc7Ny9h5ycnFM46yJSnVRBJiL/SvHScJvNTmpaydtWi4X6sTGc268P5/Q8y3Ofn8mP3zZvAaBn9zMZdcO1NG7YCKPRQGBAAH/t2+/ZNzMzE6fLVey4NqwWi+f2R59/yduz5xAbE82rk56lSeNG7Nm3j/qxMURFRrL06y88+xqNRrb8vg0oSJDdc8dttGnVkr/27WfDb5t4dPw4OrZvx+GjRz0Xah3atuGuW27mtGZNMZvNmM1mlq9aU+I8JCQmeb6OiowkISnZc/tYQqLn67w8G489O4lrR1zOI/ePY+v2P5g2439aVEBERERqhajISJJSUkpsK36tU+Tw0XhefWsGt9xwHc2bNuHnX3/jtRn/IzklldjoKI4cjS/3ORKTksq9Dyh1vVnyetSGpdh1YkWvo3jcxxISiIqMAODDz77glhuu47XJzwPw9cLFfPDJZxW+pvqxMUx+6nFche2mAE6Xi4jwcBYvXUZsdDQTH5tAcFAw3y1bzlvvvofT6TxhnCJS/ZQgE5Eqk5CUzJIfljHp5ddK3Wc2m5n05GM8PflFVq9bj9PpZPLTT2AwGE79eRKTeGX6DB5/cDxrf/qZhMRk4uOPMeKmUWXuv3X7HzRp1JB+5/Rm09bf2X/wILEx0fTq0Z1NW3737Pf0Iw/xxfxvGPfwY9gdDu6983bCQkNLHqzYxVBKasGFX5H6MdEldv3519/4+dffCPD35/aRN/Lwffdyx7jxp/x6RURERKpbSmoq0ZGRJbbVj4n2VFIV9/3ylXy/fCVWq5UJ945l9C2jeHryVBKSkmnYII69+8uuoncXu66qKskpKdSPjfFU8sfGxJCckgpATm4u096eybS3Z3Jas6a8MXUyf+7cxa+bNlf4mp578WW2bv+jzOeb9cGHzPrgQ+JiY3n5+Wc4cOgw3yz5rspfp4icOrVYikiVWfLDMs45+yzOOvMMjEYj/mYz3bp0JjoqCrOfH2azmfSMDJxOJz27n8lZZ3T718+1YeMmklNSGH7BUP7YuZOc3Fyuv3IEAf7+GI1GTmvWlHZtWgMFnzDu2L2H/7v4Qs+CAb9v/5NLLxzGpq1/J8isFiuZmVnYHQ7at2nNeecOqDCGZatWc8Wlw4mOiqJecDDXX3WF576IsDD69DqbwMAA7A4HObm5uIpVx4mIiIjUZL//8SdOl5MrLh2OyWSi/zm9ad+2Tan9mjRqxBmnd8FsNmO327HZbbjcBdc8CxYv4fabbqRxwwYAtGzenJCQqpvJ6mcy4W82e/6YjEaWrljFyGuvJiw0lNCQEEZdfy1Lli0HoPdZPWjUIA6A49nZuFwuXC5Xha9p3rcLuePmm6gfEwNAWGgofXqdDRQsXt
CieTOMRiPZOdnk5zs9jxORmkcVZCJSZRKTknnwiacZfesonnlkAi6Xkz927mLKa6+Tk5vLy9Nn8Nzjj2A2m/lx/c+
"text/plain": [
"<Figure size 1512x936 with 8 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "61b563af512640828f05fb04f9795a5b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/256 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Multiplier applied to discriminator outputs before PPO trains on them.\n",
"DISCRIM_REWARD_SCALE = 10.0\n",
"\n",
"# Start below any achievable score so the first rollout is always checkpointed;\n",
"# starting at 0 meant runs whose mean reward stayed negative were never saved.\n",
"bestReward = -np.inf\n",
"for ep in range(MAX_EP):\n",
"    # --- collect one rollout of MAX_BUFFER_SIZE agent steps ---\n",
"    state, _, _, _, _ = env.reset()\n",
"    totalRewards = []  # summed environment reward of each episode in this rollout\n",
"    totalReward = 0\n",
"    episodeSteps = 0\n",
"    for _step in tqdm(range(MAX_BUFFER_SIZE)):\n",
"        actions, predictResult = gail.getActions(state)\n",
"        nextState, reward, done, _, _ = env.step(actions)\n",
"        agentMem.saveMems(\n",
"            state=state, actorProb=predictResult, action=actions, reward=reward, done=done\n",
"        )\n",
"        state = nextState\n",
"        totalReward += reward\n",
"        episodeSteps += 1\n",
"        if done:\n",
"            totalRewards.append(totalReward)\n",
"            totalReward = 0\n",
"            episodeSteps = 0\n",
"            state, _, _, _, _ = env.reset()\n",
"    # Count the unfinished trailing episode only if it has steps; otherwise a rollout\n",
"    # that ends exactly on a terminal step would add a bogus 0-reward episode.\n",
"    if episodeSteps > 0:\n",
"        totalRewards.append(totalReward)\n",
"    meanReward = np.mean(totalRewards)\n",
"\n",
"    # --- gather expert and agent data for this update ---\n",
"    demoStates, _, demoActions, _, _ = expertMem.getRandomSample(MAX_BUFFER_SIZE)\n",
"    agentStates = agentMem.getStates()\n",
"    agentActions = agentMem.getActions()\n",
"    agentActorProbs = agentMem.getActorProbs()\n",
"    agentDones = agentMem.getDones()\n",
"\n",
"    # --- train the discriminator on expert vs. agent samples ---\n",
"    discrimLosses, demoAcc, agentAcc = gail.trainDiscriminator(\n",
"        demoStates, demoActions, agentStates, agentActions\n",
"    )\n",
"    # PPO is trained on these scaled discriminator rewards, not on the env rewards.\n",
"    discrimRewards = gail.inference(agentStates, agentActions) * DISCRIM_REWARD_SCALE\n",
"\n",
"    # --- train the PPO agent ---\n",
"    actorLosses, criticLosses, averageEntropy, _, _ = gail.trainPPO(\n",
"        agentStates, agentActorProbs, agentActions, discrimRewards, agentDones, nextState\n",
"    )\n",
"    gailHis.saveHis(\n",
"        meanReward,\n",
"        discrimLosses,\n",
"        actorLosses,\n",
"        criticLosses,\n",
"        demoAcc,\n",
"        agentAcc,\n",
"        averageEntropy,\n",
"        discrimRewards,\n",
"    )\n",
"    # wait=True defers clearing until new output arrives, avoiding plot flicker.\n",
"    clear_output(wait=True)\n",
"    gailHis.drawHis()\n",
"\n",
"    # Checkpoint whenever the mean environment reward matches or beats the best so far.\n",
"    if meanReward >= bestReward:\n",
"        bestReward = meanReward\n",
"        gail.saveWeights(meanReward)\n",
"    agentMem.clearMem()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.7 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "86e2db13b09bd6be22cb599ea60c1572b9ef36ebeaa27a4c8e961d6df315ac32"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}