# %% Cell 1: toy gym environment with a hybrid (discrete + continuous) action
import gym
from gym.spaces import Dict, Discrete, Box, Tuple
import numpy as np


class SampleGym(gym.Env):
    """Toy environment whose action is a ``(discrete, continuous)`` tuple.

    The action space is ``Tuple((Discrete(2), Box(-10, 10, (2,))))``: the
    discrete part selects which of the two continuous values is actually
    used. Observations are random samples from ``Box(-10, 10, (2, 2))``.
    """

    def __init__(self, config=None):
        """Build the action/observation spaces and read the episode settings.

        Args:
            config: Optional dict of settings. The only key read here is
                ``"p_done"`` (probability that a step ends the episode,
                default ``0.1``). ``None`` means "no settings".
        """
        # A literal ``{}`` default is created once and shared by every
        # instance, so a change made through one env's ``config`` would leak
        # into all the others. Create a fresh dict per instance instead.
        if config is None:
            config = {}
        self.config = config
        self.action_space = Tuple((Discrete(2), Box(-10, 10, (2,))))
        self.observation_space = Box(-10, 10, (2, 2))
        self.p_done = config.get("p_done", 0.1)

    def reset(self):
        """Return a random sample of the observation space as the first observation."""
        return self.observation_space.sample()

    def step(self, action):
        """Apply one hybrid action.

        Args:
            action: Pair ``(chosen_action, controls)``. ``chosen_action`` is
                used as an index into ``controls`` to pick the continuous
                value that counts for this step.

        Returns:
            Tuple ``(observation, reward, done, info)``: a random observation
            sample, the reward (``control`` for action 0, ``-control - 1``
            otherwise), a bool that is True with probability ``p_done``, and
            an empty info dict.
        """
        chosen_action = action[0]
        cnt_control = action[1][chosen_action]

        if chosen_action == 0:
            reward = cnt_control
        else:
            reward = -cnt_control - 1

        print(f"Action, {chosen_action} continuous ctrl {cnt_control}")
        return (
            self.observation_space.sample(),
            reward,
            bool(np.random.choice([True, False], p=[self.p_done, 1.0 - self.p_done])),
            {},
        )


if __name__ == "__main__":
    env = SampleGym()
    env.reset()
    env.step((1, [-1, 2.1]))  # should say use action 1 with 2.1
    env.step((0, [-1.1, 2.1]))  # should say use action 0 with -1.1

# %% Cell 2: connect to the Unity build and read one batch of agent steps
from mlagents_envs.environment import UnityEnvironment
from gym_unity.envs import UnityToGymWrapper
import numpy as np

ENV_PATH = "../Build-ParallelEnv/Aimbot-ParallelEnv"
WORKER_ID = 1
BASE_PORT = 2002

env = UnityEnvironment(
    file_name=ENV_PATH,
    seed=1,
    side_channels=[],
    worker_id=WORKER_ID,
    base_port=BASE_PORT,
)

trackedAgent = 0
try:
    env.reset()
    BEHA_SPECS = env.behavior_specs
    BEHA_NAME = list(BEHA_SPECS)[0]  # first registered behavior name
    SPEC = BEHA_SPECS[BEHA_NAME]
    print(SPEC)

    decisionSteps, terminalSteps = env.get_steps(BEHA_NAME)
finally:
    # Release the Unity process and its port once the step batch has been
    # fetched: a later cell starts another environment with the same worker id
    # and base port. The step objects fetched above remain available.
    env.close()

# Defaults so the names exist even when the tracked agent is in neither batch.
# A saved run of this notebook printed agent ids [1 2 5 7], i.e. agent 0 was
# not present in that run.
nextState = None
reward = None
done = None

if trackedAgent in decisionSteps:  # episode still running: state is stored in decision_steps
    nextState = decisionSteps[trackedAgent].obs[0]
    reward = decisionSteps[trackedAgent].reward
    done = False
if trackedAgent in terminalSteps:  # episode ended: state is stored in terminal_steps
    nextState = terminalSteps[trackedAgent].obs[0]
    reward = terminalSteps[trackedAgent].reward
    done = True
print(decisionSteps.agent_id)
print(terminalSteps)
# %% Cell 3: print each field of the DecisionSteps batch
# Sample values from a previous run are kept as `#` comments. They used to be
# bare triple-quoted strings; the last of those was the cell's final
# expression, so Jupyter echoed the whole string back as the cell result.

print("decisionSteps.agent_id", decisionSteps.agent_id)
# decisionSteps.agent_id [1 2 5 7]

print("decisionSteps.agent_id_to_index", decisionSteps.agent_id_to_index)
# decisionSteps.agent_id_to_index {1: 0, 2: 1, 5: 2, 7: 3}

print("decisionSteps.reward", decisionSteps.reward)
# decisionSteps.reward [0. 0. 0. 0.]

print("decisionSteps.action_mask", decisionSteps.action_mask)
# Sample run: a list of three boolean arrays with shapes (4, 3), (4, 3) and
# (4, 2), every entry False.

# Prints only row 0 of the first observation array (presumably the first
# agent in the batch -- confirm against agent_id_to_index).
print("decisionSteps.obs", decisionSteps.obs[0][0])
# Sample run: `decisionSteps.obs` itself was a one-element list holding a
# float32 array with one row per agent, e.g. the first row started with
#   [-15.994009, 1., -26.322788, 1., 1., ...] and ended with
#   [..., 18.35725, 21.02278, 21.053417, 0.]

