Add EndReward Broadcast function

When a round ends, add remainTime/15 to every step's reward in that round to increase the round's weight in training.
Fix a bug where the target was still decoded from the state as a one-hot vector.
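
A minimal sketch of the intended arithmetic, assuming the constants this commit sets below (ENV_TIMELIMIT = 30, RESULT_BROADCAST_RATIO = 2/ENV_TIMELIMIT); the reward buffer and remaining time are made up:

import numpy as np

ENV_TIMELIMIT = 30
RESULT_BROADCAST_RATIO = 2 / ENV_TIMELIMIT  # 2/30 == 1/15

rewards = np.array([0.1, -0.2, 0.3])  # hypothetical per-step rewards of one round
remain_time = 12.0                    # seconds left on the clock when the round was won

# on a win, every step of the round gets remain_time/15 added
print(rewards + remain_time * RESULT_BROADCAST_RATIO)  # [0.9 0.6 1.1]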
This commit is contained in:
Koha9 2022-12-03 03:58:19 +09:00
parent 3930bcd953
commit 895cd5c118
2 changed files with 220 additions and 73 deletions

View File

@@ -24,12 +24,12 @@ from typing import List
bestReward = 0
DEFAULT_SEED = 9331
ENV_PATH = "../Build/Build-ParallelEnv-Target-OffPolicy-SingleStack-SideChannel-ExtremeReward/Aimbot-ParallelEnv"
DEFAULT_SEED = 933139
ENV_PATH = "../Build/Build-ParallelEnv-Target-OffPolicy-SingleStack-SideChannel-EndReward-Easy/Aimbot-ParallelEnv"
SIDE_CHANNEL_UUID = uuid.UUID("8bbfb62a-99b4-457c-879d-b78b69066b5e")
WAND_ENTITY = "koha9"
WORKER_ID = 1
BASE_PORT = 1000
WORKER_ID = 2
BASE_PORT = 1001
# max round steps per agent is 2500/Decision_period, 25 seconds
# !!!check every parameter before running!!!
@@ -38,7 +38,7 @@ TOTAL_STEPS = 6000000
BATCH_SIZE = 512
MAX_TRAINNING_DATASETS = 8000
DECISION_PERIOD = 1
LEARNING_RATE = 1e-3
LEARNING_RATE = 8e-4
GAMMA = 0.99
GAE_LAMBDA = 0.95
EPOCHS = 4
@@ -58,7 +58,11 @@ WANDB_TACK = True
LOAD_DIR = "../PPO-Model/Aimbot-target-last.pt"
# public data
BASE_WINREWARD = 999
BASE_LOSEREWARD = -999
TARGETNUM= 4
ENV_TIMELIMIT = 30
RESULT_BROADCAST_RATIO = 2/ENV_TIMELIMIT
TotalRounds = {"Go":0,"Attack":0,"Free":0}
WinRounds = {"Go":0,"Attack":0,"Free":0}
@@ -160,7 +164,7 @@ class PPOAgent(nn.Module):
def get_actions_value(self, state: torch.Tensor, actions=None):
hidden = self.network(state)
targets = torch.argmax(state[:,0:self.targetNum],dim=1)
targets = state[:,0]
# discrete
# iterate over the targets (one per agent) so each agent uses the output head matching its target
@@ -244,7 +248,7 @@ class AimbotSideChannel(SideChannel):
receive messages from Unity
"""
thisMessage = msg.read_string()
#print(thisMessage)
# print(thisMessage)
thisResult = thisMessage.split("|")
if(thisResult[0] == "result"):
TotalRounds[thisResult[1]]+=1
@@ -256,7 +260,7 @@
print(thisMessage)
# send function
def send_string(self, data: str) -> None:
"""发送一个字符串给C#"""
# send a string toC#
msg = OutgoingMessage()
msg.write_string(data)
super().queue_message_to_send(msg)
@@ -281,6 +285,20 @@
msg.write_float32_list(data)
super().queue_message_to_send(msg)
def broadCastEndReward(rewardBF:list,remainTime:float):
    thisRewardBF = rewardBF.copy()  # copy so the caller's reward buffer is not mutated
    if (rewardBF[-1]<=-500):
        # lose: strip the extreme lose reward; no time bonus is broadcast
        thisRewardBF[-1] = rewardBF[-1]-BASE_LOSEREWARD
    elif (rewardBF[-1]>=500):
        # win: strip the extreme win reward and add remainTime*RESULT_BROADCAST_RATIO to every step
        thisRewardBF[-1] = rewardBF[-1]-BASE_WINREWARD
        thisRewardBF = (np.asarray(thisRewardBF)+(remainTime*RESULT_BROADCAST_RATIO)).tolist()
    else:
        print("!!!!!DIDNT GET RESULT REWARD!!!!!!",rewardBF[-1])
    return torch.Tensor(thisRewardBF).to(device)
if __name__ == "__main__":
args = parse_args()
@@ -304,7 +322,7 @@ if __name__ == "__main__":
# Tensorboard and WandB Recorder
game_name = "Aimbot_Target"
game_type = "OffPolicy"
game_type = "OffPolicy_HMNN_EndBC"
run_name = f"{game_name}_{game_type}_{args.seed}_{int(time.time())}"
if args.wandb_track:
wandb.init(
@@ -398,10 +416,11 @@
# finished a round, send finished memories to training datasets
# compute advantage and discounted reward
#print(i,"over")
thisRewardsTensor = broadCastEndReward(rewards_bf[i],state[i,6])
adv, rt = GAE(
agent,
args,
torch.tensor(rewards_bf[i]).to(device),
thisRewardsTensor,
torch.Tensor(dones_bf[i]).to(device),
torch.tensor(values_bf[i]).to(device),
torch.tensor(next_state[i]).to(device),
@@ -416,7 +435,7 @@
con_logprobs = torch.cat(
(con_logprobs, torch.tensor(con_logprobs_bf[i]).to(device)), 0
)
rewards = torch.cat((rewards, torch.tensor(rewards_bf[i]).to(device)), 0)
rewards = torch.cat((rewards, thisRewardsTensor), 0)
values = torch.cat((values, torch.tensor(values_bf[i]).to(device)), 0)
advantages = torch.cat((advantages, adv), 0)
returns = torch.cat((returns, rt), 0)
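
For reference, a self-contained sketch of how broadCastEndReward behaves on toy buffers, assuming the constants above (BASE_WINREWARD = 999, BASE_LOSEREWARD = -999, RESULT_BROADCAST_RATIO = 2/30) and a CPU device; the buffers are hypothetical:

import numpy as np
import torch

device = torch.device("cpu")
BASE_WINREWARD = 999
BASE_LOSEREWARD = -999
RESULT_BROADCAST_RATIO = 2 / 30

def broadCastEndReward(rewardBF: list, remainTime: float) -> torch.Tensor:
    thisRewardBF = rewardBF.copy()
    if rewardBF[-1] <= -500:
        # lose: remove the extreme lose reward, no bonus
        thisRewardBF[-1] = rewardBF[-1] - BASE_LOSEREWARD
    elif rewardBF[-1] >= 500:
        # win: remove the extreme win reward, broadcast remainTime/15 to all steps
        thisRewardBF[-1] = rewardBF[-1] - BASE_WINREWARD
        thisRewardBF = (np.asarray(thisRewardBF) + remainTime * RESULT_BROADCAST_RATIO).tolist()
    else:
        print("!!!!!DIDNT GET RESULT REWARD!!!!!!", rewardBF[-1])
    return torch.Tensor(thisRewardBF).to(device)

print(broadCastEndReward([0.1, 0.2, 999.0], remainTime=15.0))   # tensor([1.1000, 1.2000, 1.0000])
print(broadCastEndReward([0.1, 0.2, -999.0], remainTime=0.0))   # tensor([0.1000, 0.2000, 0.0000])

In the training loop the second argument is state[i,6], which holds the agent's remaining round time (30 at the start of a round, as the state dump in the notebook below shows).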

View File

@@ -525,7 +525,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -594,75 +594,203 @@
" super().queue_message_to_send(msg)\n",
" \n",
"SIDE_CHANNEL_UUID = uuid.UUID(\"8bbfb62a-99b4-457c-879d-b78b69066b5e\")\n",
"ENV_PATH = \"../Build/Build-ParallelEnv-Target-OffPolicy-SingleStack-SideChannel-ExtremeReward/Aimbot-ParallelEnv\"\n",
"ENV_PATH = \"../Build/Build-ParallelEnv-Target-OffPolicy-SingleStack-SideChannel-EndReward/Aimbot-ParallelEnv\"\n",
"aimBotsideChannel = AimbotSideChannel(SIDE_CHANNEL_UUID)\n",
"env = Aimbot(envPath=ENV_PATH, workerID=1, basePort=100,side_channels=[aimBotsideChannel])"
"env = Aimbot(envPath=ENV_PATH, workerID=123, basePort=999,side_channels=[aimBotsideChannel])"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'env' is not defined",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_34852\\4061840787.py\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0m_\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0m_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgetSteps\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mNameError\u001b[0m: name 'env' is not defined"
]
}
],
"source": [
"import numpy as np\n",
"state,_,_ = env.getSteps()\n",
"print(state[:,0:4])\n",
"print(np.argmax(state[:,0:4], axis=0))\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[1, 2, 3],\n",
" [1, 2, 3],\n",
" [1, 2, 3]])\n",
"3\n",
"0\n",
"1\n",
"2\n",
"3\n"
]
}
],
"source": [
"import torch\n",
"one_hot = torch.tensor([])\n",
"aaa = torch.tensor([1,2,3])\n",
"bbb = torch.tensor([1,2,3])\n",
"print(torch.stack([aaa,aaa,aaa]))\n",
"print(aaa.size()[0])\n",
"\n",
"for i in range(4):\n",
" print(i)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"[Categorical(logits: torch.Size([8, 3])), Categorical(logits: torch.Size([8, 3])), Categorical(logits: torch.Size([8, 2]))]\n",
"Normal(loc: torch.Size([8, 1]), scale: torch.Size([8, 1]))"
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.distributions.normal import Normal\n",
"from torch.distributions.categorical import Categorical\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() and True else \"cpu\")\n",
"\n",
"def layer_init(layer, std=np.sqrt(2), bias_const=0.0):\n",
" torch.nn.init.orthogonal_(layer.weight, std)\n",
" torch.nn.init.constant_(layer.bias, bias_const)\n",
" return layer\n",
"\n",
"class PPOAgent(nn.Module):\n",
" def __init__(self, env: Aimbot,targetNum:int):\n",
" super(PPOAgent, self).__init__()\n",
" self.targetNum = targetNum\n",
" self.discrete_size = env.unity_discrete_size\n",
" self.discrete_shape = list(env.unity_discrete_branches)\n",
" self.continuous_size = env.unity_continuous_size\n",
"\n",
" self.network = nn.Sequential(\n",
" layer_init(nn.Linear(np.array(env.unity_observation_shape).prod(), 500)),\n",
" nn.ReLU(),\n",
" layer_init(nn.Linear(500, 300)),\n",
" nn.ReLU(),\n",
" )\n",
" self.actor_dis = nn.ModuleList([layer_init(nn.Linear(300, self.discrete_size), std=0.01) for i in range(targetNum)])\n",
" self.actor_mean = nn.ModuleList([layer_init(nn.Linear(300, self.continuous_size), std=0.01) for i in range(targetNum)])\n",
" self.actor_logstd = nn.ParameterList([nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(targetNum)])\n",
" self.critic = layer_init(nn.Linear(300, 1), std=1)\n",
"\n",
" def get_value(self, state: torch.Tensor):\n",
" return self.critic(self.network(state))\n",
"\n",
" def get_actions_value(self, state: torch.Tensor, actions=None):\n",
" hidden = self.network(state)\n",
" targets = torch.argmax(state[:,0:self.targetNum],dim=1)\n",
"\n",
" # discrete\n",
" # 递归targets的数量,既agent数来实现根据target不同来选用对应的输出网络计算输出\n",
" dis_logits = torch.stack([self.actor_dis[targets[i]](hidden[i]) for i in range(targets.size()[0])])\n",
" split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)\n",
" multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]\n",
" # continuous\n",
" actions_mean = torch.stack([self.actor_mean[targets[i]](hidden[i]) for i in range(targets.size()[0])]) # self.actor_mean(hidden)\n",
" # action_logstd = torch.stack([self.actor_logstd[targets[i]].expand_as(actions_mean) for i in range(targets.size()[0])]) # self.actor_logstd.expand_as(actions_mean)\n",
" # print(action_logstd)\n",
" action_std = torch.squeeze(torch.stack([torch.exp(self.actor_logstd[targets[i]]) for i in range(targets.size()[0])]),dim = -1) # torch.exp(action_logstd)\n",
" con_probs = Normal(actions_mean, action_std)\n",
"\n",
" if actions is None:\n",
" if True:\n",
" # select actions base on probability distribution model\n",
" disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])\n",
" conAct = con_probs.sample()\n",
" actions = torch.cat([disAct.T, conAct], dim=1)\n",
" else:\n",
" # select actions base on best probability distribution\n",
" disAct = torch.stack([torch.argmax(logit, dim=1) for logit in split_logits])\n",
" conAct = actions_mean\n",
" actions = torch.cat([disAct.T, conAct], dim=1)\n",
" else:\n",
" disAct = actions[:, 0 : env.unity_discrete_type].T\n",
" conAct = actions[:, env.unity_discrete_type :]\n",
" dis_log_prob = torch.stack(\n",
" [ctgr.log_prob(act) for act, ctgr in zip(disAct, multi_categoricals)]\n",
" )\n",
" dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])\n",
" return (\n",
" actions,\n",
" dis_log_prob.sum(0),\n",
" dis_entropy.sum(0),\n",
" con_probs.log_prob(conAct).sum(1),\n",
" con_probs.entropy().sum(1),\n",
" self.critic(hidden),\n",
" )\n",
"agent = PPOAgent(env,4).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1. , -10.343613 , 0. , -7.367299 ,\n",
" 0. , 0. , 30. , -10.343662 ,\n",
" 1. , -33.708736 , 1. , 1. ,\n",
" 1. , 1. , 2. , 1. ,\n",
" 1. , 1. , 2. , 2. ,\n",
" 2. , 1. , 1. , 1. ,\n",
" 33.270493 , 39.50663 , 49.146526 , 32.595673 ,\n",
" 30.21616 , 21.163797 , 46.9299 , 1.3264331 ,\n",
" 1.2435672 , 1.2541904 , 30.08522 , 30.041445 ,\n",
" 21.072094 , 0. ],\n",
" [ 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 30. , -5.5892515 ,\n",
" 1. , -29.907726 , 1. , 1. ,\n",
" 1. , 1. , 2. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 41.408752 , 47.830173 , 45.03225 , 31.905174 ,\n",
" 41.849663 , 41.849648 , 43.001434 , 45.0322 ,\n",
" 47.48242 , 40.00285 , 41.668346 , 41.607723 ,\n",
" 41.668335 , 0. ],\n",
" [ 1. , 2.9582403 , 0. , -4.699738 ,\n",
" 0. , 0. , 30. , -5.412487 ,\n",
" 1. , -32.79967 , 1. , 2. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 2. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 20.17488 , 49.507687 , 48.162056 , 45.98998 ,\n",
" 44.75835 , 31.08564 , 32.865173 , 24.676666 ,\n",
" 12.952409 , 39.69923 , 44.564423 , 44.49966 ,\n",
" 44.564495 , 0. ],\n",
" [ 2. , -0.20171738, 0. , -10.340863 ,\n",
" 0. , 0. , 30. , -22.987915 ,\n",
" 1. , -34.37514 , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 2. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 11.631058 , 13.872022 , 18.006863 , 27.457632 ,\n",
" 46.343067 , 46.343094 , 20.155125 , 49.867714 ,\n",
" 52.965984 , 56.775608 , 46.14223 , 46.075138 ,\n",
" 46.142246 , 0. ],\n",
" [ 2. , -14.687862 , 0. , -12.615574 ,\n",
" 0. , 0. , 30. , 15.125373 ,\n",
" 1. , -30.849268 , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 2. ,\n",
" 52.430542 , 48.912865 , 46.05145 , 43.974594 ,\n",
" 42.796673 , 26.467875 , 11.072432 , 7.190229 ,\n",
" 5.483198 , 4.5500183 , 42.611244 , 42.549267 ,\n",
" 18.856438 , 0. ],\n",
" [ 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 30. , -4.0314903 ,\n",
" 1. , -29.164669 , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 44.074184 , 46.9762 , 44.228096 , 42.2335 ,\n",
" 41.102253 , 41.102367 , 42.233757 , 44.22849 ,\n",
" 44.321827 , 37.335304 , 40.924183 , 40.86467 ,\n",
" 40.924236 , 0. ],\n",
" [ 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 30. , -18.603981 ,\n",
" 1. , -29.797592 , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 2. , 2. , 2. ,\n",
" 19.134174 , 22.76088 , 29.468704 , 42.88739 ,\n",
" 41.738823 , 41.739002 , 42.88781 , 44.913647 ,\n",
" 47.704174 , 51.135338 , 20.418388 , 12.470214 ,\n",
" 12.670923 , 0. ],\n",
" [ 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 30. , -19.07032 ,\n",
" 1. , -30.246218 , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 1. , 1. , 1. , 1. ,\n",
" 18.336487 , 21.81617 , 28.251017 , 42.977867 ,\n",
" 42.18994 , 42.19034 , 43.351707 , 45.399582 ,\n",
" 48.22037 , 51.68873 , 42.00719 , 41.94621 ,\n",
" 42.00739 , 0. ]], dtype=float32)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state,_,_ = env.getSteps()\n",
"state"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"env.close()"
]
}
],
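
The notebook cells above probe the same change the training script makes: the environment now reports the active target as a plain index in state[:,0] rather than a one-hot block, so decoding with argmax is wrong. A minimal sketch of the two decodings, with made-up state rows (TARGETNUM = 4 as in the trainer):

import numpy as np

TARGETNUM = 4
# hypothetical state rows; column 0 now holds the target index directly
state = np.array([[1., -10.3, 0., -7.4],
                  [0.,   0.,  0.,  0. ],
                  [2.,   3.0, 0., -4.7]])

# old decode: assumed columns 0..TARGETNUM-1 were a one-hot target vector
print(np.argmax(state[:, 0:TARGETNUM], axis=1))  # [0 0 1] -- wrong for index-coded states

# new decode: read the index straight from column 0, as targets = state[:,0] now does in the trainer
print(state[:, 0].astype(int))  # [1 0 2]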