优化AIMemory运行效率

对应版本2.9
优化以下错误提示:UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor.
This commit is contained in:
Koha9 2023-07-24 17:49:45 +09:00
parent efb5c61f0d
commit be1322381e

View File

@ -107,12 +107,12 @@ class PPOMem:
next_done=torch.Tensor([next_done[i]]).to(self.device),
)
# send memories to training datasets
self.obs[roundTargetType] = torch.cat((self.obs[roundTargetType], torch.tensor(self.ob_bf[i]).to(self.device)), 0)
self.actions[roundTargetType] = torch.cat((self.actions[roundTargetType], torch.tensor(self.act_bf[i]).to(self.device)), 0)
self.dis_logprobs[roundTargetType] = torch.cat((self.dis_logprobs[roundTargetType], torch.tensor(self.dis_logprobs_bf[i]).to(self.device)), 0)
self.con_logprobs[roundTargetType] = torch.cat((self.con_logprobs[roundTargetType], torch.tensor(self.con_logprobs_bf[i]).to(self.device)), 0)
self.obs[roundTargetType] = torch.cat((self.obs[roundTargetType], torch.tensor(np.array(self.ob_bf[i])).to(self.device)), 0)
self.actions[roundTargetType] = torch.cat((self.actions[roundTargetType], torch.tensor(np.array(self.act_bf[i])).to(self.device)), 0)
self.dis_logprobs[roundTargetType] = torch.cat((self.dis_logprobs[roundTargetType], torch.tensor(np.array(self.dis_logprobs_bf[i])).to(self.device)), 0)
self.con_logprobs[roundTargetType] = torch.cat((self.con_logprobs[roundTargetType], torch.tensor(np.array(self.con_logprobs_bf[i])).to(self.device)), 0)
self.rewards[roundTargetType] = torch.cat((self.rewards[roundTargetType], thisRewardsTensor), 0)
self.values[roundTargetType] = torch.cat((self.values[roundTargetType], torch.tensor(self.values_bf[i]).to(self.device)), 0)
self.values[roundTargetType] = torch.cat((self.values[roundTargetType], torch.tensor(np.array(self.values_bf[i])).to(self.device)), 0)
self.advantages[roundTargetType] = torch.cat((self.advantages[roundTargetType], adv), 0)
self.returns[roundTargetType] = torch.cat((self.returns[roundTargetType], rt), 0)