Totally disparate NN by target
parent cbc385ca10
commit ad9817e7a4
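In short: the previous version used one shared trunk (obs -> 500 -> 300) with per-target actor heads and a single shared critic; this commit gives every target its own independent trunk (obs -> 300 -> 200), actor heads, and critic, dispatching each sample to the sub-network selected by the target id stored in feature 0 of the observation. Below is a minimal standalone sketch of that dispatch pattern; the layer sizes mirror the diff, while obs_size, target_num, discrete_size, and the orthogonal layer_init defaults are assumptions for illustration (the diff does not show the helper's body).

    import numpy as np
    import torch
    import torch.nn as nn

    def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
        # Assumed form of the repo's layer_init helper (CleanRL-style orthogonal init).
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    class PerTargetAgent(nn.Module):
        def __init__(self, obs_size: int, target_num: int, discrete_size: int):
            super().__init__()
            # One fully independent trunk per target; no weights are shared.
            self.network = nn.ModuleList([
                nn.Sequential(
                    layer_init(nn.Linear(obs_size, 300)),
                    nn.ReLU(),
                    layer_init(nn.Linear(300, 200)),
                    nn.ReLU(),
                ) for _ in range(target_num)
            ])
            self.actor_dis = nn.ModuleList([
                layer_init(nn.Linear(200, discrete_size), std=0.01) for _ in range(target_num)
            ])
            self.critic = nn.ModuleList([
                layer_init(nn.Linear(200, 1), std=1) for _ in range(target_num)
            ])

        def get_value(self, state: torch.Tensor) -> torch.Tensor:
            # Convention from the diff: state[:, 0] carries the integer target id.
            targets = state[:, 0].to(torch.int32)
            hidden = torch.stack([self.network[targets[i]](state[i]) for i in range(targets.size(0))])
            return torch.stack([self.critic[targets[i]](hidden[i]) for i in range(targets.size(0))])

    agent = PerTargetAgent(obs_size=8, target_num=4, discrete_size=5)
    state = torch.rand(16, 8)
    state[:, 0] = torch.randint(0, 4, (16,)).float()  # target id in feature 0
    values = agent.get_value(state)                   # shape (16, 1)

Since the routing id lives in feature 0, each sub-network also receives its own target id as a constant input feature; the diff keeps the full state vector as input rather than slicing the id off.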
@@ -47,7 +47,7 @@ CLIP_COEF = 0.1
 POLICY_COEF = 1.0
 ENTROPY_COEF = 0.01
 CRITIC_COEF = 0.5
-TARGET_LEARNING_RATE = 5e-5
+TARGET_LEARNING_RATE = 1e-5
 
 ANNEAL_LEARNING_RATE = True
 CLIP_VLOSS = True
@@ -159,23 +159,24 @@ class PPOAgent(nn.Module):
         self.discrete_shape = list(env.unity_discrete_branches)
         self.continuous_size = env.unity_continuous_size
 
-        self.network = nn.Sequential(
-            layer_init(nn.Linear(np.array(env.unity_observation_shape).prod(), 500)),
-            nn.ReLU(),
-            layer_init(nn.Linear(500, 300)),
-            nn.ReLU(),
-        )
-        self.actor_dis = nn.ModuleList([layer_init(nn.Linear(300, self.discrete_size), std=0.01) for i in range(targetNum)])
-        self.actor_mean = nn.ModuleList([layer_init(nn.Linear(300, self.continuous_size), std=0.01) for i in range(targetNum)])
+        self.network = nn.ModuleList([nn.Sequential(
+            layer_init(nn.Linear(np.array(env.unity_observation_shape).prod(), 300)),
+            nn.ReLU(),
+            layer_init(nn.Linear(300, 200)),
+            nn.ReLU()) for i in range(targetNum)])
+        self.actor_dis = nn.ModuleList([layer_init(nn.Linear(200, self.discrete_size), std=0.01) for i in range(targetNum)])
+        self.actor_mean = nn.ModuleList([layer_init(nn.Linear(200, self.continuous_size), std=0.01) for i in range(targetNum)])
         self.actor_logstd = nn.ParameterList([nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(targetNum)])
-        self.critic = layer_init(nn.Linear(300, 1), std=1)
+        self.critic = nn.ModuleList([layer_init(nn.Linear(200, 1), std=1) for i in range(targetNum)])
 
     def get_value(self, state: torch.Tensor):
-        return self.critic(self.network(state))
+        targets = state[:, 0].to(torch.int32)
+        hidden = torch.stack([self.network[targets[i]](state[i]) for i in range(targets.size()[0])])
+        return torch.stack([self.critic[targets[i]](hidden[i]) for i in range(targets.size()[0])])
 
     def get_actions_value(self, state: torch.Tensor, actions=None):
-        hidden = self.network(state)
+        targets = state[:, 0].to(torch.int32)
+        hidden = torch.stack([self.network[targets[i]](state[i]) for i in range(targets.size()[0])])
 
         # discrete
+        # iterate over the targets (i.e. the number of agents) so each sample is computed by the output network that matches its target
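The translated comment above is the heart of the change: each sample is routed through the sub-network matching its target. The per-sample list comprehension is simple but performs one forward pass per row. Purely as an illustrative alternative, not something this commit does, the batch can instead be grouped by target id so each sub-network runs once per group; dispatch_by_target below is a hypothetical helper built only on standard PyTorch ops:

    import torch
    import torch.nn as nn

    def dispatch_by_target(networks: nn.ModuleList, state: torch.Tensor) -> torch.Tensor:
        # Hypothetical batched alternative to the per-sample loop in the diff:
        # group rows sharing a target id and run each sub-network once per group.
        targets = state[:, 0].to(torch.int64)
        out = None
        for t in torch.unique(targets):
            rows = (targets == t).nonzero(as_tuple=True)[0]       # row indices for this target
            batch_out = networks[t](state.index_select(0, rows))  # one pass per group
            if out is None:
                out = state.new_zeros((state.size(0),) + batch_out.shape[1:])
            out.index_copy_(0, rows, batch_out)                   # scatter back in original order
        return out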
@@ -188,6 +189,8 @@ class PPOAgent(nn.Module):
         # print(action_logstd)
         action_std = torch.squeeze(torch.stack([torch.exp(self.actor_logstd[targets[i]]) for i in range(targets.size()[0])]), dim=-1)  # torch.exp(action_logstd)
         con_probs = Normal(actions_mean, action_std)
+        # critic
+        criticV = torch.stack([self.critic[targets[i]](hidden[i]) for i in range(targets.size()[0])])
 
         if actions is None:
             if args.train:
@@ -207,13 +210,14 @@ class PPOAgent(nn.Module):
             [ctgr.log_prob(act) for act, ctgr in zip(disAct, multi_categoricals)]
         )
         dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])
 
         return (
             actions,
             dis_log_prob.sum(0),
             dis_entropy.sum(0),
             con_probs.log_prob(conAct).sum(1),
             con_probs.entropy().sum(1),
-            self.critic(hidden),
+            criticV,
         )
@@ -436,7 +440,7 @@ if __name__ == "__main__":
                 thisRewardsTensor,
                 torch.Tensor(dones_bf[i]).to(device),
                 torch.tensor(values_bf[i]).to(device),
-                torch.tensor(next_state[i]).to(device),
+                torch.tensor([next_state[i]]).to(device),
                 torch.Tensor([next_done[i]]).to(device),
             )
             # send memories to training datasets
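A note on this last hunk: wrapping next_state[i] in a list adds a leading batch dimension of 1, making the tensor's rank consistent with the [next_done[i]] entry below it. A tiny sketch of the shape effect, assuming next_state[i] is a flat observation vector:

    import torch

    obs = [0.0, 1.0, 2.0]             # stand-in for one next_state[i]
    print(torch.tensor(obs).shape)    # torch.Size([3])    -- old call: no batch dimension
    print(torch.tensor([obs]).shape)  # torch.Size([1, 3]) -- new call: batch dimension of 1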