Change Critic NN to Multi-NN
Change Critic NN to Multi-NN. Fix the wrong remainTime (what a stupid mistake...) and fix the doubled WandB writer. Deeper target NN: the target network now receives the target state together with the hidden layer's output. Change the middle input so that everything except the raycast input goes to the target network. Change the activation function to Tanh; it works a little better than before.
parent cbc385ca10
commit bf77060456
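Summary of the diff below: the shared trunk no longer sees the target/time/gun block of the observation, a per-target "middle" network takes the first TOTAL_MIDDLE_STATE_SIZE entries of the state concatenated with the trunk's 200-d hidden output, and the actor heads and the critic become per-target ModuleLists. A minimal sketch of that data flow, using the sizes introduced in this commit; OBS_SIZE, the batch and the variable names here are made-up for illustration, not the repository's API:

import torch
import torch.nn as nn

TARGET_NUM = 4
TARGET_STATE_SIZE, TIME_STATE_SIZE, GUN_STATE_SIZE, MY_STATE_SIZE = 7, 1, 1, 4
TOTAL_MIDDLE_STATE_SIZE = TARGET_STATE_SIZE + TIME_STATE_SIZE + GUN_STATE_SIZE + MY_STATE_SIZE
OBS_SIZE = 60  # assumed total observation width, not from the repo
HEAD_INPUT_SIZE = OBS_SIZE - TARGET_STATE_SIZE - TIME_STATE_SIZE - GUN_STATE_SIZE

# shared trunk over the tail of the observation (everything after the target/time/gun block)
head = nn.Sequential(nn.Linear(HEAD_INPUT_SIZE, 256), nn.Tanh(),
                     nn.Linear(256, 200), nn.Tanh())
# one middle network and one critic per target type
target_nets = nn.ModuleList([
    nn.Sequential(nn.Linear(TOTAL_MIDDLE_STATE_SIZE + 200, 128), nn.Tanh(),
                  nn.Linear(128, 64), nn.Tanh())
    for _ in range(TARGET_NUM)])
critics = nn.ModuleList([nn.Linear(64, 1) for _ in range(TARGET_NUM)])

state = torch.rand(5, OBS_SIZE)                           # fake batch of 5 observations
state[:, 0] = torch.randint(0, TARGET_NUM, (5,)).float()  # column 0 holds the target id
hidden = head(state[:, -HEAD_INPUT_SIZE:])                # (5, 200)
middle_in = torch.cat([state[:, :TOTAL_MIDDLE_STATE_SIZE], hidden], dim=1)  # (5, 213)
targets = state[:, 0].to(torch.int32)
values = torch.stack([critics[int(t)](target_nets[int(t)](middle_in[i]))
                      for i, t in enumerate(targets)])    # per-target critic per sample
print(values.shape)                                       # torch.Size([5, 1])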
@@ -23,31 +23,32 @@ from mlagents_envs.side_channel.side_channel import (
 )
 from typing import List
 
-bestReward = 0
+bestReward = -1
 
-DEFAULT_SEED = 933139
-ENV_PATH = "../Build/Build-ParallelEnv-Target-OffPolicy-SingleStack-SideChannel-EndReward-Easy/Aimbot-ParallelEnv"
+DEFAULT_SEED = 9331
+ENV_PATH = "../Build/Build-ParallelEnv-Target-OffPolicy-SingleStack-SideChannel-EndReward-Easy-V2.5-FreeOnly-NormalMapSize/Aimbot-ParallelEnv"
 SIDE_CHANNEL_UUID = uuid.UUID("8bbfb62a-99b4-457c-879d-b78b69066b5e")
 WAND_ENTITY = "koha9"
-WORKER_ID = 2
-BASE_PORT = 1001
+WORKER_ID = 1
+BASE_PORT = 1000
 
 # max round steps per agent is 2500/Decision_period, 25 seconds
 # !!!check every parameters before run!!!
 
-TOTAL_STEPS = 6750000
-BATCH_SIZE = 512
-MAX_TRAINNING_DATASETS = 3000
+TOTAL_STEPS = 3150000
+BATCH_SIZE = 256
+MAX_TRAINNING_DATASETS = 6000
 DECISION_PERIOD = 1
-LEARNING_RATE = 1e-3
+LEARNING_RATE = 5e-4
 GAMMA = 0.99
 GAE_LAMBDA = 0.95
 EPOCHS = 4
-CLIP_COEF = 0.1
-POLICY_COEF = 1.0
-ENTROPY_COEF = 0.01
-CRITIC_COEF = 0.5
-TARGET_LEARNING_RATE = 5e-5
+CLIP_COEF = 0.11
+LOSS_COEF = [1.0, 1.0, 1.0, 1.0] # free go attack defence
+POLICY_COEF = [1.0, 1.0, 1.0, 1.0]
+ENTROPY_COEF = [0.1, 0.1, 0.1, 0.1]
+CRITIC_COEF = [0.5, 0.5, 0.5, 0.5]
+TARGET_LEARNING_RATE = 1e-6
 
 ANNEAL_LEARNING_RATE = True
 CLIP_VLOSS = True
@@ -56,7 +57,7 @@ TRAIN = True
 
 WANDB_TACK = True
 LOAD_DIR = None
-#LOAD_DIR = "../PPO-Model/Aimbot-target-last.pt"
+#LOAD_DIR = "../PPO-Model/Aimbot_Target_Hybrid_PMNN_V2_OffPolicy_EndBC_9331_1670522099-freeonly-12/Aimbot-target-last.pt"
 
 # public data
 class Targets(Enum):
@@ -65,11 +66,16 @@ class Targets(Enum):
     Attack = 2
     Defence = 3
     Num = 4
+TARGET_STATE_SIZE = 7 # 6+1
+TIME_STATE_SIZE = 1
+GUN_STATE_SIZE = 1
+MY_STATE_SIZE = 4
+TOTAL_MIDDLE_STATE_SIZE = TARGET_STATE_SIZE+TIME_STATE_SIZE+GUN_STATE_SIZE+MY_STATE_SIZE
 BASE_WINREWARD = 999
 BASE_LOSEREWARD = -999
 TARGETNUM= 4
 ENV_TIMELIMIT = 30
-RESULT_BROADCAST_RATIO = 2/ENV_TIMELIMIT
+RESULT_BROADCAST_RATIO = 1/ENV_TIMELIMIT
 TotalRounds = {"Free":0,"Go":0,"Attack":0}
 WinRounds = {"Free":0,"Go":0,"Attack":0}
 
@@ -116,6 +122,8 @@ def parse_args():
         help="load model directory")
     parser.add_argument("--decision-period", type=int, default=DECISION_PERIOD,
         help="the number of steps to run in each environment per policy rollout")
+    parser.add_argument("--result-broadcast-ratio", type=float, default=RESULT_BROADCAST_RATIO,
+        help="broadcast result when win round is reached,r=result-broadcast-ratio*remainTime")
 
     # GAE loss
     parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
@@ -155,39 +163,66 @@ class PPOAgent(nn.Module):
     def __init__(self, env: Aimbot,targetNum:int):
         super(PPOAgent, self).__init__()
         self.targetNum = targetNum
+        self.targetSize = TARGET_STATE_SIZE
+        self.timeSize = TIME_STATE_SIZE
+        self.gunSize = GUN_STATE_SIZE
+        self.myStateSize = MY_STATE_SIZE
+        self.totalMiddleSize = TOTAL_MIDDLE_STATE_SIZE
+        self.head_input_size = env.unity_observation_shape[0] - self.targetSize-self.timeSize-self.gunSize # except target state input
         self.discrete_size = env.unity_discrete_size
         self.discrete_shape = list(env.unity_discrete_branches)
         self.continuous_size = env.unity_continuous_size
 
         self.network = nn.Sequential(
-            layer_init(nn.Linear(np.array(env.unity_observation_shape).prod(), 500)),
-            nn.ReLU(),
-            layer_init(nn.Linear(500, 300)),
-            nn.ReLU(),
+            layer_init(nn.Linear(self.head_input_size, 256)),
+            nn.Tanh(),
+            layer_init(nn.Linear(256, 200)),
+            nn.Tanh(),
         )
-        self.actor_dis = nn.ModuleList([layer_init(nn.Linear(300, self.discrete_size), std=0.01) for i in range(targetNum)])
-        self.actor_mean = nn.ModuleList([layer_init(nn.Linear(300, self.continuous_size), std=0.01) for i in range(targetNum)])
+        self.targetNetwork = nn.ModuleList([nn.Sequential(
+            layer_init(nn.Linear(self.totalMiddleSize+200,128)),
+            nn.Tanh(),
+            layer_init(nn.Linear(128,64)),
+            nn.Tanh()
+        )for i in range(targetNum)])
+        self.actor_dis = nn.ModuleList([layer_init(nn.Linear(64, self.discrete_size), std=0.01) for i in range(targetNum)])
+        self.actor_mean = nn.ModuleList([layer_init(nn.Linear(64, self.continuous_size), std=0.01) for i in range(targetNum)])
         self.actor_logstd = nn.ParameterList([nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(targetNum)])
-        self.critic = layer_init(nn.Linear(300, 1), std=1)
+        self.critic = nn.ModuleList([layer_init(nn.Linear(64, 1), std=1)for i in range(targetNum)])
 
     def get_value(self, state: torch.Tensor):
-        return self.critic(self.network(state))
+        headInput = state[:,-self.head_input_size:] # except target state
+        hidden = self.network(headInput) # (n,200)
+        targets = state[:,0].to(torch.int32) # int
+
+        middleInput = state[:,0:self.totalMiddleSize] # (n,targetSize)
+        middleInput = torch.cat([middleInput,hidden],dim=1) # targetState+hidden(n,targetSize+200)
+        middleLayer = torch.stack([self.targetNetwork[targets[i]](middleInput[i]) for i in range(targets.size()[0])])
+
+        return torch.stack([self.critic[targets[i]](middleLayer[i])for i in range(targets.size()[0])])
 
     def get_actions_value(self, state: torch.Tensor, actions=None):
-        hidden = self.network(state)
-        targets = state[:,0].to(torch.int32)
+        headInput = state[:,-self.head_input_size:] # except target state
+        hidden = self.network(headInput) # (n,200)
+        targets = state[:,0].to(torch.int32) # int
+
+        middleInput = state[:,0:self.totalMiddleSize] # (n,targetSize)
+        middleInput = torch.cat([middleInput,hidden],dim=1) # targetState+hidden(n,targetSize+200)
+        middleLayer = torch.stack([self.targetNetwork[targets[i]](middleInput[i]) for i in range(targets.size()[0])])
+
         # discrete
         # loop over targets (i.e. the number of agents) so that each agent's output comes from the output network matching its target
-        dis_logits = torch.stack([self.actor_dis[targets[i]](hidden[i]) for i in range(targets.size()[0])])
+        dis_logits = torch.stack([self.actor_dis[targets[i]](middleLayer[i]) for i in range(targets.size()[0])])
         split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)
         multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]
         # continuous
-        actions_mean = torch.stack([self.actor_mean[targets[i]](hidden[i]) for i in range(targets.size()[0])]) # self.actor_mean(hidden)
+        actions_mean = torch.stack([self.actor_mean[targets[i]](middleLayer[i]) for i in range(targets.size()[0])]) # self.actor_mean(hidden)
         # action_logstd = torch.stack([self.actor_logstd[targets[i]].expand_as(actions_mean) for i in range(targets.size()[0])]) # self.actor_logstd.expand_as(actions_mean)
         # print(action_logstd)
         action_std = torch.squeeze(torch.stack([torch.exp(self.actor_logstd[targets[i]]) for i in range(targets.size()[0])]),dim = -1) # torch.exp(action_logstd)
         con_probs = Normal(actions_mean, action_std)
+        # critic
+        criticV = torch.stack([self.critic[targets[i]](middleLayer[i])for i in range(targets.size()[0])])
 
         if actions is None:
             if args.train:
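The per-sample list comprehensions above (`self.targetNetwork[targets[i]](middleInput[i])` stacked with `torch.stack`) are easy to follow but issue one sub-network call per agent. For reference only, not something this commit does: a batched variant would run every per-target head on the whole batch and then pick rows by target id. A small sketch with made-up sizes:

import torch
import torch.nn as nn

num_targets, batch, feat = 4, 6, 213            # 213 = TOTAL_MIDDLE_STATE_SIZE + 200
heads = nn.ModuleList([nn.Linear(feat, 64) for _ in range(num_targets)])
x = torch.rand(batch, feat)
targets = torch.randint(0, num_targets, (batch,))

all_out = torch.stack([h(x) for h in heads])    # (num_targets, batch, 64)
picked = all_out[targets, torch.arange(batch)]  # (batch, 64): row i comes from heads[targets[i]]

The loop form used in the commit keeps each target's parameters clearly separated and mirrors the per-agent rollout; the batched form trades extra compute on unused heads for fewer Python-level calls.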
@@ -213,7 +248,7 @@ class PPOAgent(nn.Module):
             dis_entropy.sum(0),
             con_probs.log_prob(conAct).sum(1),
             con_probs.entropy().sum(1),
-            self.critic(hidden),
+            criticV,
         )


@@ -301,11 +336,11 @@ def broadCastEndReward(rewardBF:list,remainTime:float):
     if (rewardBF[-1]<=-500):
         # print("Lose DO NOT BROAD CAST",rewardBF[-1])
         thisRewardBF[-1] = rewardBF[-1]-BASE_LOSEREWARD
-        thisRewardBF = (np.asarray(thisRewardBF)).tolist()
+        thisRewardBF = thisRewardBF
     elif (rewardBF[-1]>=500):
         # print("Win! Broadcast reward!",rewardBF[-1])
         thisRewardBF[-1] = rewardBF[-1]-BASE_WINREWARD
-        thisRewardBF = (np.asarray(thisRewardBF)+(remainTime*RESULT_BROADCAST_RATIO)).tolist()
+        thisRewardBF = (np.asarray(thisRewardBF)+(remainTime*args.result_broadcast_ratio)).tolist()
     else:
         print("!!!!!DIDNT GET RESULT REWARD!!!!!!",rewardBF[-1])
     return torch.Tensor(thisRewardBF).to(device)
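For a sense of scale of the win-branch change above, a tiny worked example with made-up rewards (the ratio now defaults to 1/ENV_TIMELIMIT instead of 2/ENV_TIMELIMIT):

import numpy as np

ENV_TIMELIMIT = 30
result_broadcast_ratio = 1 / ENV_TIMELIMIT      # new default (old default was 2/ENV_TIMELIMIT)
remainTime = 12.0                               # the round ended with 12 s left (made up)
rewardBF = [0.5, -0.2, 1000.0]                  # last step carries the +999 win reward

rewardBF[-1] -= 999                             # strip BASE_WINREWARD back out
bonus = remainTime * result_broadcast_ratio     # 0.4
print((np.asarray(rewardBF) + bonus).tolist())  # every step of the round gets +0.4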
@@ -332,7 +367,7 @@ if __name__ == "__main__":
     optimizer = optim.Adam(agent.parameters(), lr=args.lr, eps=1e-5)
 
     # Tensorboard and WandB Recorder
-    game_name = "Aimbot_Target_Hybrid_Multi_Output"
+    game_name = "Aimbot_Target_Hybrid_PMNN_V2"
     game_type = "OffPolicy_EndBC"
     run_name = f"{game_name}_{game_type}_{args.seed}_{int(time.time())}"
     if args.wandb_track:
@@ -382,12 +417,15 @@ if __name__ == "__main__":
 
     for total_steps in range(total_update_step):
         # discunt learning rate, while step == total_update_step lr will be 0
-        print("new episode")
+
         if args.annealLR:
             finalRatio = TARGET_LEARNING_RATE/args.lr
-            frac = 1.0 - finalRatio*((total_steps - 1.0) / total_update_step)
+            frac = 1.0 - ((total_steps + 1.0) / total_update_step)
             lrnow = frac * args.lr
             optimizer.param_groups[0]["lr"] = lrnow
+        else:
+            lrnow = args.lr
+        print("new episode",total_steps,"learning rate = ",lrnow)


         # MAIN LOOP: run agent in environment
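With the change above, the learning rate now anneals linearly to zero over the run rather than toward TARGET_LEARNING_RATE (the finalRatio assignment is kept but no longer feeds into frac). A minimal sketch of the resulting schedule, with a made-up total_update_step:

LEARNING_RATE = 5e-4
total_update_step = 100        # made up; the script derives this from TOTAL_STEPS

for total_steps in range(total_update_step):
    frac = 1.0 - ((total_steps + 1.0) / total_update_step)
    lrnow = frac * LEARNING_RATE
    # first update: 0.99 * 5e-4 = 4.95e-4, last update: 0.0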
@@ -424,19 +462,20 @@ if __name__ == "__main__":
                 rewards_bf[i].append(reward[i])
                 dones_bf[i].append(done[i])
                 values_bf[i].append(value_cpu[i])
+                remainTime = state[i,TARGET_STATE_SIZE]
                 if next_done[i] == True:
                     # finished a round, send finished memories to training datasets
                     # compute advantage and discounted reward
                     #print(i,"over")
                     roundTargetType = int(state[i,0])
-                    thisRewardsTensor = broadCastEndReward(rewards_bf[i],roundTargetType)
+                    thisRewardsTensor = broadCastEndReward(rewards_bf[i],remainTime)
                     adv, rt = GAE(
                         agent,
                         args,
                         thisRewardsTensor,
                         torch.Tensor(dones_bf[i]).to(device),
                         torch.tensor(values_bf[i]).to(device),
-                        torch.tensor(next_state[i]).to(device),
+                        torch.tensor([next_state[i]]).to(device),
                         torch.Tensor([next_done[i]]).to(device),
                     )
                     # send memories to training datasets
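The GAE() helper called above is not part of this diff. For readers following along, a standard generalized advantage estimation recursion over one finished round looks roughly like the sketch below; this is an assumption about its behaviour, not the repository's implementation, and the argument list differs (the real call also passes agent, args and next_state so the last value can be bootstrapped):

import torch

def gae_sketch(rewards, dones, values, next_value, next_done,
               gamma=0.99, gae_lambda=0.95):
    """rewards/dones/values: 1-D tensors for one round, oldest step first."""
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(rewards.size(0))):
        if t == rewards.size(0) - 1:
            next_nonterminal, next_val = 1.0 - next_done, next_value
        else:
            next_nonterminal, next_val = 1.0 - dones[t + 1], values[t + 1]
        delta = rewards[t] + gamma * next_val * next_nonterminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_nonterminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return advantages, returns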
@@ -518,7 +557,7 @@ if __name__ == "__main__":
                     rewards_bf[i] = []
                     dones_bf[i] = []
                     values_bf[i] = []
-                    print(f"train dataset added:{obs.size()[0]}/{args.datasetSize}")
+                    # print(f"train dataset added:{obs.size()[0]}/{args.datasetSize}")
                 state, done = next_state, next_done
                 i += 1
 
@@ -608,11 +647,11 @@ if __name__ == "__main__":
                 # total loss
                 entropy_loss = dis_entropy.mean() + con_entropy.mean()
                 loss = (
-                    dis_pg_loss * args.policy_coef
-                    + con_pg_loss * args.policy_coef
-                    - entropy_loss * args.ent_coef
-                    + v_loss * args.critic_coef
-                )
+                    dis_pg_loss * POLICY_COEF[thisT]
+                    + con_pg_loss * POLICY_COEF[thisT]
+                    - entropy_loss * ENTROPY_COEF[thisT]
+                    + v_loss * CRITIC_COEF[thisT]
+                )*LOSS_COEF[thisT]
 
                 optimizer.zero_grad()
                 loss.backward()
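Since the coefficients above are now plain per-target lists indexed by thisT, the weighting is simple elementwise arithmetic. A small illustration with the defaults from this commit and made-up loss values:

LOSS_COEF    = [1.0, 1.0, 1.0, 1.0]   # free, go, attack, defence
POLICY_COEF  = [1.0, 1.0, 1.0, 1.0]
ENTROPY_COEF = [0.1, 0.1, 0.1, 0.1]
CRITIC_COEF  = [0.5, 0.5, 0.5, 0.5]

thisT = 0                              # training the "Free" target
dis_pg_loss, con_pg_loss, entropy_loss, v_loss = 0.2, 0.3, 1.5, 0.8  # made up
loss = (dis_pg_loss * POLICY_COEF[thisT]
        + con_pg_loss * POLICY_COEF[thisT]
        - entropy_loss * ENTROPY_COEF[thisT]
        + v_loss * CRITIC_COEF[thisT]) * LOSS_COEF[thisT]
print(loss)                            # roughly 0.75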
@@ -642,7 +681,6 @@ if __name__ == "__main__":
 
             # record rewards for plotting purposes
             writer.add_scalar(f"Target{targetName}/value_loss", v_loss.item(), target_steps[thisT])
-            writer.add_scalar(f"Target{targetName}/value_loss", v_loss.item(), target_steps[thisT])
             writer.add_scalar(f"Target{targetName}/dis_policy_loss", dis_pg_loss.item(), target_steps[thisT])
             writer.add_scalar(f"Target{targetName}/con_policy_loss", con_pg_loss.item(), target_steps[thisT])
             writer.add_scalar(f"Target{targetName}/total_loss", loss.item(), target_steps[thisT])
@@ -656,10 +694,10 @@ if __name__ == "__main__":
             # New Record!
             if TotalRewardMean > bestReward:
                 bestReward = targetRewardMean
-                saveDir = "../PPO-Model/Hybrid-MNN-500-300" + str(TotalRewardMean) + ".pt"
+                saveDir = "../PPO-Model/" + run_name +"_"+ str(TotalRewardMean) + ".pt"
                 torch.save(agent, saveDir)
 
-    saveDir = "../PPO-Model/Hybrid-MNN-500-300-Last" + ".pt"
+    saveDir = "../PPO-Model/"+ run_name + "_last.pt"
     torch.save(agent, saveDir)
     env.close()
     writer.close()

@@ -792,6 +792,88 @@
 "source": [
 "env.close()"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"tensor([[1, 2, 3, 4, 5, 6, 7, 8],\n",
+" [1, 2, 3, 4, 5, 6, 7, 8],\n",
+" [1, 2, 3, 4, 5, 6, 7, 8],\n",
+" [1, 2, 3, 4, 5, 6, 7, 8],\n",
+" [1, 2, 3, 4, 5, 6, 7, 8],\n",
+" [1, 2, 3, 4, 5, 6, 7, 8],\n",
+" [1, 2, 3, 4, 5, 6, 7, 8],\n",
+" [1, 2, 3, 4, 5, 6, 7, 8],\n",
+" [1, 2, 3, 4, 5, 6, 7, 8],\n",
+" [1, 2, 3, 4, 5, 6, 7, 8]])\n",
+"(tensor([[1, 2, 3],\n",
+" [1, 2, 3],\n",
+" [1, 2, 3],\n",
+" [1, 2, 3],\n",
+" [1, 2, 3],\n",
+" [1, 2, 3],\n",
+" [1, 2, 3],\n",
+" [1, 2, 3],\n",
+" [1, 2, 3],\n",
+" [1, 2, 3]]), tensor([[4, 5, 6],\n",
+" [4, 5, 6],\n",
+" [4, 5, 6],\n",
+" [4, 5, 6],\n",
+" [4, 5, 6],\n",
+" [4, 5, 6],\n",
+" [4, 5, 6],\n",
+" [4, 5, 6],\n",
+" [4, 5, 6],\n",
+" [4, 5, 6]]), tensor([[7, 8],\n",
+" [7, 8],\n",
+" [7, 8],\n",
+" [7, 8],\n",
+" [7, 8],\n",
+" [7, 8],\n",
+" [7, 8],\n",
+" [7, 8],\n",
+" [7, 8],\n",
+" [7, 8]]))\n"
+]
+},
+{
+"data": {
+"text/plain": [
+"tensor([[2, 0, 0],\n",
+" [2, 2, 1],\n",
+" [2, 2, 1],\n",
+" [2, 1, 1],\n",
+" [2, 2, 1],\n",
+" [2, 2, 1],\n",
+" [1, 1, 1],\n",
+" [1, 2, 1],\n",
+" [1, 1, 0],\n",
+" [2, 2, 0]])"
+]
+},
+"execution_count": 9,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"import torch\n",
+"from torch.distributions.categorical import Categorical\n",
+"\n",
+"aaa = torch.tensor([[1,2,3,4,5,6,7,8] for i in range(10)])\n",
+"aaasplt = torch.split(aaa,[3,3,2],dim=1)\n",
+"multicate = [Categorical(logits=thislo) for thislo in aaasplt]\n",
+"disact = torch.stack([ctgr.sample() for ctgr in multicate])\n",
+"print(aaa)\n",
+"print(aaasplt)\n",
+"disact.T"
+]
 }
 ],
 "metadata": {
@@ -810,7 +892,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.7"
+"version": "3.9.7 (tags/v3.9.7:1016ef3, Aug 30 2021, 20:19:38) [MSC v.1929 64 bit (AMD64)]"
 },
 "orig_nbformat": 4,
 "vscode": {