{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 自由課題レポート\n",
"\n",
"## 1 課題概要\n",
" 今度使用するデータセットは自分で作成したゲーム環境が生成したエージェント観測データである。Double DQN(Deep Q-Learning)モデルを使用し、エージェントが観測した環境データによって最善な動作を予測する。\n",
"## 2 ゲーム環境\n",
"### 2.1 ゲーム概要\n",
" Ml-AgentsのサンプルオブジェクトScrollBallを模倣した、ターゲットを追いかけるゲームである。真ん中の球は自機エージェント、正方形はターゲットと設定する。ゲーム開始時にターゲットは自機エージェント以外のゲームエリア内にランダムで生成する。自機は上下左右各方向にコントロールすることができ、リアルにシミュレーションするため、Unityの物理エンジンを使用する。各方向に入力があった場合、直接に一定の速度でその方向に移動するわけではなく、一定な加速度を加えることと設定している。自機エージェントがターゲットと接触するとゲームクリアと判定し、8秒経過もしくはゲームエリアから落ちる場合は失敗判定とする。\n",
"### 2.2 観測情報\n",
" エージェントはターゲットの位置、および自機位置のx,z座標が観測できることとする。その他、自機の水平と垂直xz方向の速度と合わせてステップ毎に合計六つのデータが観測できる。\n",
"### 2.3 Reward設定\n",
" ステップ毎に環境からRewardが返すことになっている。このステップの操作に評価を行うことである。自機がターゲットに前ステップより近づくとその距離値が正のRewardとしてが戻る、逆に前ステップより遠く離れると離れた距離値が負のRewardとして戻ってくる。ゲーム成功時にもらうRewardが10,失敗するときに-10と設定している。\n",
"### 2.4 動作確定\n",
" 2種類の動作が存在する、垂直及び水平になる。垂直では-1,0,1三つの値があり、各自下、静止、上と意味する,水平も-1,0,1三つの値があり、各左、静止、右と意味する。\n",
"## 3 DoubleDQN\n",
"### 3.1 Q-Learning\n",
" Q-Learningはすべての環境状態を十分にサンプリングし、各状態ににQualityが最も高い動作を実行する機械学習手法である。\n",
" Q-Learning試行したすべての環境状態をQ-tableに記録し、その環境状態で試行した動作が環境からもらうRewardをそのQ値として記録する。まだ同様な環境が観測した場合、Q-Tableに記載したQ値が最も高い動作を実行する。だが多くな場合、ゲームの環境状態は無限であり、すべての状況を記録することは難しい、しかもメモリにも大量に消耗する。\n",
"### 3.2 DQN(Deep Q-Learning)\n",
" DQNはQ-Tableの代わりにニューラルネットワークを使用して、観測した環境状態から各動作のQ値を予測する。今回使用するDouble DQNの構造及びコードは以下になる。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\UCUNI\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tensorflow_addons\\utils\\ensure_tf_install.py:53: UserWarning: Tensorflow Addons supports using Python ops for all Tensorflow versions above or equal to 2.4.0 and strictly below 2.7.0 (nightly versions are not supported). \n",
" The versions of TensorFlow you are currently using is 2.8.0 and is not supported. \n",
"Some things might work, some things might not.\n",
"If you were to encounter a bug, do not file an issue.\n",
"If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. \n",
"You can find the compatibility matrix in TensorFlow Addon's readme:\n",
"https://github.com/tensorflow/addons\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ML-Agents Version : 0.27.0\n",
"TensroFlow Version: 2.8.0\n"
]
}
],
"source": [
"import mlagents_envs\n",
"from mlagents_envs.base_env import ActionTuple\n",
"from mlagents_envs.environment import UnityEnvironment\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_addons as tfa\n",
"import tensorboard\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import time\n",
"import datetime\n",
"from collections import deque\n",
"from IPython.display import clear_output\n",
"\n",
"print(\"ML-Agents Version :\",mlagents_envs.__version__)\n",
"print(\"TensroFlow Version:\",tf.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# 割り当てに必要なGPUメモリのみを割り当てようとする\n",
"physical_devices = tf.config.list_physical_devices('GPU')\n",
"tf.config.experimental.set_memory_growth(physical_devices[0], True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"以下環境を実行する際にディレクトリにすべて半角英数字符号となっていることが必要"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ステップ毎に環境観測データ数 6\n",
"ステップ毎に実行可能な動作数 2\n"
]
}
],
"source": [
"# 環境パラメータ\n",
"log_dir = \"ML-logs/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
"env_path = './ScrollBall-Build/ML-ScrollBall-Sample'\n",
"brain_name = 'RollerBallBrain'\n",
"\n",
"# ゲーム環境獲得\n",
"env = UnityEnvironment(file_name=env_path, seed=1, side_channels=[])\n",
"env.reset()\n",
"\n",
"# 環境スペック獲得\n",
"tracked_agent = -1\n",
"behavior_specs = env.behavior_specs\n",
"behavior_name = list(behavior_specs)[0]\n",
"spec = behavior_specs[behavior_name]\n",
"observation_specs = spec.observation_specs[0] # 観測spec\n",
"action_spec = spec.action_spec # 動作spec\n",
"\n",
"\n",
"ENV_Discrete_ACTION_SIZE = action_spec.discrete_size# 連続的な動作のSize\n",
"ENV_Continuous_ACTION_SIZE = action_spec.continuous_size# 離散的な動作のSize\n",
"STATE_SIZE = observation_specs.shape[0]# 環境観測データ数\n",
"SAVE_STEPS = 100 # SAVE_STEPS毎にNNを保存する\n",
"ACTION_SIZE = ENV_Discrete_ACTION_SIZE * 3#トータル動作数、一種類の動作に三つの動作が存在するため、*3とする\n",
"MAX_EXP_NUM = 2500 # ExperiencePoolに保存できる最大過去記録数\n",
"\n",
"EPSILON_CUT_STEP = 1300\n",
"EPISODES = 500\n",
"REPLACE_STEPS = 50\n",
"BATCH_SIZE = 256\n",
"LEARNING_RATE = 0.0005\n",
"GAMMA = 0.9\n",
"\n",
"epsilon = 1\n",
"epsilon_min = 0.01\n",
"\n",
"print(\"ステップ毎に環境観測データ数\",STATE_SIZE)\n",
"print(\"ステップ毎に実行可能な動作数\",ENV_Discrete_ACTION_SIZE)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Experience Pool\n",
"class experiencePool:\n",
" def __init__(self):\n",
" self.exp_pool = deque(maxlen=MAX_EXP_NUM)\n",
"\n",
" def add(self, state, action, reward, netx_state, done):\n",
" self.exp_pool.append((state, action, reward, netx_state, done))\n",
"\n",
" def get_random(self, num=1):\n",
" random_index = np.random.choice(len(self.exp_pool), num)\n",
" random_exps = [self.exp_pool[i] for i in random_index]\n",
" return random_exps\n",
"\n",
" def get_len(self):\n",
" return len(self.exp_pool)\n",
"\n",
"# DQNメソッド\n",
"class DQN:\n",
" def __init__(self,load,load_dir):\n",
" self.learning_rate = LEARNING_RATE\n",
" self.epsilon = 1\n",
" self.epsilon_min = 0.01\n",
" self.epsilon_cut = (1-self.epsilon_min)/EPSILON_CUT_STEP\n",
" self.gamma = GAMMA\n",
" \n",
" if load:\n",
" #既存NNデータをローディングする\n",
" self.epsilon = self.epsilon_min\n",
" main_load_dir = load_dir+\"main.h5\"\n",
" target_load_dir = load_dir+\"target.h5\"\n",
" self.main_net,self.target_net = self.loadNN(main_load_dir,target_load_dir)\n",
" else:\n",
" #新規mainとtarget NNを作成する\n",
" self.main_net = self.build_net()\n",
" self.target_net = self.build_net()\n",
" self.exp_pool = experiencePool()\n",
"\n",
" # ---------------------------------------------------------------------------------\n",
" def build_net(self):\n",
" # NNを作成\n",
" rectifiedAdam = tfa.optimizers.RectifiedAdam(learning_rate = self.learning_rate,weight_decay = 0.001)\n",
" #Adam = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)\n",
" neural_net = tf.keras.Sequential()\n",
" neural_net.add(tf.keras.layers.Dense(\n",
" units=128, activation='relu', input_dim=STATE_SIZE))\n",
" neural_net.add(tf.keras.layers.Dense(\n",
" units=256, activation='relu'))\n",
" neural_net.add(tf.keras.layers.Dense(\n",
" units=128, activation='relu'))\n",
" neural_net.add(tf.keras.layers.Dense(\n",
" units=64, activation='relu'))\n",
" neural_net.add(tf.keras.layers.Dense(\n",
" units=ACTION_SIZE, activation='elu'))\n",
"\n",
" neural_net.compile(optimizer=rectifiedAdam, loss='mse', metrics=['accuracy'])\n",
" \n",
" return neural_net\n",
"\n",
" def select_action(self, state):\n",
" # 動作Q値を予測と動作選択\n",
" random_num = np.random.sample()\n",
" \n",
" if random_num > self.epsilon:\n",
" # DQNをベースにし、動作を選択する\n",
" predictResult = self.main_net(state).numpy()[0]\n",
" actionX = np.argmax(predictResult[0:3])-1\n",
" actionZ = np.argmax(predictResult[3:])-1\n",
" action = np.array([actionX, actionZ], dtype=np.float32)\n",
" #print(\"action = \",action)\n",
" else:\n",
" # ランダムで動作を選択\n",
" actionX = np.random.randint(ACTION_SIZE/2)-1\n",
" actionY = np.random.randint(ACTION_SIZE/2)-1\n",
" action = np.array([actionX, actionY], dtype=np.float32)\n",
"\n",
" # 缩小epsilon\n",
" if self.epsilon > self.epsilon_min:\n",
" self.epsilon -= self.epsilon_cut\n",
"\n",
" return action\n",
"\n",
" def training(self):\n",
" # トレーニング開始\n",
" if self.exp_pool.get_len() >= BATCH_SIZE:\n",
" # トレーニング集を獲得\n",
" exp_set = self.exp_pool.get_random(num=BATCH_SIZE)\n",
" exp_state = [data[0] for data in exp_set] # EXP_Poolが記録した当時ラウンドの環境状態\n",
" exp_action = [data[1] for data in exp_set] # そのラウンドで選んだ動作\n",
" exp_reward = [data[2] for data in exp_set] # その動作に応じるreward\n",
" exp_next_state = [data[3] for data in exp_set] # その動作が実行した後の環境状態。\n",
" exp_done = [data[4] for data in exp_set] # 実行後にゲームが終了したか\n",
"\n",
" exp_state = np.asarray(exp_state).squeeze()\n",
" exp_action = np.asarray(exp_action).squeeze()\n",
" exp_next_state = np.asarray(exp_next_state).squeeze()\n",
"\n",
" # 各ネットでQ値予測\n",
" target_net_q = self.target_net(exp_next_state).numpy() # target_NN 未来状況のQ値を予測\n",
" main_net_q = self.main_net(exp_state).numpy() # main_NN 現在状況のQ値を予測\n",
"\n",
" # トレーニング用Q値、目標y、\n",
" y = main_net_q.copy() # (1,6)\n",
" \n",
" # Batch全体インデクス、[0,1,......,BATCH_SIZE]\n",
" batch_index = np.arange(BATCH_SIZE, dtype=np.int32)\n",
" \n",
" # 動作の値(-1,0,1)によってQ値のインデクス(Xは(0,1,2)、Zは(3,4,5),各自Shapeは(1,BATCH_SIZE))を作成\n",
" exp_actionX_index = exp_action[:,0] + 1\n",
" exp_actionZ_index = exp_action[:,1] + 4\n",
" exp_actionX_index = exp_actionX_index.astype(np.int)\n",
" exp_actionZ_index = exp_actionZ_index.astype(np.int)\n",
" \n",
" # target_NNが未来状況によって予測したQ値から 垂直/水平動作各自の最大値Q値を摘出\n",
" fixedX = np.max(target_net_q[:, :3], axis=1) # (batchsize,1)\n",
" fixedZ = np.max(target_net_q[:, -3:], axis=1) # (batchsize,1)\n",
" # そのラウンドで受けたreward+未来最大Q値の和で修正値となる\n",
" fixedX = exp_reward + self.gamma*fixedX \n",
" fixedZ = exp_reward + self.gamma*fixedZ\n",
" # ゲーム終了のラウンドでの修正値は元のreward,ゲーム続行する時の修正値はfixedXとfixedYとする\n",
" y_fixedX = np.where(exp_done,exp_reward,fixedX)\n",
" y_fixedZ = np.where(exp_done,exp_reward,fixedZ)\n",
" \n",
" # 修正値を応用\n",
" y[batch_index, exp_actionX_index] = y_fixedX\n",
" y[batch_index, exp_actionZ_index] = y_fixedZ\n",
"\n",
" # main_netに入れて、フィットする\n",
" self.main_net.fit(exp_state, y, epochs=5, verbose=0,callbacks = [tb_callback])\n",
" \n",
" def saveNN(self):\n",
" # 両NNを保存する\n",
" main_save_dir= \"ML-Model/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")+\"main.h5\"\n",
" target_save_dir= \"ML-Model/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")+\"target.h5\"\n",
" self.main_net.save(main_save_dir)\n",
" self.target_net.save(target_save_dir)\n",
" print(\"Model Saved\")\n",
"\n",
" def loadNN(self,main_load_dir,target_load_dir):\n",
" # 両NNをローディングする\n",
" main_net_loaded = tf.keras.models.load_model(main_load_dir)\n",
" target_net_loaded = tf.keras.models.load_model(target_load_dir)\n",
" print(\"Model Loaded\")\n",
" return main_net_loaded,target_net_loaded"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA20AAAI/CAYAAADkwzGCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABvX0lEQVR4nO3dd5gUVdqG8buGKBgQWkXFiFlRVAyIK1mMYFZUhMIVcRXTbrfpUzHtuphzXAsRAyZQFEWyrhgAswKKCghIaHKGmanvj8EBV0DCDDXh/l0XF9WnqnuewdqBd89b5wRxHCNJkiRJKplykg4gSZIkSVozizZJkiRJKsEs2iRJkiSpBLNokyRJkqQSzKJNkiRJkkowizZJkiRJKsEqJh0AIJVKxbvuumvSMSRJkiQpEaNGjcrGcbzN6s6ViKJt1113ZeTIkUnHkCRJkqREBEEwYU3nbI+UJEmSpBLMok2SJEmSSjCLNkmSJEkqwUrEM22rs3z5ciZNmsSSJUuSjqL1VLVqVerUqUOlSpWSjiJJkiSVeiW2aJs0aRJbbLEFu+66K0EQJB1H6yiOY2bOnMmkSZPYbbfdko4jSZIklXoltj1yyZIl1KpVy4KtlAmCgFq1ajlDKkmSJBWRElu0ARZspZT/3SRJkqSiU6KLtiTNnDmT+vXrU79+fWrXrs2OO+5Y+HrZsmW/u/b+++9n0aJFf/qZTZo0YeTIkTzwwANceeWVheMXX3wxLVq0KHz90EMPcfnllzNy5Eguv/zyDcr46aefrvW9kiRJkkqHEvtMW9Jq1arFF198AUDXrl3ZfPPN+cc//rHaa++//37OP/98qlWrtk6f3ahRI55//vnC119++SV5eXnk5eVRoUIFhg8fTps2bWjQoAENGjTY4IyHH374OuWRJEmSVHI507YeBg0axMEHH0y9evXo2LEjS5cu5cEHH2TKlCk0bdqUpk2bAnDJJZfQoEED9t9/f26++eY/fE79+vX5/vvvWbx4MXPnzmWzzTajfv36fP311wAMHz6cRo0aMXToUE466SSgoCjr2LEjTZo0Yffdd+fBBx9ca9b/fW/79u35y1/+wi677MLrr79OJpOhXr16HHfccSxfvhyAUaNG0bhxYw499FBatWrFr7/+CsCDDz7Ifvvtx4EHHsg555xTNH+YkiRJktaJRds6WrJkCR06dKBXr158/fXX5Obm8thjj3H55Zezww47MGTIEIYMGQLAHXfcwciRI/nqq68YNmwYX3311e8+q2LFihx88MGMGDGCjz/+mCOOOIIjjzyS4cOHM3nyZOI4ZqeddvpDhjFjxtC/f38+/fRTbrnllsJia138+OOPDB48mDfffJPzzz+fpk2b8vXXX7PZZpvx9ttvs3z5crp06cKrr77KqFGj6NixIzfccAMAd955J59//jlfffUVjz/++Eb8KUqSJElaXxZt6ygvL4/ddtuNvfbaC4D27dvz/vvvr/bal19+mUMOOYSDDz6Yb7/9lu++++4P1xx11FEMHz6c4cOH07BhQxo2bFj4+qijjlrt55544olUqVKFVCrFtttuy7Rp09Y5//HHH0+lSpWoV68eeXl5HHfccQDUq1eP8ePHM3bsWL755htatmxJ/fr1uf3225k0aRIABx54IOeddx49e/akYkU7aiVJkqRNqdT8C7zr0K7cMuyWwtcjLxoJQIOnVj7zdXPjm+napCs73LMDvy4oaO07ZPtDGNVpFJ36duKpz54qvHby1ZPZYYsdijznzz//zN13382IESPYeuut6dChw2qXv2/UqBGPP/44S5Ys4dJLL2Wbbbbhu+++Y5tttllj0ValSpXC4woVKpCbm7vOuX57b05ODpUqVSpc4TEnJ4fc3FziOGb//ffno48++sN73377bd5//3369u3LHXfcwddff23xJkmSJG0ipeZf3l2bdKVrk65/GI9vjv8wNuXvU/4w9uTJT/LkyU9u8NevUKEC48ePZ9y4ceyxxx4899xzNG7cGIAtttiC+fPnk0qlmDdvHtWrV2errbZi2rRpvPPOOzRp0uQPn9ewYUM6dOjAjjvuyLbbbgvANttswxtvvMErr7yywTk31N57782MGTP46KOPaNiwIcuXL+f7779n33335ZdffqFp06YcffTRvPTSSyxYsIAaNWps8oySJElSeVRqirakVa1alSiKOPPMM8nNzeWwww6jc+fOAHTq1Injjjuu8Nm2gw8+mH322YeddtqJRo0arfbztt56a7bZZhv233//wrGGDRvy4YcfctBBB22S72lVlStX5tVXX+Xyyy9n7ty55ObmcuWVV7LXXntx/vnnM3fuXOI45vLLL7dgkyRJkjahII7/OFO1qTVo0CAeOXLk78ZGjx7Nvvvum1AibSz/+0mSJEnrLgiCUXEcr3a/rz9diCQIgmeCIJgeBME3q4z1CoLgixW/xgdB8MWK8V2DIFi8yjmXGpQkSZKkjbAu7ZHdgYeBHr8NxHF89m/HQRDcA8xd5fof4ziuX0T5JEmSJKlc+9OiLY7j94Mg2HV154KCJQjPApoVcS5JkiRJEhu/T9tfgGlxHP+wythuQRB8HgTBsCAI/rKRny9JkiRJ5drGrh7ZFnhxlde/AjvHcTwzCIJDgT5BEOwfx/G8/31jEASdgE4AO++880bGkCRJkqSyaYNn2oIgqAicBvT6bSyO46VxHM9ccTwK+BHYa3Xvj+P4yTiOG8Rx3GCbbbbZ0BiSJEmSVKZtTHtkC2BMHMeTfhsIgmCbIAgqrDjeHdgT+GnjIiZj5syZ1K9fn/r161O7dm123HHHwtfLli373bX3338/ixYt+tPPbNKkCb9tbbDrrrtSr169ws8cPnz4Gt93wgknMGfOHAA233zz1V7ToUMHXn31VU499VTq16/PHnvswVZbbfW7zz/qqKPW8buXJEmSyobs119z19FHk+3QgWyHDtx14olks9mkY62XP22PDILgRaAJkAqCYBJwcxzH/wHO4fetkQDHALcGQbAcyAc6x3E8q2gjbxq1atXiiy++AKBr165svvnm/OMf/1jttffffz/nn38+1apVW6+vMWTIEFKp1J9e169fv3X+zN69ewMwdOhQ7r77bt56663Cc2srDCVJkqQyZ8IEoiZNyMyaBdtuC0CmXz+IItLpdMLh1t26rB7Zdg3jHVYz9hrw2sbHKpkGDRrEP/7xD3JzcznssMN47LHHeOKJJ5gyZQpNmzYllUoxZMgQLrnkEkaMGMHixYs544wzuOWWW9bp80855RR++eUXlixZwhVXXEGnTp2Aglm5kSNH/q7Ai+OYLl26MGDAAHbaaScqV678p5+/+eabs2DBAoYOHcrNN99MjRo1+PrrrznrrLOoV68eDzzwAIsXL6ZPnz7UrVuXGTNm0LlzZyZOnAgUFKeNGjVi2LBhXHHFFQAEQcD777/PFltssb5/nJIkSVLx+fFHaN6c8OqroXJlwjAsGG/YcOVxKbGxC5GUG0uWLKFDhw4MGjSIvfbaiwsuuIDHHnuMK6+8knvvvfd3s
2Z33HEHNWvWJC8vj+bNm/PVV19x4IEH/uEzmzZtSoUKFahSpQqffPIJzzzzDDVr1mTx4sUcdthhnH766dSqVWu1eXr37s3YsWP57rvvmDZtGvvttx8dO3Zc5+/nyy+/ZPTo0dSsWZPdd9+dv/71r3z66ac88MADPPTQQ9x///1cccUVXHXVVRx99NFMnDiRVq1aMXr0aO6++24eeeQRGjVqxIIFC6hateqG/aFKkiRJRST75ZdEbdsSbr01ANFXXxHefDOpf/yDVefUStMM2282dsn/EiWbzXLXXXcVS49qXl4eu+22G3vtVbCuSvv27Xn//fdXe+3LL7/MIYccwsEHH8y3337Ld999t9rrhgwZwhdffMEnn3wCwIMPPshBBx3EkUceyS+//MIPP/yw2vcBvP/++7Rt25YKFSqwww470KzZ+m2Vd9hhh7H99ttTpUoV6taty7HHHgtAvXr1GD9+PAADBw7ksssuo379+rRu3Zp58+axYMECGjVqxNVXX82DDz7InDlzqFjR2l+SJEk
"text/plain": [
"<Figure size 1080x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2oAAAI/CAYAAAAGHyr7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABv/ElEQVR4nO39e3Db553ne34eAAQJ3iGRFElAtmRL1sU3iXIn6c6lk3QudjrxRcxMdWrPqenaPZvZqukzMztnz05md0+fqT41VXtm58zunrOpqZMzO1W9VTuT6Q1pR06c2NNJnEt30m2LkmwLkGJZvojEjyJFifyBVxDAs3+QkGRZF14A/C54v6pSsUgI+Bomgd8Hz/N8v8ZaKwAAAACAf0S8LgAAAAAA8GEENQAAAADwGYIaAAAAAPgMQQ0AAAAAfIagBgAAAAA+Q1ADAAAAAJ+JefXAPT09ds+ePV49PAAAAAB46uTJk1estb23+55nQW3Pnj16/fXXvXp4AAAAAPCUMeb9O32PrY8AAAAA4DMENQAAAADwGYIaAAAAAPgMQQ0AAAAAfIagBgAAAAA+Q1ADAAAAAJ8hqAEAAACAzxDUAAAAAMBnCGoAAAAA4DMENQAAAADwGYIaAAAAAPgMQQ0AAAAAfIagBgAAAAA+Q1ADAAAAAJ8hqAEAAACAzxDUAAAAAMBnCGoAAAAA4DMENQAAAADwGYIaAAAAAPjMhoKaMeZJY8x5Y8wFY8y3bvP9/7sx5vT6/35rjJmteqUAAAAA0CBi97qBMSYq6duSvihpXNJrxpgT1tpM5TbW2v/9Tbf/LyUdrUGtAAAAANAQNrKi9jFJF6y1F621BUnflfTMXW7/DUn/oRrFAQAAAEAj2khQS0m6dNOfx9e/9hHGmPsl7ZX00+2XBgAAANTOr96+ov/iz1/XaqnsdSl1M3JyXP/0e294XQY2oNrNRP5I0vestaXbfdMY801jzOvGmNenp6er/NAAAADAxv2bn1/QX2Yv6+fnG+O61Fqr/+dP3tZ/fP2Szk26XpeDe9hIUJuQtPumP6fXv3Y7f6S7bHu01n7HWvuEtfaJ3t7ejVcJAAAAVFFudkl//c6MJGlkbNzjaurj9fev6YOri5Kk0bE7Xc7DLzYS1F6TtN8Ys9cYE9daGDtx642MMQclJSX9urolAgAAANX1/KkJWSt94dAu/SQ7pdnFgtcl1dzIyXG1xqP65L6dev7UhIoNtOUziO4Z1Ky1RUl/IullSVlJf2GtPWuM+TNjzNM33fSPJH3XWmtrUyoAAACwfdZajYyN62N7dugff2G/CqWyXnzD8bqsmlpeLemHbzh68pF+/eefuF/T+RX96sIVr8vCXdyzPb8kWWtfkvTSLV/701v+/M+rVxYAAABQG6cvzeri9IL+/mce0MODnTrY36GRk+P6zz9xv9el1cwrmcvKrxT19aG0ju1Jqru1SSNjE/rsgT6vS8MdVLuZCAAAAOBro2MTao5F9NSjAzLG6PhQSqcvzeqd6XmvS6uZ0bFxDXa16BMP7FRzLKqvPTaoV85Oyl1e9bo03AFBDQAAAA1jpVjSiTM5ffnhfnW2NEmSnj2SUsSshZkwmnKX9YvfTuu5oZQiESNJGj6W1kqxrJdCvuUzyAhqAAAAaBg/zU5pbmlVx4dujAXu62zRp/f36vmxCZXL4Wu38MLpCZWtdHwoff1rj6e79GBvW8N0vAwighoAAAAaxsjYhPo6mvXp/R8eFTV8LK3c3LJ+c3HGo8pqw1qrkZMTOrK7Ww/2tl//+tqWz7Ree++a3p9Z8LBC3AlBDQAAAA1hZn5Fr56f0nNHU4qubwGs+NLhXepojul7IVthOptzdf5yXsPH0h/53nNHUzKGmWp+RVADAABAQzhxJqdi2X5oC2BFS1NUf/jYgH781qQWVooeVFcbo2MTikcj+tpjAx/53mB3Qr/34E6NnhoXE7b8h6AGAACAhjAyNq5HUp060N9x2+8PH0trsVDSj9+arHNltbFaKuv7pyf0B4f61N0av+1thofSunR1Sa+9d63O1eFeCGoAAAAIvfOTeb014Wr4NqtpFU/cn9R9O1o1eioc2x9/fn5aMwuF264gVjz5SL9a41GNnAzHv3OYENQAAAAQeqNj44pFjJ5+fPCOt6nMVPvrd2aUm12qY3W1MXpqXDvb4vrsgd473qY1HtNTjwzoh286Wl4t1bE63AtBDQAAAKFWLJX1/KkJffZAn3a2N9/1tsePpmWt9PypYDfYmF0s6C8zU3r6yKCaone/5B8+ltL8SlEvnw3Hls+wIKgBAAAg1P7qnRlN5Vc0fNPstDu5b2erPrZnh0bGgt1g4wdvOCqUynfd6lnxib07lepO0P3RZwhqAAAACLWRk+PqSjTp84f6NnT74WMpXZxe0OlLs7UtrIZGxsZ1YFeHHh7svOdtIxGj546m9Mu3p3XZXa5DddgIghoAAABCy11e1ctnJ/X044NqjkU39HeeenRAzbFIYFeY3pme16kPZnV8KCVjzL3/gqTjQymVrfRCwLd8hglBDQAAAKH1ozcdrRTLOr6BbY8VnS1N+vLD/TpxJqeVYvAabDw/NqGIWRtovVEP9Lbr6H3dgd/yGSYENQAAAITWyMkJPdDbpiO7uzf194aPpTW3tKqfZqdqU1iNlMtWz5+a0Kf396qvs2VTf3d4KK3fXp7X2Zxbo+qwGQQ1AAAAhNIHM4v62/euangoveEtgBWf2tejvo5mjQRs++Nv3p3RxOzSplYQK7762IDi0YhGxpip5gcENQAAAITS6KlxmU1uAayIrjfYePX8lGbmV2pQXW2MnJxQR3NMX364f9N/t7s1ri8c7tOJ0zmtlso1qA6bQVADAABA6FhrNTo2od97cKcGuxNbuo/jQ2kVy1YnzuSqXF1tLKwU9aO3HH3l0QG1NG2sccqthofSmlko6NXz01WuDptFUAMAAEDovP7+NX1wdVHHj957jtidHOjv0COpzsBsBXz57KQWCyUNH9v6v/NnHurVzra4RgPy7xxmBDUAAACEzsjJcbXGo3rykc1vAbzZ8FBab024Oj+Zr1JltTMyNq7dOxL6nT3JLd9HUzSiZ46k9JPslGYXC1WsDptFUAMAAECoLK+W9MM3HD35SL/ammPbuq+nHx9ULGJ8v8KUm13SX78zo+NHN9845VbHh1IqlMp68Q2nStVhKwhqAAAACJVXMpeVXynq60Nb3wJYsbO9WZ890KfnT02o6OMGG8+fmpC1ayuA2/XwYKcO9ndo5KS/w2nYEdQAAAAQKiMnxzXY1aJPPLCzKvc3PJTSVH5Ff/XOTFXur9qstRoZG9fv7Enqvp2t274/Y4yGh9I6fWlW70zPV6FCbAVBDQAAAKEx5S7rl29P67mhlCKR7W0BrPj8oT51JZp8u8J0ZnxOF6cXqrKaVvHMkUFFjHy/5TPMCGoAAAAIjRdOT6hs11rrV0tzLKqnHx/Uy2cn5S6vVu1+q2Xk5LiaYxF95bGBqt1nX2eLPvNQr54fm1C5bKt2v9g4ghoAAABCwVqrk
ZMTOrK7Ww/2tlf1vo8PpbRSLOtHb/qrwcZKsaQTZ3L60sP96mxpqup9Dw+llZtb1q8v+nPLZ9gR1AAAABAKZ3Ouzl/Ob2uO2J0c2d2tB3rbNHJyour3vR0/OzeluaVVDQ+lqn7fXzy8Sx0tscDMkQsbghoAAABCYWRsXPFoRF+r4hbAikqDjb9976o+mFms+v1v1fdOTqi3o1mf2tdT9ftuaYrqq48N6MdvTWphpVj1+8fdEdQAAAAQeKulsk6czukPDvWpuzVek8d47mhKxkijp/yxwjQzv6JXz0/puaMpxaK1uaw/PpTWYqGkH781WZP7x50R1AAAABB4Pz8/rZmFQlU7H95qsDuh33twp0bHJmSt9w02TpzJqVi2Nf13fuL+pO7f2cr2Rw8Q1AAAABB4I2Pj2tkW1+8f6K3p4xw/mtYHVxf1+vvXavo4GzEyNq6HBzt1oL+jZo9hjNHxo2n9+uKMJmaXavY4+CiCGgAAAAJtdrGgn2Sn9PSRQTXVaAtgxZOP9Ks1HvV8ptr5ybz
"text/plain": [
"<Figure size 1080x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Episode 261 done,Reward = 10.0 mean steps = 13.049618320610687 exp_num = 2500 \n",
"epsilon = 0.009999999999982863\n",
"\n",
"Episode 262 done,Reward = -10.0 mean steps = 13.049429657794677 exp_num = 2500 \n",
"epsilon = 0.009999999999982863\n",
"target_net replaced\n",
"\n",
"Episode 263 done,Reward = -10.0 mean steps = 13.151515151515152 exp_num = 2500 \n",
"epsilon = 0.009999999999982863\n",
"\n",
"Episode 264 done,Reward = 10.0 mean steps = 13.124528301886793 exp_num = 2500 \n",
"epsilon = 0.009999999999982863\n",
"\n",
"Episode 265 done,Reward = 10.0 mean steps = 13.135338345864662 exp_num = 2500 \n",
"epsilon = 0.009999999999982863\n",
"target_net replaced\n",
"Model Saved\n",
"\n",
"Episode 266 done,Reward = 10.0 mean steps = 13.142322097378278 exp_num = 2500 \n",
"epsilon = 0.009999999999982863\n",
"\n",
"Episode 267 done,Reward = -10.0 mean steps = 13.156716417910447 exp_num = 2500 \n",
"epsilon = 0.009999999999982863\n",
"\n",
"Episode 268 done,Reward = 10.0 mean steps = 13.118959107806692 exp_num = 2500 \n",
"epsilon = 0.009999999999982863\n"
]
},
{
"ename": "UnityCommunicatorStoppedException",
"evalue": "Communicator has exited.",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mUnityCommunicatorStoppedException\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_13628/1299734337.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 51\u001b[0m \u001b[1;31m# 動作をゲーム環境に渡す。\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 52\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mset_actions\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbehavior_name\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbehavior_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maction\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0maction_Tuple\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 53\u001b[1;33m \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 54\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 55\u001b[0m \u001b[1;31m# 環境が動作を実行し、次の環境状態を返す。\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\mlagents_envs\\timers.py\u001b[0m in \u001b[0;36mwrapped\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 303\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 304\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mhierarchical_timer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__qualname__\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 305\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 306\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 307\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m \u001b[1;31m# type: ignore\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\mlagents_envs\\environment.py\u001b[0m in \u001b[0;36mstep\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 333\u001b[0m \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_communicator\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexchange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstep_input\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_poll_process\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 334\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0moutputs\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 335\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mUnityCommunicatorStoppedException\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Communicator has exited.\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 336\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_update_behavior_specs\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 337\u001b[0m \u001b[0mrl_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0moutputs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrl_output\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mUnityCommunicatorStoppedException\u001b[0m: Communicator has exited."
]
}
],
"source": [
"# エージェント作成\n",
"agent = DQN(load = False,load_dir=\"ML-Model/\" + \"20220205-051103\")\n",
"# トレーニング済みNNをローディングするにはこちらのコードを使用↓\n",
"#agent = DQN(load = True,load_dir=\"ML-Model/\" + \"FinalNN-\")\n",
"\n",
"total_steps = 0 \n",
"steps_list = []\n",
"successTimes = 0\n",
"failTimes = 0\n",
"successTimes_his = []\n",
"failTimes_his = []\n",
"this10TimesWin = 0\n",
"MeanWinPerin10Times = []\n",
"\n",
"for episode in range(EPISODES):\n",
" # episode 開始 \n",
" done = False #ゲーム終了状態をFalse\n",
" steps = 0 \n",
" # 環境初期状態を獲得\n",
" env.reset()\n",
" decision_steps, terminal_steps = env.get_steps(behavior_name)\n",
" state = decision_steps.obs[0]\n",
" state = np.reshape(state, [1, STATE_SIZE])\n",
" \n",
" # ゲームスタート\n",
" while True:\n",
" reward = 0\n",
" steps+=1\n",
" total_steps += 1\n",
" # エージェントナンバーをトラックする\n",
" if tracked_agent == -1 and len(decision_steps) >= 1:\n",
" tracked_agent = decision_steps.agent_id[0]\n",
" \n",
" # REPLACE_STEPS毎でtarget_net にmain_netで入れ替える\n",
" if total_steps % REPLACE_STEPS == 0 and total_steps !=0:\n",
" agent.target_net.set_weights(agent.main_net.get_weights())\n",
" print('target_net replaced')\n",
" \n",
" # SAVE_STEPS毎でNNを保存する\n",
" if total_steps % SAVE_STEPS ==0 and total_steps !=0:\n",
" agent.saveNN()\n",
" \n",
" # main_netで動作選択\n",
" action = agent.select_action(state=state)\n",
" continuous_actions = np.array([[0]], dtype=np.float)\n",
" discrete_actions = np.expand_dims(action,axis=0)\n",
" # 動作をML-Agentsが認識可能なActionTuple型に変換\n",
" action_Tuple = ActionTuple(\n",
" continuous=continuous_actions, discrete=discrete_actions)\n",
"\n",
" # 動作をゲーム環境に渡す。\n",
" env.set_actions(behavior_name=behavior_name, action=action_Tuple)\n",
" env.step()\n",
" \n",
" # 環境が動作を実行し、次の環境状態を返す。\n",
" decision_steps, terminal_steps = env.get_steps(behavior_name)\n",
" if tracked_agent in decision_steps: # ゲーム終了していない場合、環境状態がdecision_stepsに保存される\n",
" next_state = decision_steps[tracked_agent].obs[0]\n",
" next_state = np.reshape(next_state,[1,STATE_SIZE])\n",
" reward = decision_steps[tracked_agent].reward\n",
" if tracked_agent in terminal_steps: # ゲーム終了した場合、環境状態がterminal_stepsに保存される\n",
" next_state = terminal_steps[tracked_agent].obs[0]\n",
" next_state = np.reshape(next_state,[1,STATE_SIZE])\n",
" reward = terminal_steps[tracked_agent].reward\n",
" done = True\n",
" \n",
" # Experience_poolに保存\n",
" agent.exp_pool.add(state,action,reward,next_state,done)\n",
" #print(\"Reward = \",reward)\n",
" # 環境状態を次状態に変更。\n",
" state = next_state\n",
" \n",
" # ゲーム終了後処理\n",
" if done:\n",
" mean_steps = total_steps/(episode+1)\n",
" print(\"\\nEpisode\",episode,\"done,Reward =\",reward,\"mean steps =\",mean_steps,\"exp_num =\",agent.exp_pool.get_len(),\"\\nepsilon =\",agent.epsilon)\n",
" agent.training()\n",
" if(reward >=10):\n",
" successTimes+=1\n",
" this10TimesWin+=1\n",
" else:\n",
" failTimes+=1\n",
" successTimes_his.append(successTimes)\n",
" failTimes_his.append(failTimes)\n",
" if episode % 10 ==0 and episode !=0:\n",
" clear_output()\n",
" this10TimesWinPer = float(this10TimesWin/10)\n",
" this10TimesWin = 0\n",
" MeanWinPerin10Times.append(this10TimesWinPer)\n",
" # 合計成功数(緑)、合計失敗数(赤)を図で表示する\n",
" plt.figure(figsize = (15,10))\n",
" plt.plot(range(len(successTimes_his)), successTimes_his,color='green',linestyle='--', linewidth=1, label='TotalWinTimes')\n",
" plt.plot(range(len(successTimes_his)), failTimes_his,color='red', linewidth=1,marker='o', markersize=1, markerfacecolor='black',markeredgecolor='black',label='TotalFaildTimes')\n",
" # plt.savefig('output.jpg')\n",
" plt.legend()\n",
" plt.savefig(\"wintimes.png\")\n",
" plt.show()\n",
" \n",
" #10回実行した後の成功確率を図で表示する\n",
" plt.figure(figsize=(15,10))\n",
" plt.plot(MeanWinPerin10Times)\n",
" plt.savefig(\"steps.png\")\n",
" plt.show()\n",
" break\n",
"env.close()\n",
"print(\"Finished~\")"
]
}
],
"metadata": {
"interpreter": {
"hash": "c62a1b52b24525839a95f7ca2b53f501cc329096d80c6be9aea5c814c594ecdd"
},
"kernelspec": {
"display_name": "Python 3.9.7 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "undefined.undefined.undefined"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}