Compare commits
1 Commits
OffP-FullM
...
OnPolicy
Author | SHA1 | Date | |
---|---|---|---|
1e974ada2a |
@ -12,12 +12,13 @@ class Aimbot(gym.Env):
|
||||
envPath: str,
|
||||
workerID: int = 1,
|
||||
basePort: int = 100,
|
||||
side_channels: list = []
|
||||
):
|
||||
super(Aimbot, self).__init__()
|
||||
self.env = UnityEnvironment(
|
||||
file_name=envPath,
|
||||
seed=1,
|
||||
side_channels=[],
|
||||
side_channels=side_channels,
|
||||
worker_id=workerID,
|
||||
base_port=basePort,
|
||||
)
|
||||
|
@ -3,6 +3,7 @@ import wandb
|
||||
import time
|
||||
import numpy as np
|
||||
import random
|
||||
import uuid
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
@ -12,15 +13,24 @@ from torch.distributions.normal import Normal
|
||||
from torch.distributions.categorical import Categorical
|
||||
from distutils.util import strtobool
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from mlagents_envs.environment import UnityEnvironment
|
||||
from mlagents_envs.side_channel.side_channel import (
|
||||
SideChannel,
|
||||
IncomingMessage,
|
||||
OutgoingMessage,
|
||||
)
|
||||
from typing import List
|
||||
|
||||
bestReward = 0
|
||||
|
||||
DEFAULT_SEED = 9331
|
||||
ENV_PATH = "../Build/Build-ParallelEnv-BigArea-6Enemy/Aimbot-ParallelEnv"
|
||||
ENV_PATH = "../Build/Build-ParallelEnv-Target-OffPolicy-SingleStack-SideChannel/Aimbot-ParallelEnv"
|
||||
SIDE_CHANNEL_UUID = uuid.UUID("8bbfb62a-99b4-457c-879d-b78b69066b5e")
|
||||
WAND_ENTITY = "koha9"
|
||||
WORKER_ID = 1
|
||||
BASE_PORT = 1000
|
||||
|
||||
# !!! Check every parameter before running !!!
|
||||
|
||||
TOTAL_STEPS = 2000000
|
||||
STEP_NUM = 314
|
||||
@ -44,6 +54,10 @@ WANDB_TACK = False
|
||||
LOAD_DIR = None
|
||||
# LOAD_DIR = "../PPO-Model/SmallArea-256-128-hybrid-2nd-trainning.pt"
|
||||
|
||||
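# Public data: per-mode round counters ("Go" / "Attack" / "Free"), updated by
# AimbotSideChannel.on_message_received when Unity reports a round result.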
|
||||
TotalRounds = {"Go":0,"Attack":0,"Free":0}
|
||||
WinRounds = {"Go":0,"Attack":0,"Free":0}
|
||||
|
||||
|
||||
def parse_args():
|
||||
# fmt: off
|
||||
@ -178,6 +192,51 @@ class PPOAgent(nn.Module):
|
||||
self.critic(hidden),
|
||||
)
|
||||
|
||||
class AimbotSideChannel(SideChannel):
    """Side channel exchanging messages with the Unity Aimbot environment.

    Incoming messages are strings whose fields are separated by "|".
    A message whose first field is "result" is read as
    ``result|<mode>|<outcome>``: it increments the module-level
    ``TotalRounds[<mode>]`` counter and, when ``<outcome>`` is "Win",
    ``WinRounds[<mode>]`` as well. The ``send_*`` helpers queue an outgoing
    message carrying a single value of the corresponding type.
    """

    def __init__(self, channel_id: uuid.UUID) -> None:
        """Register the channel under ``channel_id``."""
        super().__init__(channel_id)

    def on_message_received(self, msg: IncomingMessage) -> None:
        """Handle one message sent by Unity.

        Note: this method of the SideChannel interface must be implemented
        to receive messages from Unity.

        Every message is echoed to stdout. "result" messages update the
        module-level ``TotalRounds`` / ``WinRounds`` counters; "Error"
        messages are printed a second time. A "result" message that has
        fewer than three fields, or names a mode that is not a key of the
        counters, is reported and ignored instead of raising from inside
        the callback (and without leaving the counters half-updated).
        """
        message = msg.read_string()
        print(message)
        fields = message.split("|")
        kind = fields[0]  # str.split always yields at least one element

        if kind == "result":
            # Validate before touching the counters so a malformed message
            # can neither raise IndexError/KeyError nor bump TotalRounds
            # without the matching WinRounds bookkeeping.
            if (
                len(fields) < 3
                or fields[1] not in TotalRounds
                or fields[1] not in WinRounds
            ):
                print("Ignoring malformed result message:", message)
                return
            mode, outcome = fields[1], fields[2]
            TotalRounds[mode] += 1
            if outcome == "Win":
                WinRounds[mode] += 1
            print(TotalRounds)
            print(WinRounds)
        elif kind == "Error":
            print(message)

    # Send helpers
    def send_string(self, data: str) -> None:
        """Send a string to the C# side."""
        msg = OutgoingMessage()
        msg.write_string(data)
        super().queue_message_to_send(msg)

    def send_bool(self, data: bool) -> None:
        """Send a bool to the C# side."""
        msg = OutgoingMessage()
        msg.write_bool(data)
        super().queue_message_to_send(msg)

    def send_int(self, data: int) -> None:
        """Send an int to the C# side (written with ``write_int32``)."""
        msg = OutgoingMessage()
        msg.write_int32(data)
        super().queue_message_to_send(msg)

    def send_float(self, data: float) -> None:
        """Send a float to the C# side (written with ``write_float32``)."""
        msg = OutgoingMessage()
        msg.write_float32(data)
        super().queue_message_to_send(msg)

    def send_float_list(self, data: List[float]) -> None:
        """Send a list of floats to the C# side."""
        msg = OutgoingMessage()
        msg.write_float32_list(data)
        super().queue_message_to_send(msg)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
@ -188,7 +247,8 @@ if __name__ == "__main__":
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
|
||||
|
||||
# Initialize environment, agent and optimizer
|
||||
env = Aimbot(envPath=args.path, workerID=args.workerID, basePort=args.baseport)
|
||||
aimBotsideChannel = AimbotSideChannel(SIDE_CHANNEL_UUID);
|
||||
env = Aimbot(envPath=args.path, workerID=args.workerID, basePort=args.baseport,side_channels=[aimBotsideChannel])
|
||||
if args.load_dir is None:
|
||||
agent = PPOAgent(env).to(device)
|
||||
else:
|
||||
@ -424,6 +484,9 @@ if __name__ == "__main__":
|
||||
"charts/SPS", int(global_step / (time.time() - start_time)), global_step
|
||||
)
|
||||
writer.add_scalar("charts/Reward", rewardsMean, global_step)
|
||||
writer.add_scalar("charts/GoWinRatio", WinRounds["Go"]/TotalRounds["Go"] if TotalRounds["Go"] != 0 else 0, global_step)
|
||||
writer.add_scalar("charts/AttackWinRatio", WinRounds["Attack"]/TotalRounds["Attack"] if TotalRounds["Attack"] != 0 else 0, global_step)
|
||||
writer.add_scalar("charts/FreeWinRatio", WinRounds["Free"]/TotalRounds["Free"] if TotalRounds["Free"] != 0 else 0, global_step)
|
||||
if rewardsMean > bestReward:
|
||||
bestReward = rewardsMean
|
||||
saveDir = "../PPO-Model/bigArea-384-128-hybrid-" + str(rewardsMean) + ".pt"
|
||||
|
Loading…
Reference in New Issue
Block a user