2022-10-14 16:08:08 +00:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 1,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import numpy as np\n",
|
|
|
|
"import tensorflow as tf\n",
|
|
|
|
"import time\n",
|
|
|
|
"import datetime\n",
|
|
|
|
"import aimBotEnv\n",
|
|
|
|
"\n",
|
|
|
|
"from GAIL import GAIL\n",
|
|
|
|
"from GAILConfig import GAILConfig\n",
|
|
|
|
"from PPOConfig import PPOConfig\n",
|
|
|
|
"from GAILMem import GAILMem\n",
|
|
|
|
"from GAILHistory import GAILHistory\n",
|
|
|
|
"from IPython.display import clear_output\n",
|
|
|
|
"from tqdm.notebook import tqdm as tqdm"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 2,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# Enable memory growth so TensorFlow allocates GPU memory only as it is needed\n",
|
|
|
|
"physical_devices = tf.config.list_physical_devices(\"GPU\")\n",
|
|
|
|
"tf.config.experimental.set_memory_growth(physical_devices[0], True)\n",
|
|
|
|
"tf.random.set_seed(9331)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 3,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"√√√√√Enviroment Initialized Success√√√√√\n",
|
|
|
|
"√√√√√Buffer Initialized Success√√√√√\n",
|
|
|
|
"√√√√√Buffer Initialized Success√√√√√\n",
|
|
|
|
"---------thisPPO Params---------\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
"self.stateSize = 93\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
"self.disActShape = [3, 3, 2]\n",
|
|
|
|
"self.disActSize 3\n",
|
|
|
|
"self.disOutputSize 8\n",
|
|
|
|
"self.conActSize = 1\n",
|
|
|
|
"self.conActRange = 10\n",
|
|
|
|
"self.conOutputSize = 2\n",
|
|
|
|
"---------thisPPO config---------\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
"self.NNShape = [512, 512, 256]\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
"self.criticLR = 0.002\n",
|
|
|
|
"self.actorLR = 0.002\n",
|
|
|
|
"self.gamma = 0.99\n",
|
|
|
|
"self.lmbda = 0.95\n",
|
|
|
|
"self.clipRange = 0.2\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
"self.entropyWeight = 0.005\n",
|
|
|
|
"self.trainEpochs = 5\n",
|
|
|
|
"self.saveDir = GAIL-Model/1020-0318/\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
"self.loadModelDir = None\n",
|
|
|
|
"---------Actor Model Create Success---------\n",
|
|
|
|
"Model: \"model_1\"\n",
|
|
|
|
"__________________________________________________________________________________________________\n",
|
|
|
|
" Layer (type) Output Shape Param # Connected to \n",
|
|
|
|
"==================================================================================================\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" stateInput (InputLayer) [(None, 93)] 0 [] \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" dense0 (Dense) (None, 512) 48128 ['stateInput[0][0]'] \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" dense1 (Dense) (None, 512) 262656 ['dense0[0][0]'] \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" dense2 (Dense) (None, 256) 131328 ['dense1[0][0]'] \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" muOut (Dense) (None, 1) 257 ['dense2[0][0]'] \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" sigmaOut (Dense) (None, 1) 257 ['dense2[0][0]'] \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" disAct0 (Dense) (None, 3) 771 ['dense2[0][0]'] \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" disAct1 (Dense) (None, 3) 771 ['dense2[0][0]'] \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" disAct2 (Dense) (None, 2) 514 ['dense2[0][0]'] \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
|
|
|
" tf.math.multiply (TFOpLambda) (None, 1) 0 ['muOut[0][0]'] \n",
|
|
|
|
" \n",
|
|
|
|
" tf.math.add (TFOpLambda) (None, 1) 0 ['sigmaOut[0][0]'] \n",
|
|
|
|
" \n",
|
|
|
|
" totalOut (Concatenate) (None, 10) 0 ['disAct0[0][0]', \n",
|
|
|
|
" 'disAct1[0][0]', \n",
|
|
|
|
" 'disAct2[0][0]', \n",
|
|
|
|
" 'tf.math.multiply[0][0]', \n",
|
|
|
|
" 'tf.math.add[0][0]'] \n",
|
|
|
|
" \n",
|
|
|
|
"==================================================================================================\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
"Total params: 444,682\n",
|
|
|
|
"Trainable params: 444,682\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
"Non-trainable params: 0\n",
|
|
|
|
"__________________________________________________________________________________________________\n",
|
|
|
|
"---------Critic Model Create Success---------\n",
|
|
|
|
"Model: \"model\"\n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
" Layer (type) Output Shape Param # \n",
|
|
|
|
"=================================================================\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" stateInput (InputLayer) [(None, 93)] 0 \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" dense0 (Dense) (None, 512) 48128 \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" dense1 (Dense) (None, 512) 262656 \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" dense2 (Dense) (None, 256) 131328 \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" dense (Dense) (None, 1) 257 \n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" \n",
|
|
|
|
"=================================================================\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
"Total params: 442,369\n",
|
|
|
|
"Trainable params: 442,369\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
"Non-trainable params: 0\n",
|
|
|
|
"_________________________________________________________________\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"ENV_PATH = \"./Build-CloseEnemyCut/Aimbot-PPO\"\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
"EXPERT_DIR = \"GAIL-Expert-Data/1015-0148/pack-53518.npz\"\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
"WORKER_ID = 1\n",
|
|
|
|
"BASE_PORT = 200\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
"MAX_BUFFER_SIZE = 256\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
"\n",
|
|
|
|
"MAX_EP = 1000000000\n",
|
|
|
|
"STACKSTATESSIZE = 3\n",
|
|
|
|
"STACKINTERCE = 29\n",
|
|
|
|
"\n",
|
|
|
|
"env = aimBotEnv.makeEnv(\n",
|
|
|
|
" envPath=ENV_PATH,\n",
|
|
|
|
" workerID=WORKER_ID,\n",
|
|
|
|
" basePort=BASE_PORT,\n",
|
|
|
|
" stackSize=STACKSTATESSIZE,\n",
|
|
|
|
" stackIntercal=STACKINTERCE,\n",
|
|
|
|
")\n",
|
|
|
|
"\n",
|
|
|
|
"STATE_SIZE = env.STATE_SIZE\n",
|
|
|
|
"DISACT_SHAPE = env.DISCRETE_SHAPE\n",
|
|
|
|
"CONACT_SIZE = env.CONTINUOUS_SIZE\n",
|
|
|
|
"CONACT_RANGE = 10\n",
|
|
|
|
"\n",
|
|
|
|
"ppoConf = PPOConfig(\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" NNShape=[512, 512, 256],\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" actorLR=2e-3,\n",
|
|
|
|
" criticLR=2e-3,\n",
|
|
|
|
" gamma=0.99,\n",
|
|
|
|
" lmbda=0.95,\n",
|
|
|
|
" clipRange=0.20,\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" entropyWeight=5e-3,\n",
|
|
|
|
" trainEpochs=5,\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" saveDir=\"GAIL-Model/\" + datetime.datetime.now().strftime(\"%m%d-%H%M\") + \"/\",\n",
|
|
|
|
" loadModelDir=None,\n",
|
|
|
|
")\n",
|
|
|
|
"gailConf = GAILConfig(\n",
|
|
|
|
" discrimNNShape=[256, 128],\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" discrimLR=1e-4,\n",
|
|
|
|
" discrimTrainEpochs=5,\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" discrimSaveDir=\"GAIL-Model/\" + datetime.datetime.now().strftime(\"%m%d-%H%M\") + \"/\",\n",
|
|
|
|
" ppoConfig=ppoConf\n",
|
|
|
|
")\n",
|
|
|
|
"\n",
|
|
|
|
"agentMem = GAILMem()\n",
|
|
|
|
"expertMem = GAILMem()\n",
|
|
|
|
"expertMem.loadMemFile(EXPERT_DIR)\n",
|
|
|
|
"gailHis = GAILHistory()\n",
|
|
|
|
"gail = GAIL(\n",
|
|
|
|
" stateSize=STATE_SIZE,\n",
|
|
|
|
" disActShape=DISACT_SHAPE,\n",
|
|
|
|
" conActSize=CONACT_SIZE,\n",
|
|
|
|
" conActRange=CONACT_RANGE,\n",
|
|
|
|
" gailConfig=gailConf,\n",
|
|
|
|
")\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 4,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
2022-10-23 14:38:07 +00:00
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"20777.3\n"
|
|
|
|
]
|
|
|
|
},
|
2022-10-14 16:08:08 +00:00
|
|
|
{
|
|
|
|
"data": {
|
2022-10-23 14:38:07 +00:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABMgAAALyCAYAAAAv/+j+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzdd3xT9ffH8VeSpm3S0j0oG9lbUFBAloACDhw/3AtxIyiKinugIuBEUfwiIop7IChDkS2gqCxBWbKhdC86kjTJ74+2sbUD0LZpm/fz8eBBc3Nzc3Ibyu3JOedjMMS0cSMiIiIiIiIiIuKjjN4OQERERERERERExJuUIBMREREREREREZ+mBJmIiIiIiIiIiPg0JchERERERERERMSnKUEmIiIiIiIiIiI+TQkyERERERERERHxaUqQiYhUkbjYWH76YQkmo37UioiISN3z+AP3c/vIG+nSsQOfzn6n0o+//Jt5NIirX+nHFREpi5+3AxAR+Tcef+B+zju3P478fByOfHbu3s1Lb7zJgUOHvR2aiIiIiE/Zsm07V468pdKPe+5Fl/7rx86bO4fnX36VXzZuqsSIRKQuU1mDiNRacz/7gnMvupSLr7qWpOQUHr1/nNdiUZWYiIiISOXQdZWIeIMqyETklM2bO4cvFnzD0EEDaRgXx9KVq5gxazaPP3g/nTt24I8dO3nkmefIOn6cDu3acs8dt9G8aROOJSTyypsz2LhlKwAXnD+Y668YQXR0FOkZGXzwyed8vXARAN26dOapCQ/wyZfzuO7KK3C5XLz17mwWfre0VDw2u50fVq3m+ccf9WyLiozg/rvv4vROHcnNzeOTL+fx2dfz8Teb+f7rLxh+9fVkZGZy0zVXccuN13PepSPIycnhtptuwGqx8Opbb9PrrB7cPvIGGsXFcTw7h2+WfMc7788FCton5304h+defIVRN1xL/LEERo9/iNG33MwF5w8mOyeHjz7/qkScF5w3mJuvv4aw0FAyMjJ5e/Ycvlu+oqq+TSIiIiKVqnXLFjx6/zgaNWzA+g2/4Ha7gb+v2y6++noArr9yBCMuHU6Q1UpySipTp73Br5s2YzQauf7KEVw09HzCw8I4dPgIDz75NIlJyfz0wxKmTpvOVZddgslk4rLrb+KnH5bwfzeM5PDReB5/4H7ybHk0qF+fLp06suevvUx4+lluuOoKhp03iNS0dJ54/gV27fmrwtdgNpsZfcvNDOzXF4Blq1Yz/Z13cTgchIaE8PiD99OlYwdcLjf7DhzgzvsewO12l/uaDAYD1105guHDhlAvOJhfNm1myqvTyMw6jr/ZzCP3j6NnjzMxGo0cOnKU8Y8+QWp6epV+n0Tk31GCTET+lQF9ejP2wYcxmUzMmTGdNi1b8NxLr7D/wEFefn4iV1w6nAWLlvDyc8/w1AtT+emXX+ne9XQmPfkYV468lfSMDNLSM7j/sSc5Eh9P186deOX5ify5cxc79+wBICIigqCgIC666lp6nNGNSU88yuq168k6frxELIGBAZw3oD+Hjx4FwGAw8OLEp1m9bj2PP/cCMdFRvD5lEgcOH+bnX3/jz5276NalEyvWrKVr504cS0igS4f2rP/lV7p27sQnX84DIC8vj2cmv8je/Qdo0awZ06Y8z649f7F63XrPc3ft0omrbr4Vt8vN8GFD6H32Wdxwx2jy8vKY9OTjJWK8b/QdjBx9DwcPHyYyIoKQevWq+LskIiIiUjn8/PyY/PQTfPrV13z+9QL69urJxEcn8MGnn5fYr0mjRvzf8Iu5efRYklNSiYuNxVhYEXb1/13G4HP7c98jT3Dw8GFantacPJvN89h+vXsyasw92Gz2MmMY2K8v90x4lH37D/Dy8xN5Z9orzHz/A6a9PZNbb7yee+64jdHjH6rwddx0zVV0bN+WG+64C7cbpjzzJCOvvZr/vfc+14y4nMSkZIZcfiUAHdu1xe12V/iaRlxyMf169+TO+x
4kPSOD+0bfyfgxd/PE8y8w7LzBBAdZufjq63E4HLRqcRp59rJfm4h4n2pXReRf+XzeAlLT00lKSWHLtm1s37GDXXv+wu5wsGrtOlq3bMGQQeey7udfPJ8wbti4iT937aZXj+4ArPt5A0fi4wHYtPV3fv5tI106dfA8R35+Pu9+8CFOp5P1G34hJzePJo0bee6/ZsTlLP36C5YvmEeXjh14+oWpALRv05qwsFDenfsR+fn5HI0/xvxFSxjcv5/nubp27ozJaKTlac35bN58unbphL/ZTLs2rdn0++8AbNyylb/27cftdrNn3z6+X76Srl06lTgP77w/l7w8Gza7nYH9+vLpV/NITEomM+s4cz7+tMS+LrebFs2bEuDvT0pqKvsOHKjk74qIiIhI1ejYri1+Jj8++XIeTqeTFWt+5I+du0rt53I5MZvNNG/aFJPJRHxCgud67+KhQ3h79hwOHi6YGbtn7z4yM7M8j53z8adkZh3HVk4SadXadezcvcdzvWm321m8dBkul4sfVq6idcsWJ3wd5w8cwKwPPiItPYP0jAxmffAhQwcNBAquPaMiI4iLjcHpdLJl2/YTvqbLLryAGe/OISk5GYfDwTvvz+XcvudgMhrJd+YTEhJC4wYNcLlc7Ny9h5ycnFM46yJSnVRBJiL/SvHScJvNTmpaydtWi4X6sTGc268P5/Q8y3Ofn8mP3zZvAaBn9zMZdcO1NG7YCKPRQGBAAH/t2+/ZNzMzE6fLVey4NqwWi+f2R59/yduz5xAbE82rk56lSeNG7Nm3j/qxMURFRrL06y88+xqNRrb8vg0oSJDdc8dttGnVkr/27WfDb5t4dPw4OrZvx+GjRz0Xah3atuGuW27mtGZNMZvNmM1mlq9aU+I8JCQmeb6OiowkISnZc/tYQqLn67w8G489O4lrR1zOI/ePY+v2P5g2439aVEBERERqhajISJJSUkpsK36tU+Tw0XhefWsGt9xwHc2bNuHnX3/jtRn/IzklldjoKI4cjS/3ORKTksq9Dyh1vVnyetSGpdh1YkWvo3jcxxISiIqMAODDz77glhuu47XJzwPw9cLFfPDJZxW+pvqxMUx+6nFche2mAE6Xi4jwcBYvXUZsdDQTH5tAcFAw3y1bzlvvvofT6TxhnCJS/ZQgE5Eqk5CUzJIfljHp5ddK3Wc2m5n05GM8PflFVq9bj9PpZPLTT2AwGE79eRKTeGX6DB5/cDxrf/qZhMRk4uOPMeKmUWXuv3X7HzRp1JB+5/Rm09bf2X/wILEx0fTq0Z1NW3737Pf0Iw/xxfxvGPfwY9gdDu6983bCQkNLHqzYxVBKasGFX5H6MdEldv3519/4+dffCPD35/aRN/Lwffdyx7jxp/x6RURERKpbSmoq0ZGRJbbVj4n2VFIV9/3ylXy/fCVWq5UJ945l9C2jeHryVBKSkmnYII69+8uuoncXu66qKskpKdSPjfFU8sfGxJCckgpATm4u096eybS3Z3Jas6a8MXUyf+7cxa+bNlf4mp578WW2bv+jzOeb9cGHzPrgQ+JiY3n5+Wc4cOgw3yz5rspfp4icOrVYikiVWfLDMs45+yzOOvMMjEYj/mYz3bp0JjoqCrOfH2azmfSMDJxOJz27n8lZZ3T718+1YeMmklNSGH7BUP7YuZOc3Fyuv3IEAf7+GI1GTmvWlHZtWgMFnzDu2L2H/7v4Qs+CAb9v/5NLLxzGpq1/J8isFiuZmVnYHQ7at2nNeecOqDCGZatWc8Wlw4mOiqJecDDXX3WF576IsDD69DqbwMAA7A4HObm5uIpVx4mIiIjUZL//8SdOl5MrLh2OyWSi/zm9ad+2Tan9mjRqxBmnd8FsNmO327HZbbjcBdc8CxYv4fabbqRxwwYAtGzenJCQqpvJ6mcy4W82e/6YjEaWrljFyGuvJiw0lNCQEEZdfy1Lli0HoPdZPWjUIA6A49nZuFwuXC5Xha9p3rcLuePmm6gfEwNAWGgofXqdDRQsXt
CieTOMRiPZOdnk5zs9jxORmkcVZCJSZRKTknnwiacZfesonnlkAi6Xkz927mLKa6+Tk5vLy9Nn8Nzjj2A2m/lx/c+
|
2022-10-14 16:08:08 +00:00
|
|
|
"text/plain": [
|
2022-10-23 14:38:07 +00:00
|
|
|
"<Figure size 1512x936 with 8 Axes>"
|
2022-10-14 16:08:08 +00:00
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"application/vnd.jupyter.widget-view+json": {
|
2022-10-23 14:38:07 +00:00
|
|
|
"model_id": "61b563af512640828f05fb04f9795a5b",
|
2022-10-14 16:08:08 +00:00
|
|
|
"version_major": 2,
|
|
|
|
"version_minor": 0
|
|
|
|
},
|
|
|
|
"text/plain": [
|
2022-10-23 14:38:07 +00:00
|
|
|
" 0%| | 0/256 [00:00<?, ?it/s]"
|
2022-10-14 16:08:08 +00:00
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"bestReward = 0\n",
|
|
|
|
"for ep in range(MAX_EP):\n",
|
|
|
|
" # get sample\n",
|
|
|
|
" state, _, _, _, _ = env.reset()\n",
|
|
|
|
" totalRewards = []\n",
|
|
|
|
" totalReward = 0\n",
|
|
|
|
" saveNow = 0\n",
|
|
|
|
" for step in tqdm(range(MAX_BUFFER_SIZE)):\n",
|
|
|
|
" actions, predictResult = gail.getActions(state)\n",
|
|
|
|
" nextState, reward, done, _, saveNow = env.step(actions)\n",
|
|
|
|
" agentMem.saveMems(\n",
|
|
|
|
" state=state, actorProb=predictResult, action=actions, reward=reward, done=done\n",
|
|
|
|
" )\n",
|
|
|
|
" state = nextState\n",
|
|
|
|
" totalReward += reward\n",
|
|
|
|
" if done:\n",
|
|
|
|
" totalRewards.append(totalReward)\n",
|
|
|
|
" totalReward = 0\n",
|
|
|
|
" state, _, _, _, _ = env.reset()\n",
|
|
|
|
" # add reward to history\n",
|
|
|
|
" totalRewards.append(totalReward)\n",
|
|
|
|
" # get all memory data\n",
|
|
|
|
" demoStates, _, demoActions, _, _ = expertMem.getRandomSample(MAX_BUFFER_SIZE)\n",
|
|
|
|
" agentStates = agentMem.getStates()\n",
|
|
|
|
" agentActions = agentMem.getActions()\n",
|
|
|
|
" agentActorProbs = agentMem.getActorProbs()\n",
|
|
|
|
" agentDones = agentMem.getDones()\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" # train discriminator\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" discrimLosses, demoAcc, agentAcc = gail.trainDiscriminator(\n",
|
|
|
|
" demoStates, demoActions, agentStates, agentActions\n",
|
|
|
|
" )\n",
|
|
|
|
" # get discriminator-predicted rewards\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" discrimRewards = gail.inference(agentStates, agentActions) * 10.0\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" # train agentPPO\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" actorLosses, criticLosses, averageEntropy, discreteEntropys, continuousEntropys = gail.trainPPO(\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" agentStates, agentActorProbs, agentActions, discrimRewards, agentDones, nextState\n",
|
|
|
|
" )\n",
|
|
|
|
" gailHis.saveHis(\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" np.mean(totalRewards),\n",
|
|
|
|
" discrimLosses,\n",
|
|
|
|
" actorLosses,\n",
|
|
|
|
" criticLosses,\n",
|
|
|
|
" demoAcc,\n",
|
|
|
|
" agentAcc,\n",
|
|
|
|
" averageEntropy,\n",
|
|
|
|
" discrimRewards,\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" )\n",
|
|
|
|
" clear_output()\n",
|
|
|
|
" gailHis.drawHis()\n",
|
|
|
|
" # new best mean reward? then save the weights\n",
|
2022-10-23 14:38:07 +00:00
|
|
|
" if np.mean(totalRewards) >= bestReward:\n",
|
2022-10-14 16:08:08 +00:00
|
|
|
" bestReward = np.mean(totalRewards)\n",
|
|
|
|
" gail.saveWeights(np.mean(totalRewards))\n",
|
|
|
|
" agentMem.clearMem()\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "Python 3.9.7 64-bit",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
"language_info": {
|
|
|
|
"codemirror_mode": {
|
|
|
|
"name": "ipython",
|
|
|
|
"version": 3
|
|
|
|
},
|
|
|
|
"file_extension": ".py",
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
"version": "3.9.7"
|
|
|
|
},
|
|
|
|
"orig_nbformat": 4,
|
|
|
|
"vscode": {
|
|
|
|
"interpreter": {
|
|
|
|
"hash": "86e2db13b09bd6be22cb599ea60c1572b9ef36ebeaa27a4c8e961d6df315ac32"
|
|
|
|
}
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 2
|
|
|
|
}
|