Change Param based on a Paper
Change Param based on a Paper, and it works!

parent 3116831ae6
commit 0e0d98d8b1
@@ -38,27 +38,27 @@ BASE_PORT = 1000
 TOTAL_STEPS = 3150000
 BATCH_SIZE = 1024
 MAX_TRAINNING_DATASETS = 6000
-DECISION_PERIOD = 2
-FREEZE_HEAD_NETWORK = False
-LEARNING_RATE = 1e-3
+DECISION_PERIOD = 1
+LEARNING_RATE = 5e-4
 GAMMA = 0.99
-GAE_LAMBDA = 0.9
-EPOCHS = 2
-CLIP_COEF = 0.1
+GAE_LAMBDA = 0.95
+EPOCHS = 3
+CLIP_COEF = 0.11
 LOSS_COEF = [1.0, 1.0, 1.0, 1.0] # free go attack defence
 POLICY_COEF = [1.0, 1.0, 1.0, 1.0]
-ENTROPY_COEF = [1.0, 1.0, 1.0, 1.0]
+ENTROPY_COEF = [0.1, 0.1, 0.1, 0.1]
 CRITIC_COEF = [0.5, 0.5, 0.5, 0.5]
 TARGET_LEARNING_RATE = 1e-6
+FREEZE_VIEW_NETWORK = False

 ANNEAL_LEARNING_RATE = True
 CLIP_VLOSS = True
 NORM_ADV = True
 TRAIN = True

-WANDB_TACK = True
+WANDB_TACK = False
 LOAD_DIR = None
-#LOAD_DIR = "../PPO-Model/Aimbot_Target_Hybrid_PMNN_V2_OffPolicy_EndBC_9331_1670634636-freeonly-14/Aimbot_Target_Hybrid_PMNN_V2_OffPolicy_EndBC_9331_1670634636_-0.35597783.pt"
+# LOAD_DIR = "../PPO-Model/Aimbot_Target_Hybrid_PMNN_V2_OffPolicy_EndBC_9331_1670986948-freeonly-20/Aimbot_Target_Hybrid_PMNN_V2_OffPolicy_EndBC_9331_1670986948_0.7949778.pt"

 # public data
 class Targets(Enum):
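For context on the tuned values above, a minimal sketch of where CLIP_COEF enters the clipped PPO surrogate; every tensor below is random and purely illustrative, not from this repo.

import torch

CLIP_COEF = 0.11
logratio = torch.randn(1024) * 0.05               # new_logprob - old_logprob
ratio = logratio.exp()
advantages = torch.randn(1024)

# standard clipped surrogate: take the pessimistic branch per sample
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1 - CLIP_COEF, 1 + CLIP_COEF)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
print(pg_loss)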
@@ -67,11 +67,12 @@ class Targets(Enum):
     Attack = 2
     Defence = 3
     Num = 4
-TARGET_STATE_SIZE = 7 # 6+1
+TARGET_STATE_SIZE = 6
+INAREA_STATE_SIZE = 1
 TIME_STATE_SIZE = 1
 GUN_STATE_SIZE = 1
 MY_STATE_SIZE = 4
-TOTAL_T_STATE_SIZE = TARGET_STATE_SIZE+TIME_STATE_SIZE+GUN_STATE_SIZE+MY_STATE_SIZE
+TOTAL_T_SIZE = TARGET_STATE_SIZE+INAREA_STATE_SIZE+TIME_STATE_SIZE+GUN_STATE_SIZE+MY_STATE_SIZE
 BASE_WINREWARD = 999
 BASE_LOSEREWARD = -999
 TARGETNUM= 4
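A quick sketch of how the new sizes compose (TOTAL_T_SIZE = 6+1+1+1+4 = 13) and how the model later slices an observation into non-ray and ray parts; the observation width 30 below is made up for illustration.

import torch

TARGET_STATE_SIZE = 6
INAREA_STATE_SIZE = 1
TIME_STATE_SIZE = 1
GUN_STATE_SIZE = 1
MY_STATE_SIZE = 4
TOTAL_T_SIZE = TARGET_STATE_SIZE+INAREA_STATE_SIZE+TIME_STATE_SIZE+GUN_STATE_SIZE+MY_STATE_SIZE

obs = torch.randn(1, 30)              # 30 is an invented observation width
non_ray = obs[:, :TOTAL_T_SIZE]       # 13 scalar state features at the front
ray = obs[:, TOTAL_T_SIZE:]           # the remaining entries are ray hits
assert non_ray.shape[1] == 13 and ray.shape[1] == 30 - TOTAL_T_SIZE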
@@ -107,8 +108,8 @@ def parse_args():
     # model parameters
     parser.add_argument("--train",type=lambda x: bool(strtobool(x)), default=TRAIN, nargs="?", const=True,
                         help="Train Model or not")
-    parser.add_argument("--freeze-headnet", type=lambda x: bool(strtobool(x)), default=FREEZE_HEAD_NETWORK, nargs="?", const=True,
-                        help="freeze head network or not")
+    parser.add_argument("--freeze-viewnet", type=lambda x: bool(strtobool(x)), default=FREEZE_VIEW_NETWORK, nargs="?", const=True,
+                        help="freeze view network or not")
     parser.add_argument("--datasetSize", type=int, default=MAX_TRAINNING_DATASETS,
                         help="training dataset size,start training while dataset collect enough data")
     parser.add_argument("--minibatchSize", type=int, default=BATCH_SIZE,
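The new --freeze-viewnet flag uses the same strtobool idiom as the surrounding arguments; a self-contained sketch of how that idiom parses (the flag name is reused here purely for illustration):

from argparse import ArgumentParser
from distutils.util import strtobool

parser = ArgumentParser()
# bare flag means True (const=True); a string value is parsed; omitted -> default
parser.add_argument("--freeze-viewnet", type=lambda x: bool(strtobool(x)),
                    default=False, nargs="?", const=True)
print(parser.parse_args([]).freeze_viewnet)                          # False
print(parser.parse_args(["--freeze-viewnet"]).freeze_viewnet)        # True
print(parser.parse_args(["--freeze-viewnet", "no"]).freeze_viewnet)  # False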
@@ -166,70 +167,73 @@ class PPOAgent(nn.Module):
     def __init__(self, env: Aimbot, targetNum: int):
         super(PPOAgent, self).__init__()
         self.targetNum = targetNum
         self.stateSize = env.unity_observation_shape[0]
         self.agentNum = env.unity_agent_num
         self.targetSize = TARGET_STATE_SIZE
         self.timeSize = TIME_STATE_SIZE
         self.gunSize = GUN_STATE_SIZE
         self.myStateSize = MY_STATE_SIZE
-        self.totalTSize = TOTAL_T_STATE_SIZE
-        self.targetInputSize = TOTAL_T_STATE_SIZE - TIME_STATE_SIZE - 1 # all target except time and target state
-        self.totalRaySize = env.unity_observation_shape[0] - TOTAL_T_STATE_SIZE
-        self.criticInputSize = env.unity_observation_shape[0] - TIME_STATE_SIZE - 1 # all except time and target state
+        self.raySize = env.unity_observation_shape[0] - TOTAL_T_SIZE
+        self.nonRaySize = TOTAL_T_SIZE
         self.head_input_size = env.unity_observation_shape[0] - self.targetSize-self.timeSize-self.gunSize # except target state input

         self.discrete_size = env.unity_discrete_size
         self.discrete_shape = list(env.unity_discrete_branches)
         self.continuous_size = env.unity_continuous_size

         self.viewNetwork = nn.Sequential(
-            layer_init(nn.Linear(self.totalRaySize, 200)),
-            nn.Tanh(),
+            layer_init(nn.Linear(self.raySize, 200)),
+            nn.Tanh()
         )
         self.targetNetworks = nn.ModuleList([nn.Sequential(
-            layer_init(nn.Linear(self.targetInputSize,128)),
+            layer_init(nn.Linear(self.nonRaySize, 100)),
             nn.Tanh()
         )for i in range(targetNum)])
         self.middleNetworks = nn.ModuleList([nn.Sequential(
-            layer_init(nn.Linear(328,256)),
-            nn.Softplus()
+            layer_init(nn.Linear(300,200)),
+            nn.Tanh()
         )for i in range(targetNum)])
-        self.actor_dis = nn.ModuleList([layer_init(nn.Linear(256, self.discrete_size), std=0.5) for i in range(targetNum)])
-        self.actor_mean = nn.ModuleList([layer_init(nn.Linear(256, self.continuous_size), std=0) for i in range(targetNum)])
+        self.actor_dis = nn.ModuleList([layer_init(nn.Linear(200, self.discrete_size), std=0.5) for i in range(targetNum)])
+        self.actor_mean = nn.ModuleList([layer_init(nn.Linear(200, self.continuous_size), std=0.5) for i in range(targetNum)])
-        # self.actor_logstd = nn.ModuleList([layer_init(nn.Linear(256, self.continuous_size), std=1) for i in range(targetNum)])
-        self.actor_logstd = nn.ParameterList([nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(targetNum)])
-        self.critic = nn.ModuleList([nn.Sequential(
-            layer_init(nn.Linear(self.criticInputSize, 512)),
-            nn.Tanh(),
-            layer_init(nn.Linear(512, 256)),
-            nn.Tanh(),
-            layer_init(nn.Linear(256, 1), std=0.5))for i in range(targetNum)])
+        self.actor_logstd = nn.Parameter(torch.zeros(1, self.continuous_size))
+        self.critic = nn.ModuleList([layer_init(nn.Linear(200, 1), std=1)for i in range(targetNum)])

     def get_value(self, state: torch.Tensor):
-        targets = state[:,0].to(torch.int32) # int
-        headInput = torch.cat([state[:,1:self.targetSize],state[:,self.targetSize+self.timeSize:]],dim=1) # except target state
-
-        return torch.stack([self.critic[targets[i]](headInput[i])for i in range(targets.size()[0])])
+        target = state[:,0].to(torch.int32) # int
+        thisStateNum = target.size()[0]
+        viewInput = state[:,-self.raySize:] # all ray input
+        targetInput = state[:,:self.nonRaySize]
+        viewLayer = self.viewNetwork(viewInput)
+        targetLayer = torch.stack([self.targetNetworks[target[i]](targetInput[i]) for i in range(thisStateNum)])
+        middleInput = torch.cat([viewLayer,targetLayer],dim = 1)
+        middleLayer = torch.stack([self.middleNetworks[target[i]](middleInput[i]) for i in range(thisStateNum)])
+        criticV = torch.stack([self.critic[target[i]](middleLayer[i]) for i in range(thisStateNum)]) # self.critic
+        return criticV

     def get_actions_value(self, state: torch.Tensor, actions=None):
-        targets = state[:,0].to(torch.int32) # int
-        viewInput = state[:,-self.totalRaySize:] # all ray input
-        targetInput = torch.cat([state[:,1:self.targetSize],state[:,self.targetSize+self.timeSize:self.totalTSize]],dim=1) # all target state except time and the target id itself
-
+        target = state[:,0].to(torch.int32) # int
+        thisStateNum = target.size()[0]
+        viewInput = state[:,-self.raySize:] # all ray input
+        targetInput = state[:,:self.nonRaySize]
         viewLayer = self.viewNetwork(viewInput)
-        targetLayer = torch.stack([self.targetNetworks[targets[i]](targetInput[i]) for i in range(targets.size()[0])])
+        targetLayer = torch.stack([self.targetNetworks[target[i]](targetInput[i]) for i in range(thisStateNum)])
         middleInput = torch.cat([viewLayer,targetLayer],dim = 1)
-        middleLayer = torch.stack([self.middleNetworks[targets[i]](middleInput[i]) for i in range(targets.size()[0])])
+        middleLayer = torch.stack([self.middleNetworks[target[i]](middleInput[i]) for i in range(thisStateNum)])

         # discrete
         # loop over the batch (one entry per agent) so each agent uses the output network matching its target
-        dis_logits = torch.stack([self.actor_dis[targets[i]](middleLayer[i]) for i in range(targets.size()[0])])
+        dis_logits = torch.stack([self.actor_dis[target[i]](middleLayer[i]) for i in range(thisStateNum)])
         split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)
         multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]
         # continuous
-        actions_mean = torch.stack([self.actor_mean[targets[i]](middleLayer[i]) for i in range(targets.size()[0])]) # self.actor_mean(hidden)
-        # action_logstd = torch.stack([self.actor_logstd[targets[i]].expand_as(actions_mean) for i in range(targets.size()[0])]) # self.actor_logstd.expand_as(actions_mean)
+        actions_mean = torch.stack([self.actor_mean[target[i]](middleLayer[i]) for i in range(thisStateNum)]) # self.actor_mean(hidden)
+        action_logstd = self.actor_logstd.expand_as(actions_mean) # self.actor_logstd.expand_as(actions_mean)
         # print(action_logstd)
-        action_std = torch.squeeze(torch.stack([torch.exp(self.actor_logstd[targets[i]]) for i in range(targets.size()[0])]),dim = -1) # torch.exp(action_logstd)
-        action_std = torch.clamp(action_std,1e-10)
+        action_std = torch.exp(action_logstd) # torch.exp(action_logstd)
         con_probs = Normal(actions_mean, action_std)
         # critic
-        criticV = self.get_value(state)
+        criticV = torch.stack([self.critic[target[i]](middleLayer[i]) for i in range(thisStateNum)]) # self.critic

         if actions is None:
             if args.train:
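The replacement actor_logstd is a single state-independent parameter shared across targets, instead of the per-target ParameterList it removes; a minimal sketch of how it broadcasts (shapes here are illustrative, not this repo's):

import torch
import torch.nn as nn
from torch.distributions.normal import Normal

continuous_size = 3
actor_logstd = nn.Parameter(torch.zeros(1, continuous_size))  # one shared row

actions_mean = torch.randn(8, continuous_size)         # stand-in actor output
action_logstd = actor_logstd.expand_as(actions_mean)   # broadcast to batch
action_std = torch.exp(action_logstd)                  # std = e^0 = 1 at init
dist = Normal(actions_mean, action_std)
print(dist.sample().shape)                             # torch.Size([8, 3])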
@@ -369,11 +373,11 @@ if __name__ == "__main__":
     else:
         agent = torch.load(args.load_dir)
         # freeze
-        if args.freeze_headnet:
-            # freeze the head network
+        if args.freeze_viewnet:
+            # freeze the view network
             for p in agent.viewNetwork.parameters():
                 p.requires_grad = False
-            print("HEAD NETWORK FREEZED")
+            print("VIEW NETWORK FREEZED")
         print("Load Agent", args.load_dir)
         print(agent.eval())
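A small sketch of the freeze pattern applied above, with stand-in modules (the names below are placeholders, not this repo's API): gradients are disabled on the view network, and an optimizer is then built from whatever still requires grad.

import torch
import torch.nn as nn

view_network = nn.Linear(30, 200)   # stand-in for agent.viewNetwork
other_head = nn.Linear(200, 1)      # stand-in for the rest of the agent

for p in view_network.parameters():
    p.requires_grad = False

trainable = [p for p in list(view_network.parameters()) + list(other_head.parameters())
             if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=5e-4)
print(sum(p.numel() for p in trainable))  # 201: only other_head still trains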
@@ -489,7 +493,7 @@ if __name__ == "__main__":
                     thisRewardsTensor,
                     torch.Tensor(dones_bf[i]).to(device),
                     torch.tensor(values_bf[i]).to(device),
-                    torch.Tensor(next_state[i]).to(device).unsqueeze(dim = 0),
+                    torch.tensor(next_state[i]).to(device).unsqueeze(0),
                     torch.Tensor([next_done[i]]).to(device),
                 )
                 # send memories to training datasets
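The switch from torch.Tensor(...) to torch.tensor(...) matters for scalars: the legacy constructor rejects a bare bool (the notebook error further down in this diff shows exactly that), while the factory function infers dtype from the data. A quick demonstration:

import torch

print(torch.tensor(True))          # tensor(True)
print(torch.tensor([1.0, 2.0]))    # dtype inferred as float32
try:
    torch.Tensor(True)             # constructor treats data as a size/sequence
except TypeError as e:
    print("TypeError:", e)         # new(): data must be a sequence (got bool)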
@@ -522,7 +526,7 @@ if __name__ == "__main__":
                 trainQueue.append(i)
         if(len(trainQueue)>0):
             break
-        # state, done = next_state, next_done
+        state, done = next_state, next_done
     else:
         step += 1
         # skip this step use last predict action
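The re-enabled `state, done = next_state, next_done` is the standard rollout hand-off; a toy, self-contained loop showing why it must run every step (the environment below is invented for illustration):

import random

def env_step(state, action):
    next_state = state + action
    return next_state, -abs(next_state), abs(next_state) > 3  # obs, reward, done

state, done = 0, False
for step in range(20):
    action = random.choice([-1, 1])
    next_state, reward, next_done = env_step(state, action)
    # without this hand-off, every step would act on a stale observation
    state, done = next_state, next_done
    if done:
        state, done = 0, False  # reset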
@@ -625,11 +629,9 @@ if __name__ == "__main__":
                     # discrete ratio
                     dis_logratio = new_dis_logprob - b_dis_logprobs[mb_inds]
                     dis_ratio = dis_logratio.exp()
-                    # dis_ratio = (new_dis_logprob / (b_dis_logprobs[mb_inds]+1e-8)).mean()
                     # continuous ratio
                     con_logratio = new_con_logprob - b_con_logprobs[mb_inds]
                     con_ratio = con_logratio.exp()
-                    # con_ratio = (new_con_logprob / (b_con_logprobs[mb_inds]+1e-8)).mean()

                     """
                     # early stop
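The removed commented-out lines divided raw log-probabilities, which is not the importance ratio and needed an epsilon to dodge division by zero; the surviving computation subtracts in log space and exponentiates. A tiny worked example:

import torch

new_logprob = torch.tensor([-1.20, -0.70])
old_logprob = torch.tensor([-1.00, -0.80])
logratio = new_logprob - old_logprob
ratio = logratio.exp()     # pi_new(a|s) / pi_old(a|s)
print(ratio)               # tensor([0.8187, 1.1052])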
@@ -673,10 +675,22 @@ if __name__ == "__main__":
                     loss = (
                         dis_pg_loss * POLICY_COEF[thisT]
                         + con_pg_loss * POLICY_COEF[thisT]
-                        - entropy_loss * ENTROPY_COEF[thisT]
+                        + entropy_loss * ENTROPY_COEF[thisT]
                         + v_loss * CRITIC_COEF[thisT]
                     )*LOSS_COEF[thisT]

+                    if torch.isnan(loss).any():
+                        print("LOSS Include NAN!!!")
+                        if torch.isnan(dis_pg_loss).any():
+                            print("dis_pg_loss include nan")
+                        if torch.isnan(con_pg_loss).any():
+                            print("con_pg_loss include nan")
+                        if torch.isnan(entropy_loss).any():
+                            print("entropy_loss include nan")
+                        if torch.isnan(v_loss).any():
+                            print("v_loss include nan")
+                        raise ValueError("loss contains NaN")
+
                     optimizer.zero_grad()
                     loss.backward()
                     # Clips gradient norm of an iterable of parameters.
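The NaN guard's idiom, shown standalone (this mirrors the notebook experiment at the end of this diff): apply torch.isnan to the tensor first, then reduce with .any(). The reversed order, torch.isnan(x.any()), fails because .any() yields a bool tensor, which isnan does not accept.

import torch

x = torch.tensor([1.0, float("nan"), 2.0])
print(torch.isnan(x))        # tensor([False,  True, False])
print(torch.isnan(x).any())  # tensor(True)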
@@ -833,18 +833,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
-     "ename": "TypeError",
-     "evalue": "new(): data must be a sequence (got bool)",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
-      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_42068\\1624049819.py\u001b[0m in \u001b[0;36m<cell line: 5>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdistributions\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnormal\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mNormal\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m \u001b[0maaa\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'cuda'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 6\u001b[0m \u001b[0maaa\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
-      "\u001b[1;31mTypeError\u001b[0m: new(): data must be a sequence (got bool)"
-     ],
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([False,  True, False], device='cuda:0')\n",
+      "tensor(True, device='cuda:0')\n"
+     ]
     }
    ],
@@ -853,8 +850,8 @@
    "import numpy as np\n",
    "from torch.distributions.normal import Normal\n",
    "\n",
-   "aaa = torch.Tensor(True).to('cuda').unsqueeze(0)\n",
-   "aaa"
+   "print(torch.isnan(torch.tensor([1,float('nan'),2]).to(\"cuda\")))\n",
+   "print(torch.isnan(torch.tensor([1,float('nan'),2]).to(\"cuda\")).any())"
   ]
  }
 ],