Totally disparate NN per target

Give every target its own fully separate network. Previously only the actor heads were per-target; now the shared backbone and the critic are also nn.ModuleLists indexed by the target ID stored in state[:, 0], and get_value / get_actions_value dispatch each sample to the sub-network matching its target. Backbone widths shrink from 500/300 to 300/200, and TARGET_LEARNING_RATE drops from 5e-5 to 1e-5.
Koha9 2022-12-03 21:35:33 +09:00
parent cbc385ca10
commit ad9817e7a4


@@ -47,7 +47,7 @@ CLIP_COEF = 0.1
 POLICY_COEF = 1.0
 ENTROPY_COEF = 0.01
 CRITIC_COEF = 0.5
-TARGET_LEARNING_RATE = 5e-5
+TARGET_LEARNING_RATE = 1e-5
 ANNEAL_LEARNING_RATE = True
 CLIP_VLOSS = True
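
Note: the hunk above only lowers TARGET_LEARNING_RATE from 5e-5 to 1e-5. Below is a hedged sketch of how such a constant is commonly consumed when ANNEAL_LEARNING_RATE is enabled; the Adam optimizer choice, the stand-in agent, and TOTAL_UPDATES are assumptions for illustration, not code from this repository.

import torch
import torch.nn as nn

# Hypothetical sketch (not from this repo): linear learning-rate annealing as
# typically paired with a TARGET_LEARNING_RATE / ANNEAL_LEARNING_RATE pair.
TARGET_LEARNING_RATE = 1e-5
ANNEAL_LEARNING_RATE = True
TOTAL_UPDATES = 10                       # stand-in value for illustration

agent = nn.Linear(4, 2)                  # stand-in for the PPOAgent
optimizer = torch.optim.Adam(agent.parameters(), lr=TARGET_LEARNING_RATE, eps=1e-5)

for update in range(1, TOTAL_UPDATES + 1):
    if ANNEAL_LEARNING_RATE:
        frac = 1.0 - (update - 1.0) / TOTAL_UPDATES   # decays linearly from 1.0 toward 0
        optimizer.param_groups[0]["lr"] = frac * TARGET_LEARNING_RATE
    # ... rollout collection and PPO updates would go here ...

Under such a schedule the configured value is the starting learning rate, so lowering it shifts the whole decay curve down rather than only the final step.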
@@ -159,23 +159,24 @@ class PPOAgent(nn.Module):
         self.discrete_shape = list(env.unity_discrete_branches)
         self.continuous_size = env.unity_continuous_size
-        self.network = nn.Sequential(
-            layer_init(nn.Linear(np.array(env.unity_observation_shape).prod(), 500)),
+        self.network = nn.ModuleList([nn.Sequential(
+            layer_init(nn.Linear(np.array(env.unity_observation_shape).prod(), 300)),
             nn.ReLU(),
-            layer_init(nn.Linear(500, 300)),
-            nn.ReLU(),
-        )
-        self.actor_dis = nn.ModuleList([layer_init(nn.Linear(300, self.discrete_size), std=0.01) for i in range(targetNum)])
-        self.actor_mean = nn.ModuleList([layer_init(nn.Linear(300, self.continuous_size), std=0.01) for i in range(targetNum)])
+            layer_init(nn.Linear(300, 200)),
+            nn.ReLU()) for i in range(targetNum)])
+        self.actor_dis = nn.ModuleList([layer_init(nn.Linear(200, self.discrete_size), std=0.01) for i in range(targetNum)])
+        self.actor_mean = nn.ModuleList([layer_init(nn.Linear(200, self.continuous_size), std=0.01) for i in range(targetNum)])
         self.actor_logstd = nn.ParameterList([nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(targetNum)])
-        self.critic = layer_init(nn.Linear(300, 1), std=1)
+        self.critic = nn.ModuleList([layer_init(nn.Linear(200, 1), std=1) for i in range(targetNum)])

     def get_value(self, state: torch.Tensor):
-        return self.critic(self.network(state))
+        targets = state[:, 0].to(torch.int32)
+        hidden = torch.stack([self.network[targets[i]](state[i]) for i in range(targets.size()[0])])
+        return torch.stack([self.critic[targets[i]](hidden[i]) for i in range(targets.size()[0])])

     def get_actions_value(self, state: torch.Tensor, actions=None):
-        hidden = self.network(state)
+        targets = state[:, 0].to(torch.int32)
+        hidden = torch.stack([self.network[targets[i]](state[i]) for i in range(targets.size()[0])])

         # discrete
         # Loop over the targets (i.e. the number of agents) so that each sample is routed to the output network that matches its target
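
Note: the core of this hunk is the dispatch pattern: every per-target sub-network lives in an nn.ModuleList, and the target ID is read from the first feature of each observation (state[:, 0]). A minimal, self-contained sketch of that pattern follows; OBS_SIZE, HIDDEN, and TARGET_NUM are placeholder sizes, not the repository's real dimensions.

import torch
import torch.nn as nn

# Minimal sketch of per-target sub-networks selected by the target id stored
# in state[:, 0]. All sizes below are illustrative placeholders.
OBS_SIZE, HIDDEN, TARGET_NUM = 8, 16, 3

class PerTargetNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.ModuleList(
            [nn.Sequential(nn.Linear(OBS_SIZE, HIDDEN), nn.ReLU()) for _ in range(TARGET_NUM)]
        )
        self.critic = nn.ModuleList([nn.Linear(HIDDEN, 1) for _ in range(TARGET_NUM)])

    def get_value(self, state: torch.Tensor) -> torch.Tensor:
        targets = state[:, 0].to(torch.int32)   # target id encoded in the first observation feature
        hidden = torch.stack([self.backbone[targets[i]](state[i]) for i in range(targets.size(0))])
        return torch.stack([self.critic[targets[i]](hidden[i]) for i in range(targets.size(0))])

# Usage: a batch of 4 observations whose first feature is the target id (0..2).
state = torch.rand(4, OBS_SIZE)
state[:, 0] = torch.tensor([0.0, 1.0, 2.0, 1.0])
print(PerTargetNet().get_value(state).shape)    # torch.Size([4, 1])

Because indexing an nn.ModuleList needs one Python integer per sample, both the sketch and the commit dispatch observations one at a time inside a list comprehension; the note after the next hunk sketches a batched alternative.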
@@ -188,6 +189,8 @@ class PPOAgent(nn.Module):
         # print(action_logstd)
         action_std = torch.squeeze(torch.stack([torch.exp(self.actor_logstd[targets[i]]) for i in range(targets.size()[0])]), dim=-1)  # torch.exp(action_logstd)
         con_probs = Normal(actions_mean, action_std)
+        # critic
+        criticV = torch.stack([self.critic[targets[i]](hidden[i]) for i in range(targets.size()[0])])

         if actions is None:
             if args.train:
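
Note: criticV above is built with the same per-sample loop. A hypothetical optimization, not part of this commit, is to group samples by target with a boolean mask so each sub-network runs once per target instead of once per observation. A sketch under those assumptions (placeholder sizes again):

import torch
import torch.nn as nn

# Hypothetical batched alternative (not from this repo): evaluate each
# per-target critic once on all samples that share that target.
OBS_SIZE, HIDDEN, TARGET_NUM = 8, 16, 3
backbone = nn.ModuleList([nn.Sequential(nn.Linear(OBS_SIZE, HIDDEN), nn.ReLU()) for _ in range(TARGET_NUM)])
critic = nn.ModuleList([nn.Linear(HIDDEN, 1) for _ in range(TARGET_NUM)])

def batched_value(state: torch.Tensor) -> torch.Tensor:
    targets = state[:, 0].to(torch.int64)
    values = state.new_zeros(state.size(0), 1)
    for t in range(TARGET_NUM):
        mask = targets == t                     # all samples belonging to target t
        if mask.any():
            values[mask] = critic[t](backbone[t](state[mask]))
    return values

state = torch.rand(5, OBS_SIZE)
state[:, 0] = torch.tensor([0.0, 2.0, 1.0, 0.0, 2.0])
print(batched_value(state).shape)               # torch.Size([5, 1])

The masked assignment fills each sample's row in its original position, so the result matches the torch.stack version element for element while calling each sub-network at most once per batch.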
@@ -207,13 +210,14 @@ class PPOAgent(nn.Module):
             [ctgr.log_prob(act) for act, ctgr in zip(disAct, multi_categoricals)]
         )
         dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])
         return (
             actions,
             dis_log_prob.sum(0),
             dis_entropy.sum(0),
             con_probs.log_prob(conAct).sum(1),
             con_probs.entropy().sum(1),
-            self.critic(hidden),
+            criticV,
         )
@@ -436,7 +440,7 @@ if __name__ == "__main__":
                 thisRewardsTensor,
                 torch.Tensor(dones_bf[i]).to(device),
                 torch.tensor(values_bf[i]).to(device),
-                torch.tensor(next_state[i]).to(device),
+                torch.tensor([next_state[i]]).to(device),
                 torch.Tensor([next_done[i]]).to(device),
             )
             # send memories to training datasets
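
Note: the final hunk only wraps next_state[i] in a list before torch.tensor(). As the toy lines below show (illustrative shapes only), that prepends a dimension of size 1, presumably so the stored next-state carries the same leading [1, ...] dimension as the torch.Tensor([next_done[i]]) entry on the following line.

import numpy as np
import torch

# Illustrative only: wrapping a per-agent observation in a list before
# torch.tensor() prepends a dimension of size 1.
next_state_i = np.zeros(6, dtype=np.float32)     # stand-in per-agent observation

print(torch.tensor(next_state_i).shape)          # torch.Size([6])
print(torch.tensor([next_state_i]).shape)        # torch.Size([1, 6])
print(torch.Tensor([0.0]).shape)                 # torch.Size([1]), like the next_done entry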