Change Learning timing
change learning timing to each episode end.
This commit is contained in:
parent
a0895c7449
commit
32d398dbef
@ -8,6 +8,7 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from AimbotEnv import Aimbot
|
||||
from tqdm import tqdm
|
||||
from torch.distributions.normal import Normal
|
||||
from torch.distributions.categorical import Categorical
|
||||
from distutils.util import strtobool
|
||||
@ -16,34 +17,34 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
bestReward = 0
|
||||
|
||||
DEFAULT_SEED = 9331
|
||||
ENV_PATH = "../Build/Build-ParallelEnv-BigArea-6Enemy/Aimbot-ParallelEnv"
|
||||
ENV_PATH = "../Build/Build-ParallelEnv-BigArea-6Enemy-EndBonus/Aimbot-ParallelEnv"
|
||||
WAND_ENTITY = "koha9"
|
||||
WORKER_ID = 1
|
||||
BASE_PORT = 1000
|
||||
|
||||
# max round steps per agent is 2500, 25 seconds
|
||||
|
||||
TOTAL_STEPS = 2000000
|
||||
STEP_NUM = 314
|
||||
TOTAL_STEPS = 4000000
|
||||
BATCH_SIZE = 512
|
||||
MAX_TRAINNING_DATASETS = 8000
|
||||
DECISION_PERIOD = 2
|
||||
LEARNING_RATE = 7e-4
|
||||
GAMMA = 0.99
|
||||
GAE_LAMBDA = 0.95
|
||||
MINIBATCH_NUM = 4
|
||||
EPOCHS = 4
|
||||
CLIP_COEF = 0.1
|
||||
POLICY_COEF = 1.0
|
||||
ENTROPY_COEF = 0.01
|
||||
CRITIC_COEF = 0.5
|
||||
|
||||
ANNEAL_LEARNING_RATE = True
|
||||
ANNEAL_LEARNING_RATE = False
|
||||
CLIP_VLOSS = True
|
||||
NORM_ADV = True
|
||||
TRAIN = True
|
||||
TRAIN = False
|
||||
|
||||
WANDB_TACK = False
|
||||
LOAD_DIR = None
|
||||
# LOAD_DIR = "../PPO-Model/SmallArea-256-128-hybrid-2nd-trainning.pt"
|
||||
|
||||
LOAD_DIR = "../PPO-Model/bigArea-4.pt"
|
||||
|
||||
def parse_args():
|
||||
# fmt: off
|
||||
@ -67,10 +68,10 @@ def parse_args():
|
||||
# model parameters
|
||||
parser.add_argument("--train",type=lambda x: bool(strtobool(x)), default=TRAIN, nargs="?", const=True,
|
||||
help="Train Model or not")
|
||||
parser.add_argument("--stepNum", type=int, default=STEP_NUM,
|
||||
help="the number of steps to run in each environment per policy rollout")
|
||||
parser.add_argument("--minibatchesNum", type=int, default=MINIBATCH_NUM,
|
||||
help="the number of mini-batches")
|
||||
parser.add_argument("--datasetSize", type=int, default=MAX_TRAINNING_DATASETS,
|
||||
help="training dataset size,start training while dataset collect enough data")
|
||||
parser.add_argument("--minibatchSize", type=int, default=BATCH_SIZE,
|
||||
help="nimi batch size")
|
||||
parser.add_argument("--epochs", type=int, default=EPOCHS,
|
||||
help="the K epochs to update the policy")
|
||||
parser.add_argument("--annealLR", type=lambda x: bool(strtobool(x)), default=ANNEAL_LEARNING_RATE, nargs="?", const=True,
|
||||
@ -179,6 +180,40 @@ class PPOAgent(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def GAE(agent, args, rewards, dones, values, next_obs, next_done):
|
||||
# GAE
|
||||
with torch.no_grad():
|
||||
next_value = agent.get_value(next_obs).reshape(1, -1)
|
||||
data_size = rewards.size()[0]
|
||||
if args.gae:
|
||||
advantages = torch.zeros_like(rewards).to(device)
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(data_size)):
|
||||
if t == data_size - 1:
|
||||
nextnonterminal = 1.0 - next_done
|
||||
nextvalues = next_value
|
||||
else:
|
||||
nextnonterminal = 1.0 - dones[t + 1]
|
||||
nextvalues = values[t + 1]
|
||||
delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
|
||||
advantages[t] = lastgaelam = (
|
||||
delta + args.gamma * args.gaeLambda * nextnonterminal * lastgaelam
|
||||
)
|
||||
returns = advantages + values
|
||||
else:
|
||||
returns = torch.zeros_like(rewards).to(device)
|
||||
for t in reversed(range(data_size)):
|
||||
if t == data_size - 1:
|
||||
nextnonterminal = 1.0 - next_done
|
||||
next_return = next_value
|
||||
else:
|
||||
nextnonterminal = 1.0 - dones[t + 1]
|
||||
next_return = returns[t + 1]
|
||||
returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
|
||||
advantages = returns - values
|
||||
return advantages, returns
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
random.seed(args.seed)
|
||||
@ -199,11 +234,11 @@ if __name__ == "__main__":
|
||||
optimizer = optim.Adam(agent.parameters(), lr=args.lr, eps=1e-5)
|
||||
|
||||
# Tensorboard and WandB Recorder
|
||||
game_name = "Aimbot"
|
||||
game_name = "Aimbot-BigArea-6Enemy-EndBonus"
|
||||
run_name = f"{game_name}_{args.seed}_{int(time.time())}"
|
||||
if args.wandb_track:
|
||||
wandb.init(
|
||||
project=run_name,
|
||||
project=game_name,
|
||||
entity=args.wandb_entity,
|
||||
sync_tensorboard=True,
|
||||
config=vars(args),
|
||||
@ -219,94 +254,165 @@ if __name__ == "__main__":
|
||||
% ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
|
||||
)
|
||||
|
||||
# Memory Record
|
||||
obs = torch.zeros((args.stepNum, env.unity_agent_num) + env.unity_observation_shape).to(device)
|
||||
actions = torch.zeros((args.stepNum, env.unity_agent_num) + (env.unity_action_size,)).to(device)
|
||||
dis_logprobs = torch.zeros((args.stepNum, env.unity_agent_num)).to(device)
|
||||
con_logprobs = torch.zeros((args.stepNum, env.unity_agent_num)).to(device)
|
||||
rewards = torch.zeros((args.stepNum, env.unity_agent_num)).to(device)
|
||||
dones = torch.zeros((args.stepNum, env.unity_agent_num)).to(device)
|
||||
values = torch.zeros((args.stepNum, env.unity_agent_num)).to(device)
|
||||
# Trajectory Buffer
|
||||
ob_bf = [[] for i in range(env.unity_agent_num)]
|
||||
act_bf = [[] for i in range(env.unity_agent_num)]
|
||||
dis_logprobs_bf = [[] for i in range(env.unity_agent_num)]
|
||||
con_logprobs_bf = [[] for i in range(env.unity_agent_num)]
|
||||
rewards_bf = [[] for i in range(env.unity_agent_num)]
|
||||
dones_bf = [[] for i in range(env.unity_agent_num)]
|
||||
values_bf = [[] for i in range(env.unity_agent_num)]
|
||||
|
||||
# TRY NOT TO MODIFY: start the game
|
||||
args.batch_size = int(env.unity_agent_num * args.stepNum)
|
||||
args.minibatch_size = int(args.batch_size // args.minibatchesNum)
|
||||
total_update_step = args.total_timesteps // args.batch_size
|
||||
total_update_step = args.total_timesteps // args.datasetSize
|
||||
global_step = 0
|
||||
start_time = time.time()
|
||||
next_obs, _, _ = env.reset()
|
||||
next_obs = torch.Tensor(next_obs).to(device)
|
||||
next_done = torch.zeros(env.unity_agent_num).to(device)
|
||||
state, _, done = env.reset()
|
||||
# state = torch.Tensor(next_obs).to(device)
|
||||
# next_done = torch.zeros(env.unity_agent_num).to(device)
|
||||
|
||||
for total_steps in range(total_update_step):
|
||||
# discunt learning rate, while step == total_update_step lr will be 0
|
||||
print("new episode")
|
||||
if args.annealLR:
|
||||
frac = 1.0 - (total_steps - 1.0) / total_update_step
|
||||
lrnow = frac * args.lr
|
||||
optimizer.param_groups[0]["lr"] = lrnow
|
||||
|
||||
# initialize empty training datasets
|
||||
obs = torch.tensor([]).to(device) # (n,env.unity_observation_size)
|
||||
actions = torch.tensor([]).to(device) # (n,env.unity_action_size)
|
||||
dis_logprobs = torch.tensor([]).to(device) # (n,1)
|
||||
con_logprobs = torch.tensor([]).to(device) # (n,1)
|
||||
rewards = torch.tensor([]).to(device) # (n,1)
|
||||
values = torch.tensor([]).to(device) # (n,1)
|
||||
advantages = torch.tensor([]).to(device) # (n,1)
|
||||
returns = torch.tensor([]).to(device) # (n,1)
|
||||
|
||||
# MAIN LOOP: run agent in environment
|
||||
for i in range(args.stepNum * args.decision_period):
|
||||
i = 0
|
||||
training = False
|
||||
while True:
|
||||
if i % args.decision_period == 0:
|
||||
step = round(i / args.decision_period)
|
||||
# Choose action by agent
|
||||
global_step += 1 * env.unity_agent_num
|
||||
obs[step] = next_obs
|
||||
dones[step] = next_done
|
||||
|
||||
with torch.no_grad():
|
||||
# predict actions
|
||||
action, dis_logprob, _, con_logprob, _, value = agent.get_actions_value(
|
||||
next_obs
|
||||
torch.Tensor(state).to(device)
|
||||
)
|
||||
value = value.flatten()
|
||||
next_obs, reward, done = env.step(action.cpu().numpy())
|
||||
|
||||
# variable from GPU to CPU
|
||||
action_cpu = action.cpu().numpy()
|
||||
dis_logprob_cpu = dis_logprob.cpu().numpy()
|
||||
con_logprob_cpu = con_logprob.cpu().numpy()
|
||||
value_cpu = value.cpu().numpy()
|
||||
# Environment step
|
||||
next_state, reward, next_done = env.step(action_cpu)
|
||||
|
||||
# save memories
|
||||
actions[step] = action
|
||||
dis_logprobs[step] = dis_logprob
|
||||
con_logprobs[step] = con_logprob
|
||||
values[step] = value
|
||||
rewards[step] = torch.tensor(reward).to(device).view(-1)
|
||||
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(
|
||||
device
|
||||
)
|
||||
for i in range(env.unity_agent_num):
|
||||
# save memories to buffers
|
||||
ob_bf[i].append(state[i])
|
||||
act_bf[i].append(action_cpu[i])
|
||||
dis_logprobs_bf[i].append(dis_logprob_cpu[i])
|
||||
con_logprobs_bf[i].append(con_logprob_cpu[i])
|
||||
rewards_bf[i].append(reward[i])
|
||||
dones_bf[i].append(done[i])
|
||||
values_bf[i].append(value_cpu[i])
|
||||
if next_done[i] == True:
|
||||
# finished a round, send finished memories to training datasets
|
||||
# compute advantage and discounted reward
|
||||
adv, rt = GAE(
|
||||
agent,
|
||||
args,
|
||||
torch.tensor(rewards_bf[i]).to(device),
|
||||
torch.Tensor(dones_bf[i]).to(device),
|
||||
torch.tensor(values_bf[i]).to(device),
|
||||
torch.tensor(next_state[i]).to(device),
|
||||
torch.Tensor([next_done[i]]).to(device),
|
||||
)
|
||||
# send memories to training datasets
|
||||
obs = torch.cat((obs, torch.tensor(ob_bf[i]).to(device)), 0)
|
||||
actions = torch.cat((actions, torch.tensor(act_bf[i]).to(device)), 0)
|
||||
dis_logprobs = torch.cat(
|
||||
(dis_logprobs, torch.tensor(dis_logprobs_bf[i]).to(device)), 0
|
||||
)
|
||||
con_logprobs = torch.cat(
|
||||
(con_logprobs, torch.tensor(con_logprobs_bf[i]).to(device)), 0
|
||||
)
|
||||
rewards = torch.cat((rewards, torch.tensor(rewards_bf[i]).to(device)), 0)
|
||||
values = torch.cat((values, torch.tensor(values_bf[i]).to(device)), 0)
|
||||
advantages = torch.cat((advantages, adv), 0)
|
||||
returns = torch.cat((returns, rt), 0)
|
||||
|
||||
# clear buffers
|
||||
ob_bf[i] = []
|
||||
act_bf[i] = []
|
||||
dis_logprobs_bf[i] = []
|
||||
con_logprobs_bf[i] = []
|
||||
rewards_bf[i] = []
|
||||
dones_bf[i] = []
|
||||
values_bf[i] = []
|
||||
print(f"train dataset:{obs.size()[0]}/{args.datasetSize}")
|
||||
|
||||
if obs.size()[0] >= args.datasetSize:
|
||||
# start train NN
|
||||
break
|
||||
state, done = next_state, next_done
|
||||
else:
|
||||
# skip this step use last predict action
|
||||
next_obs, reward, done = env.step(action.cpu().numpy())
|
||||
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(
|
||||
device
|
||||
)
|
||||
next_obs, reward, done = env.step(action_cpu)
|
||||
# save memories
|
||||
for i in range(env.unity_agent_num):
|
||||
if next_done[i] == True:
|
||||
# save last memories to buffers
|
||||
ob_bf[i].append(state[i])
|
||||
act_bf[i].append(action_cpu[i])
|
||||
dis_logprobs_bf[i].append(dis_logprob_cpu[i])
|
||||
con_logprobs_bf[i].append(con_logprob_cpu[i])
|
||||
rewards_bf[i].append(reward[i])
|
||||
dones_bf[i].append(done[i])
|
||||
values_bf[i].append(value_cpu[i])
|
||||
# finished a round, send finished memories to training datasets
|
||||
# compute advantage and discounted reward
|
||||
adv, rt = GAE(
|
||||
agent,
|
||||
args,
|
||||
torch.tensor(rewards_bf[i]).to(device),
|
||||
torch.Tensor(dones_bf[i]).to(device),
|
||||
torch.tensor(values_bf[i]).to(device),
|
||||
torch.tensor(next_state[i]).to(device),
|
||||
torch.Tensor([next_done[i]]).to(device),
|
||||
)
|
||||
# send memories to training datasets
|
||||
obs = torch.cat((obs, torch.tensor(ob_bf[i]).to(device)), 0)
|
||||
actions = torch.cat((actions, torch.tensor(act_bf[i]).to(device)), 0)
|
||||
dis_logprobs = torch.cat(
|
||||
(dis_logprobs, torch.tensor(dis_logprobs_bf[i]).to(device)), 0
|
||||
)
|
||||
con_logprobs = torch.cat(
|
||||
(con_logprobs, torch.tensor(con_logprobs_bf[i]).to(device)), 0
|
||||
)
|
||||
rewards = torch.cat((rewards, torch.tensor(rewards_bf[i]).to(device)), 0)
|
||||
values = torch.cat((values, torch.tensor(values_bf[i]).to(device)), 0)
|
||||
advantages = torch.cat((advantages, adv), 0)
|
||||
returns = torch.cat((returns, rt), 0)
|
||||
|
||||
# GAE
|
||||
with torch.no_grad():
|
||||
next_value = agent.get_value(next_obs).reshape(1, -1)
|
||||
if args.gae:
|
||||
advantages = torch.zeros_like(rewards).to(device)
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(args.stepNum)):
|
||||
if t == args.stepNum - 1:
|
||||
nextnonterminal = 1.0 - next_done
|
||||
nextvalues = next_value
|
||||
else:
|
||||
nextnonterminal = 1.0 - dones[t + 1]
|
||||
nextvalues = values[t + 1]
|
||||
delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
|
||||
advantages[t] = lastgaelam = (
|
||||
delta + args.gamma * args.gaeLambda * nextnonterminal * lastgaelam
|
||||
)
|
||||
returns = advantages + values
|
||||
else:
|
||||
returns = torch.zeros_like(rewards).to(device)
|
||||
for t in reversed(range(args.stepNum)):
|
||||
if t == args.stepNum - 1:
|
||||
nextnonterminal = 1.0 - next_done
|
||||
next_return = next_value
|
||||
else:
|
||||
nextnonterminal = 1.0 - dones[t + 1]
|
||||
next_return = returns[t + 1]
|
||||
returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
|
||||
advantages = returns - values
|
||||
# clear buffers
|
||||
ob_bf[i] = []
|
||||
act_bf[i] = []
|
||||
dis_logprobs_bf[i] = []
|
||||
con_logprobs_bf[i] = []
|
||||
rewards_bf[i] = []
|
||||
dones_bf[i] = []
|
||||
values_bf[i] = []
|
||||
print(f"train dataset:{obs.size()[0]}/{args.datasetSize}")
|
||||
state, done = next_state, next_done
|
||||
i += 1
|
||||
|
||||
if args.train:
|
||||
# flatten the batch
|
||||
@ -317,15 +423,15 @@ if __name__ == "__main__":
|
||||
b_advantages = advantages.reshape(-1)
|
||||
b_returns = returns.reshape(-1)
|
||||
b_values = values.reshape(-1)
|
||||
|
||||
b_size = b_obs.size()[0]
|
||||
# Optimizing the policy and value network
|
||||
b_inds = np.arange(args.batch_size)
|
||||
b_inds = np.arange(b_size)
|
||||
# clipfracs = []
|
||||
for epoch in range(args.epochs):
|
||||
# shuffle all datasets
|
||||
np.random.shuffle(b_inds)
|
||||
for start in range(0, args.batch_size, args.minibatch_size):
|
||||
end = start + args.minibatch_size
|
||||
for start in range(0, b_size, args.minibatchSize):
|
||||
end = start + args.minibatchSize
|
||||
mb_inds = b_inds[start:end]
|
||||
mb_advantages = b_advantages[mb_inds]
|
||||
|
||||
|
@ -434,41 +434,119 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"x = torch.randn(2, 3).to(\"cuda\")\n",
|
||||
"print(x)\n",
|
||||
"print(torch.cat((x, x, x), 0))\n",
|
||||
"print(torch.cat((x, x, x), 1))\n",
|
||||
"\n",
|
||||
"aa = torch.empty(0).to(\"cuda\")\n",
|
||||
"torch.cat([aa,x])\n",
|
||||
"bb = [[]]*2\n",
|
||||
"print(bb)\n",
|
||||
"bb.append(x.to(\"cpu\").tolist())\n",
|
||||
"bb.append(x.to(\"cpu\").tolist())\n",
|
||||
"print(bb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 64,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"x : torch.Size([2, 3, 4])\n",
|
||||
"x : torch.Size([6, 2, 3, 4])\n",
|
||||
"x : torch.Size([6, 2, 3, 4])\n"
|
||||
"tensor([[-1.1090, 0.4686, 0.6883],\n",
|
||||
" [-0.1862, -0.3943, -0.0202],\n",
|
||||
" [ 0.1436, -0.9444, -1.2079],\n",
|
||||
" [-2.9434, -2.5989, -0.6653],\n",
|
||||
" [ 0.4668, 0.8548, -0.4641],\n",
|
||||
" [-0.3956, -0.2832, -0.1889],\n",
|
||||
" [-0.2801, -0.2092, 1.7254],\n",
|
||||
" [ 2.7938, -0.7742, 0.7053]], device='cuda:0')\n",
|
||||
"(8, 0)\n",
|
||||
"---\n",
|
||||
"[[array([-1.1090169, 0.4685607, 0.6883437], dtype=float32)], [array([-0.1861974 , -0.39429024, -0.02016036], dtype=float32)], [array([ 0.14360362, -0.9443668 , -1.2079065 ], dtype=float32)], [array([-2.9433894 , -2.598913 , -0.66532046], dtype=float32)], [array([ 0.46684313, 0.8547877 , -0.46408093], dtype=float32)], [array([-0.39563984, -0.2831819 , -0.18891 ], dtype=float32)], [array([-0.28008553, -0.20918302, 1.7253567 ], dtype=float32)], [array([ 2.7938051, -0.7742478, 0.705279 ], dtype=float32)]]\n",
|
||||
"[[array([-1.1090169, 0.4685607, 0.6883437], dtype=float32)], [], [array([ 0.14360362, -0.9443668 , -1.2079065 ], dtype=float32)], [array([-2.9433894 , -2.598913 , -0.66532046], dtype=float32)], [array([ 0.46684313, 0.8547877 , -0.46408093], dtype=float32)], [array([-0.39563984, -0.2831819 , -0.18891 ], dtype=float32)], [array([-0.28008553, -0.20918302, 1.7253567 ], dtype=float32)], [array([ 2.7938051, -0.7742478, 0.705279 ], dtype=float32)]]\n",
|
||||
"---\n",
|
||||
"[array([-1.1090169, 0.4685607, 0.6883437], dtype=float32), array([-1.1090169, 0.4685607, 0.6883437], dtype=float32)]\n",
|
||||
"vvv tensor([[-1.1090, 0.4686, 0.6883],\n",
|
||||
" [-1.1090, 0.4686, 0.6883]], device='cuda:0')\n",
|
||||
"tensor([[-1.1090, 0.4686, 0.6883],\n",
|
||||
" [-1.1090, 0.4686, 0.6883]], device='cuda:0')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 64,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"agent_num = 8\n",
|
||||
"ob_buffer = [[]for i in range(agent_num)]\n",
|
||||
"obs = torch.randn(8, 3).to(\"cuda\")\n",
|
||||
"print(obs)\n",
|
||||
"print(np.shape(np.array(ob_buffer)))\n",
|
||||
"print('---')\n",
|
||||
"obs_cpu = obs.to(\"cpu\").numpy()\n",
|
||||
"for i in range(agent_num):\n",
|
||||
" ob_buffer[i].append(obs_cpu[i])\n",
|
||||
"print(ob_buffer)\n",
|
||||
"ob_buffer[1] = []\n",
|
||||
"print(ob_buffer)\n",
|
||||
"print('---')\n",
|
||||
"for i in range(agent_num):\n",
|
||||
" ob_buffer[i].append(obs_cpu[i])\n",
|
||||
"print(ob_buffer[0])\n",
|
||||
"vvv = torch.tensor(ob_buffer[0]).to(\"cuda\")\n",
|
||||
"print(\"vvv\",vvv)\n",
|
||||
"empt = torch.tensor([]).to(\"cuda\")\n",
|
||||
"vvvv = torch.cat((empt,vvv),0)\n",
|
||||
"print(vvvv)\n",
|
||||
"vvvv.size()[0]>0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"start 0\n",
|
||||
"end 3\n",
|
||||
"start 3\n",
|
||||
"end 6\n",
|
||||
"start 6\n",
|
||||
"end 9\n",
|
||||
"start 9\n",
|
||||
"end 12\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"#1\n",
|
||||
"x = torch.randn(2, 1, 1)#为1可以扩展为3和4\n",
|
||||
"x = x.expand(2, 3, 4)\n",
|
||||
"print('x :', x.size())\n",
|
||||
"\n",
|
||||
"#2\n",
|
||||
"#扩展一个新的维度必须在最前面,否则会报错\n",
|
||||
"#x = x.expand(2, 3, 4, 6)\n",
|
||||
"\n",
|
||||
"x = x.expand(6, 2, 3, 4)\n",
|
||||
"print('x :', x.size())\n",
|
||||
"\n",
|
||||
"#3\n",
|
||||
"#某一个维度为-1表示不改变该维度的大小\n",
|
||||
"x = x.expand(6, -1, -1, -1)\n",
|
||||
"print('x :', x.size())\n",
|
||||
"\n",
|
||||
"x : torch.Size([2, 3, 4])\n",
|
||||
"x : torch.Size([6, 2, 3, 4])\n",
|
||||
"x : torch.Size([6, 2, 3, 4])"
|
||||
"for i in range(0,10,3):\n",
|
||||
" print(\"start\",i)\n",
|
||||
" print('end',i+3)"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
Loading…
Reference in New Issue
Block a user