import uuid
from typing import List

import numpy as np
import torch
from torch import nn
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
from mlagents_envs.side_channel.side_channel import (
    SideChannel,
    IncomingMessage,
    OutgoingMessage,
)

import airecorder
from aimbotEnv import Aimbot


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    # orthogonal weight initialization with a constant bias, applied to every linear layer below
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, bias_const)
    return layer
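
# Example of how layer_init is used below (illustrative only): hidden layers keep the default
# std=sqrt(2), while actor/critic heads use a smaller std so initial outputs stay near zero:
#   hidden = layer_init(nn.Linear(in_size, 200))
#   head = layer_init(nn.Linear(200, 1), std=1)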


class PPOAgent(nn.Module):
    """Multi-head PPO actor-critic.

    The first element of each observation (state[:, 0]) is a target-type id; every target
    type gets its own target/middle/actor/critic head, while the ray-cast part of the
    observation is shared through a single view network.
    """

    def __init__(
        self,
        env: Aimbot,
        trainAgent: bool,
        targetNum: int,
        target_state_size: int,
        time_state_size: int,
        gun_state_size: int,
        my_state_size: int,
        total_t_size: int,
    ):
        super(PPOAgent, self).__init__()
        self.trainAgent = trainAgent
        self.targetNum = targetNum
        self.stateSize = env.unity_observation_shape[0]
        self.agentNum = env.unity_agent_num
        self.targetSize = target_state_size
        self.timeSize = time_state_size
        self.gunSize = gun_state_size
        self.myStateSize = my_state_size
        self.raySize = env.unity_observation_shape[0] - total_t_size
        self.nonRaySize = total_t_size
        self.head_input_size = (
            env.unity_observation_shape[0] - self.targetSize - self.timeSize - self.gunSize
        )  # input size excluding the target state

        self.unityDiscreteType = env.unity_discrete_type
        self.discrete_size = env.unity_discrete_size
        self.discrete_shape = list(env.unity_discrete_branches)
        self.continuous_size = env.unity_continuous_size

        # shared ray-cast encoder and one set of heads per target type
        self.viewNetwork = nn.Sequential(layer_init(nn.Linear(self.raySize, 200)), nn.LeakyReLU())
        self.targetNetworks = nn.ModuleList(
            [
                nn.Sequential(layer_init(nn.Linear(self.nonRaySize, 100)), nn.LeakyReLU())
                for i in range(targetNum)
            ]
        )
        self.middleNetworks = nn.ModuleList(
            [
                nn.Sequential(layer_init(nn.Linear(300, 200)), nn.LeakyReLU())
                for i in range(targetNum)
            ]
        )
        self.actor_dis = nn.ModuleList(
            [layer_init(nn.Linear(200, self.discrete_size), std=0.5) for i in range(targetNum)]
        )
        self.actor_mean = nn.ModuleList(
            [layer_init(nn.Linear(200, self.continuous_size), std=0.5) for i in range(targetNum)]
        )
        # self.actor_logstd = nn.ModuleList([layer_init(nn.Linear(200, self.continuous_size), std=1) for i in range(targetNum)])
        # self.actor_logstd = nn.Parameter(torch.zeros(1, self.continuous_size))
        self.actor_logstd = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(targetNum)]
        )  # one state-independent log-std vector per target type
        self.critic = nn.ModuleList(
            [layer_init(nn.Linear(200, 1), std=1) for i in range(targetNum)]
        )

    def get_value(self, state: torch.Tensor):
        # critic-only forward pass, routing each sample through the heads of its target type
        target = state[:, 0].to(torch.int32)  # target type id of every sample
        thisStateNum = target.size()[0]
        viewInput = state[:, -self.raySize :]  # all ray inputs
        targetInput = state[:, : self.nonRaySize]
        viewLayer = self.viewNetwork(viewInput)
        targetLayer = torch.stack(
            [self.targetNetworks[target[i]](targetInput[i]) for i in range(thisStateNum)]
        )
        middleInput = torch.cat([viewLayer, targetLayer], dim=1)
        middleLayer = torch.stack(
            [self.middleNetworks[target[i]](middleInput[i]) for i in range(thisStateNum)]
        )
        criticV = torch.stack(
            [self.critic[target[i]](middleLayer[i]) for i in range(thisStateNum)]
        )
        return criticV

    def get_actions_value(self, state: torch.Tensor, actions=None):
        target = state[:, 0].to(torch.int32)  # target type id of every sample
        thisStateNum = target.size()[0]
        viewInput = state[:, -self.raySize :]  # all ray inputs
        targetInput = state[:, : self.nonRaySize]
        viewLayer = self.viewNetwork(viewInput)
        targetLayer = torch.stack(
            [self.targetNetworks[target[i]](targetInput[i]) for i in range(thisStateNum)]
        )
        middleInput = torch.cat([viewLayer, targetLayer], dim=1)
        middleLayer = torch.stack(
            [self.middleNetworks[target[i]](middleInput[i]) for i in range(thisStateNum)]
        )

        # discrete branch
        # iterate over the batch so every sample uses the output heads of its own target type
        dis_logits = torch.stack(
            [self.actor_dis[target[i]](middleLayer[i]) for i in range(thisStateNum)]
        )
        split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)
        multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]
        # continuous branch
        actions_mean = torch.stack(
            [self.actor_mean[target[i]](middleLayer[i]) for i in range(thisStateNum)]
        )
        # action_logstd = torch.stack([self.actor_logstd[target[i]](middleLayer[i]) for i in range(thisStateNum)])
        # action_logstd = self.actor_logstd.expand_as(actions_mean)
        action_logstd = torch.stack(
            [torch.squeeze(self.actor_logstd[target[i]], 0) for i in range(thisStateNum)]
        )
        action_std = torch.exp(action_logstd)
        con_probs = Normal(actions_mean, action_std)
        # critic
        criticV = torch.stack(
            [self.critic[target[i]](middleLayer[i]) for i in range(thisStateNum)]
        )

        if actions is None:
            if self.trainAgent:
                # training: sample actions from the current distributions
                disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])
                conAct = con_probs.sample()
                actions = torch.cat([disAct.T, conAct], dim=1)
            else:
                # evaluation: a deterministic choice (argmax / mean) would go here,
                # but actions are still sampled from the distributions
                # disAct = torch.stack([torch.argmax(logit, dim=1) for logit in split_logits])
                # conAct = actions_mean
                disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])
                conAct = con_probs.sample()
                actions = torch.cat([disAct.T, conAct], dim=1)
        else:
            # actions were given: split them back into discrete and continuous parts
            disAct = actions[:, 0 : self.unityDiscreteType].T
            conAct = actions[:, self.unityDiscreteType :]
        dis_log_prob = torch.stack(
            [ctgr.log_prob(act) for act, ctgr in zip(disAct, multi_categoricals)]
        )
        dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])
        return (
            actions,
            dis_log_prob.sum(0),
            dis_entropy.sum(0),
            con_probs.log_prob(conAct).sum(1),
            con_probs.entropy().sum(1),
            criticV,
        )
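
# Illustrative rollout step (not part of the original file; `agent`, `state` and `device`
# are assumed to come from the surrounding training loop):
#   with torch.no_grad():
#       actions, dis_logprob, dis_entropy, con_logprob, con_entropy, value = agent.get_actions_value(
#           torch.as_tensor(state, dtype=torch.float32, device=device)
#       )
#   # `actions` packs the discrete branches first, then the continuous values, and is what
#   # gets sent back to the Unity environment for the next step.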


def GAE(agent, args, rewards, dones, values, next_obs, next_done, device):
    """Compute advantages and returns for PPO.

    With args.gae enabled this is generalized advantage estimation,
        delta_t = r_t + gamma * (1 - done_{t+1}) * V(s_{t+1}) - V(s_t)
        A_t     = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1},
    otherwise plain discounted returns with advantages = returns - values.
    """
    with torch.no_grad():
        next_value = agent.get_value(next_obs).reshape(1, -1)
        data_size = rewards.size()[0]
        if args.gae:
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(data_size)):
                if t == data_size - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = (
                    delta + args.gamma * args.gaeLambda * nextnonterminal * lastgaelam
                )
            returns = advantages + values
        else:
            returns = torch.zeros_like(rewards).to(device)
            for t in reversed(range(data_size)):
                if t == data_size - 1:
                    nextnonterminal = 1.0 - next_done
                    next_return = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    next_return = returns[t + 1]
                returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
            advantages = returns - values
    return advantages, returns
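
# Illustrative call (assumption: rewards/dones/values are rollout tensors of shape
# [num_steps, num_agents], and next_obs/next_done describe the step right after the buffer):
#   advantages, returns = GAE(agent, args, rewards, dones, values, next_obs, next_done, device)
#   # `returns` is the critic regression target; `advantages` feeds the PPO policy loss.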


class AimbotSideChannel(SideChannel):
    def __init__(self, channel_id: uuid.UUID) -> None:
        super().__init__(channel_id)

    def on_message_received(self, msg: IncomingMessage) -> None:
        """Receive messages from Unity.

        This method of the SideChannel interface must be implemented to receive messages
        from Unity. Messages arrive as pipe-separated strings such as
        "Warning|Message1|Message2|Message3" or
        "Error|Message1|Message2|Message3".
        """
        global SCrecieved  # make sure this variable is global
        thisMessage = msg.read_string()
        thisResult = thisMessage.split("|")
        if thisResult[0] == "result":
            airecorder.total_rounds[thisResult[1]] += 1
            if thisResult[2] == "Win":
                airecorder.win_rounds[thisResult[1]] += 1
            # print(TotalRounds)
            # print(WinRounds)
        elif thisResult[0] == "Error":
            print(thisMessage)

    # # while Message type is Warning
    # if(thisResult[0] == "Warning"):
    #     # while Message1 is result means one game is over
    #     if (thisResult[1] == "Result"):
    #         TotalRounds[thisResult[2]]+=1
    #         # while Message3 is Win means this agent win this game
    #         if(thisResult[3] == "Win"):
    #             WinRounds[thisResult[2]]+=1
    #     # while Message1 is GameState means this game is just start
    #     # and tell python which game mode is
    #     elif (thisResult[1] == "GameState"):
    #         SCrecieved = 1
    # # while Message type is Error
    # elif(thisResult[0] == "Error"):
    #     print(thisMessage)

    # send helpers
    def send_string(self, data: str) -> None:
        # send a string to C#
        msg = OutgoingMessage()
        msg.write_string(data)
        super().queue_message_to_send(msg)

    def send_bool(self, data: bool) -> None:
        msg = OutgoingMessage()
        msg.write_bool(data)
        super().queue_message_to_send(msg)

    def send_int(self, data: int) -> None:
        msg = OutgoingMessage()
        msg.write_int32(data)
        super().queue_message_to_send(msg)

    def send_float(self, data: float) -> None:
        msg = OutgoingMessage()
        msg.write_float32(data)
        super().queue_message_to_send(msg)

    def send_float_list(self, data: List[float]) -> None:
        msg = OutgoingMessage()
        msg.write_float32_list(data)
        super().queue_message_to_send(msg)
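
# Illustrative wiring (assumptions: the channel UUID must match the one registered on the
# C# side, and the Aimbot wrapper forwards side channels to the underlying UnityEnvironment;
# SIDE_CHANNEL_UUID and the Aimbot constructor arguments are hypothetical):
#   SIDE_CHANNEL_UUID = uuid.UUID("...")
#   aimbot_side_channel = AimbotSideChannel(SIDE_CHANNEL_UUID)
#   env = Aimbot(..., side_channels=[aimbot_side_channel])
#   aimbot_side_channel.send_string("some|message")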