优化AIMemory运行效率
对应版本2.9 优化以下错误提示:UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor.
This commit is contained in:
parent
efb5c61f0d
commit
be1322381e
@ -107,12 +107,12 @@ class PPOMem:
|
||||
next_done=torch.Tensor([next_done[i]]).to(self.device),
|
||||
)
|
||||
# send memories to training datasets
|
||||
self.obs[roundTargetType] = torch.cat((self.obs[roundTargetType], torch.tensor(self.ob_bf[i]).to(self.device)), 0)
|
||||
self.actions[roundTargetType] = torch.cat((self.actions[roundTargetType], torch.tensor(self.act_bf[i]).to(self.device)), 0)
|
||||
self.dis_logprobs[roundTargetType] = torch.cat((self.dis_logprobs[roundTargetType], torch.tensor(self.dis_logprobs_bf[i]).to(self.device)), 0)
|
||||
self.con_logprobs[roundTargetType] = torch.cat((self.con_logprobs[roundTargetType], torch.tensor(self.con_logprobs_bf[i]).to(self.device)), 0)
|
||||
self.obs[roundTargetType] = torch.cat((self.obs[roundTargetType], torch.tensor(np.array(self.ob_bf[i])).to(self.device)), 0)
|
||||
self.actions[roundTargetType] = torch.cat((self.actions[roundTargetType], torch.tensor(np.array(self.act_bf[i])).to(self.device)), 0)
|
||||
self.dis_logprobs[roundTargetType] = torch.cat((self.dis_logprobs[roundTargetType], torch.tensor(np.array(self.dis_logprobs_bf[i])).to(self.device)), 0)
|
||||
self.con_logprobs[roundTargetType] = torch.cat((self.con_logprobs[roundTargetType], torch.tensor(np.array(self.con_logprobs_bf[i])).to(self.device)), 0)
|
||||
self.rewards[roundTargetType] = torch.cat((self.rewards[roundTargetType], thisRewardsTensor), 0)
|
||||
self.values[roundTargetType] = torch.cat((self.values[roundTargetType], torch.tensor(self.values_bf[i]).to(self.device)), 0)
|
||||
self.values[roundTargetType] = torch.cat((self.values[roundTargetType], torch.tensor(np.array(self.values_bf[i])).to(self.device)), 0)
|
||||
self.advantages[roundTargetType] = torch.cat((self.advantages[roundTargetType], adv), 0)
|
||||
self.returns[roundTargetType] = torch.cat((self.returns[roundTargetType], rt), 0)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user