Fully separated NN
Use NNs fully separated by Target for prediction and training. Modify the NN construction and the prediction algorithm to work with the Full Multi NN.
This commit is contained in:
parent
52ccce88bc
commit
9432eaa76e
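The gist of the change: instead of one shared view network feeding per-target middle layers, every target id now selects its own fully independent sub-network. A minimal sketch of that routing pattern in PyTorch follows; the sizes and names here are illustrative stand-ins, not the repo's actual values.

import torch
import torch.nn as nn

# One independent MLP per target id, no shared trunk: the "Full Multi NN"
# idea in miniature. state_size and target_num are made-up example values.
state_size, target_num = 30, 4
hidden_networks = nn.ModuleList(
    nn.Sequential(
        nn.Linear(state_size, 128), nn.LeakyReLU(),
        nn.Linear(128, 64), nn.LeakyReLU(),
    )
    for _ in range(target_num)
)

state = torch.randn(8, state_size)           # batch of 8 observations
target = torch.randint(0, target_num, (8,))  # each row carries its target id
# Route every sample through the sub-network owned by its target id.
hidden = torch.stack([hidden_networks[target[i]](state[i]) for i in range(8)])
print(hidden.shape)  # torch.Size([8, 64])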
@@ -47,9 +47,8 @@ if __name__ == "__main__":
     # freeze
     if args.freeze_viewnet:
         # freeze the view network
-        for p in agent.viewNetwork.parameters():
-            p.requires_grad = False
-        print("VIEW NETWORK FREEZE")
+        print("FREEZE VIEW NETWORK is not compatible with Full MNN!")
+        raise NotImplementedError
     print("Load Agent", args.load_dir)
     print(agent.eval())
     # optimizer
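Note on the hunk above: the deleted branch froze the shared view network in the standard PyTorch way; since the Full MNN has no shared module left, the option now raises instead. For reference, a self-contained version of the freeze the old code performed (the nn.Linear here is a hypothetical stand-in for the removed view network):

import torch.nn as nn

view_stub = nn.Linear(4, 2)  # stand-in for the removed shared view network
for p in view_stub.parameters():
    p.requires_grad = False  # no gradients are computed for these parameters
print(all(not p.requires_grad for p in view_stub.parameters()))  # True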
@@ -10,8 +10,8 @@ WORKER_ID = 1
 BASE_PORT = 1000

 # tensorboard names
-GAME_NAME = "Aimbot_Target_Hybrid_PMNN_V3"
-GAME_TYPE = "Mix_Verification"
+GAME_NAME = "Aimbot_Target_Hybrid_Full_MNN_V1"
+GAME_TYPE = "Mix_Train"

 # max round steps per agent is 2500/Decision_period, 25 seconds
 TOTAL_STEPS = 3150000
@@ -27,16 +27,16 @@ LOSS_COEF = [1.0, 1.0, 1.0, 1.0]  # free go attack defence
 POLICY_COEF = [1.0, 1.0, 1.0, 1.0]
 ENTROPY_COEF = [0.05, 0.05, 0.05, 0.05]
 CRITIC_COEF = [0.5, 0.5, 0.5, 0.5]
-TARGET_LEARNING_RATE = 1e-6
+TARGET_LEARNING_RATE = 1e-5

 FREEZE_VIEW_NETWORK = False
 BROADCASTREWARD = False
 ANNEAL_LEARNING_RATE = True
 CLIP_VLOSS = True
 NORM_ADV = False
-TRAIN = False
-SAVE_MODEL = False
-WANDB_TACK = False
+TRAIN = True
+SAVE_MODEL = True
+WANDB_TACK = True
 LOAD_DIR = None
 #LOAD_DIR = "../PPO-Model/PList_Go_LeakyReLU_9331_1677965178_bestGoto/PList_Go_LeakyReLU_9331_1677965178_10.709002.pt"
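This hunk raises TARGET_LEARNING_RATE tenfold and switches the run from verification to a full training configuration. With ANNEAL_LEARNING_RATE = True, PPO setups in this style commonly decay the learning rate linearly to zero over TOTAL_STEPS; the sketch below assumes that convention, since the actual schedule is outside this diff.

import torch
import torch.nn as nn

TARGET_LEARNING_RATE = 1e-5  # new value from this commit
TOTAL_STEPS = 3150000

model = nn.Linear(4, 2)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=TARGET_LEARNING_RATE)

def anneal_lr(step: int) -> None:
    # Linear decay from the initial LR to 0 over TOTAL_STEPS, the common
    # PPO convention this flag usually toggles; the repo's schedule may differ.
    frac = 1.0 - step / TOTAL_STEPS
    optimizer.param_groups[0]["lr"] = frac * TARGET_LEARNING_RATE

anneal_lr(1_575_000)                    # halfway through training
print(optimizer.param_groups[0]["lr"])  # 5e-06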
@@ -46,84 +46,70 @@ class PPOAgent(nn.Module):
         self.discrete_shape = list(env.unity_discrete_branches)
         self.continuous_size = env.unity_continuous_size

-        self.view_network = nn.Sequential(layer_init(nn.Linear(self.ray_state_size, 200)), nn.LeakyReLU())
-        self.target_networks = nn.ModuleList(
+        self.hidden_networks = nn.ModuleList(
             [
-                nn.Sequential(layer_init(nn.Linear(self.state_size_without_ray, 100)), nn.LeakyReLU())
-                for i in range(self.target_num)
-            ]
-        )
-        self.middle_networks = nn.ModuleList(
-            [
-                nn.Sequential(layer_init(nn.Linear(300, 200)), nn.LeakyReLU())
+                nn.Sequential(
+                    layer_init(nn.Linear(self.state_size, 128)),
+                    nn.LeakyReLU(),
+                    layer_init(nn.Linear(128, 64)),
+                    nn.LeakyReLU(),
+                )
                 for i in range(self.target_num)
             ]
         )

         self.actor_dis = nn.ModuleList(
-            [layer_init(nn.Linear(200, self.discrete_size), std=0.5) for i in range(self.target_num)]
+            [layer_init(nn.Linear(64, self.discrete_size), std=0.5) for i in range(self.target_num)]
         )
         self.actor_mean = nn.ModuleList(
-            [layer_init(nn.Linear(200, self.continuous_size), std=0.5) for i in range(self.target_num)]
+            [layer_init(nn.Linear(64, self.continuous_size), std=0.5) for i in range(self.target_num)]
         )
         self.actor_logstd = nn.ParameterList(
             [nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(self.target_num)]
-        )  # nn.Parameter(torch.zeros(1, self.continuous_size))
+        )
         self.critic = nn.ModuleList(
-            [layer_init(nn.Linear(200, 1), std=1) for i in range(self.target_num)]
+            [layer_init(nn.Linear(64, 1), std=1) for i in range(self.target_num)]
         )

     def get_value(self, state: torch.Tensor):
         # get critic value
         # state.size()[0] is batch_size
         target = state[:, 0].to(torch.int32)  # int
-        this_state_num = target.size()[0]
-        view_input = state[:, -self.ray_state_size:]  # all ray input
-        target_input = state[:, : self.state_size_without_ray]
-        view_layer = self.view_network(view_input)
-        target_layer = torch.stack(
-            [self.target_networks[target[i]](target_input[i]) for i in range(this_state_num)]
-        )
-        middle_input = torch.cat([view_layer, target_layer], dim=1)
-        middle_layer = torch.stack(
-            [self.middle_networks[target[i]](middle_input[i]) for i in range(this_state_num)]
+        hidden_output = torch.stack(
+            [self.hidden_networks[target[i]](state[i]) for i in range(state.size()[0])]
         )
         criticV = torch.stack(
-            [self.critic[target[i]](middle_layer[i]) for i in range(this_state_num)]
-        )  # self.critic
+            [self.critic[target[i]](hidden_output[i]) for i in range(state.size()[0])]
+        )
         return criticV

     def get_actions_value(self, state: torch.Tensor, actions=None):
         # get actions and value
         target = state[:, 0].to(torch.int32)  # int
-        this_state_num = target.size()[0]
-        view_input = state[:, -self.ray_state_size:]  # all ray input
-        target_input = state[:, : self.state_size_without_ray]
-        view_layer = self.view_network(view_input)
-        target_layer = torch.stack(
-            [self.target_networks[target[i]](target_input[i]) for i in range(this_state_num)]
-        )
-        middle_input = torch.cat([view_layer, target_layer], dim=1)
-        middle_layer = torch.stack(
-            [self.middle_networks[target[i]](middle_input[i]) for i in range(this_state_num)]
+        hidden_output = torch.stack(
+            [self.hidden_networks[target[i]](state[i]) for i in range(target.size()[0])]
         )

         # discrete
         # iterate over the targets (i.e. the agents) so each output is computed by the output network matching that agent's target
         dis_logits = torch.stack(
-            [self.actor_dis[target[i]](middle_layer[i]) for i in range(this_state_num)]
+            [self.actor_dis[target[i]](hidden_output[i]) for i in range(target.size()[0])]
         )
         split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)
         multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]
         # continuous
         actions_mean = torch.stack(
-            [self.actor_mean[target[i]](middle_layer[i]) for i in range(this_state_num)]
+            [self.actor_mean[target[i]](hidden_output[i]) for i in range(target.size()[0])]
         )  # self.actor_mean(hidden)
         action_logstd = torch.stack(
-            [torch.squeeze(self.actor_logstd[target[i]], 0) for i in range(this_state_num)]
+            [torch.squeeze(self.actor_logstd[target[i]], 0) for i in range(target.size()[0])]
         )
         # print(action_logstd)
         action_std = torch.exp(action_logstd)  # torch.exp(action_logstd)
         con_probs = Normal(actions_mean, action_std)
         # critic
         criticV = torch.stack(
-            [self.critic[target[i]](middle_layer[i]) for i in range(this_state_num)]
+            [self.critic[target[i]](hidden_output[i]) for i in range(target.size()[0])]
         )  # self.critic

         if actions is None:
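Design note on the routing above: the per-sample list comprehensions keep the target dispatch obvious but perform one forward pass per row. A purely illustrative alternative, not what this commit does, is to group rows by target id and run each sub-network once per group:

import torch
import torch.nn as nn

target_num, state_size = 4, 30  # example sizes, not the repo's
hidden_networks = nn.ModuleList(
    nn.Sequential(nn.Linear(state_size, 64), nn.LeakyReLU())
    for _ in range(target_num)
)

state = torch.randn(16, state_size)
target = torch.randint(0, target_num, (16,))

# One forward pass per target id instead of one per sample.
hidden = torch.empty(16, 64)
for t in range(target_num):
    mask = target == t
    if mask.any():
        hidden[mask] = hidden_networks[t](state[mask])
print(hidden.shape)  # torch.Size([16, 64])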