Aimbot-PPO/Aimbot-PPO-Python/Pytorch/Archive/testarea.ipynb


{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Action, 1 continuous ctrl 2.1\n",
"Action, 0 continuous ctrl -1.1\n"
]
}
],
"source": [
"import gym\n",
"from gym.spaces import Dict, Discrete, Box, Tuple\n",
"import numpy as np\n",
"\n",
"\n",
"class SampleGym(gym.Env):\n",
" def __init__(self, config={}):\n",
" self.config = config\n",
" self.action_space = Tuple((Discrete(2), Box(-10, 10, (2,))))\n",
" self.observation_space = Box(-10, 10, (2, 2))\n",
" self.p_done = config.get(\"p_done\", 0.1)\n",
"\n",
" def reset(self):\n",
" return self.observation_space.sample()\n",
"\n",
" def step(self, action):\n",
" chosen_action = action[0]\n",
" cnt_control = action[1][chosen_action]\n",
"\n",
" if chosen_action == 0:\n",
" reward = cnt_control\n",
" else:\n",
" reward = -cnt_control - 1\n",
"\n",
" print(f\"Action, {chosen_action} continuous ctrl {cnt_control}\")\n",
" return (\n",
" self.observation_space.sample(),\n",
" reward,\n",
" bool(np.random.choice([True, False], p=[self.p_done, 1.0 - self.p_done])),\n",
" {},\n",
" )\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" env = SampleGym()\n",
" env.reset()\n",
" env.step((1, [-1, 2.1])) # should say use action 1 with 2.1\n",
" env.step((0, [-1.1, 2.1])) # should say use action 0 with -1.1"
]
},
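{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: `SampleGym` is a minimal demo of a hybrid action space, `Tuple((Discrete(2), Box(-10, 10, (2,))))`: an action is `(discrete_choice, continuous_params)`, and `step` reads the continuous control at the index of the chosen discrete action (`action[1][chosen_action]`). The PPO agents later in this notebook use the same discrete-plus-continuous split."
]
},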
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from mlagents_envs.environment import UnityEnvironment\n",
"import numpy as np\n",
"\n",
"ENV_PATH = \"../Build-ParallelEnv/Aimbot-ParallelEnv\"\n",
"WORKER_ID = 1\n",
"BASE_PORT = 2002\n",
"\n",
"env = UnityEnvironment(\n",
" file_name=ENV_PATH,\n",
" seed=1,\n",
" side_channels=[],\n",
" worker_id=WORKER_ID,\n",
" base_port=BASE_PORT,\n",
")\n",
"\n",
"trackedAgent = 0\n",
"env.reset()\n",
"BEHA_SPECS = env.behavior_specs\n",
"BEHA_NAME = list(BEHA_SPECS)[0]\n",
"SPEC = BEHA_SPECS[BEHA_NAME]\n",
"print(SPEC)\n",
"\n",
"decisionSteps, terminalSteps = env.get_steps(BEHA_NAME)\n",
"\n",
"if trackedAgent in decisionSteps: # ゲーム終了していない場合、環境状態がdecision_stepsに保存される\n",
" nextState = decisionSteps[trackedAgent].obs[0]\n",
" reward = decisionSteps[trackedAgent].reward\n",
" done = False\n",
"if trackedAgent in terminalSteps: # ゲーム終了した場合、環境状態がterminal_stepsに保存される\n",
" nextState = terminalSteps[trackedAgent].obs[0]\n",
" reward = terminalSteps[trackedAgent].reward\n",
" done = True\n",
"print(decisionSteps.agent_id)\n",
"print(terminalSteps)\n"
]
},
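{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note on the ML-Agents stepping API: `env.get_steps(behavior_name)` returns a `(DecisionSteps, TerminalSteps)` pair. `DecisionSteps` holds agents that are mid-episode and requesting an action; `TerminalSteps` holds agents whose episode just ended, so their final observation and reward are read from there and `done` is set to `True`, as in the cell above."
]
},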
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"decisionSteps.agent_id [1 2 5 7]\n",
"decisionSteps.agent_id_to_index {1: 0, 2: 1, 5: 2, 7: 3}\n",
"decisionSteps.reward [0. 0. 0. 0.]\n",
"decisionSteps.action_mask [array([[False, False, False],\n",
" [False, False, False],\n",
" [False, False, False],\n",
" [False, False, False]]), array([[False, False, False],\n",
" [False, False, False],\n",
" [False, False, False],\n",
" [False, False, False]]), array([[False, False],\n",
" [False, False],\n",
" [False, False],\n",
" [False, False]])]\n",
"decisionSteps.obs [ 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0.\n",
" 0. 0. -15.994009 1. -26.322788 1.\n",
" 1. 1. 1. 1. 1. 2.\n",
" 1. 1. 1. 1. 1. 1.\n",
" 1. 1.3519633 1.6946528 2.3051548 3.673389 9.067246\n",
" 17.521473 21.727095 22.753294 24.167128 25.905216 18.35725\n",
" 21.02278 21.053417 0. ]\n"
]
},
{
"data": {
"text/plain": [
"'decisionSteps.obs [array([[-15.994009 , 1. , -26.322788 , 1. , 1. ,\\n 1. , 1. , 1. , 1. , 2. ,\\n 1. , 1. , 1. , 1. , 1. ,\\n 1. , 1. , 1.3519633, 1.6946528, 2.3051548,\\n 3.673389 , 9.067246 , 17.521473 , 21.727095 , 22.753294 ,\\n 24.167128 , 25.905216 , 18.35725 , 21.02278 , 21.053417 ,\\n 0. ],\\n [ -1.8809433, 1. , -25.66834 , 1. , 2. ,\\n 1. , 1. , 1. , 1. , 1. ,\\n 1. , 1. , 1. , 1. , 1. ,\\n 1. , 1. , 16.768637 , 23.414627 , 22.04486 ,\\n 21.050663 , 20.486784 , 20.486784 , 21.050665 , 15.049731 ,\\n 11.578419 , 9.695194 , 20.398016 , 20.368341 , 20.398016 ,\\n...\\n 20.551746 , 20.00118 , 20.001116 , 20.551594 , 21.5222 ,\\n 17.707508 , 14.86889 , 19.914494 , 19.885508 , 19.914463 ,\\n 0. ]], dtype=float32)]'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(\"decisionSteps.agent_id\",decisionSteps.agent_id)\n",
"# decisionSteps.agent_id [1 2 5 7]\n",
"print(\"decisionSteps.agent_id_to_index\",decisionSteps.agent_id_to_index)\n",
"# decisionSteps.agent_id_to_index {1: 0, 2: 1, 5: 2, 7: 3}\n",
"print(\"decisionSteps.reward\",decisionSteps.reward)\n",
"# decisionSteps.reward [0. 0. 0. 0.]\n",
"print(\"decisionSteps.action_mask\",decisionSteps.action_mask)\n",
"'''\n",
"decisionSteps.action_mask [array([[False, False, False],\n",
" [False, False, False],\n",
" [False, False, False],\n",
" [False, False, False]]), array([[False, False, False],\n",
" [False, False, False],\n",
" [False, False, False],\n",
" [False, False, False]]), array([[False, False],\n",
" [False, False],\n",
" [False, False],\n",
" [False, False]])]\n",
"'''\n",
"print(\"decisionSteps.obs\", decisionSteps.obs[0][0])\n",
"'''decisionSteps.obs [array([[-15.994009 , 1. , -26.322788 , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 2. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1.3519633, 1.6946528, 2.3051548,\n",
" 3.673389 , 9.067246 , 17.521473 , 21.727095 , 22.753294 ,\n",
" 24.167128 , 25.905216 , 18.35725 , 21.02278 , 21.053417 ,\n",
" 0. ],\n",
" [ -1.8809433, 1. , -25.66834 , 1. , 2. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 16.768637 , 23.414627 , 22.04486 ,\n",
" 21.050663 , 20.486784 , 20.486784 , 21.050665 , 15.049731 ,\n",
" 11.578419 , 9.695194 , 20.398016 , 20.368341 , 20.398016 ,\n",
"...\n",
" 20.551746 , 20.00118 , 20.001116 , 20.551594 , 21.5222 ,\n",
" 17.707508 , 14.86889 , 19.914494 , 19.885508 , 19.914463 ,\n",
" 0. ]], dtype=float32)]'''\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from AimbotEnv import Aimbot\n",
"\n",
"ENV_PATH = \"../Build-ParallelEnv/Aimbot-ParallelEnv\"\n",
"WORKER_ID = 1\n",
"BASE_PORT = 2002\n",
"\n",
"env = Aimbot(envPath=ENV_PATH,workerID= WORKER_ID,basePort= BASE_PORT)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[ 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , -15.994009 , 1. , -26.322788 , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 2. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1.3519633, 1.6946528,\n",
" 2.3051548, 3.673389 , 9.067246 , 17.521473 , 21.727095 ,\n",
" 22.753294 , 24.167128 , 25.905216 , 18.35725 , 21.02278 ,\n",
" 21.053417 , 0. , -15.994003 , 1. , -26.322784 ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1.3519667,\n",
" 1.6946585, 2.3051722, 3.6734192, 9.067533 , 21.145092 ,\n",
" 21.727148 , 22.753365 , 24.167217 , 25.905317 , 18.358263 ,\n",
" 21.022812 , 21.053455 , 0. ],\n",
" [ 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , -1.8809433, 1. , -25.66834 , 1. ,\n",
" 2. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 16.768637 , 23.414627 ,\n",
" 22.04486 , 21.050663 , 20.486784 , 20.486784 , 21.050665 ,\n",
" 15.049731 , 11.578419 , 9.695194 , 20.398016 , 20.368341 ,\n",
" 20.398016 , 0. , -1.8809433, 1. , -25.66834 ,\n",
" 1. , 1. , 2. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 2. ,\n",
" 2. , 1. , 1. , 1. , 25.098585 ,\n",
" 15.749494 , 22.044899 , 21.050697 , 20.486813 , 20.486813 ,\n",
" 21.050694 , 15.049746 , 3.872317 , 3.789325 , 20.398046 ,\n",
" 20.368372 , 20.398046 , 0. ],\n",
" [ 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , -13.672583 , 1. , -26.479263 , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 5.3249803, 6.401276 ,\n",
" 8.374101 , 12.8657875, 21.302414 , 21.30242 , 21.888742 ,\n",
" 22.92251 , 24.346794 , 26.09773 , 21.210114 , 21.179258 ,\n",
" 21.210117 , 0. , -13.672583 , 1. , -26.479263 ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 2. , 1. , 1. ,\n",
" 2. , 1. , 1. , 2. , 5.3249855,\n",
" 6.4012837, 8.374114 , 12.865807 , 21.302446 , 21.30245 ,\n",
" 16.168503 , 22.922543 , 24.346823 , 7.1110754, 21.210148 ,\n",
" 21.17929 , 12.495141 , 0. ],\n",
" [ 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , -4.9038744, 1. , -25.185507 , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 20.33171 , 22.859762 ,\n",
" 21.522427 , 20.551746 , 20.00118 , 20.001116 , 20.551594 ,\n",
" 21.5222 , 17.707508 , 14.86889 , 19.914494 , 19.885508 ,\n",
" 19.914463 , 0. , -4.9038773, 1. , -25.185507 ,\n",
" 1. , 2. , 1. , 2. , 1. ,\n",
" 1. , 1. , 1. , 2. , 1. ,\n",
" 1. , 1. , 1. , 1. , 15.905993 ,\n",
" 22.85977 , 11.566693 , 20.551773 , 20.00121 , 20.001146 ,\n",
" 20.551619 , 7.135157 , 17.707582 , 14.868943 , 19.914528 ,\n",
" 19.88554 , 19.914494 , 0. ]], dtype=float32),\n",
" [[-0.05], [-0.05], [-0.05], [-0.05]],\n",
" [[False], [False], [False], [False]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"env.unity_observation_shape\n",
"(128, 4) + env.unity_observation_shape\n",
"env.reset()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[1, 2, 3],\n",
" [1, 2, 3],\n",
" [1, 2, 3],\n",
" [1, 2, 3]], device='cuda:0')\n",
"tensor([[1],\n",
" [2],\n",
" [3],\n",
" [4]], device='cuda:0')\n"
]
},
{
"data": {
"text/plain": [
"tensor([[1, 2, 3, 1],\n",
" [1, 2, 3, 2],\n",
" [1, 2, 3, 3],\n",
" [1, 2, 3, 4]], device='cuda:0')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"aa = torch.tensor([[1,2,3],[1,2,3],[1,2,3],[1,2,3]]).to(\"cuda:0\")\n",
"bb = torch.tensor([[1],[2],[3],[4]]).to(\"cuda:0\")\n",
"print(aa)\n",
"print(bb)\n",
"torch.cat([aa,bb],axis = 1)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"ename": "AttributeError",
"evalue": "Can't get attribute 'PPOAgent' on <module '__main__'>",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_31348\\1930153251.py\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mmymodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"../PPO-Model/SmallArea-256-128-hybrid.pt\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[0mmymodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0meval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\Users\\UCUNI\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36mload\u001b[1;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[0;32m 710\u001b[0m \u001b[0mopened_file\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mseek\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0morig_position\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 711\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 712\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_load\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopened_zipfile\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 713\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_legacy_load\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 714\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\Users\\UCUNI\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36m_load\u001b[1;34m(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)\u001b[0m\n\u001b[0;32m 1047\u001b[0m \u001b[0munpickler\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mUnpicklerWrapper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata_file\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1048\u001b[0m \u001b[0munpickler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpersistent_load\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpersistent_load\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1049\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0munpickler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1050\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1051\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_utils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_validate_loaded_sparse_tensors\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\Users\\UCUNI\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36mfind_class\u001b[1;34m(self, mod_name, name)\u001b[0m\n\u001b[0;32m 1040\u001b[0m \u001b[1;32mpass\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1041\u001b[0m \u001b[0mmod_name\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mload_module_mapping\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmod_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmod_name\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1042\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfind_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmod_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1043\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1044\u001b[0m \u001b[1;31m# Load the data (which may in turn use `persistent_load` to load tensors)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mAttributeError\u001b[0m: Can't get attribute 'PPOAgent' on <module '__main__'>"
]
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"def layer_init(layer, std=np.sqrt(2), bias_const=0.0):\n",
" torch.nn.init.orthogonal_(layer.weight, std)\n",
" torch.nn.init.constant_(layer.bias, bias_const)\n",
" return layer\n",
"\n",
"class PPOAgent(nn.Module):\n",
" def __init__(self, env: Aimbot):\n",
" super(PPOAgent, self).__init__()\n",
" self.discrete_size = env.unity_discrete_size\n",
" self.discrete_shape = list(env.unity_discrete_branches)\n",
" self.continuous_size = env.unity_continuous_size\n",
"\n",
" self.network = nn.Sequential(\n",
" layer_init(nn.Linear(np.array(env.unity_observation_shape).prod(), 256)),\n",
" nn.ReLU(),\n",
" layer_init(nn.Linear(256, 128)),\n",
" nn.ReLU(),\n",
" )\n",
" self.actor_dis = layer_init(nn.Linear(128, self.discrete_size), std=0.01)\n",
" self.actor_mean = layer_init(nn.Linear(128, self.continuous_size), std=0.01)\n",
" self.actor_logstd = nn.Parameter(torch.zeros(1, self.continuous_size))\n",
" self.critic = layer_init(nn.Linear(128, 1), std=1)\n",
"\n",
" def get_value(self, state: torch.Tensor):\n",
" return self.critic(self.network(state))\n",
"\n",
" def get_actions_value(self, state: torch.Tensor, actions=None):\n",
" hidden = self.network(state)\n",
" # discrete\n",
" dis_logits = self.actor_dis(hidden)\n",
" split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)\n",
" multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]\n",
" # continuous\n",
" actions_mean = self.actor_mean(hidden)\n",
" action_logstd = self.actor_logstd.expand_as(actions_mean)\n",
" action_std = torch.exp(action_logstd)\n",
" con_probs = Normal(actions_mean, action_std)\n",
"\n",
" if actions is None:\n",
" disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])\n",
" conAct = con_probs.sample()\n",
" actions = torch.cat([disAct.T, conAct], dim=1)\n",
" else:\n",
" disAct = actions[:, 0 : env.unity_discrete_type].T\n",
" conAct = actions[:, env.unity_discrete_type :]\n",
" dis_log_prob = torch.stack(\n",
" [ctgr.log_prob(act) for act, ctgr in zip(disAct, multi_categoricals)]\n",
" )\n",
" dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])\n",
" return (\n",
" actions,\n",
" dis_log_prob.sum(0),\n",
" dis_entropy.sum(0),\n",
" con_probs.log_prob(conAct).sum(1),\n",
" con_probs.entropy().sum(1),\n",
" self.critic(hidden),\n",
" )\n",
"\n",
"\n",
"mymodel = torch.load(\"../PPO-Model/SmallArea-256-128-hybrid.pt\")\n",
"mymodel.eval()"
]
},
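{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `AttributeError: Can't get attribute 'PPOAgent' on <module '__main__'>` above is a pickling issue, not a corrupt file: `torch.load` on a whole-model checkpoint unpickles the object, so the `PPOAgent` class must already be defined under the module path it was saved from (here `__main__`). The stored error output was most likely produced by a run where the class definition had not been executed yet; running the definition in this cell before `torch.load` resolves it."
]
},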
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"\n",
"x = torch.randn(2, 3).to(\"cuda\")\n",
"print(x)\n",
"print(torch.cat((x, x, x), 0))\n",
"print(torch.cat((x, x, x), 1))\n",
"\n",
"aa = torch.empty(0).to(\"cuda\")\n",
"torch.cat([aa,x])\n",
"bb = [[]]*2\n",
"print(bb)\n",
"bb.append(x.to(\"cpu\").tolist())\n",
"bb.append(x.to(\"cpu\").tolist())\n",
"print(bb)"
]
},
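{
"cell_type": "markdown",
"metadata": {},
"source": [
"An aside on the `bb = [[]]*2` line above: `[[]] * n` repeats a reference to the *same* inner list n times, so mutating one slot mutates all of them. The buffer cell below correctly uses a list comprehension (`[[] for i in range(agent_num)]`), which creates independent lists. A minimal demonstration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# [[]] * n aliases one inner list n times\n",
"a = [[]] * 2\n",
"a[0].append(1)\n",
"print(a) # [[1], [1]] - both slots changed\n",
"\n",
"# a comprehension builds independent lists\n",
"b = [[] for _ in range(2)]\n",
"b[0].append(1)\n",
"print(b) # [[1], []]"
]
},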
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-1.1090, 0.4686, 0.6883],\n",
" [-0.1862, -0.3943, -0.0202],\n",
" [ 0.1436, -0.9444, -1.2079],\n",
" [-2.9434, -2.5989, -0.6653],\n",
" [ 0.4668, 0.8548, -0.4641],\n",
" [-0.3956, -0.2832, -0.1889],\n",
" [-0.2801, -0.2092, 1.7254],\n",
" [ 2.7938, -0.7742, 0.7053]], device='cuda:0')\n",
"(8, 0)\n",
"---\n",
"[[array([-1.1090169, 0.4685607, 0.6883437], dtype=float32)], [array([-0.1861974 , -0.39429024, -0.02016036], dtype=float32)], [array([ 0.14360362, -0.9443668 , -1.2079065 ], dtype=float32)], [array([-2.9433894 , -2.598913 , -0.66532046], dtype=float32)], [array([ 0.46684313, 0.8547877 , -0.46408093], dtype=float32)], [array([-0.39563984, -0.2831819 , -0.18891 ], dtype=float32)], [array([-0.28008553, -0.20918302, 1.7253567 ], dtype=float32)], [array([ 2.7938051, -0.7742478, 0.705279 ], dtype=float32)]]\n",
"[[array([-1.1090169, 0.4685607, 0.6883437], dtype=float32)], [], [array([ 0.14360362, -0.9443668 , -1.2079065 ], dtype=float32)], [array([-2.9433894 , -2.598913 , -0.66532046], dtype=float32)], [array([ 0.46684313, 0.8547877 , -0.46408093], dtype=float32)], [array([-0.39563984, -0.2831819 , -0.18891 ], dtype=float32)], [array([-0.28008553, -0.20918302, 1.7253567 ], dtype=float32)], [array([ 2.7938051, -0.7742478, 0.705279 ], dtype=float32)]]\n",
"---\n",
"[array([-1.1090169, 0.4685607, 0.6883437], dtype=float32), array([-1.1090169, 0.4685607, 0.6883437], dtype=float32)]\n",
"vvv tensor([[-1.1090, 0.4686, 0.6883],\n",
" [-1.1090, 0.4686, 0.6883]], device='cuda:0')\n",
"tensor([[-1.1090, 0.4686, 0.6883],\n",
" [-1.1090, 0.4686, 0.6883]], device='cuda:0')\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"agent_num = 8\n",
"ob_buffer = [[]for i in range(agent_num)]\n",
"obs = torch.randn(8, 3).to(\"cuda\")\n",
"print(obs)\n",
"print(np.shape(np.array(ob_buffer)))\n",
"print('---')\n",
"obs_cpu = obs.to(\"cpu\").numpy()\n",
"for i in range(agent_num):\n",
" ob_buffer[i].append(obs_cpu[i])\n",
"print(ob_buffer)\n",
"ob_buffer[1] = []\n",
"print(ob_buffer)\n",
"print('---')\n",
"for i in range(agent_num):\n",
" ob_buffer[i].append(obs_cpu[i])\n",
"print(ob_buffer[0])\n",
"vvv = torch.tensor(ob_buffer[0]).to(\"cuda\")\n",
"print(\"vvv\",vvv)\n",
"empt = torch.tensor([]).to(\"cuda\")\n",
"vvvv = torch.cat((empt,vvv),0)\n",
"print(vvvv)\n",
"vvvv.size()[0]>0"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from AimbotEnv import Aimbot\n",
"from enum import Enum\n",
"import uuid\n",
"from mlagents_envs.side_channel.side_channel import (\n",
" SideChannel,\n",
" IncomingMessage,\n",
" OutgoingMessage,\n",
")\n",
"from typing import List\n",
"\n",
"class Targets(Enum):\n",
" Free = 0\n",
" Go = 1\n",
" Attack = 2\n",
" Num = 3\n",
"TotalRounds = {\"Go\":0,\"Attack\":0,\"Free\":0}\n",
"WinRounds = {\"Go\":0,\"Attack\":0,\"Free\":0}\n",
"\n",
"class AimbotSideChannel(SideChannel):\n",
" def __init__(self, channel_id: uuid.UUID) -> None:\n",
" super().__init__(channel_id)\n",
" def on_message_received(self, msg: IncomingMessage) -> None:\n",
" \"\"\"\n",
" Note: We must implement this method of the SideChannel interface to\n",
" receive messages from Unity\n",
" \"\"\"\n",
" thisMessage = msg.read_string()\n",
" #print(thisMessage)\n",
" thisResult = thisMessage.split(\"|\")\n",
" if(thisResult[0] == \"result\"):\n",
" TotalRounds[thisResult[1]]+=1\n",
" if(thisResult[2] == \"Win\"):\n",
" WinRounds[thisResult[1]]+=1\n",
" #print(TotalRounds)\n",
" #print(WinRounds)\n",
" elif(thisResult[0] == \"Error\"):\n",
" print(thisMessage)\n",
"\t# 发送函数\n",
" def send_string(self, data: str) -> None:\n",
" \"\"\"发送一个字符串给C#\"\"\"\n",
" msg = OutgoingMessage()\n",
" msg.write_string(data)\n",
" super().queue_message_to_send(msg)\n",
"\n",
" def send_bool(self, data: bool) -> None:\n",
" msg = OutgoingMessage()\n",
" msg.write_bool(data)\n",
" super().queue_message_to_send(msg)\n",
"\n",
" def send_int(self, data: int) -> None:\n",
" msg = OutgoingMessage()\n",
" msg.write_int32(data)\n",
" super().queue_message_to_send(msg)\n",
"\n",
" def send_float(self, data: float) -> None:\n",
" msg = OutgoingMessage()\n",
" msg.write_float32(data)\n",
" super().queue_message_to_send(msg)\n",
"\n",
" def send_float_list(self, data: List[float]) -> None:\n",
" msg = OutgoingMessage()\n",
" msg.write_float32_list(data)\n",
" super().queue_message_to_send(msg)\n",
" \n",
"SIDE_CHANNEL_UUID = uuid.UUID(\"8bbfb62a-99b4-457c-879d-b78b69066b5e\")\n",
"ENV_PATH = \"../Build/Build-ParallelEnv-Target-OffPolicy-SingleStack-SideChannel-EndReward/Aimbot-ParallelEnv\"\n",
"aimBotsideChannel = AimbotSideChannel(SIDE_CHANNEL_UUID)\n",
"env = Aimbot(envPath=ENV_PATH, workerID=123, basePort=999,side_channels=[aimBotsideChannel])"
]
},
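{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the parser in `on_message_received` implies, messages arriving from Unity are pipe-delimited strings: `result|<target name>|<Win or Lose>` updates `TotalRounds`/`WinRounds`, and anything starting with `Error|` is printed verbatim. The outgoing helpers (`send_string`, `send_bool`, ...) each wrap one value in an `OutgoingMessage`. The payload format the C# side expects is not defined in this notebook, so the sketch below uses a made-up string purely to show the call:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical payload - the real C#-side message format is not specified here\n",
"aimBotsideChannel.send_string(\"hello|from|python\")"
]
},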
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.distributions.normal import Normal\n",
"from torch.distributions.categorical import Categorical\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() and True else \"cpu\")\n",
"\n",
"def layer_init(layer, std=np.sqrt(2), bias_const=0.0):\n",
" torch.nn.init.orthogonal_(layer.weight, std)\n",
" torch.nn.init.constant_(layer.bias, bias_const)\n",
" return layer\n",
"\n",
"class PPOAgent(nn.Module):\n",
" def __init__(self, env: Aimbot,targetNum:int):\n",
" super(PPOAgent, self).__init__()\n",
" self.targetNum = targetNum\n",
" self.discrete_size = env.unity_discrete_size\n",
" self.discrete_shape = list(env.unity_discrete_branches)\n",
" self.continuous_size = env.unity_continuous_size\n",
"\n",
" self.network = nn.Sequential(\n",
" layer_init(nn.Linear(np.array(env.unity_observation_shape).prod(), 500)),\n",
" nn.ReLU(),\n",
" layer_init(nn.Linear(500, 300)),\n",
" nn.ReLU(),\n",
" )\n",
" self.actor_dis = nn.ModuleList([layer_init(nn.Linear(300, self.discrete_size), std=0.01) for i in range(targetNum)])\n",
" self.actor_mean = nn.ModuleList([layer_init(nn.Linear(300, self.continuous_size), std=0.01) for i in range(targetNum)])\n",
" self.actor_logstd = nn.ParameterList([nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(targetNum)])\n",
" self.critic = layer_init(nn.Linear(300, 1), std=1)\n",
"\n",
" def get_value(self, state: torch.Tensor):\n",
" return self.critic(self.network(state))\n",
"\n",
" def get_actions_value(self, state: torch.Tensor, actions=None):\n",
" hidden = self.network(state)\n",
" targets = torch.argmax(state[:,0:self.targetNum],dim=1)\n",
"\n",
" # discrete\n",
" # 递归targets的数量,既agent数来实现根据target不同来选用对应的输出网络计算输出\n",
" dis_logits = torch.stack([self.actor_dis[targets[i]](hidden[i]) for i in range(targets.size()[0])])\n",
" split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)\n",
" multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]\n",
" # continuous\n",
" actions_mean = torch.stack([self.actor_mean[targets[i]](hidden[i]) for i in range(targets.size()[0])]) # self.actor_mean(hidden)\n",
" # action_logstd = torch.stack([self.actor_logstd[targets[i]].expand_as(actions_mean) for i in range(targets.size()[0])]) # self.actor_logstd.expand_as(actions_mean)\n",
" # print(action_logstd)\n",
" action_std = torch.squeeze(torch.stack([torch.exp(self.actor_logstd[targets[i]]) for i in range(targets.size()[0])]),dim = -1) # torch.exp(action_logstd)\n",
" con_probs = Normal(actions_mean, action_std)\n",
"\n",
" if actions is None:\n",
" if True:\n",
" # select actions base on probability distribution model\n",
" disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])\n",
" conAct = con_probs.sample()\n",
" actions = torch.cat([disAct.T, conAct], dim=1)\n",
" else:\n",
" # select actions base on best probability distribution\n",
" disAct = torch.stack([torch.argmax(logit, dim=1) for logit in split_logits])\n",
" conAct = actions_mean\n",
" actions = torch.cat([disAct.T, conAct], dim=1)\n",
" else:\n",
" disAct = actions[:, 0 : env.unity_discrete_type].T\n",
" conAct = actions[:, env.unity_discrete_type :]\n",
" dis_log_prob = torch.stack(\n",
" [ctgr.log_prob(act) for act, ctgr in zip(disAct, multi_categoricals)]\n",
" )\n",
" dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])\n",
" return (\n",
" actions,\n",
" dis_log_prob.sum(0),\n",
" dis_entropy.sum(0),\n",
" con_probs.log_prob(conAct).sum(1),\n",
" con_probs.entropy().sum(1),\n",
" self.critic(hidden),\n",
" )\n",
"agent = PPOAgent(env,4).to(device)"
]
},
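{
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-target dispatch above indexes an `nn.ModuleList` by each sample's target id and re-stacks the per-sample outputs. This is a Python-level loop over the batch (fine for small batches, O(batch) overhead otherwise). A self-contained toy version of the same pattern, with made-up layer sizes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"heads = nn.ModuleList([nn.Linear(4, 2) for _ in range(3)]) # one head per target type\n",
"x = torch.randn(5, 4) # batch of 5 hidden vectors\n",
"tgt = torch.tensor([0, 2, 1, 1, 0]) # target id of each sample\n",
"\n",
"# route each sample through the head matching its target, then re-stack\n",
"out = torch.stack([heads[int(t)](x[i]) for i, t in enumerate(tgt)])\n",
"print(out.shape) # torch.Size([5, 2])"
]
},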
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1. , -10.343613 , 0. , -7.367299 ,\n",
" 0. , 0. , 30. , -10.343662 ,\n",
" 1. , -33.708736 , 1. , 1. ,\n",
" 1. , 1. , 2. , 1. ,\n",
" 1. , 1. , 2. , 2. ,\n",
" 2. , 1. , 1. , 1. ,\n",
" 33.270493 , 39.50663 , 49.146526 , 32.595673 ,\n",
" 30.21616 , 21.163797 , 46.9299 , 1.3264331 ,\n",
" 1.2435672 , 1.2541904 , 30.08522 , 30.041445 ,\n",
" 21.072094 , 0. ],\n",
" [ 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 30. , -5.5892515 ,\n",
" 1. , -29.907726 , 1. , 1. ,\n",
" 1. , 1. , 2. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 41.408752 , 47.830173 , 45.03225 , 31.905174 ,\n",
" 41.849663 , 41.849648 , 43.001434 , 45.0322 ,\n",
" 47.48242 , 40.00285 , 41.668346 , 41.607723 ,\n",
" 41.668335 , 0. ],\n",
" [ 1. , 2.9582403 , 0. , -4.699738 ,\n",
" 0. , 0. , 30. , -5.412487 ,\n",
" 1. , -32.79967 , 1. , 2. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 2. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 20.17488 , 49.507687 , 48.162056 , 45.98998 ,\n",
" 44.75835 , 31.08564 , 32.865173 , 24.676666 ,\n",
" 12.952409 , 39.69923 , 44.564423 , 44.49966 ,\n",
" 44.564495 , 0. ],\n",
" [ 2. , -0.20171738, 0. , -10.340863 ,\n",
" 0. , 0. , 30. , -22.987915 ,\n",
" 1. , -34.37514 , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 2. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 11.631058 , 13.872022 , 18.006863 , 27.457632 ,\n",
" 46.343067 , 46.343094 , 20.155125 , 49.867714 ,\n",
" 52.965984 , 56.775608 , 46.14223 , 46.075138 ,\n",
" 46.142246 , 0. ],\n",
" [ 2. , -14.687862 , 0. , -12.615574 ,\n",
" 0. , 0. , 30. , 15.125373 ,\n",
" 1. , -30.849268 , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 2. ,\n",
" 52.430542 , 48.912865 , 46.05145 , 43.974594 ,\n",
" 42.796673 , 26.467875 , 11.072432 , 7.190229 ,\n",
" 5.483198 , 4.5500183 , 42.611244 , 42.549267 ,\n",
" 18.856438 , 0. ],\n",
" [ 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 30. , -4.0314903 ,\n",
" 1. , -29.164669 , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 44.074184 , 46.9762 , 44.228096 , 42.2335 ,\n",
" 41.102253 , 41.102367 , 42.233757 , 44.22849 ,\n",
" 44.321827 , 37.335304 , 40.924183 , 40.86467 ,\n",
" 40.924236 , 0. ],\n",
" [ 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 30. , -18.603981 ,\n",
" 1. , -29.797592 , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 2. , 2. , 2. ,\n",
" 19.134174 , 22.76088 , 29.468704 , 42.88739 ,\n",
" 41.738823 , 41.739002 , 42.88781 , 44.913647 ,\n",
" 47.704174 , 51.135338 , 20.418388 , 12.470214 ,\n",
" 12.670923 , 0. ],\n",
" [ 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 30. , -19.07032 ,\n",
" 1. , -30.246218 , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 18.336487 , 21.81617 , 28.251017 , 42.977867 ,\n",
" 42.18994 , 42.19034 , 43.351707 , 45.399582 ,\n",
" 48.22037 , 51.68873 , 42.00719 , 41.94621 ,\n",
" 42.00739 , 0. ]], dtype=float32)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state,_,_ = env.getSteps()\n",
"state"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"env.close()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[2, 3, 5, 6, 7, 8],\n",
" [2, 3, 5, 6, 7, 8],\n",
" [2, 3, 5, 6, 7, 8],\n",
" [2, 3, 5, 6, 7, 8],\n",
" [2, 3, 5, 6, 7, 8],\n",
" [2, 3, 5, 6, 7, 8],\n",
" [2, 3, 5, 6, 7, 8],\n",
" [2, 3, 5, 6, 7, 8],\n",
" [2, 3, 5, 6, 7, 8],\n",
" [2, 3, 5, 6, 7, 8]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"from torch.distributions.categorical import Categorical\n",
"\n",
"aaa = torch.tensor([[1,2,3,4,5,6,7,8] for i in range(10)])\n",
"aaasplt = torch.split(aaa,[3,3,2],dim=1)\n",
"multicate = [Categorical(logits=thislo) for thislo in aaasplt]\n",
"disact = torch.stack([ctgr.sample() for ctgr in multicate])\n",
"#print(aaa)\n",
"#print(aaasplt)\n",
"torch.cat([aaa[:,1:3],aaa[:,4:]],dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.distributions.categorical import Categorical\n",
"\n",
"logits = torch.Tensor([[0.5,0.25]])\n",
"lgst = Categorical(logits=logits)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0.0000, -0.2500]])\n",
"tensor([[1.0000, 0.7788]])\n",
"tensor([[1.7788]])\n",
"tensor([[0.5622, 0.4378]])\n"
]
},
{
"data": {
"text/plain": [
"tensor([[0.6854]])"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# calculate entropy of log probability\n",
"def entropy(logits):\n",
" a0 = logits - logits.max(1, keepdim=True)[0]\n",
" print(a0)\n",
" ea0 = torch.exp(a0)\n",
" print(ea0)\n",
" z0 = ea0.sum(1, keepdim=True)\n",
" print(z0)\n",
" p0 = ea0 / z0\n",
" print(p0)\n",
" return (p0 * (torch.log(z0) - a0)).sum(1, keepdim=True)\n",
"entropy(logits)"
]
},
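{
"cell_type": "markdown",
"metadata": {},
"source": [
"What `entropy(logits)` computes, in closed form: with shifted logits $a_i = z_i - \\max_j z_j$, normalizer $Z = \\sum_i e^{a_i}$, and probabilities $p_i = e^{a_i}/Z$,\n",
"\n",
"$$H = \\sum_i p_i\\,(\\log Z - a_i) = -\\sum_i p_i \\log p_i,$$\n",
"\n",
"the standard categorical entropy, computed with the max-shift for numerical stability. It matches `Categorical(logits=...).entropy()`."
]
},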
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.5000, 0.2500, 0.2500]])\n",
"tensor([[1.0397]])\n"
]
},
{
"data": {
"text/plain": [
"tensor([1.0397])"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"probs = torch.Tensor([[0.5,0.25,0.25]])\n",
"print(probs)\n",
"# calculate entropy of probability\n",
"def entropy2(probs):\n",
" return -(probs * torch.log(probs)).sum(1, keepdim=True)\n",
"print(entropy2(probs))\n",
"lgst2 = Categorical(probs=probs)\n",
"lgst2.entropy()"
]
},
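{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hand check of `entropy2`: for $p = (0.5, 0.25, 0.25)$,\n",
"\n",
"$$H = -\\sum_i p_i \\ln p_i = 0.5\\ln 2 + 0.25\\ln 4 + 0.25\\ln 4 \\approx 0.3466 \\cdot 3 \\approx 1.0397,$$\n",
"\n",
"which matches both the manual function and `Categorical(probs=probs).entropy()` above."
]
},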
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[2.1121]])\n"
]
},
{
"data": {
"text/plain": [
"tensor([[2.1121]])"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torch.distributions.normal import Normal\n",
"mu = torch.Tensor([[1]])\n",
"sigma = torch.Tensor([[2]])\n",
"# calculate entropy of Normal distribution\n",
"def entropy3(mu,sigma):\n",
" return 0.5 * (1 + torch.log(2 * sigma * sigma * 3.1415926))\n",
"\n",
"print(entropy3(mu,sigma))\n",
"nm = Normal(mu,sigma)\n",
"nm.entropy()"
]
},
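{
"cell_type": "markdown",
"metadata": {},
"source": [
"`entropy3` is the differential entropy of a Gaussian,\n",
"\n",
"$$H = \\tfrac{1}{2}\\ln(2\\pi e \\sigma^2) = \\tfrac{1}{2}\\bigl(1 + \\ln(2\\pi\\sigma^2)\\bigr),$$\n",
"\n",
"which depends only on $\\sigma$, not on $\\mu$. With $\\sigma = 2$: $H = \\tfrac{1}{2}(1 + \\ln(8\\pi)) \\approx 2.1121$, matching `Normal(mu, sigma).entropy()`."
]
},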
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"from AimbotEnv import Aimbot\n",
"from enum import Enum\n",
"from torch.distributions.normal import Normal\n",
"from torch.distributions.categorical import Categorical\n",
"\n",
"DEFAULT_SEED = 9331\n",
"ENV_PATH = \"../Build/3.0/Mix/Aimbot-ParallelEnv\"\n",
"WAND_ENTITY = \"koha9\"\n",
"WORKER_ID = 1\n",
"BASE_PORT = 1000\n",
"\n",
"# tensorboard names\n",
"game_name = \"Aimbot_Target_Hybrid_PMNN_V3\"\n",
"game_type = \"PList_Mix_LeakyReLU_512Batch\"\n",
"\n",
"# max round steps per agent is 2500/Decision_period, 25 seconds\n",
"# !!!check every parameters before run!!!\n",
"\n",
"TOTAL_STEPS = 3150000\n",
"BATCH_SIZE = 512\n",
"MAX_TRAINNING_DATASETS = 6000\n",
"DECISION_PERIOD = 1\n",
"LEARNING_RATE = 6.5e-4\n",
"GAMMA = 0.99\n",
"GAE_LAMBDA = 0.95\n",
"EPOCHS = 3\n",
"CLIP_COEF = 0.11\n",
"LOSS_COEF = [1.0, 1.0, 1.0, 1.0] # free go attack defence\n",
"POLICY_COEF = [1.0, 1.0, 1.0, 1.0]\n",
"ENTROPY_COEF = [0.05, 0.05, 0.05, 0.05]\n",
"CRITIC_COEF = [0.5, 0.5, 0.5, 0.5]\n",
"TARGET_LEARNING_RATE = 1e-6\n",
"FREEZE_VIEW_NETWORK = False\n",
"\n",
"BROADCASTREWARD = False\n",
"ANNEAL_LEARNING_RATE = True\n",
"CLIP_VLOSS = True\n",
"NORM_ADV = False\n",
"TRAIN = True\n",
"SAVE_MODEL = True\n",
"WANDB_TACK = True\n",
"LOAD_DIR = None\n",
"LOAD_DIR = \"../PPO-Model/PList_Goto_LeakyReLU_256Batch_9331_1678785562/PList_Goto_LeakyReLU_256Batch_9331_1678785562_8.370919.pt\"\n",
"\n",
"# public data\n",
"class Targets(Enum):\n",
" Free = 0\n",
" Go = 1\n",
" Attack = 2\n",
" Defence = 3\n",
" Num = 4\n",
"TARGET_STATE_SIZE = 6\n",
"INAREA_STATE_SIZE = 1\n",
"TIME_STATE_SIZE = 1\n",
"GUN_STATE_SIZE = 1\n",
"MY_STATE_SIZE = 4\n",
"TOTAL_T_SIZE = TARGET_STATE_SIZE+INAREA_STATE_SIZE+TIME_STATE_SIZE+GUN_STATE_SIZE+MY_STATE_SIZE\n",
"BASE_WINREWARD = 999\n",
"BASE_LOSEREWARD = -999\n",
"TARGETNUM= 4\n",
"ENV_TIMELIMIT = 30\n",
"RESULT_BROADCAST_RATIO = 1/ENV_TIMELIMIT\n",
"\n",
"def layer_init(layer, std=np.sqrt(2), bias_const=0.0):\n",
" torch.nn.init.orthogonal_(layer.weight, std)\n",
" torch.nn.init.constant_(layer.bias, bias_const)\n",
" return layer\n",
"\n",
"\n",
"class PPOAgent(nn.Module):\n",
" def __init__(self, env: Aimbot,targetNum:int):\n",
" super(PPOAgent, self).__init__()\n",
" self.targetNum = targetNum\n",
" self.stateSize = env.unity_observation_shape[0]\n",
" self.agentNum = env.unity_agent_num\n",
" self.targetSize = TARGET_STATE_SIZE\n",
" self.timeSize = TIME_STATE_SIZE\n",
" self.gunSize = GUN_STATE_SIZE\n",
" self.myStateSize = MY_STATE_SIZE\n",
" self.raySize = env.unity_observation_shape[0] - TOTAL_T_SIZE\n",
" self.nonRaySize = TOTAL_T_SIZE\n",
" self.head_input_size = env.unity_observation_shape[0] - self.targetSize-self.timeSize-self.gunSize# except target state input\n",
"\n",
" self.discrete_size = env.unity_discrete_size\n",
" self.discrete_shape = list(env.unity_discrete_branches)\n",
" self.continuous_size = env.unity_continuous_size\n",
"\n",
" self.viewNetwork = nn.Sequential(\n",
" layer_init(nn.Linear(self.raySize, 200)),\n",
" nn.LeakyReLU()\n",
" )\n",
" self.targetNetworks = nn.ModuleList([nn.Sequential(\n",
" layer_init(nn.Linear(self.nonRaySize, 100)),\n",
" nn.LeakyReLU()\n",
" )for i in range(targetNum)])\n",
" self.middleNetworks = nn.ModuleList([nn.Sequential(\n",
" layer_init(nn.Linear(300,200)),\n",
" nn.LeakyReLU()\n",
" )for i in range(targetNum)])\n",
" self.actor_dis = nn.ModuleList([layer_init(nn.Linear(200, self.discrete_size), std=0.5) for i in range(targetNum)])\n",
" self.actor_mean = nn.ModuleList([layer_init(nn.Linear(200, self.continuous_size), std=0.5) for i in range(targetNum)])\n",
" # self.actor_logstd = nn.ModuleList([layer_init(nn.Linear(200, self.continuous_size), std=1) for i in range(targetNum)])\n",
" # self.actor_logstd = nn.Parameter(torch.zeros(1, self.continuous_size))\n",
" self.actor_logstd = nn.ParameterList([nn.Parameter(torch.zeros(1,self.continuous_size))for i in range(targetNum)]) # nn.Parameter(torch.zeros(1, self.continuous_size))\n",
" self.critic = nn.ModuleList([layer_init(nn.Linear(200, 1), std=1)for i in range(targetNum)])\n",
"\n",
" def get_value(self, state: torch.Tensor):\n",
" target = state[:,0].to(torch.int32) # int\n",
" thisStateNum = target.size()[0]\n",
" viewInput = state[:,-self.raySize:] # all ray input\n",
" targetInput = state[:,:self.nonRaySize]\n",
" viewLayer = self.viewNetwork(viewInput)\n",
" targetLayer = torch.stack([self.targetNetworks[target[i]](targetInput[i]) for i in range(thisStateNum)])\n",
" middleInput = torch.cat([viewLayer,targetLayer],dim = 1)\n",
" middleLayer = torch.stack([self.middleNetworks[target[i]](middleInput[i]) for i in range(thisStateNum)])\n",
" criticV = torch.stack([self.critic[target[i]](middleLayer[i]) for i in range(thisStateNum)]) # self.critic\n",
" return criticV\n",
"\n",
" def get_actions_value(self, state: torch.Tensor, actions=None):\n",
" target = state[:,0].to(torch.int32) # int\n",
" thisStateNum = target.size()[0]\n",
" viewInput = state[:,-self.raySize:] # all ray input\n",
" targetInput = state[:,:self.nonRaySize]\n",
" viewLayer = self.viewNetwork(viewInput)\n",
" targetLayer = torch.stack([self.targetNetworks[target[i]](targetInput[i]) for i in range(thisStateNum)])\n",
" middleInput = torch.cat([viewLayer,targetLayer],dim = 1)\n",
" middleLayer = torch.stack([self.middleNetworks[target[i]](middleInput[i]) for i in range(thisStateNum)])\n",
"\n",
" # discrete\n",
" # 递归targets的数量,既agent数来实现根据target不同来选用对应的输出网络计算输出\n",
" dis_logits = torch.stack([self.actor_dis[target[i]](middleLayer[i]) for i in range(thisStateNum)])\n",
" split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)\n",
" multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]\n",
" # continuous\n",
" actions_mean = torch.stack([self.actor_mean[target[i]](middleLayer[i]) for i in range(thisStateNum)]) # self.actor_mean(hidden)\n",
" # action_logstd = torch.stack([self.actor_logstd[target[i]](middleLayer[i]) for i in range(thisStateNum)]) # self.actor_logstd(hidden)\n",
" # action_logstd = self.actor_logstd.expand_as(actions_mean) # self.actor_logstd.expand_as(actions_mean)\n",
" action_logstd = torch.stack([torch.squeeze(self.actor_logstd[target[i]],0) for i in range(thisStateNum)])\n",
" # print(action_logstd)\n",
" action_std = torch.exp(action_logstd) # torch.exp(action_logstd)\n",
" con_probs = Normal(actions_mean, action_std)\n",
" # critic\n",
" criticV = torch.stack([self.critic[target[i]](middleLayer[i]) for i in range(thisStateNum)]) # self.critic\n",
"\n",
" if actions is None:\n",
" if True:\n",
" # select actions base on probability distribution model\n",
" disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])\n",
" conAct = con_probs.sample()\n",
" actions = torch.cat([disAct.T, conAct], dim=1)\n",
" else:\n",
" # select actions base on best probability distribution\n",
" # disAct = torch.stack([torch.argmax(logit, dim=1) for logit in split_logits])\n",
" conAct = actions_mean\n",
" disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])\n",
" conAct = con_probs.sample()\n",
" actions = torch.cat([disAct.T, conAct], dim=1)\n",
" else:\n",
" disAct = actions[:, 0 : env.unity_discrete_type].T\n",
" conAct = actions[:, env.unity_discrete_type :]\n",
" dis_log_prob = torch.stack(\n",
" [ctgr.log_prob(act) for act, ctgr in zip(disAct, multi_categoricals)]\n",
" )\n",
" dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])\n",
" return (\n",
" actions,\n",
" dis_log_prob.sum(0),\n",
" dis_entropy.sum(0),\n",
" con_probs.log_prob(conAct).sum(1),\n",
" con_probs.entropy().sum(1),\n",
" criticV,\n",
" )"
]
},
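{
"cell_type": "markdown",
"metadata": {},
"source": [
"Observation layout assumed by this agent, as read off the slicing in `get_value`/`get_actions_value`: the first `TOTAL_T_SIZE` (= 13) entries of each state are the non-ray part, with the target id at index 0 (`state[:, 0]` selects the per-target sub-networks), and the remaining `state[:, -raySize:]` entries are the ray-cast observations fed to the shared `viewNetwork`."
]
},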
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"niceGotoLOAD_DIR = \"C:/Users/UCUNI/OneDrive/Unity/ML-Agents/Aimbot-PPO/Aimbot-PPO-Python/PPO-Model/PList_Go_LeakyReLU_9331_1677965178_GOTOModel/PList_Go_LeakyReLU_9331_1677965178_last.pt\"\n",
"badGotoLoar_Dir = \"C:/Users/UCUNI/OneDrive/Unity/ML-Agents/Aimbot-PPO/Aimbot-PPO-Python/PPO-Model/PList_Attack_LeakyReLU_9331_1678547500/PList_Attack_LeakyReLU_9331_1678547500_last.pt\"\n",
"\n",
"niceGotoAgent = torch.load(niceGotoLOAD_DIR)\n",
"badGotoAgent = torch.load(badGotoLoar_Dir)\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"odict_keys(['viewNetwork.0.weight', 'viewNetwork.0.bias', 'targetNetworks.0.0.weight', 'targetNetworks.0.0.bias', 'targetNetworks.1.0.weight', 'targetNetworks.1.0.bias', 'targetNetworks.2.0.weight', 'targetNetworks.2.0.bias', 'targetNetworks.3.0.weight', 'targetNetworks.3.0.bias', 'middleNetworks.0.0.weight', 'middleNetworks.0.0.bias', 'middleNetworks.1.0.weight', 'middleNetworks.1.0.bias', 'middleNetworks.2.0.weight', 'middleNetworks.2.0.bias', 'middleNetworks.3.0.weight', 'middleNetworks.3.0.bias', 'actor_dis.0.weight', 'actor_dis.0.bias', 'actor_dis.1.weight', 'actor_dis.1.bias', 'actor_dis.2.weight', 'actor_dis.2.bias', 'actor_dis.3.weight', 'actor_dis.3.bias', 'actor_mean.0.weight', 'actor_mean.0.bias', 'actor_mean.1.weight', 'actor_mean.1.bias', 'actor_mean.2.weight', 'actor_mean.2.bias', 'actor_mean.3.weight', 'actor_mean.3.bias', 'actor_logstd.0', 'actor_logstd.1', 'actor_logstd.2', 'actor_logstd.3', 'critic.0.weight', 'critic.0.bias', 'critic.1.weight', 'critic.1.bias', 'critic.2.weight', 'critic.2.bias', 'critic.3.weight', 'critic.3.bias'])\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA8QAAAEuCAYAAABI7Ns6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABAoElEQVR4nO3debzdVX3v//faw5nHzIQMEEIIk4CAggLihLPi0NbWoXa41VbbarXtr+3tXHvvrVWrxWq1eq11KBat81QVAqiggMxDgEBIyHxyTs64z57W74+z7S/FJJ93f9LLDd/X8/Hg8SDZ73y+372+67vWd+19kpVyzgIAAAAAoGhKj/UJAAAAAADwWGBBDAAAAAAoJBbEAAAAAIBCYkEMAAAAACgkFsQAAAAAgEJiQQwAAAAAKCQWxAAAPIpSSg+mlJ71X1j/VSmlb5jZ16WUrv2vOhcAAI52LIgBAI97KaVXppSuTynNpJT2dP7/11JK6bE+t/+snPMncs6XPBq1UkpXpZR++dGoBQDA0YgFMQDgcS2l9FZJ75H0DkkrJC2X9AZJT5XU9RieGgAAeIyxIAYAPG6llIYl/ZmkX8s5X5FznsoLfphzflXOef5HuZTSx1JKe1NKW1NK/z2lVOq8Vur8emvn2+WPder+6Biv6bw2llL6gyOcy/EppYmD6n4opbTnoNf/KaX05oPO58MppZ0ppYdTSn+RUip3XvsPPwadUrokpXRPSulASunvUkqbHvmtb0rpr1NK4ymlB1JKz+v83tslXSjpspTSdErpsrTg3Z33OZlSui2ldNpPeBkAAPi/FgtiAMDj2fmSuiV9Psj9raRhSeskPU3SayX9Que113X+e3rn9QFJl0lSSukUSe+X9BpJKyUtlrTqUAfIOT8gaVLSWZ3fukjSdErp5M6vnyZpU+f/PyqpKWl9J3+JpB/70eaU0hJJV0j6vc6x75H0lEfEntz5/SWS/krSh1NKKef8B5KukfSmnPNAzvlNneNcJGlDpz1+WtLYod4PAACPByyIAQCPZ0sk7cs5N3/0Gyml73a+qZ1LKV3U+eb1lZJ+r/MN8oOS3qmFRa4kvUrSu3LOW3LO01pYfL4ypVSR9ApJX8o5X935tvkPJbWPcD6bJD0tpbSi8+srOr8+XtKQpFtSSsslPV/Sm3POMznnPZLe3TnHR3q+pDtyzp/tvMf3Str1iMzWnPOHcs4tSf8o6Rgt/Nj4oTQkDUraKCnlnO/KOe88wvsBAOCoVnmsTwAAgP9CY5KWpJQqP1oU55yfIkkppe1a+GB4iaSqpK0H/bmtko7t/P/KQ7xW0cKicqWkbT96Iec8k1I60jeqmyS9WNJ2SVdLukoLC++apGtyzu2U0trO+ew86N/8Kh18nIM88vi5874Otuug12c7NQcOdXI552+nlC6T9D5Ja1NKn5X0tpzz5BHeEwAARy2+IQYAPJ59T9K8pJccIbNPC9+Mrj3o99ZIerjz/zsO8VpT0m5JOyWt/tELKaU+Lfzo8uFs0sLf27248//XauEf9zr4x6W3dc55Sc55pPPfUM751EPU26mDfkS7869mH/JHtg8j/9hv5PzenPPZkk7Rwo9O//Z/oh4AAEcVFsQAgMetnPOEpD+V9HcppVeklAY7/0jWmZL6O5mWpE9Lenvn9bWSfkvSxztlPiXpLZ1/FGtA0l9KurzzjfMVkl6YUrogpdSlhX/A67Bza875Xklzkl4taVPnm9fdkl6uzoK48yPK35D0zpTSUOd8T0gpPe0QJb8s6fSU0qWdH+F+oxb+JW3Xbi38vWhJUkrp3JTSk1NKVUkzWvjm+kg/Ag4AwFGNBTEA4HEt5/xXWljg/o4WFoC7Jf29pN+V9N1O7Ne1sADcooVvbT8p6SOd1z4i6Z+08CPOD2hhkfjrndp3aGER+kktfFs7roUfhz6STZLGcs7bDvp1knTTQZnXamFLqDs7Na/Qwt/9feR72yfpp7Twj2WNaeFb3Ru08A2z4z2SXtH5F6jfq4W/x/yhzjG3dmq+w6wFAMBRJ+X8Yz8tBQAAjkKdLZ22S3pVzvnKx/p8AAD4vx3fEAMAcBRLKT0npTSSUuqW9Pta+Lb5usf4tAAAOCqwIAYA4Oh2vqT7tfCPg71I0qU557nH9pQAADg68CPTAAAAAIBC4htiAAAAAEAhsSAGAAAAABRS5Ugvnv+N3w1/nnrH9kXekcrej2ZvWLM7zDz4vdVWrcZgfMzSEm9nitZcOQ41vc8XBrYcsdklSXNneH/9K6X4PZYrLatW93cGw8zoC3ZYtRrtuC3Gr/G2ysxnTsWh2+Jzl2R9BNTYMOvVyimOmH8jofRwT5hpDnvXcXTlgTBTKXvbiu7dMxRmqru6rFrHnhP3nYdvWGnVaq2qhZlqd9OqtWRoJszsuH+pVStX43YdWBofT5KaN4+EmfIZ8bWWvH64asSrtfne+Br1bY3HOEmaXdcIM6nL6/flnd1hZvT0fVatRjMe70f6vDH6wQeXhZnTN2wLM5J01/XHW7m0Ju5jrYYxp0latmQyzOzZF48TrrIxFkrSUy++PczcsNN7Vpibi8ewvj7vWSEb88L07gGrlozxJBkZSeruje+1ru9682j56WNhpmk8A0jSxiV7wsyNPzjRqtWzJn5WmB3v9WoNx9e7XvPGuTxhzJHD8fWRpNKeuFZplfcMs2I0bq9t2xZbtYZuj89r8vS6VSvV4rEp93jzQv9oPE7PTsVzhySVjHutNev1iaXHxPPtvjHvfsxGe61e6819O2+On8ubg17bl4biPt3d4/X75cNxX9260+urD7769w45SPMNMQAAAACgkFgQAwAAAAAKiQUxAAAAAKCQWBADAAAAAAqJBTEAAAAAoJBYEAMAAAAACokFMQAAAACgkFgQAwAAAAAKiQUxAAAAAKCQUs75sC+uv/wvDv9iR6NW8Y40V7ZiS9eOh5nxA/1WrZTC01er5X0mUNrRE2aaQy2rVmWoHteaqlq1nI80Kv0Nq1SrHl+jnoF5q9bc/t4wU+prWrWqXXFu3jieJHWN1sJMfX98rSWpPBxfx54er+1np7rDTJ4177VK3O+7hrzrWKnEfbrd9u6h2pjRJwa89srN+Ji55o05Tp9oPeSNOa0lcZ+odHvjRNO53jlZtaqD8fVuTHdZtboG4/dYnzXHL+P8u/rj40lS3Tj/VG1btRwVY1ySpO7uODe9z+tfg0unrdzMVDyGtY17SJLKXXGbtSa9610ajO9v596WpNz2+r5XLI6kshGSVDGud6XqjQG16Xhe0JQ3L/SujPtOs+mNmY058/425Hp8vavGmCNJi0fi97hr14hVS/NxW1QmvPZqLo/P3+1fpUp8P/b3efP7zIwxTrS8+8y5jhXzma9ptH3fUDxvS9Ls/r44ZKwVJCkZY2HJHSeqcVskc4hrNOL26jLnq7m9cXs5z76SVCrF7VUx+rMkzdcevTHngZ/7/UO2LN8QAwAAAAAKiQUxAAAAAKCQWBADAAAAAAqJBTEAAAAAoJBYEAMAAAAACokFMQAAAACgk
FgQAwAAAAAKiQUxAAAAAKCQjrije2OyK65Q9TahVo+3+fLebaNhpjTQsGolY7nfnvY2e64eF2/43jI3ji6V47YoD3ibaGejWd2N1Tes2RVm7t+11KpV6o3Pv6fX29y7Nmf0w+6WVatUivtradDrX63aEW8fSdKM2b8Gl8X9a6bUY9XqH4w3rJ/aO2DVqjfivlMdnbdqyWj7xaNxO0jSxGS8eXy7y+sT9QPdYSYt995jd0/cd+rGxveS1Lcibova1kGrVnM2PmZladxvJKkxH/f7gZE5q9b07rgftnq8z20XLZsMM/t3D1m11IyP2ZiN20GSupcbfbrpjdHTE17fyeaYb9UypvjysDeWO+fV1efVaj0Y9532Sq9P9/XH9/ecMw9J6u+La7WyeX2SMV8t8tpr/sF4rFh52m6r1t5S3PZ1Y5yQpDwf32vJmDskadfDxvOj+aygSvxw1e73nhXKzrOoOQaUjefHSWNclaTUiE+ssti7h9rl+Bq12957rPbEz4+NRtmqVTauUWvKe06T0Q1
"text/plain": [
"<Figure size 1440x360 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"print(niceGotoAgent.state_dict().keys())\n",
"# 获取第一层权重张量\n",
"Goodweights = niceGotoAgent.state_dict()['targetNetworks.1.0.weight'].cpu()\n",
"Badweights = badGotoAgent.state_dict()['targetNetworks.1.0.weight'].cpu()\n",
"# 将权重张量转换为numpy数组并可视化\n",
"\n",
"fig,(ax1,ax2) = plt.subplots(2,1,figsize=(20,5))\n",
"ax1.imshow(np.rot90(Goodweights.numpy()))\n",
"ax1.set_title('Good weights')\n",
"ax1.axis('off')\n",
"ax2.imshow(np.rot90(Badweights.numpy()))\n",
"ax2.set_title('Bad weights')\n",
"ax2.axis('off')\n",
"# 显示图表\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA8QAAAEuCAYAAABI7Ns6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA7jElEQVR4nO3de6xl6Vnf+d+71r6dvc+1uqq6Xe6LbXDhW8cOdoYAji/CYxQTY0shGYS5OYqUDCaBuWSQYTIBBBMJwiCcDpmMA0oCceSog8YGgtISwd3mYhIbG2xj2cbubld13bqqTp3bvq291jt/nN2kaFed58d0kU71+n6klrpqP+dZ73rXe91713lTzlkAAAAAALRN8WwXAAAAAACAZwMbYgAAAABAK7EhBgAAAAC0EhtiAAAAAEArsSEGAAAAALQSG2IAAAAAQCuxIQYA4BZKKT2WUnrTn2H+d6SUHjJjvzul9Jt/VmUBAOB2x4YYAPCcl1L61pTS76aUDlJKl5b//z0ppfRsl+1PK+f8r3POb74VuVJKH0op/c1bkQsAgNsRG2IAwHNaSul/kfQzkn5S0l2S7pT0tyV9vaTes1g0AADwLGNDDAB4zkopbUj6UUnfk3N+MOe8lw99POf8jpzz7Km4lNK/Sik9mVJ6PKX0v6eUiuVrxfLPjy8/Xf5Xy7xPXeM7lq9dSSn90BFleWFK6dp1ed+bUrp03eu/kFL6/uvK83MppfMppSdSSj+WUiqXr/2Jr0GnlN6cUvpsSmknpfSzKaWHn/6pb0rpH6WUtlNKj6aU/vLy735c0l+S9EBKaT+l9EA69NPL+9xNKX0ypfSKZ/gYAAD4bxYbYgDAc9nXSupL+kAQ948lbUh6kaTXS/pOSe9cvvbdy//euHx9VdIDkpRSepmkfyrpOySdknSHpLtvdIGc86OSdiX9+eVfvU7Sfkrppcs/v17Sw8v//xeSFpK+chn/Zklf9tXmlNJxSQ9Kevfy2p+V9HVPC/ua5d8fl/QTkn4upZRyzj8k6cOSvjfnvJpz/t7ldV4n6fSyPv66pCs3uh8AAJ4L2BADAJ7Ljku6nHNePPUXKaXfXn5SO0kpvW75yeu3Snr38hPkxyT9lA43uZL0Dkn/V875iznnfR1uPr81pdSR9C2SfiXn/Mjy0+a/L6k5ojwPS3p9Sumu5Z8fXP75hZLWJf1+SulOSW+R9P0554Oc8yVJP70s49O9RdKnc86/tLzH90i68LSYx3PO780515L+paTn6fBr4zdSSVqT9BJJKef8mZzz+SPuBwCA21rn2S4AAAB/hq5IOp5S6jy1Kc45f50kpZTO6vCN4eOSupIev+7nHpf0/OX/n7rBax0dbipPSTrz1As554OU0lGfqD4s6ZslnZX0iKQP6XDjPZX04Zxzk1K6b1me89f9zq/i+utc5+nXz8v7ut6F614fL3Ou3qhwOef/mFJ6QNI/kXRfSumXJP2vOefdI+4JAIDbFp8QAwCey35H0kzS246IuazDT0bvu+7v7pX0xPL/z93gtYWki5LOS7rnqRdSSkMdfnX5Zh7W4b/bfcPy/39Th7/c6/qvS59Zlvl4znlz+d96zvnlN8h3Xtd9RXv5W7Nv+JXtm8hf9hc5vyfn/GpJL9PhV6f/3p8iHwAAtxU2xACA56yc8zVJPyLpZ1NK35JSWlv+kqxXSRotY2pJ/1bSjy9fv0/S/yzpF5dp/o2k/2n5S7FWJf2fkt6//MT5QUl/JaX02pRST4e/wOumc2vO+fOSJpK+XdLDy09eL0r6q1puiJdfUX5I0k+llNaX5f2KlNLrb5DyVyXdn1J6+/Ir3O/S4W/Sdl3U4b+LliSllP5CSulrUkpdSQc6/OT6qK+AAwBwW2NDDAB4Tss5/4QON7j/mw43gBcl/TNJPyDpt5dhf0eHG8Av6vBT2/dJ+vnlaz8v6Rd0+BXnR3W4Sfw7y9yf1uEm9H06/LR2W4dfhz7Kw5Ku5JzPXPfnJOn3rov5Th0eCfWHy5wP6vDf/j793i5L+ms6/GVZV3T4qe5HdfgJs+NnJH3L8jdQv0eH/475vctrPr7M+ZNmLgAAbjsp5y/7thQAALgNLY90OivpHTnn33i2ywMAwH/r+IQYAIDbWErpG1NKmymlvqQf1OGnzR95losFAMBtgQ0xAAC3t6+V9AUd/nKwt0p6e8558uwWCQCA2wNfmQYAAAAAtBKfEAMAAAAAWokNMQAAAACglTpHvfi1D/1A+H3qc2ePeVcqva9mn773Yhjz2O/cY+Wq1uJrFse9kynqSRkHLbz3F1a/eGS1S5Imr/T++VdK8T2WndrK1f+ttTBm65vOWbmqJq6L7Q97R2XmV+3FQZ+Myy7JeguoOj32cuUUh5j/IqF4YhDGLDa857h1aieM6ZTesaJPXloPY7oXelau578mbjtPfPSUlau+exrGdPsLK9fx9YMw5twXTli5cjeu19UT8fUkafGJzTCmfGX8rCWvHd696eX63OfjZzR8PB7jJGn8oiqMST2v3Zfn+2HM1v2XrVzVIh7vN4feGP3YYyfDmPtPnwljJOkzv/tCKy7dG7exujLmNEknj++GMZcux+OEqzTGQkn6+jd8Koz56HlvrTCZxGPYcOitFbIxL+xfXLVyyRhPkhEjSf2VuK/1ftubR8s3XgljFsYaQJJecvxSGPOx//xiK9fg3nitMN5e8XJtxM97PvXGuXzNmCM34ucjScWlOFdxt7eGuWsrrq8zZ+6wcq1/Ki7X7v1zK1eaxmNTHnjzwmgrHqfHe/HcIUmF0dfqsdcmTjwvnm8vX/H6Yzbq6577vLnv/Cfidflizav7Yj1u0/2B1+7v3Ijb6uPnvbb62Le/+4aDNJ8QAwAAAABaiQ0xAAAAAKCV2BADAAAAAFqJDTEAAAAAoJXYEAMAAAAAWokNMQAAAACgldgQAwAAAABaiQ0xAAAAAKCV2BADAAAAAFop5Zxv+uJXvv/Hbv7iUjXteFealFbYifu2w5jtnZGVK6Ww+Kpr7z2B4twgjFms11auzvo8zrXXtXI5b2l0RpWVqp7Hz2iwOrNyTa6uhDHFcGHl6vbiuJlxPUnqbU3DmPnV+FlLUrkRP8fBwKv78V4/jMljs6914nbfW/eeY6cTt+mm8frQ9IrRJla9+sqL+Jp56o05Tpuov+SNOfXxuE10+t44sXCed05Wru5a/Lyr/Z6Vq7cW3+N8bI5fRvl7o/h6kjQ3yp+6jZXL0THGJUnq9+O4/cte+1o7sW/FHezFY1hj9CFJKntxndW73vMu1uL+7fRtScqN1/a9ZHFIKo0gSR3jeXe63hgw3Y/nBe1588LKqbjtLBbemFlNzP5tyPP4eXeNMUeS7tiM7/HChU0rl2ZxXXSuefW1uDMuv9u+ik7cH0dDb34/ODDGidrrZ85z7JhrvoVR98P1eN6WpPHVYRxk7BUkKRljYeGOE924LpI5xFVVXF89c76aPBnXl7P2laSiiOurY7RnSZpNb92Y8+i3/eANa5ZPiAEAAAAArcSGGAAAAADQSmyIAQAAA
ACtxIYYAAAAANBKbIgBAAAAAK3EhhgAAAAA0EpsiAEAAAAArcSGGAAAAADQSkee6F7t9uIMXe8Qag28w5efPLMVxhSrlZUrGdv9Zt877Ln7gvjA99o8OLoo47ooV71DtLNRre7B6qfvvRDGfOHCCStXsRKXf7DiHe49nRjtsF9buYoibq/Fmte+6umR3UeSdGC2r7WTcfs6KAZWrtFafGD93pOrVq55Fbed7tbMyiWj7u/YiutBkq7txofHNz2vTcx3+mFMutO7x/4gbjtz4+B7SRreFdfF9PE1K9diHF+zcyJuN5JUzeJ2v7o5sXLtX4zbYT3w3rc9dnI3jLl6cd3KpUV8zWoc14Mk9e802vTCG6P3r3ltJ5tjvpXLmOLLDW8sd8rVG3q56sfittOc8tr0cBT374kzD0kaDeNcdTafTzLmq2Nefc0ei8eKU6+4aOV6sojrfm6ME5KUZ3FfS8bcIUkXnjDWj+ZaQZ14cdWMvLVC6axFzTGgNNaPu8a4KkmpigvWucPrQ00ZP6Om8e6
"text/plain": [
"<Figure size 1440x360 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"badGotoAgent.targetNetworks[1] = niceGotoAgent.targetNetworks[1]\n",
"badGotoAgent.middleNetworks[1] = niceGotoAgent.middleNetworks[1]\n",
"badGotoAgent.actor_dis[1] = niceGotoAgent.actor_dis[1]\n",
"badGotoAgent.actor_mean[1] = niceGotoAgent.actor_mean[1]\n",
"badGotoAgent.actor_logstd[1] = niceGotoAgent.actor_logstd[1]\n",
"badGotoAgent.critic[1] = niceGotoAgent.critic[1]\n",
"# 获取第一层权重张量\n",
"Goodweights = niceGotoAgent.state_dict()['targetNetworks.1.0.weight'].cpu()\n",
"Badweights = badGotoAgent.state_dict()['targetNetworks.1.0.weight'].cpu()\n",
"# 将权重张量转换为numpy数组并可视化\n",
"\n",
"fig,(ax1,ax2) = plt.subplots(2,1,figsize=(20,5))\n",
"ax1.imshow(np.rot90(Goodweights.numpy()))\n",
"ax1.set_title('Good weights')\n",
"ax1.axis('off')\n",
"ax2.imshow(np.rot90(Badweights.numpy()))\n",
"ax2.set_title('Bad weights')\n",
"ax2.axis('off')\n",
"# 显示图表\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"saveDir = \"C:/Users/UCUNI/OneDrive/Unity/ML-Agents/Aimbot-PPO/Aimbot-PPO-Python/PPO-Model/Chimera-1677965178-1678547500.pt\"\n",
"torch.save(badGotoAgent,saveDir)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"import torch\n",
"print(torch.cuda.is_available())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.7 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "86e2db13b09bd6be22cb599ea60c1572b9ef36ebeaa27a4c8e961d6df315ac32"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}