change network and fix trainset bug

Koha9 2022-12-14 02:44:01 +09:00
parent bf77060456
commit 3116831ae6
4 changed files with 127 additions and 120 deletions

View File

@@ -26,7 +26,7 @@ from typing import List
 bestReward = -1
 DEFAULT_SEED = 9331
-ENV_PATH = "../Build/Build-ParallelEnv-Target-OffPolicy-SingleStack-SideChannel-EndReward-Easy-V2.5-FreeOnly-NormalMapSize/Aimbot-ParallelEnv"
+ENV_PATH = "../Build/Build-ParallelEnv-Target-OffPolicy-SingleStack-SideChannel-EndReward-Easy-V2.7-FreeOnly-NormalMapSize/Aimbot-ParallelEnv"
 SIDE_CHANNEL_UUID = uuid.UUID("8bbfb62a-99b4-457c-879d-b78b69066b5e")
 WAND_ENTITY = "koha9"
 WORKER_ID = 1
@@ -36,17 +36,18 @@ BASE_PORT = 1000
 # !!!check every parameters before run!!!
 TOTAL_STEPS = 3150000
-BATCH_SIZE = 256
+BATCH_SIZE = 1024
 MAX_TRAINNING_DATASETS = 6000
-DECISION_PERIOD = 1
-LEARNING_RATE = 5e-4
+DECISION_PERIOD = 2
+FREEZE_HEAD_NETWORK = False
+LEARNING_RATE = 1e-3
 GAMMA = 0.99
-GAE_LAMBDA = 0.95
-EPOCHS = 4
-CLIP_COEF = 0.11
+GAE_LAMBDA = 0.9
+EPOCHS = 2
+CLIP_COEF = 0.1
 LOSS_COEF = [1.0, 1.0, 1.0, 1.0] # free go attack defence
 POLICY_COEF = [1.0, 1.0, 1.0, 1.0]
-ENTROPY_COEF = [0.1, 0.1, 0.1, 0.1]
+ENTROPY_COEF = [1.0, 1.0, 1.0, 1.0]
 CRITIC_COEF = [0.5, 0.5, 0.5, 0.5]
 TARGET_LEARNING_RATE = 1e-6
@@ -57,7 +58,7 @@ TRAIN = True
 WANDB_TACK = True
 LOAD_DIR = None
-#LOAD_DIR = "../PPO-Model/Aimbot_Target_Hybrid_PMNN_V2_OffPolicy_EndBC_9331_1670522099-freeonly-12/Aimbot-target-last.pt"
+#LOAD_DIR = "../PPO-Model/Aimbot_Target_Hybrid_PMNN_V2_OffPolicy_EndBC_9331_1670634636-freeonly-14/Aimbot_Target_Hybrid_PMNN_V2_OffPolicy_EndBC_9331_1670634636_-0.35597783.pt"

 # public data
 class Targets(Enum):
@@ -70,7 +71,7 @@ TARGET_STATE_SIZE = 7 # 6+1
 TIME_STATE_SIZE = 1
 GUN_STATE_SIZE = 1
 MY_STATE_SIZE = 4
-TOTAL_MIDDLE_STATE_SIZE = TARGET_STATE_SIZE+TIME_STATE_SIZE+GUN_STATE_SIZE+MY_STATE_SIZE
+TOTAL_T_STATE_SIZE = TARGET_STATE_SIZE+TIME_STATE_SIZE+GUN_STATE_SIZE+MY_STATE_SIZE
 BASE_WINREWARD = 999
 BASE_LOSEREWARD = -999
 TARGETNUM= 4
@@ -106,6 +107,8 @@ def parse_args():
 # model parameters
 parser.add_argument("--train",type=lambda x: bool(strtobool(x)), default=TRAIN, nargs="?", const=True,
     help="Train Model or not")
+parser.add_argument("--freeze-headnet", type=lambda x: bool(strtobool(x)), default=FREEZE_HEAD_NETWORK, nargs="?", const=True,
+    help="freeze head network or not")
 parser.add_argument("--datasetSize", type=int, default=MAX_TRAINNING_DATASETS,
     help="training dataset size,start training while dataset collect enough data")
 parser.add_argument("--minibatchSize", type=int, default=BATCH_SIZE,
@@ -167,49 +170,52 @@ class PPOAgent(nn.Module):
         self.timeSize = TIME_STATE_SIZE
         self.gunSize = GUN_STATE_SIZE
         self.myStateSize = MY_STATE_SIZE
-        self.totalMiddleSize = TOTAL_MIDDLE_STATE_SIZE
-        self.head_input_size = env.unity_observation_shape[0] - self.targetSize-self.timeSize-self.gunSize # except target state input
+        self.totalTSize = TOTAL_T_STATE_SIZE
+        self.targetInputSize = TOTAL_T_STATE_SIZE - TIME_STATE_SIZE - 1 # all target state except time and the target type
+        self.totalRaySize = env.unity_observation_shape[0] - TOTAL_T_STATE_SIZE
+        self.criticInputSize = env.unity_observation_shape[0] - TIME_STATE_SIZE - 1 # everything except time and the target type
         self.discrete_size = env.unity_discrete_size
         self.discrete_shape = list(env.unity_discrete_branches)
         self.continuous_size = env.unity_continuous_size

-        self.network = nn.Sequential(
-            layer_init(nn.Linear(self.head_input_size, 256)),
-            nn.Tanh(),
-            layer_init(nn.Linear(256, 200)),
+        self.viewNetwork = nn.Sequential(
+            layer_init(nn.Linear(self.totalRaySize, 200)),
             nn.Tanh(),
         )
-        self.targetNetwork = nn.ModuleList([nn.Sequential(
-            layer_init(nn.Linear(self.totalMiddleSize+200,128)),
-            nn.Tanh(),
-            layer_init(nn.Linear(128,64)),
+        self.targetNetworks = nn.ModuleList([nn.Sequential(
+            layer_init(nn.Linear(self.targetInputSize,128)),
             nn.Tanh()
         )for i in range(targetNum)])
-        self.actor_dis = nn.ModuleList([layer_init(nn.Linear(64, self.discrete_size), std=0.01) for i in range(targetNum)])
-        self.actor_mean = nn.ModuleList([layer_init(nn.Linear(64, self.continuous_size), std=0.01) for i in range(targetNum)])
+        self.middleNetworks = nn.ModuleList([nn.Sequential(
+            layer_init(nn.Linear(328,256)),
+            nn.Softplus()
+        )for i in range(targetNum)])
+        self.actor_dis = nn.ModuleList([layer_init(nn.Linear(256, self.discrete_size), std=0.5) for i in range(targetNum)])
+        self.actor_mean = nn.ModuleList([layer_init(nn.Linear(256, self.continuous_size), std=0) for i in range(targetNum)])
+        # self.actor_logstd = nn.ModuleList([layer_init(nn.Linear(256, self.continuous_size), std=1) for i in range(targetNum)])
         self.actor_logstd = nn.ParameterList([nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(targetNum)])
-        self.critic = nn.ModuleList([layer_init(nn.Linear(64, 1), std=1)for i in range(targetNum)])
+        self.critic = nn.ModuleList([nn.Sequential(
+            layer_init(nn.Linear(self.criticInputSize, 512)),
+            nn.Tanh(),
+            layer_init(nn.Linear(512, 256)),
+            nn.Tanh(),
+            layer_init(nn.Linear(256, 1), std=0.5))for i in range(targetNum)])

     def get_value(self, state: torch.Tensor):
-        headInput = state[:,-self.head_input_size:] # except target state
-        hidden = self.network(headInput) # (n,200)
         targets = state[:,0].to(torch.int32) # int
-        middleInput = state[:,0:self.totalMiddleSize] # (n,targetSize)
-        middleInput = torch.cat([middleInput,hidden],dim=1) # targetState+hidden(n,targetSize+200)
-        middleLayer = torch.stack([self.targetNetwork[targets[i]](middleInput[i]) for i in range(targets.size()[0])])
-        return torch.stack([self.critic[targets[i]](middleLayer[i])for i in range(targets.size()[0])])
+        headInput = torch.cat([state[:,1:self.targetSize],state[:,self.targetSize+self.timeSize:]],dim=1) # everything except the target type and time
+        return torch.stack([self.critic[targets[i]](headInput[i])for i in range(targets.size()[0])])

     def get_actions_value(self, state: torch.Tensor, actions=None):
-        headInput = state[:,-self.head_input_size:] # except target state
-        hidden = self.network(headInput) # (n,200)
         targets = state[:,0].to(torch.int32) # int
-        middleInput = state[:,0:self.totalMiddleSize] # (n,targetSize)
-        middleInput = torch.cat([middleInput,hidden],dim=1) # targetState+hidden(n,targetSize+200)
-        middleLayer = torch.stack([self.targetNetwork[targets[i]](middleInput[i]) for i in range(targets.size()[0])])
+        viewInput = state[:,-self.totalRaySize:] # all ray input
+        targetInput = torch.cat([state[:,1:self.targetSize],state[:,self.targetSize+self.timeSize:self.totalTSize]],dim=1) # all target state except time and the target type itself
+        viewLayer = self.viewNetwork(viewInput)
+        targetLayer = torch.stack([self.targetNetworks[targets[i]](targetInput[i]) for i in range(targets.size()[0])])
+        middleInput = torch.cat([viewLayer,targetLayer],dim = 1)
+        middleLayer = torch.stack([self.middleNetworks[targets[i]](middleInput[i]) for i in range(targets.size()[0])])

         # discrete
         # iterate over the targets (one per agent) and use the output head that matches each agent's target
         dis_logits = torch.stack([self.actor_dis[targets[i]](middleLayer[i]) for i in range(targets.size()[0])])
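Note on the data flow introduced by this hunk: the single shared trunk is replaced by a split encoder. The ray observations at the tail of the state go through a shared viewNetwork, the target-related state goes through a per-target targetNetworks head, and the two embeddings are concatenated (200 + 128 = 328) before the per-target middleNetworks and actor heads; the critic gets its own deeper MLP over (nearly) the whole state. A standalone sketch of that forward pass, where the 30-ray width and the batch size are assumptions for illustration (the real sizes come from the Unity environment):

    import torch
    import torch.nn as nn

    TARGET_STATE_SIZE, TIME_STATE_SIZE = 7, 1
    TOTAL_T_STATE_SIZE = 13                       # 7 + 1 + 1 + 4, as in the constants above
    RAY_SIZE = 30                                 # assumed number of ray-sensor values
    OBS_SIZE = TOTAL_T_STATE_SIZE + RAY_SIZE
    TARGET_INPUT_SIZE = TOTAL_T_STATE_SIZE - TIME_STATE_SIZE - 1   # drop time and the target type
    TARGET_NUM = 4

    view_net = nn.Sequential(nn.Linear(RAY_SIZE, 200), nn.Tanh())
    target_nets = nn.ModuleList(
        [nn.Sequential(nn.Linear(TARGET_INPUT_SIZE, 128), nn.Tanh()) for _ in range(TARGET_NUM)])
    middle_nets = nn.ModuleList(
        [nn.Sequential(nn.Linear(200 + 128, 256), nn.Softplus()) for _ in range(TARGET_NUM)])

    state = torch.rand(5, OBS_SIZE)                   # batch of 5 observations
    targets = torch.randint(0, TARGET_NUM, (5,))      # in the script: state[:, 0].to(torch.int32)
    view_in = state[:, -RAY_SIZE:]                    # ray block at the tail of the state
    target_in = torch.cat([state[:, 1:TARGET_STATE_SIZE],
                           state[:, TARGET_STATE_SIZE + TIME_STATE_SIZE:TOTAL_T_STATE_SIZE]], dim=1)

    view_out = view_net(view_in)                                                      # (5, 200)
    target_out = torch.stack([target_nets[int(t)](target_in[i]) for i, t in enumerate(targets)])
    middle_out = torch.stack([middle_nets[int(t)](torch.cat([view_out[i], target_out[i]]))
                              for i, t in enumerate(targets)])                        # (5, 256)
    print(middle_out.shape)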
@@ -220,9 +226,10 @@ class PPOAgent(nn.Module):
        # action_logstd = torch.stack([self.actor_logstd[targets[i]].expand_as(actions_mean) for i in range(targets.size()[0])]) # self.actor_logstd.expand_as(actions_mean)
        # print(action_logstd)
        action_std = torch.squeeze(torch.stack([torch.exp(self.actor_logstd[targets[i]]) for i in range(targets.size()[0])]),dim = -1) # torch.exp(action_logstd)
+       action_std = torch.clamp(action_std,1e-10)
        con_probs = Normal(actions_mean, action_std)
        # critic
-       criticV = torch.stack([self.critic[targets[i]](middleLayer[i])for i in range(targets.size()[0])])
+       criticV = self.get_value(state)

        if actions is None:
            if args.train:
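The new clamp guards the continuous head: exp() of a very negative log-std underflows to zero in float32, and torch.distributions.Normal rejects a zero scale (or yields infinite log-probabilities when argument validation is off). A small reproduction of the failure mode the clamp prevents:

    import torch
    from torch.distributions.normal import Normal

    log_std = torch.tensor([-200.0, 0.0])   # extreme but reachable after many updates
    std = torch.exp(log_std)                # tensor([0., 1.]): the first entry underflowed
    std = torch.clamp(std, 1e-10)           # keep the scale strictly positive, as in the hunk
    dist = Normal(torch.zeros(2), std)
    print(dist.log_prob(torch.zeros(2)))    # finite log-probabilities instead of an error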
@@ -361,6 +368,12 @@ if __name__ == "__main__":
     agent = PPOAgent(env,TARGETNUM).to(device)
 else:
     agent = torch.load(args.load_dir)
+    # freeze
+    if args.freeze_headnet:
+        # freeze the head network
+        for p in agent.viewNetwork.parameters():
+            p.requires_grad = False
+        print("HEAD NETWORK FREEZED")
     print("Load Agent", args.load_dir)
     print(agent.eval())
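Setting requires_grad = False on viewNetwork stops the shared ray encoder from updating, which is the point of --freeze-headnet when fine-tuning a loaded model on new targets. If the optimizer is (re)built after loading, it can also be restricted to the parameters that still require gradients; a self-contained sketch, where the tiny module and the Adam settings are assumptions:

    import torch
    import torch.nn as nn

    class TinyAgent(nn.Module):
        def __init__(self):
            super().__init__()
            self.viewNetwork = nn.Sequential(nn.Linear(30, 200), nn.Tanh())  # shared ray encoder
            self.actor = nn.Linear(200, 5)                                   # stand-in for the heads

    agent = TinyAgent()
    for p in agent.viewNetwork.parameters():
        p.requires_grad = False                 # frozen: no gradients, no updates

    trainable = [p for p in agent.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=1e-3)   # only the unfrozen heads keep learning
    print(sum(p.numel() for p in trainable))           # parameter count of the trainable part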
@@ -429,12 +442,13 @@ if __name__ == "__main__":
 # MAIN LOOP: run agent in environment
-i = 0
+step = 0
 training = False
 trainQueue = []
+last_reward = [0.for i in range(env.unity_agent_num)]
 while True:
-    if i % args.decision_period == 0:
-        step = round(i / args.decision_period)
+    if step % args.decision_period == 0:
+        step += 1
         # Choose action by agent
         with torch.no_grad():
@@ -459,7 +473,7 @@ if __name__ == "__main__":
 act_bf[i].append(action_cpu[i])
 dis_logprobs_bf[i].append(dis_logprob_cpu[i])
 con_logprobs_bf[i].append(con_logprob_cpu[i])
-rewards_bf[i].append(reward[i])
+rewards_bf[i].append(reward[i]+last_reward[i])
 dones_bf[i].append(done[i])
 values_bf[i].append(value_cpu[i])
 remainTime = state[i,TARGET_STATE_SIZE]
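With DECISION_PERIOD raised to 2 the agent only stores a transition every second environment frame; last_reward carries the reward earned on the skipped frame so it is folded into the next stored transition instead of being dropped. A toy illustration of that intent (the exact point at which the carried value is reset may differ in the script):

    # Fold skipped-frame rewards into the next stored transition.
    decision_period = 2
    frame_rewards = [0.1, 0.2, 0.3, 0.4]      # hypothetical per-frame rewards
    last_reward = 0.0
    stored = []

    for step, r in enumerate(frame_rewards):
        if step % decision_period == 0:
            stored.append(r + last_reward)    # decision frame: bank the carried reward
            last_reward = 0.0
        else:
            last_reward = r                   # skipped frame: remember its reward

    print(stored)   # [0.1, 0.5] -- the 0.2 from the skipped frame is not lost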
@@ -475,7 +489,7 @@ if __name__ == "__main__":
     thisRewardsTensor,
     torch.Tensor(dones_bf[i]).to(device),
     torch.tensor(values_bf[i]).to(device),
-    torch.tensor([next_state[i]]).to(device),
+    torch.Tensor(next_state[i]).to(device).unsqueeze(dim = 0),
     torch.Tensor([next_done[i]]).to(device),
 )
 # send memories to training datasets
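The GAE call now builds the bootstrap state with torch.Tensor(next_state[i]).unsqueeze(dim = 0) instead of torch.tensor([next_state[i]]); both produce a (1, obs) tensor from an array, but the two constructors behave differently, and torch.Tensor rejects a bare Python bool, which is exactly the TypeError reproduced in the notebook cell in the second file of this commit. A quick comparison with an assumed observation width:

    import numpy as np
    import torch

    next_state_i = np.random.rand(43).astype(np.float32)   # one agent's observation (width assumed)

    old_way = torch.tensor([next_state_i])                  # shape (1, 43), keeps the numpy dtype
    new_way = torch.Tensor(next_state_i).unsqueeze(dim=0)   # shape (1, 43), always float32
    print(old_way.shape, new_way.shape, old_way.dtype, new_way.dtype)

    print(torch.tensor(True))   # tensor(True) -- torch.tensor copies any scalar
    # torch.Tensor(True)        # TypeError: new(): data must be a sequence (got bool)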
@@ -508,15 +522,16 @@ if __name__ == "__main__":
             trainQueue.append(i)
         if(len(trainQueue)>0):
             break
-        state, done = next_state, next_done
+        # state, done = next_state, next_done
 else:
+    step += 1
     # skip this step use last predict action
-    next_obs, reward, next_done = env.step(action_cpu)
+    next_state, reward, next_done = env.step(action_cpu)
     # save memories
     for i in range(env.unity_agent_num):
         if next_done[i] == True:
             #print(i,"over???")
-            # save last memories to buffers
+            # save memories to buffers
             ob_bf[i].append(state[i])
             act_bf[i].append(action_cpu[i])
             dis_logprobs_bf[i].append(dis_logprob_cpu[i])
@@ -524,30 +539,33 @@ if __name__ == "__main__":
 rewards_bf[i].append(reward[i])
 dones_bf[i].append(done[i])
 values_bf[i].append(value_cpu[i])
+remainTime = state[i,TARGET_STATE_SIZE]
 # finished a round, send finished memories to training datasets
 # compute advantage and discounted reward
+roundTargetType = int(state[i,0])
+thisRewardsTensor = broadCastEndReward(rewards_bf[i],remainTime)
 adv, rt = GAE(
     agent,
     args,
-    torch.tensor(rewards_bf[i]).to(device),
+    thisRewardsTensor,
     torch.Tensor(dones_bf[i]).to(device),
     torch.tensor(values_bf[i]).to(device),
-    torch.tensor(next_state[i]).to(device),
+    torch.Tensor(next_state[i]).to(device).unsqueeze(dim = 0),
     torch.Tensor([next_done[i]]).to(device),
 )
 # send memories to training datasets
-obs = torch.cat((obs, torch.tensor(ob_bf[i]).to(device)), 0)
-actions = torch.cat((actions, torch.tensor(act_bf[i]).to(device)), 0)
-dis_logprobs = torch.cat(
-    (dis_logprobs, torch.tensor(dis_logprobs_bf[i]).to(device)), 0
+obs[roundTargetType] = torch.cat((obs[roundTargetType], torch.tensor(ob_bf[i]).to(device)), 0)
+actions[roundTargetType] = torch.cat((actions[roundTargetType], torch.tensor(act_bf[i]).to(device)), 0)
+dis_logprobs[roundTargetType] = torch.cat(
+    (dis_logprobs[roundTargetType], torch.tensor(dis_logprobs_bf[i]).to(device)), 0
 )
-con_logprobs = torch.cat(
-    (con_logprobs, torch.tensor(con_logprobs_bf[i]).to(device)), 0
+con_logprobs[roundTargetType] = torch.cat(
+    (con_logprobs[roundTargetType], torch.tensor(con_logprobs_bf[i]).to(device)), 0
 )
-rewards = torch.cat((rewards, torch.tensor(rewards_bf[i]).to(device)), 0)
-values = torch.cat((values, torch.tensor(values_bf[i]).to(device)), 0)
-advantages = torch.cat((advantages, adv), 0)
-returns = torch.cat((returns, rt), 0)
+rewards[roundTargetType] = torch.cat((rewards[roundTargetType], thisRewardsTensor), 0)
+values[roundTargetType] = torch.cat((values[roundTargetType], torch.tensor(values_bf[i]).to(device)), 0)
+advantages[roundTargetType] = torch.cat((advantages[roundTargetType], adv), 0)
+returns[roundTargetType] = torch.cat((returns[roundTargetType], rt), 0)

 # clear buffers
 ob_bf[i] = []
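The training buffers are now containers indexed by target type (roundTargetType is read from column 0 of the state), so each target type accumulates and trains on its own dataset instead of one mixed pool. A minimal sketch of that bookkeeping, with an assumed observation width:

    import torch

    TARGET_NUM = 4
    OBS_SIZE = 43                                                  # assumed observation width
    obs = [torch.empty(0, OBS_SIZE) for _ in range(TARGET_NUM)]    # one dataset per target type

    def add_round(buckets, round_target_type, round_obs):
        # append a finished round to the bucket of its own target type
        buckets[round_target_type] = torch.cat((buckets[round_target_type], round_obs), dim=0)
        return buckets

    obs = add_round(obs, 2, torch.rand(8, OBS_SIZE))   # 8 steps collected against target type 2
    print([o.size(0) for o in obs])                    # [0, 0, 8, 0]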
@@ -557,8 +575,10 @@ if __name__ == "__main__":
         rewards_bf[i] = []
         dones_bf[i] = []
         values_bf[i] = []
-        # print(f"train dataset added:{obs.size()[0]}/{args.datasetSize}")
+        print(f"train dataset {Targets(roundTargetType).name} added:{obs[roundTargetType].size()[0]}/{args.datasetSize}")
-    state, done = next_state, next_done
+    state = next_state
+    last_reward = reward
 i += 1

 if args.train:
@@ -574,14 +594,16 @@ if __name__ == "__main__":
 b_advantages = advantages[thisT].reshape(-1)
 b_returns = returns[thisT].reshape(-1)
 b_values = values[thisT].reshape(-1)
-b_size = b_obs[thisT].size()[0]
+b_size = b_obs.size()[0]
 # Optimizing the policy and value network
 b_inds = np.arange(b_size)
 # clipfracs = []
 for epoch in range(args.epochs):
+    print(epoch,end="")
     # shuffle all datasets
     np.random.shuffle(b_inds)
     for start in range(0, b_size, args.minibatchSize):
+        print(".",end="")
         end = start + args.minibatchSize
         mb_inds = b_inds[start:end]
         mb_advantages = b_advantages[mb_inds]
@@ -603,9 +625,11 @@ if __name__ == "__main__":
 # discrete ratio
 dis_logratio = new_dis_logprob - b_dis_logprobs[mb_inds]
 dis_ratio = dis_logratio.exp()
+# dis_ratio = (new_dis_logprob / (b_dis_logprobs[mb_inds]+1e-8)).mean()
 # continuous ratio
 con_logratio = new_con_logprob - b_con_logprobs[mb_inds]
 con_ratio = con_logratio.exp()
+# con_ratio = (new_con_logprob / (b_con_logprobs[mb_inds]+1e-8)).mean()

 """
 # early stop
@@ -665,6 +689,7 @@ if __name__ == "__main__":
     break
 """
 # record mean reward before clear history
+print("done")
 targetRewardMean = np.mean(rewards[thisT].to("cpu").detach().numpy().copy())
 meanRewardList.append(targetRewardMean)
 targetName = Targets(thisT).name
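The b_size change is the trainset bug from the commit title: b_obs is already the per-target tensor (selected with thisT earlier, as the surrounding b_* lines show), so indexing it with thisT again picks a single sample and its feature width ends up being used as the dataset size. A small reproduction under assumed dataset sizes:

    import torch

    obs = [torch.rand(100, 43) for _ in range(4)]   # per-target datasets (sizes assumed)
    thisT = 2
    b_obs = obs[thisT]                              # (100, 43): the dataset for target type 2

    wrong = b_obs[thisT].size()[0]   # b_obs[2] is one sample -> 43, the feature width
    right = b_obs.size()[0]          # 100, the actual number of training samples
    print(wrong, right)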

View File

@@ -795,69 +795,25 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "tensor([[1, 2, 3, 4, 5, 6, 7, 8],\n",
-      "        [1, 2, 3, 4, 5, 6, 7, 8],\n",
-      "        [1, 2, 3, 4, 5, 6, 7, 8],\n",
-      "        [1, 2, 3, 4, 5, 6, 7, 8],\n",
-      "        [1, 2, 3, 4, 5, 6, 7, 8],\n",
-      "        [1, 2, 3, 4, 5, 6, 7, 8],\n",
-      "        [1, 2, 3, 4, 5, 6, 7, 8],\n",
-      "        [1, 2, 3, 4, 5, 6, 7, 8],\n",
-      "        [1, 2, 3, 4, 5, 6, 7, 8],\n",
-      "        [1, 2, 3, 4, 5, 6, 7, 8]])\n",
-      "(tensor([[1, 2, 3],\n",
-      "        [1, 2, 3],\n",
-      "        [1, 2, 3],\n",
-      "        [1, 2, 3],\n",
-      "        [1, 2, 3],\n",
-      "        [1, 2, 3],\n",
-      "        [1, 2, 3],\n",
-      "        [1, 2, 3],\n",
-      "        [1, 2, 3],\n",
-      "        [1, 2, 3]]), tensor([[4, 5, 6],\n",
-      "        [4, 5, 6],\n",
-      "        [4, 5, 6],\n",
-      "        [4, 5, 6],\n",
-      "        [4, 5, 6],\n",
-      "        [4, 5, 6],\n",
-      "        [4, 5, 6],\n",
-      "        [4, 5, 6],\n",
-      "        [4, 5, 6],\n",
-      "        [4, 5, 6]]), tensor([[7, 8],\n",
-      "        [7, 8],\n",
-      "        [7, 8],\n",
-      "        [7, 8],\n",
-      "        [7, 8],\n",
-      "        [7, 8],\n",
-      "        [7, 8],\n",
-      "        [7, 8],\n",
-      "        [7, 8],\n",
-      "        [7, 8]]))\n"
-     ]
-    },
    {
     "data": {
      "text/plain": [
-      "tensor([[2, 0, 0],\n",
-      "        [2, 2, 1],\n",
-      "        [2, 2, 1],\n",
-      "        [2, 1, 1],\n",
-      "        [2, 2, 1],\n",
-      "        [2, 2, 1],\n",
-      "        [1, 1, 1],\n",
-      "        [1, 2, 1],\n",
-      "        [1, 1, 0],\n",
-      "        [2, 2, 0]])"
+      "tensor([[2, 3, 5, 6, 7, 8],\n",
+      "        [2, 3, 5, 6, 7, 8],\n",
+      "        [2, 3, 5, 6, 7, 8],\n",
+      "        [2, 3, 5, 6, 7, 8],\n",
+      "        [2, 3, 5, 6, 7, 8],\n",
+      "        [2, 3, 5, 6, 7, 8],\n",
+      "        [2, 3, 5, 6, 7, 8],\n",
+      "        [2, 3, 5, 6, 7, 8],\n",
+      "        [2, 3, 5, 6, 7, 8],\n",
+      "        [2, 3, 5, 6, 7, 8]])"
      ]
     },
-    "execution_count": 9,
+    "execution_count": 6,
     "metadata": {},
    "output_type": "execute_result"
    }
@@ -870,9 +826,35 @@
    "aaasplt = torch.split(aaa,[3,3,2],dim=1)\n",
    "multicate = [Categorical(logits=thislo) for thislo in aaasplt]\n",
    "disact = torch.stack([ctgr.sample() for ctgr in multicate])\n",
-   "print(aaa)\n",
-   "print(aaasplt)\n",
-   "disact.T"
+   "#print(aaa)\n",
+   "#print(aaasplt)\n",
+   "torch.cat([aaa[:,1:3],aaa[:,4:]],dim=1)"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 26,
+  "metadata": {},
+  "outputs": [
+   {
+    "ename": "TypeError",
+    "evalue": "new(): data must be a sequence (got bool)",
+    "output_type": "error",
+    "traceback": [
+     "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+     "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
+     "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_42068\\1624049819.py\u001b[0m in \u001b[0;36m<cell line: 5>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdistributions\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnormal\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mNormal\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m \u001b[0maaa\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'cuda'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 6\u001b[0m \u001b[0maaa\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+     "\u001b[1;31mTypeError\u001b[0m: new(): data must be a sequence (got bool)"
+    ]
+   }
+  ],
+  "source": [
+   "import torch\n",
+   "import numpy as np\n",
+   "from torch.distributions.normal import Normal\n",
+   "\n",
+   "aaa = torch.Tensor(True).to('cuda').unsqueeze(0)\n",
+   "aaa"
   ]
  }
 ],
@@ -892,7 +874,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.9.7 (tags/v3.9.7:1016ef3, Aug 30 2021, 20:19:38) [MSC v.1929 64 bit (AMD64)]"
+ "version": "3.9.7"
 },
 "orig_nbformat": 4,
 "vscode": {