575 lines
30 KiB
Plaintext
575 lines
30 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Action, 1 continuous ctrl 2.1\n",
|
|
"Action, 0 continuous ctrl -1.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import gym\n",
|
|
"from gym.spaces import Dict, Discrete, Box, Tuple\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"\n",
|
|
"class SampleGym(gym.Env):\n",
|
|
"    \"\"\"Toy environment with a hybrid (discrete + continuous) action space.\n",
|
|
"\n",
|
|
"    An action is a pair ``(choice, controls)``: ``choice`` is 0 or 1 and\n",
|
|
"    ``controls`` holds two continuous values. ``choice`` selects which of the\n",
|
|
"    two controls is used for the step. Observations are sampled at random\n",
|
|
"    from ``observation_space``.\n",
|
|
"    \"\"\"\n",
|
|
"\n",
|
|
"    def __init__(self, config=None):\n",
|
|
"        \"\"\"Build the spaces and read the optional ``p_done`` setting.\n",
|
|
"\n",
|
|
"        Args:\n",
|
|
"            config: optional mapping; ``p_done`` (default 0.1) is the\n",
|
|
"                probability that a step ends the episode.\n",
|
|
"        \"\"\"\n",
|
|
"        # Use a fresh dict per instance: a ``{}`` default would be one object\n",
|
|
"        # shared by every SampleGym created without an explicit config.\n",
|
|
"        self.config = {} if config is None else config\n",
|
|
"        self.action_space = Tuple((Discrete(2), Box(-10, 10, (2,))))\n",
|
|
"        self.observation_space = Box(-10, 10, (2, 2))\n",
|
|
"        self.p_done = self.config.get(\"p_done\", 0.1)\n",
|
|
"\n",
|
|
"    def reset(self):\n",
|
|
"        \"\"\"Return a randomly sampled initial observation.\"\"\"\n",
|
|
"        return self.observation_space.sample()\n",
|
|
"\n",
|
|
"    def step(self, action):\n",
|
|
"        \"\"\"Apply ``action`` and return ``(observation, reward, done, info)``.\n",
|
|
"\n",
|
|
"        The reward is the selected control when the discrete choice is 0 and\n",
|
|
"        ``-control - 1`` otherwise. ``done`` is True with probability\n",
|
|
"        ``p_done``; ``info`` is always an empty dict.\n",
|
|
"        \"\"\"\n",
|
|
"        chosen_action = action[0]\n",
|
|
"        # The discrete choice doubles as the index of the control to use.\n",
|
|
"        cnt_control = action[1][chosen_action]\n",
|
|
"\n",
|
|
"        if chosen_action == 0:\n",
|
|
"            reward = cnt_control\n",
|
|
"        else:\n",
|
|
"            reward = -cnt_control - 1\n",
|
|
"\n",
|
|
"        print(f\"Action, {chosen_action} continuous ctrl {cnt_control}\")\n",
|
|
"        return (\n",
|
|
"            self.observation_space.sample(),\n",
|
|
"            reward,\n",
|
|
"            bool(np.random.choice([True, False], p=[self.p_done, 1.0 - self.p_done])),\n",
|
|
"            {},\n",
|
|
"        )\n",
|
|
"\n",
|
|
"\n",
|
|
"if __name__ == \"__main__\":\n",
|
|
"    env = SampleGym()\n",
|
|
"    env.reset()\n",
|
|
"    # Each action is (discrete choice, [control 0, control 1]). The step should\n",
|
|
"    # report action 1 with 2.1 first, then action 0 with -1.1.\n",
|
|
"    for sample_action in ((1, [-1, 2.1]), (0, [-1.1, 2.1])):\n",
|
|
"        env.step(sample_action)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from mlagents_envs.environment import UnityEnvironment\n",
|
|
"from gym_unity.envs import UnityToGymWrapper\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"# Unity build to launch and the port settings used to talk to it.\n",
|
|
"ENV_PATH = \"../Build-ParallelEnv/Aimbot-ParallelEnv\"\n",
|
|
"WORKER_ID, BASE_PORT = 1, 2002\n",
|
|
"\n",
|
|
"env = UnityEnvironment(\n",
|
|
"    file_name=ENV_PATH, seed=1, side_channels=[], worker_id=WORKER_ID, base_port=BASE_PORT\n",
|
|
")\n",
|
|
"env.reset()\n",
|
|
"\n",
|
|
"# Look up the spec of the first behavior name listed in behavior_specs.\n",
|
|
"BEHA_SPECS = env.behavior_specs\n",
|
|
"BEHA_NAME = list(BEHA_SPECS)[0]\n",
|
|
"SPEC = BEHA_SPECS[BEHA_NAME]\n",
|
|
"print(SPEC)\n",
|
|
"\n",
|
|
"trackedAgent = 0\n",
|
|
"decisionSteps, terminalSteps = env.get_steps(BEHA_NAME)\n",
|
|
"\n",
|
|
"# When the game has ended, the environment state is stored in terminalSteps;\n",
|
|
"# while it is still running, the state is stored in decisionSteps.\n",
|
|
"# The terminal entry takes precedence if the agent appears in both.\n",
|
|
"if trackedAgent in terminalSteps:\n",
|
|
"    nextState = terminalSteps[trackedAgent].obs[0]\n",
|
|
"    reward = terminalSteps[trackedAgent].reward\n",
|
|
"    done = True\n",
|
|
"elif trackedAgent in decisionSteps:\n",
|
|
"    nextState = decisionSteps[trackedAgent].obs[0]\n",
|
|
"    reward = decisionSteps[trackedAgent].reward\n",
|
|
"    done = False\n",
|
|
"print(decisionSteps.agent_id)\n",
|
|
"print(terminalSteps)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"decisionSteps.agent_id [1 2 5 7]\n",
|
|
"decisionSteps.agent_id_to_index {1: 0, 2: 1, 5: 2, 7: 3}\n",
|
|
"decisionSteps.reward [0. 0. 0. 0.]\n",
|
|
"decisionSteps.action_mask [array([[False, False, False],\n",
|
|
" [False, False, False],\n",
|
|
" [False, False, False],\n",
|
|
" [False, False, False]]), array([[False, False, False],\n",
|
|
" [False, False, False],\n",
|
|
" [False, False, False],\n",
|
|
" [False, False, False]]), array([[False, False],\n",
|
|
" [False, False],\n",
|
|
" [False, False],\n",
|
|
" [False, False]])]\n",
|
|
"decisionSteps.obs [ 0. 0. 0. 0. 0. 0.\n",
|
|
" 0. 0. 0. 0. 0. 0.\n",
|
|
" 0. 0. 0. 0. 0. 0.\n",
|
|
" 0. 0. 0. 0. 0. 0.\n",
|
|
" 0. 0. 0. 0. 0. 0.\n",
|
|
" 0. 0. 0. 0. 0. 0.\n",
|
|
" 0. 0. 0. 0. 0. 0.\n",
|
|
" 0. 0. 0. 0. 0. 0.\n",
|
|
" 0. 0. 0. 0. 0. 0.\n",
|
|
" 0. 0. 0. 0. 0. 0.\n",
|
|
" 0. 0. -15.994009 1. -26.322788 1.\n",
|
|
" 1. 1. 1. 1. 1. 2.\n",
|
|
" 1. 1. 1. 1. 1. 1.\n",
|
|
" 1. 1.3519633 1.6946528 2.3051548 3.673389 9.067246\n",
|
|
" 17.521473 21.727095 22.753294 24.167128 25.905216 18.35725\n",
|
|
" 21.02278 21.053417 0. ]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'decisionSteps.obs [array([[-15.994009 , 1. , -26.322788 , 1. , 1. ,\\n 1. , 1. , 1. , 1. , 2. ,\\n 1. , 1. , 1. , 1. , 1. ,\\n 1. , 1. , 1.3519633, 1.6946528, 2.3051548,\\n 3.673389 , 9.067246 , 17.521473 , 21.727095 , 22.753294 ,\\n 24.167128 , 25.905216 , 18.35725 , 21.02278 , 21.053417 ,\\n 0. ],\\n [ -1.8809433, 1. , -25.66834 , 1. , 2. ,\\n 1. , 1. , 1. , 1. , 1. ,\\n 1. , 1. , 1. , 1. , 1. ,\\n 1. , 1. , 16.768637 , 23.414627 , 22.04486 ,\\n 21.050663 , 20.486784 , 20.486784 , 21.050665 , 15.049731 ,\\n 11.578419 , 9.695194 , 20.398016 , 20.368341 , 20.398016 ,\\n...\\n 20.551746 , 20.00118 , 20.001116 , 20.551594 , 21.5222 ,\\n 17.707508 , 14.86889 , 19.914494 , 19.885508 , 19.914463 ,\\n 0. ]], dtype=float32)]'"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Print the per-agent fields of the current decision batch.\n",
|
|
"# Output recorded from one run:\n",
|
|
"#   decisionSteps.agent_id          [1 2 5 7]\n",
|
|
"#   decisionSteps.agent_id_to_index {1: 0, 2: 1, 5: 2, 7: 3}\n",
|
|
"#   decisionSteps.reward            [0. 0. 0. 0.]\n",
|
|
"#   decisionSteps.action_mask       list of three all-False boolean arrays,\n",
|
|
"#                                   4 rows each, with 3, 3 and 2 columns\n",
|
|
"for fieldName in (\"agent_id\", \"agent_id_to_index\", \"reward\", \"action_mask\"):\n",
|
|
"    print(f\"decisionSteps.{fieldName}\", getattr(decisionSteps, fieldName))\n",
|
|
"\n",
|
|
"# Only the first row of the first observation array is printed here.\n",
|
|
"print(\"decisionSteps.obs\", decisionSteps.obs[0][0])\n",
|
|
"'''decisionSteps.obs [array([[-15.994009 , 1. , -26.322788 , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 2. ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1.3519633, 1.6946528, 2.3051548,\n",
|
|
" 3.673389 , 9.067246 , 17.521473 , 21.727095 , 22.753294 ,\n",
|
|
" 24.167128 , 25.905216 , 18.35725 , 21.02278 , 21.053417 ,\n",
|
|
" 0. ],\n",
|
|
" [ -1.8809433, 1. , -25.66834 , 1. , 2. ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 16.768637 , 23.414627 , 22.04486 ,\n",
|
|
" 21.050663 , 20.486784 , 20.486784 , 21.050665 , 15.049731 ,\n",
|
|
" 11.578419 , 9.695194 , 20.398016 , 20.368341 , 20.398016 ,\n",
|
|
"...\n",
|
|
" 20.551746 , 20.00118 , 20.001116 , 20.551594 , 21.5222 ,\n",
|
|
" 17.707508 , 14.86889 , 19.914494 , 19.885508 , 19.914463 ,\n",
|
|
" 0. ]], dtype=float32)]'''\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from AimbotEnv import Aimbot\n",
|
|
"\n",
|
|
"# Build location and connection settings passed to the Aimbot wrapper.\n",
|
|
"ENV_PATH = \"../Build-ParallelEnv/Aimbot-ParallelEnv\"\n",
|
|
"WORKER_ID, BASE_PORT = 1, 2002\n",
|
|
"\n",
|
|
"env = Aimbot(\n",
|
|
"    envPath=ENV_PATH,\n",
|
|
"    workerID=WORKER_ID,\n",
|
|
"    basePort=BASE_PORT,\n",
|
|
")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(array([[ 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , -15.994009 , 1. , -26.322788 , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 2. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 1.3519633, 1.6946528,\n",
|
|
" 2.3051548, 3.673389 , 9.067246 , 17.521473 , 21.727095 ,\n",
|
|
" 22.753294 , 24.167128 , 25.905216 , 18.35725 , 21.02278 ,\n",
|
|
" 21.053417 , 0. , -15.994003 , 1. , -26.322784 ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 1.3519667,\n",
|
|
" 1.6946585, 2.3051722, 3.6734192, 9.067533 , 21.145092 ,\n",
|
|
" 21.727148 , 22.753365 , 24.167217 , 25.905317 , 18.358263 ,\n",
|
|
" 21.022812 , 21.053455 , 0. ],\n",
|
|
" [ 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , -1.8809433, 1. , -25.66834 , 1. ,\n",
|
|
" 2. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 16.768637 , 23.414627 ,\n",
|
|
" 22.04486 , 21.050663 , 20.486784 , 20.486784 , 21.050665 ,\n",
|
|
" 15.049731 , 11.578419 , 9.695194 , 20.398016 , 20.368341 ,\n",
|
|
" 20.398016 , 0. , -1.8809433, 1. , -25.66834 ,\n",
|
|
" 1. , 1. , 2. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 2. ,\n",
|
|
" 2. , 1. , 1. , 1. , 25.098585 ,\n",
|
|
" 15.749494 , 22.044899 , 21.050697 , 20.486813 , 20.486813 ,\n",
|
|
" 21.050694 , 15.049746 , 3.872317 , 3.789325 , 20.398046 ,\n",
|
|
" 20.368372 , 20.398046 , 0. ],\n",
|
|
" [ 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , -13.672583 , 1. , -26.479263 , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 5.3249803, 6.401276 ,\n",
|
|
" 8.374101 , 12.8657875, 21.302414 , 21.30242 , 21.888742 ,\n",
|
|
" 22.92251 , 24.346794 , 26.09773 , 21.210114 , 21.179258 ,\n",
|
|
" 21.210117 , 0. , -13.672583 , 1. , -26.479263 ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 2. , 1. , 1. ,\n",
|
|
" 2. , 1. , 1. , 2. , 5.3249855,\n",
|
|
" 6.4012837, 8.374114 , 12.865807 , 21.302446 , 21.30245 ,\n",
|
|
" 16.168503 , 22.922543 , 24.346823 , 7.1110754, 21.210148 ,\n",
|
|
" 21.17929 , 12.495141 , 0. ],\n",
|
|
" [ 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , -4.9038744, 1. , -25.185507 , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 1. ,\n",
|
|
" 1. , 1. , 1. , 20.33171 , 22.859762 ,\n",
|
|
" 21.522427 , 20.551746 , 20.00118 , 20.001116 , 20.551594 ,\n",
|
|
" 21.5222 , 17.707508 , 14.86889 , 19.914494 , 19.885508 ,\n",
|
|
" 19.914463 , 0. , -4.9038773, 1. , -25.185507 ,\n",
|
|
" 1. , 2. , 1. , 2. , 1. ,\n",
|
|
" 1. , 1. , 1. , 2. , 1. ,\n",
|
|
" 1. , 1. , 1. , 1. , 15.905993 ,\n",
|
|
" 22.85977 , 11.566693 , 20.551773 , 20.00121 , 20.001146 ,\n",
|
|
" 20.551619 , 7.135157 , 17.707582 , 14.868943 , 19.914528 ,\n",
|
|
" 19.88554 , 19.914494 , 0. ]], dtype=float32),\n",
|
|
" [[-0.05], [-0.05], [-0.05], [-0.05]],\n",
|
|
" [[False], [False], [False], [False]])"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"# Only the last expression of a cell is displayed, so the next two lines\n",
|
|
"# produce no output; they only evaluate the observation-shape expressions.\n",
|
|
"env.unity_observation_shape\n",
|
|
"(128, 4) + env.unity_observation_shape\n",
|
|
"# NOTE(review): the recorded output is a 3-item tuple - presumably\n",
|
|
"# (observations, rewards, dones) for four agents; confirm against Aimbot.reset().\n",
|
|
"env.reset()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"tensor([[1, 2, 3],\n",
|
|
" [1, 2, 3],\n",
|
|
" [1, 2, 3],\n",
|
|
" [1, 2, 3]], device='cuda:0')\n",
|
|
"tensor([[1],\n",
|
|
" [2],\n",
|
|
" [3],\n",
|
|
" [4]], device='cuda:0')\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[1, 2, 3, 1],\n",
|
|
" [1, 2, 3, 2],\n",
|
|
" [1, 2, 3, 3],\n",
|
|
" [1, 2, 3, 4]], device='cuda:0')"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import torch\n",
|
|
"# Fall back to the CPU so the cell also runs on machines without a GPU.\n",
|
|
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
"aa = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]).to(device)\n",
|
|
"bb = torch.tensor([[1], [2], [3], [4]]).to(device)\n",
|
|
"print(aa)\n",
|
|
"print(bb)\n",
|
|
"# Concatenate along dim 1: bb becomes an extra last column of aa.\n",
|
|
"torch.cat([aa, bb], dim=1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"ename": "AttributeError",
|
|
"evalue": "Can't get attribute 'PPOAgent' on <module '__main__'>",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_31348\\1930153251.py\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mmymodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"../PPO-Model/SmallArea-256-128-hybrid.pt\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[0mmymodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0meval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
|
"\u001b[1;32mc:\\Users\\UCUNI\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36mload\u001b[1;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[0;32m 710\u001b[0m \u001b[0mopened_file\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mseek\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0morig_position\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 711\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 712\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_load\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopened_zipfile\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 713\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_legacy_load\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 714\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
|
|
"\u001b[1;32mc:\\Users\\UCUNI\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36m_load\u001b[1;34m(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)\u001b[0m\n\u001b[0;32m 1047\u001b[0m \u001b[0munpickler\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mUnpicklerWrapper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata_file\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1048\u001b[0m \u001b[0munpickler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpersistent_load\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpersistent_load\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1049\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0munpickler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1050\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1051\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_utils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_validate_loaded_sparse_tensors\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
|
"\u001b[1;32mc:\\Users\\UCUNI\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36mfind_class\u001b[1;34m(self, mod_name, name)\u001b[0m\n\u001b[0;32m 1040\u001b[0m \u001b[1;32mpass\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1041\u001b[0m \u001b[0mmod_name\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mload_module_mapping\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmod_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmod_name\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1042\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfind_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmod_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1043\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1044\u001b[0m \u001b[1;31m# Load the data (which may in turn use `persistent_load` to load tensors)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
|
"\u001b[1;31mAttributeError\u001b[0m: Can't get attribute 'PPOAgent' on <module '__main__'>"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"from torch import nn\n",
|
|
"from torch.distributions import Categorical, Normal\n",
|
|
"\n",
|
|
"def layer_init(layer, std=np.sqrt(2), bias_const=0.0):\n",
|
|
"    \"\"\"Initialise ``layer`` in place and return it.\n",
|
|
"\n",
|
|
"    The weight receives an orthogonal initialisation with gain ``std`` and\n",
|
|
"    every bias entry is set to ``bias_const``.\n",
|
|
"    \"\"\"\n",
|
|
"    initializers = torch.nn.init\n",
|
|
"    initializers.orthogonal_(layer.weight, gain=std)\n",
|
|
"    initializers.constant_(layer.bias, val=bias_const)\n",
|
|
"    return layer\n",
|
|
"\n",
|
|
"class PPOAgent(nn.Module):\n",
|
|
"    \"\"\"Actor-critic network for a hybrid (discrete + continuous) action space.\n",
|
|
"\n",
|
|
"    A shared two-layer MLP feeds three heads: logits covering every discrete\n",
|
|
"    branch, the mean of a Gaussian over the continuous actions (with a learned,\n",
|
|
"    state-independent log-std), and a scalar state value.\n",
|
|
"    \"\"\"\n",
|
|
"\n",
|
|
"    def __init__(self, env: \"Aimbot\"):\n",
|
|
"        \"\"\"Size the network from the environment description.\n",
|
|
"\n",
|
|
"        Args:\n",
|
|
"            env: object exposing ``unity_observation_shape``,\n",
|
|
"                ``unity_discrete_size``, ``unity_discrete_branches`` and\n",
|
|
"                ``unity_continuous_size``.\n",
|
|
"        \"\"\"\n",
|
|
"        super().__init__()\n",
|
|
"        self.discrete_size = env.unity_discrete_size\n",
|
|
"        self.discrete_shape = list(env.unity_discrete_branches)\n",
|
|
"        self.continuous_size = env.unity_continuous_size\n",
|
|
"\n",
|
|
"        # Flattened observation length as a plain Python int.\n",
|
|
"        obs_size = int(np.prod(env.unity_observation_shape))\n",
|
|
"        self.network = nn.Sequential(\n",
|
|
"            layer_init(nn.Linear(obs_size, 256)),\n",
|
|
"            nn.ReLU(),\n",
|
|
"            layer_init(nn.Linear(256, 128)),\n",
|
|
"            nn.ReLU(),\n",
|
|
"        )\n",
|
|
"        self.actor_dis = layer_init(nn.Linear(128, self.discrete_size), std=0.01)\n",
|
|
"        self.actor_mean = layer_init(nn.Linear(128, self.continuous_size), std=0.01)\n",
|
|
"        self.actor_logstd = nn.Parameter(torch.zeros(1, self.continuous_size))\n",
|
|
"        self.critic = layer_init(nn.Linear(128, 1), std=1)\n",
|
|
"\n",
|
|
"    def get_value(self, state: torch.Tensor):\n",
|
|
"        \"\"\"Return the critic's value estimate for ``state``.\"\"\"\n",
|
|
"        return self.critic(self.network(state))\n",
|
|
"\n",
|
|
"    def get_actions_value(self, state: torch.Tensor, actions=None):\n",
|
|
"        \"\"\"Sample actions for ``state``, or re-evaluate the given ``actions``.\n",
|
|
"\n",
|
|
"        Args:\n",
|
|
"            state: batch of observations with the batch on dim 0.\n",
|
|
"            actions: optional batch of earlier actions, one row per sample,\n",
|
|
"                laid out as one column per discrete branch followed by the\n",
|
|
"                continuous values. New actions are sampled when omitted.\n",
|
|
"\n",
|
|
"        Returns:\n",
|
|
"            Tuple ``(actions, discrete log-prob, discrete entropy, continuous\n",
|
|
"            log-prob, continuous entropy, value)``. Discrete terms are summed\n",
|
|
"            over branches and continuous terms over action dimensions.\n",
|
|
"        \"\"\"\n",
|
|
"        hidden = self.network(state)\n",
|
|
"        # discrete: one categorical distribution per branch\n",
|
|
"        dis_logits = self.actor_dis(hidden)\n",
|
|
"        split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)\n",
|
|
"        multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]\n",
|
|
"        # continuous: Gaussian with a state-independent standard deviation\n",
|
|
"        actions_mean = self.actor_mean(hidden)\n",
|
|
"        action_logstd = self.actor_logstd.expand_as(actions_mean)\n",
|
|
"        action_std = torch.exp(action_logstd)\n",
|
|
"        con_probs = Normal(actions_mean, action_std)\n",
|
|
"\n",
|
|
"        if actions is None:\n",
|
|
"            disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])\n",
|
|
"            conAct = con_probs.sample()\n",
|
|
"            actions = torch.cat([disAct.T, conAct], dim=1)\n",
|
|
"        else:\n",
|
|
"            # The sampling branch above puts one column per discrete branch\n",
|
|
"            # first, so the split point comes from the model itself rather\n",
|
|
"            # than from a notebook-global `env`.\n",
|
|
"            num_branches = len(self.discrete_shape)\n",
|
|
"            disAct = actions[:, :num_branches].T\n",
|
|
"            conAct = actions[:, num_branches:]\n",
|
|
"        dis_log_prob = torch.stack(\n",
|
|
"            [ctgr.log_prob(act) for act, ctgr in zip(disAct, multi_categoricals)]\n",
|
|
"        )\n",
|
|
"        dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])\n",
|
|
"        return (\n",
|
|
"            actions,\n",
|
|
"            dis_log_prob.sum(0),\n",
|
|
"            dis_entropy.sum(0),\n",
|
|
"            con_probs.log_prob(conAct).sum(1),\n",
|
|
"            con_probs.entropy().sum(1),\n",
|
|
"            self.critic(hidden),\n",
|
|
"        )\n",
|
|
"\n",
|
|
"\n",
|
|
"# torch.load unpickles the checkpoint, so PPOAgent must already be defined in\n",
|
|
"# this session; the recorded AttributeError came from loading without it.\n",
|
|
"# NOTE(review): unpickling can run arbitrary code - only load files you trust.\n",
|
|
"mymodel = torch.load(\"../PPO-Model/SmallArea-256-128-hybrid.pt\")\n",
|
|
"# NOTE(review): assumes the checkpoint holds a whole nn.Module; eval() then\n",
|
|
"# switches it to evaluation mode.\n",
|
|
"mymodel.eval()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"# Use the GPU when one is available, otherwise stay on the CPU.\n",
|
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
"\n",
|
|
"x = torch.randn(2, 3).to(device)\n",
|
|
"print(x)\n",
|
|
"print(torch.cat((x, x, x), 0))  # joined along dim 0 -> shape (6, 3)\n",
|
|
"print(torch.cat((x, x, x), 1))  # joined along dim 1 -> shape (2, 9)\n",
|
|
"\n",
|
|
"# torch.cat accepts an empty 1-D tensor next to x; the result is discarded here.\n",
|
|
"aa = torch.empty(0).to(device)\n",
|
|
"torch.cat([aa, x])\n",
|
|
"# Build independent inner lists: [[]] * 2 would repeat one shared list object.\n",
|
|
"bb = [[] for _ in range(2)]\n",
|
|
"print(bb)\n",
|
|
"bb.append(x.to(\"cpu\").tolist())\n",
|
|
"bb.append(x.to(\"cpu\").tolist())\n",
|
|
"print(bb)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 64,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"tensor([[-1.1090, 0.4686, 0.6883],\n",
|
|
" [-0.1862, -0.3943, -0.0202],\n",
|
|
" [ 0.1436, -0.9444, -1.2079],\n",
|
|
" [-2.9434, -2.5989, -0.6653],\n",
|
|
" [ 0.4668, 0.8548, -0.4641],\n",
|
|
" [-0.3956, -0.2832, -0.1889],\n",
|
|
" [-0.2801, -0.2092, 1.7254],\n",
|
|
" [ 2.7938, -0.7742, 0.7053]], device='cuda:0')\n",
|
|
"(8, 0)\n",
|
|
"---\n",
|
|
"[[array([-1.1090169, 0.4685607, 0.6883437], dtype=float32)], [array([-0.1861974 , -0.39429024, -0.02016036], dtype=float32)], [array([ 0.14360362, -0.9443668 , -1.2079065 ], dtype=float32)], [array([-2.9433894 , -2.598913 , -0.66532046], dtype=float32)], [array([ 0.46684313, 0.8547877 , -0.46408093], dtype=float32)], [array([-0.39563984, -0.2831819 , -0.18891 ], dtype=float32)], [array([-0.28008553, -0.20918302, 1.7253567 ], dtype=float32)], [array([ 2.7938051, -0.7742478, 0.705279 ], dtype=float32)]]\n",
|
|
"[[array([-1.1090169, 0.4685607, 0.6883437], dtype=float32)], [], [array([ 0.14360362, -0.9443668 , -1.2079065 ], dtype=float32)], [array([-2.9433894 , -2.598913 , -0.66532046], dtype=float32)], [array([ 0.46684313, 0.8547877 , -0.46408093], dtype=float32)], [array([-0.39563984, -0.2831819 , -0.18891 ], dtype=float32)], [array([-0.28008553, -0.20918302, 1.7253567 ], dtype=float32)], [array([ 2.7938051, -0.7742478, 0.705279 ], dtype=float32)]]\n",
|
|
"---\n",
|
|
"[array([-1.1090169, 0.4685607, 0.6883437], dtype=float32), array([-1.1090169, 0.4685607, 0.6883437], dtype=float32)]\n",
|
|
"vvv tensor([[-1.1090, 0.4686, 0.6883],\n",
|
|
" [-1.1090, 0.4686, 0.6883]], device='cuda:0')\n",
|
|
"tensor([[-1.1090, 0.4686, 0.6883],\n",
|
|
" [-1.1090, 0.4686, 0.6883]], device='cuda:0')\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"True"
|
|
]
|
|
},
|
|
"execution_count": 64,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"\n",
|
|
"# Use the GPU when one is available, otherwise stay on the CPU.\n",
|
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
"\n",
|
|
"agent_num = 8\n",
|
|
"# One independent observation history per agent.\n",
|
|
"ob_buffer = [[] for _ in range(agent_num)]\n",
|
|
"# One row of observations per agent, so the row count follows agent_num.\n",
|
|
"obs = torch.randn(agent_num, 3).to(device)\n",
|
|
"print(obs)\n",
|
|
"print(np.shape(np.array(ob_buffer)))  # (agent_num, 0): every buffer is still empty\n",
|
|
"print('---')\n",
|
|
"obs_cpu = obs.to(\"cpu\").numpy()\n",
|
|
"for agent_buffer, agent_obs in zip(ob_buffer, obs_cpu):\n",
|
|
"    agent_buffer.append(agent_obs)\n",
|
|
"print(ob_buffer)\n",
|
|
"ob_buffer[1] = []  # reset a single agent's history\n",
|
|
"print(ob_buffer)\n",
|
|
"print('---')\n",
|
|
"for agent_buffer, agent_obs in zip(ob_buffer, obs_cpu):\n",
|
|
"    agent_buffer.append(agent_obs)\n",
|
|
"print(ob_buffer[0])\n",
|
|
"# Convert the list of ndarrays to one ndarray first; torch.tensor on a list\n",
|
|
"# of ndarrays is slow and emits a UserWarning.\n",
|
|
"vvv = torch.tensor(np.array(ob_buffer[0])).to(device)\n",
|
|
"print(\"vvv\", vvv)\n",
|
|
"# Concatenating onto an empty tensor leaves vvv unchanged.\n",
|
|
"empt = torch.tensor([]).to(device)\n",
|
|
"vvvv = torch.cat((empt, vvv), 0)\n",
|
|
"print(vvvv)\n",
|
|
"vvvv.size()[0] > 0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"{'Go': 1, 'Attack': 0, 'Free': 0}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Per-category counters, all starting at zero.\n",
|
|
"Total = dict.fromkeys((\"Go\", \"Attack\", \"Free\"), 0)\n",
|
|
"\n",
|
|
"Total[\"Go\"] = Total[\"Go\"] + 1\n",
|
|
"print(Total)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3.9.7 64-bit",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.7"
|
|
},
|
|
"orig_nbformat": 4,
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "86e2db13b09bd6be22cb599ea60c1572b9ef36ebeaa27a4c8e961d6df315ac32"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|