Fully separated NNs

Use NNs that are fully separated by Target for prediction and training. Rework the NN construction and the prediction algorithm to fit the Full Multi NN.
commit 9432eaa76e
parent 52ccce88bc
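In short: instead of one shared view network plus per-target target/middle networks, the agent now keeps one fully independent network per Target, selected per sample by the target id stored in the first state feature. A minimal, self-contained sketch of that dispatch pattern follows (the class name, sizes, and the single output head are illustrative assumptions, not the repository's definitions):

import torch
import torch.nn as nn

# Minimal sketch of the Full Multi NN idea: one independent MLP per target,
# chosen per sample via the target id stored in state[:, 0].
# state_size=30, hidden 128/64, target_num=4, out_size=8 are made-up sizes.
class FullMultiNN(nn.Module):
    def __init__(self, state_size: int = 30, target_num: int = 4, out_size: int = 8):
        super().__init__()
        self.hidden_networks = nn.ModuleList(
            nn.Sequential(
                nn.Linear(state_size, 128),
                nn.LeakyReLU(),
                nn.Linear(128, 64),
                nn.LeakyReLU(),
            )
            for _ in range(target_num)
        )
        self.heads = nn.ModuleList(nn.Linear(64, out_size) for _ in range(target_num))

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        target = state[:, 0].to(torch.int32)  # first feature carries the target id
        # Per-sample dispatch: each row only ever passes through its own target's net.
        hidden = torch.stack(
            [self.hidden_networks[target[i]](state[i]) for i in range(state.size(0))]
        )
        return torch.stack([self.heads[target[i]](hidden[i]) for i in range(state.size(0))])

net = FullMultiNN()
batch = torch.rand(5, 30)
batch[:, 0] = torch.randint(0, 4, (5,)).float()  # target ids 0..3
print(net(batch).shape)  # torch.Size([5, 8])

Since no weights are shared across targets, a gradient step on one target's samples leaves every other target's parameters untouched, which is also why freezing a shared view network no longer has anything to act on (see the first hunk below).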
@@ -47,9 +47,8 @@ if __name__ == "__main__":
     # freeze
     if args.freeze_viewnet:
         # freeze the view network
-        for p in agent.viewNetwork.parameters():
-            p.requires_grad = False
-        print("VIEW NETWORK FREEZE")
+        print("FREEZE VIEW NETWORK is not compatible with Full MNN!")
+        raise NotImplementedError
     print("Load Agent", args.load_dir)
     print(agent.eval())
     # optimizer
@@ -10,8 +10,8 @@ WORKER_ID = 1
 BASE_PORT = 1000
 
 # tensorboard names
-GAME_NAME = "Aimbot_Target_Hybrid_PMNN_V3"
-GAME_TYPE = "Mix_Verification"
+GAME_NAME = "Aimbot_Target_Hybrid_Full_MNN_V1"
+GAME_TYPE = "Mix_Train"
 
 # max round steps per agent is 2500/Decision_period, 25 seconds
 TOTAL_STEPS = 3150000
@@ -27,16 +27,16 @@ LOSS_COEF = [1.0, 1.0, 1.0, 1.0] # free go attack defence
 POLICY_COEF = [1.0, 1.0, 1.0, 1.0]
 ENTROPY_COEF = [0.05, 0.05, 0.05, 0.05]
 CRITIC_COEF = [0.5, 0.5, 0.5, 0.5]
-TARGET_LEARNING_RATE = 1e-6
+TARGET_LEARNING_RATE = 1e-5
 
 FREEZE_VIEW_NETWORK = False
 BROADCASTREWARD = False
 ANNEAL_LEARNING_RATE = True
 CLIP_VLOSS = True
 NORM_ADV = False
-TRAIN = False
-SAVE_MODEL = False
-WANDB_TACK = False
+TRAIN = True
+SAVE_MODEL = True
+WANDB_TACK = True
 LOAD_DIR = None
 #LOAD_DIR = "../PPO-Model/PList_Go_LeakyReLU_9331_1677965178_bestGoto/PList_Go_LeakyReLU_9331_1677965178_10.709002.pt"
 
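TARGET_LEARNING_RATE rises from 1e-6 to 1e-5 while ANNEAL_LEARNING_RATE stays enabled. Assuming the flag anneals the optimizer rate toward that target over TOTAL_STEPS (the starting rate and the exact schedule are not shown in this hunk), the mechanics look roughly like this:

import torch

# Hedged sketch: one common way to anneal the optimizer LR toward a floor such
# as TARGET_LEARNING_RATE. LEARNING_RATE and the linear schedule are assumptions.
LEARNING_RATE = 5e-4  # assumed initial rate, not from this diff
TARGET_LEARNING_RATE = 1e-5
TOTAL_STEPS = 3150000

def annealed_lr(step: int) -> float:
    # Linear interpolation from LEARNING_RATE down to TARGET_LEARNING_RATE.
    frac = min(step / TOTAL_STEPS, 1.0)
    return LEARNING_RATE + frac * (TARGET_LEARNING_RATE - LEARNING_RATE)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=LEARNING_RATE)
for step in (0, TOTAL_STEPS // 2, TOTAL_STEPS):
    for group in optimizer.param_groups:
        group["lr"] = annealed_lr(step)
    print(step, optimizer.param_groups[0]["lr"])  # 5e-4 -> 2.55e-4 -> 1e-5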
@@ -46,84 +46,70 @@ class PPOAgent(nn.Module):
         self.discrete_shape = list(env.unity_discrete_branches)
         self.continuous_size = env.unity_continuous_size
 
-        self.view_network = nn.Sequential(layer_init(nn.Linear(self.ray_state_size, 200)), nn.LeakyReLU())
-        self.target_networks = nn.ModuleList(
-            [
-                nn.Sequential(layer_init(nn.Linear(self.state_size_without_ray, 100)), nn.LeakyReLU())
-                for i in range(self.target_num)
-            ]
-        )
-        self.middle_networks = nn.ModuleList(
+        self.hidden_networks = nn.ModuleList(
             [
-                nn.Sequential(layer_init(nn.Linear(300, 200)), nn.LeakyReLU())
+                nn.Sequential(
+                    layer_init(nn.Linear(self.state_size, 128)),
+                    nn.LeakyReLU(),
+                    layer_init(nn.Linear(128, 64)),
+                    nn.LeakyReLU(),
+                )
                 for i in range(self.target_num)
             ]
         )
 
         self.actor_dis = nn.ModuleList(
-            [layer_init(nn.Linear(200, self.discrete_size), std=0.5) for i in range(self.target_num)]
+            [layer_init(nn.Linear(64, self.discrete_size), std=0.5) for i in range(self.target_num)]
         )
         self.actor_mean = nn.ModuleList(
-            [layer_init(nn.Linear(200, self.continuous_size), std=0.5) for i in range(self.target_num)]
+            [layer_init(nn.Linear(64, self.continuous_size), std=0.5) for i in range(self.target_num)]
         )
         self.actor_logstd = nn.ParameterList(
             [nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(self.target_num)]
-        ) # nn.Parameter(torch.zeros(1, self.continuous_size))
+        )
         self.critic = nn.ModuleList(
-            [layer_init(nn.Linear(200, 1), std=1) for i in range(self.target_num)]
+            [layer_init(nn.Linear(64, 1), std=1) for i in range(self.target_num)]
         )
 
     def get_value(self, state: torch.Tensor):
+        # get critic value
+        # state.size()[0] is batch_size
         target = state[:, 0].to(torch.int32)  # int
-        this_state_num = target.size()[0]
-        view_input = state[:, -self.ray_state_size:]  # all ray input
-        target_input = state[:, : self.state_size_without_ray]
-        view_layer = self.view_network(view_input)
-        target_layer = torch.stack(
-            [self.target_networks[target[i]](target_input[i]) for i in range(this_state_num)]
-        )
-        middle_input = torch.cat([view_layer, target_layer], dim=1)
-        middle_layer = torch.stack(
-            [self.middle_networks[target[i]](middle_input[i]) for i in range(this_state_num)]
+        hidden_output = torch.stack(
+            [self.hidden_networks[target[i]](state[i]) for i in range(state.size()[0])]
         )
         criticV = torch.stack(
-            [self.critic[target[i]](middle_layer[i]) for i in range(this_state_num)]
-        )  # self.critic
+            [self.critic[target[i]](hidden_output[i]) for i in range(state.size()[0])]
+        )
         return criticV
 
     def get_actions_value(self, state: torch.Tensor, actions=None):
+        # get actions and value
         target = state[:, 0].to(torch.int32)  # int
-        this_state_num = target.size()[0]
-        view_input = state[:, -self.ray_state_size:]  # all ray input
-        target_input = state[:, : self.state_size_without_ray]
-        view_layer = self.view_network(view_input)
-        target_layer = torch.stack(
-            [self.target_networks[target[i]](target_input[i]) for i in range(this_state_num)]
-        )
-        middle_input = torch.cat([view_layer, target_layer], dim=1)
-        middle_layer = torch.stack(
-            [self.middle_networks[target[i]](middle_input[i]) for i in range(this_state_num)]
+        hidden_output = torch.stack(
+            [self.hidden_networks[target[i]](state[i]) for i in range(target.size()[0])]
         )
 
         # discrete
         # loop over the targets (i.e. the agents) so each sample is computed by the output network that matches its target
         dis_logits = torch.stack(
-            [self.actor_dis[target[i]](middle_layer[i]) for i in range(this_state_num)]
+            [self.actor_dis[target[i]](hidden_output[i]) for i in range(target.size()[0])]
         )
         split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)
         multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]
         # continuous
         actions_mean = torch.stack(
-            [self.actor_mean[target[i]](middle_layer[i]) for i in range(this_state_num)]
+            [self.actor_mean[target[i]](hidden_output[i]) for i in range(target.size()[0])]
         )  # self.actor_mean(hidden)
         action_logstd = torch.stack(
-            [torch.squeeze(self.actor_logstd[target[i]], 0) for i in range(this_state_num)]
+            [torch.squeeze(self.actor_logstd[target[i]], 0) for i in range(target.size()[0])]
         )
         # print(action_logstd)
         action_std = torch.exp(action_logstd)  # torch.exp(action_logstd)
         con_probs = Normal(actions_mean, action_std)
         # critic
         criticV = torch.stack(
-            [self.critic[target[i]](middle_layer[i]) for i in range(this_state_num)]
+            [self.critic[target[i]](hidden_output[i]) for i in range(target.size()[0])]
         )  # self.critic
 
         if actions is None:
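The discrete branch keeps the usual multi-discrete head: one flat logits vector is split per action branch and each branch is wrapped in its own Categorical. A self-contained illustration (the branch sizes are made up; the real ones come from env.unity_discrete_branches):

import torch
from torch.distributions import Categorical

# Multi-discrete head: split one flat logits vector per action branch and
# sample each branch independently. Branch sizes here are illustrative.
discrete_shape = [3, 3, 2]                        # e.g. three action branches
dis_logits = torch.randn(4, sum(discrete_shape))  # batch of 4 flat logit vectors

split_logits = torch.split(dis_logits, discrete_shape, dim=1)
multi_categoricals = [Categorical(logits=l) for l in split_logits]

actions = torch.stack([c.sample() for c in multi_categoricals], dim=1)
log_prob = torch.stack(
    [c.log_prob(a) for a, c in zip(actions.T, multi_categoricals)], dim=1
).sum(dim=1)  # joint log-prob is the sum over independent branches
print(actions.shape, log_prob.shape)  # torch.Size([4, 3]) torch.Size([4])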