import numpy as np
import torch
import argparse
import time

from torch import nn
from aimbotEnv import Aimbot
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, bias_const)
    return layer


neural_size_1 = 400
neural_size_2 = 300


class PPOAgent(nn.Module):
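    """Multi-target PPO agent with a hybrid (discrete + continuous) action space.

    One hidden network, one discrete actor head, one continuous actor head and
    one critic head are created per target; the first element of each
    observation carries the target id and selects which set of heads is used.
    """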
    def __init__(
        self,
        env: Aimbot,
        this_args: argparse.Namespace,
        device: torch.device,
    ):
        super(PPOAgent, self).__init__()
        self.device = device
        self.args = this_args
        self.train_agent = self.args.train
        self.target_num = self.args.target_num
        self.unity_observation_shape = env.unity_observation_shape
        self.unity_action_size = env.unity_action_size
        self.state_size = self.unity_observation_shape[0]
        self.agent_num = env.unity_agent_num

        self.unity_discrete_type = env.unity_discrete_type
        self.discrete_size = env.unity_discrete_size
        self.discrete_shape = list(env.unity_discrete_branches)
        self.continuous_size = env.unity_continuous_size

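        # One hidden network and one set of actor/critic heads per target,
        # indexed by the target id carried in the first observation element.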
        self.hidden_networks = nn.ModuleList(
            [
                nn.Sequential(
                    layer_init(nn.Linear(self.state_size, neural_size_1)),
                    nn.LeakyReLU(),
                    layer_init(nn.Linear(neural_size_1, neural_size_2)),
                    nn.LeakyReLU(),
                )
                for i in range(self.target_num)
            ]
        )

        self.actor_dis = nn.ModuleList(
            [layer_init(nn.Linear(neural_size_2, self.discrete_size), std=0.5) for i in range(self.target_num)]
        )
        self.actor_mean = nn.ModuleList(
            [layer_init(nn.Linear(neural_size_2, self.continuous_size), std=0) for i in range(self.target_num)]
        )
        self.actor_logstd = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(self.target_num)]
        )
        self.critic = nn.ModuleList(
            [layer_init(nn.Linear(neural_size_2, 1), std=0) for i in range(self.target_num)]
        )

    def get_value(self, state: torch.Tensor):
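        """Return the critic value for each state in the batch.

        state[:, 0] holds the target id and selects the hidden network and
        critic head used for that row.
        """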
        # get critic value
        # state.size()[0] is batch_size
        target = state[:, 0].to(torch.int32)  # target id per row
        hidden_output = torch.stack(
            [self.hidden_networks[target[i]](state[i]) for i in range(state.size()[0])]
        )
        criticV = torch.stack(
            [self.critic[target[i]](hidden_output[i]) for i in range(state.size()[0])]
        )
        return criticV

    def get_actions_value(self, state: torch.Tensor, actions=None):
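        """Sample or evaluate actions and return log-probs, entropies and values.

        If actions is None, actions are sampled when self.train_agent is True
        and chosen greedily otherwise; if actions is given, it is evaluated
        under the current policy. Returns (actions, discrete log-prob, discrete
        entropy, continuous log-prob, continuous entropy, critic value).
        """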
        # get actions and value
        target = state[:, 0].to(torch.int32)  # target id per row
        hidden_output = torch.stack(
            [self.hidden_networks[target[i]](state[i]) for i in range(target.size()[0])]
        )

        # discrete
        # Loop over the targets (one per agent) so that each agent's output is
        # computed by the head matching its target id.
        dis_logits = torch.stack(
            [self.actor_dis[target[i]](hidden_output[i]) for i in range(target.size()[0])]
        )
        split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)
        multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]
        # continuous
        actions_mean = torch.stack(
            [self.actor_mean[target[i]](hidden_output[i]) for i in range(target.size()[0])]
        )
        action_logstd = torch.stack(
            [torch.squeeze(self.actor_logstd[target[i]], 0) for i in range(target.size()[0])]
        )
        action_std = torch.exp(action_logstd)
        con_probs = Normal(actions_mean, action_std)
        # critic
        criticV = torch.stack(
            [self.critic[target[i]](hidden_output[i]) for i in range(target.size()[0])]
        )

        if actions is None:
            if self.train_agent:
                # sample actions from the probability distributions
                dis_act = torch.stack([ctgr.sample() for ctgr in multi_categoricals])
                con_act = con_probs.sample()
                actions = torch.cat([dis_act.T, con_act], dim=1)
            else:
                # select the most probable actions (deterministic)
                dis_act = torch.stack([torch.argmax(logit, dim=1) for logit in split_logits])
                con_act = actions_mean
                actions = torch.cat([dis_act.T, con_act], dim=1)
        else:
            dis_act = actions[:, 0: self.unity_discrete_type].T
            con_act = actions[:, self.unity_discrete_type:]
        dis_log_prob = torch.stack(
            [ctgr.log_prob(act) for act, ctgr in zip(dis_act, multi_categoricals)]
        )
        dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])
        return (
            actions,
            dis_log_prob.sum(0),
            dis_entropy.sum(0),
            con_probs.log_prob(con_act).sum(1),
            con_probs.entropy().sum(1),
            criticV,
        )

    def train_net(self, this_train_ind: int, ppo_memories, optimizer) -> tuple:
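        """Run PPO updates on the rollout stored at index this_train_ind in
        ppo_memories, looping args.epochs times over shuffled minibatches.

        Returns the last minibatch's (v_loss, dis_pg_loss, con_pg_loss, loss,
        entropy_loss).
        """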
        start_time = time.time()
        # flatten the batch
        b_obs = ppo_memories.obs[this_train_ind].reshape((-1,) + self.unity_observation_shape)
        b_dis_logprobs = ppo_memories.dis_logprobs[this_train_ind].reshape(-1)
        b_con_logprobs = ppo_memories.con_logprobs[this_train_ind].reshape(-1)
        b_actions = ppo_memories.actions[this_train_ind].reshape((-1,) + (self.unity_action_size,))
        b_advantages = ppo_memories.advantages[this_train_ind].reshape(-1)
        b_returns = ppo_memories.returns[this_train_ind].reshape(-1)
        b_values = ppo_memories.values[this_train_ind].reshape(-1)
        b_size = b_obs.size()[0]
        # optimizing the policy and value network
        b_index = np.arange(b_size)

        for epoch in range(self.args.epochs):
            print("epoch:", epoch, end="")
            # shuffle all datasets
            np.random.shuffle(b_index)
            for start in range(0, b_size, self.args.minibatchSize):
                print(".", end="")
                end = start + self.args.minibatchSize
                mb_index = b_index[start:end]
                if np.size(mb_index) <= 1:
                    break
                mb_advantages = b_advantages[mb_index]

                # normalize advantages
                if self.args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                        mb_advantages.std() + 1e-8
                    )

                (
                    _,
                    new_dis_logprob,
                    dis_entropy,
                    new_con_logprob,
                    con_entropy,
                    new_value,
                ) = self.get_actions_value(b_obs[mb_index], b_actions[mb_index])
                # discrete ratio
                dis_log_ratio = new_dis_logprob - b_dis_logprobs[mb_index]
                dis_ratio = dis_log_ratio.exp()
                # continuous ratio
                con_log_ratio = new_con_logprob - b_con_logprobs[mb_index]
                con_ratio = con_log_ratio.exp()

                """
                # early stop
                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]
                """

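                # Clipped surrogate objective (PPO): for ratio r_t and advantage A_t,
                #   L_clip = E[ min(r_t * A_t, clamp(r_t, 1 - eps, 1 + eps) * A_t) ]
                # computed below (as a negated max, for minimization) separately for
                # the discrete and the continuous action ratios.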
                # discrete policy loss
                dis_pg_loss_orig = -mb_advantages * dis_ratio
                dis_pg_loss_clip = -mb_advantages * torch.clamp(
                    dis_ratio, 1 - self.args.clip_coef, 1 + self.args.clip_coef
                )
                dis_pg_loss = torch.max(dis_pg_loss_orig, dis_pg_loss_clip).mean()
                # continuous policy loss
                con_pg_loss_orig = -mb_advantages * con_ratio
                con_pg_loss_clip = -mb_advantages * torch.clamp(
                    con_ratio, 1 - self.args.clip_coef, 1 + self.args.clip_coef
                )
                con_pg_loss = torch.max(con_pg_loss_orig, con_pg_loss_clip).mean()

                # value loss
                new_value = new_value.view(-1)
                if self.args.clip_vloss:
                    v_loss_unclipped = (new_value - b_returns[mb_index]) ** 2
                    v_clipped = b_values[mb_index] + torch.clamp(
                        new_value - b_values[mb_index],
                        -self.args.clip_coef,
                        self.args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_index]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((new_value - b_returns[mb_index]) ** 2).mean()

                # total loss
                entropy_loss = dis_entropy.mean() + con_entropy.mean()
                loss = (
                    dis_pg_loss * self.args.policy_coef[this_train_ind]
                    + con_pg_loss * self.args.policy_coef[this_train_ind]
                    + entropy_loss * self.args.entropy_coef[this_train_ind]
                    + v_loss * self.args.critic_coef[this_train_ind]
                ) * self.args.loss_coef[this_train_ind]

                if torch.isnan(loss).any():
                    print("LOSS contains NaN!!!")
                    if torch.isnan(dis_pg_loss).any():
                        print("dis_pg_loss contains NaN")
                    if torch.isnan(con_pg_loss).any():
                        print("con_pg_loss contains NaN")
                    if torch.isnan(entropy_loss).any():
                        print("entropy_loss contains NaN")
                    if torch.isnan(v_loss).any():
                        print("v_loss contains NaN")
                    raise ValueError("loss contains NaN")

                optimizer.zero_grad()
                loss.backward()
                # clip the gradient norm of all parameters
                nn.utils.clip_grad_norm_(self.parameters(), self.args.max_grad_norm)
                optimizer.step()

                """
                if args.target_kl is not None:
                    if approx_kl > args.target_kl:
                        break
                """
        return v_loss, dis_pg_loss, con_pg_loss, loss, entropy_loss

    def gae(
        self,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        values: torch.Tensor,
        next_obs: torch.Tensor,
        next_done: torch.Tensor,
    ) -> tuple:
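        """Compute advantages and returns for a rollout.

        With args.gae enabled this is Generalized Advantage Estimation:
            delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
            A_t     = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}
        otherwise plain discounted returns are computed and advantages are
        returns minus values.
        """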
        # GAE
        with torch.no_grad():
            next_value = self.get_value(next_obs).reshape(1, -1)
            data_size = rewards.size()[0]
            if self.args.gae:
                advantages = torch.zeros_like(rewards).to(self.device)
                last_gae_lam = 0
                for t in reversed(range(data_size)):
                    if t == data_size - 1:
                        next_non_terminal = 1.0 - next_done
                        next_values = next_value
                    else:
                        next_non_terminal = 1.0 - dones[t + 1]
                        next_values = values[t + 1]
                    delta = rewards[t] + self.args.gamma * next_values * next_non_terminal - values[t]
                    advantages[t] = last_gae_lam = (
                        delta + self.args.gamma * self.args.gaeLambda * next_non_terminal * last_gae_lam
                    )
                returns = advantages + values
            else:
                returns = torch.zeros_like(rewards).to(self.device)
                for t in reversed(range(data_size)):
                    if t == data_size - 1:
                        next_non_terminal = 1.0 - next_done
                        next_return = next_value
                    else:
                        next_non_terminal = 1.0 - dones[t + 1]
                        next_return = returns[t + 1]
                    returns[t] = rewards[t] + self.args.gamma * next_non_terminal * next_return
                advantages = returns - values
        return advantages, returns
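

# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hypothetical example of wiring PPOAgent together. The real Aimbot
# environment and training loop live elsewhere in this repository; the stub
# below only mirrors the attributes and argparse fields that this file reads,
# so every concrete value here is an assumption, not the project's real config.
if __name__ == "__main__":
    from types import SimpleNamespace

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hypothetical stand-in for the Aimbot env, exposing only what PPOAgent uses.
    stub_env = SimpleNamespace(
        unity_observation_shape=(30,),   # assumed observation size
        unity_action_size=5,             # discrete branches + continuous dims
        unity_agent_num=8,
        unity_discrete_type=2,           # number of discrete action branches
        unity_discrete_size=5,           # total number of discrete logits
        unity_discrete_branches=(3, 2),  # logits per branch
        unity_continuous_size=3,
    )

    # Only the argparse fields referenced in this file, with placeholder values.
    args = argparse.Namespace(
        train=True, target_num=4, epochs=4, minibatchSize=256, norm_adv=True,
        clip_coef=0.2, clip_vloss=True, max_grad_norm=0.5,
        policy_coef=[1.0] * 4, entropy_coef=[0.01] * 4,
        critic_coef=[0.5] * 4, loss_coef=[1.0] * 4,
        gae=True, gamma=0.99, gaeLambda=0.95,
    )

    agent = PPOAgent(env=stub_env, this_args=args, device=device).to(device)
    optimizer = torch.optim.Adam(agent.parameters(), lr=3e-4)

    # One forward pass: observations whose first element is the target id.
    obs = torch.zeros(stub_env.unity_agent_num, *stub_env.unity_observation_shape)
    obs[:, 0] = torch.randint(0, args.target_num, (stub_env.unity_agent_num,))
    with torch.no_grad():
        actions, dis_lp, dis_ent, con_lp, con_ent, value = agent.get_actions_value(obs.to(device))
    print(actions.shape, value.shape)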