# %% Cell 4: start the project's Aimbot environment wrapper
from AimbotEnv import Aimbot

ENV_PATH = "../Build-ParallelEnv/Aimbot-ParallelEnv"
WORKER_ID = 1
BASE_PORT = 2002

# NOTE(review): same build path, worker id and base port as the raw
# UnityEnvironment cell; presumably only one of the two can be connected at a
# time -- confirm before running both in one kernel session.
env = Aimbot(envPath=ENV_PATH, workerID=WORKER_ID, basePort=BASE_PORT)
, 16.768637 , 23.414627 ,\n", " 22.04486 , 21.050663 , 20.486784 , 20.486784 , 21.050665 ,\n", " 15.049731 , 11.578419 , 9.695194 , 20.398016 , 20.368341 ,\n", " 20.398016 , 0. , -1.8809433, 1. , -25.66834 ,\n", " 1. , 1. , 2. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 2. ,\n", " 2. , 1. , 1. , 1. , 25.098585 ,\n", " 15.749494 , 22.044899 , 21.050697 , 20.486813 , 20.486813 ,\n", " 21.050694 , 15.049746 , 3.872317 , 3.789325 , 20.398046 ,\n", " 20.368372 , 20.398046 , 0. ],\n", " [ 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , -13.672583 , 1. , -26.479263 , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 5.3249803, 6.401276 ,\n", " 8.374101 , 12.8657875, 21.302414 , 21.30242 , 21.888742 ,\n", " 22.92251 , 24.346794 , 26.09773 , 21.210114 , 21.179258 ,\n", " 21.210117 , 0. , -13.672583 , 1. , -26.479263 ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 2. , 1. , 1. ,\n", " 2. , 1. , 1. , 2. , 5.3249855,\n", " 6.4012837, 8.374114 , 12.865807 , 21.302446 , 21.30245 ,\n", " 16.168503 , 22.922543 , 24.346823 , 7.1110754, 21.210148 ,\n", " 21.17929 , 12.495141 , 0. ],\n", " [ 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , -4.9038744, 1. , -25.185507 , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 20.33171 , 22.859762 ,\n", " 21.522427 , 20.551746 , 20.00118 , 20.001116 , 20.551594 ,\n", " 21.5222 , 17.707508 , 14.86889 , 19.914494 , 19.885508 ,\n", " 19.914463 , 0. , -4.9038773, 1. , -25.185507 ,\n", " 1. , 2. , 1. , 2. , 1. ,\n", " 1. , 1. , 1. , 2. , 1. ,\n", " 1. , 1. , 1. , 1. 
# %% Cell 5: inspect the observation shape, then reset the Aimbot environment
import numpy as np

# Jupyter only echoes a cell's final expression, so these two values were
# previously computed and silently thrown away. Print them explicitly.
print("env.unity_observation_shape:", env.unity_observation_shape)
# NOTE(review): (128, 4) looks like (rollout length, agent count) for a
# buffer allocation -- confirm the intended meaning.
print("(128, 4) + env.unity_observation_shape:", (128, 4) + env.unity_observation_shape)

# Left as the last expression so the reset result is shown as the cell output.
env.reset()
, 16.768637 , 23.414627 ,\n", " 22.04486 , 21.050663 , 20.486784 , 20.486784 , 21.050665 ,\n", " 15.049731 , 11.578419 , 9.695194 , 20.398016 , 20.368341 ,\n", " 20.398016 , 0. , -1.8809433, 1. , -25.66834 ,\n", " 1. , 2. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 16.768671 ,\n", " 23.414669 , 22.044899 , 21.050697 , 20.486813 , 20.486813 ,\n", " 21.050694 , 15.049746 , 11.578423 , 9.695195 , 20.398046 ,\n", " 20.368372 , 20.398046 , 0. ], dtype=float32),\n", " array([ 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , -13.672583 , 1. , -26.479263 , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 5.3249803, 6.401276 ,\n", " 8.374101 , 12.8657875, 21.302414 , 21.30242 , 21.888742 ,\n", " 22.92251 , 24.346794 , 26.09773 , 21.210114 , 21.179258 ,\n", " 21.210117 , 0. , -13.672583 , 1. , -26.479263 ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 5.3249855,\n", " 6.4012837, 8.374114 , 12.865807 , 21.302446 , 21.30245 ,\n", " 21.888773 , 22.922543 , 24.346823 , 26.097757 , 21.210148 ,\n", " 21.17929 , 21.21015 , 0. ], dtype=float32),\n", " array([ 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , 0. , 0. , 0. , 0. ,\n", " 0. , -4.9038744, 1. , -25.185507 , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 20.33171 , 22.859762 ,\n", " 21.522427 , 20.551746 , 20.00118 , 20.001116 , 20.551594 ,\n", " 21.5222 , 17.707508 , 14.86889 , 19.914494 , 19.885508 ,\n", " 19.914463 , 0. , -4.9038773, 1. , -25.185507 ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. , 1. ,\n", " 1. , 1. , 1. , 1. 
# %% Cell 6: send one all-zero action per agent to the Aimbot environment
import numpy as np

# 4 x 4 block of integer zeros (platform default integer dtype), built
# directly instead of zeroing a reshaped arange.
actions = np.zeros((4, 4), dtype=int)
print(actions)

# Left as the last expression so the step result is shown as the cell output.
env.step(actions)