Aimbot-PPO/Aimbot-PPO-MultiScene/Assets/Script/InGame/AgentWithGun.cs
Koha9 818928a5aa Add Gun State, fix PPO GAIL class bug
Add Gun state
fix PPO GAIL class errors
2022-10-23 23:38:07 +09:00

680 lines
24 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using System;
using System.Reflection;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.UI;
using UnityEditor;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using XCharts;
using XCharts.Runtime;
/* Main ML-Agents controller for the gun-wielding agent */
public class AgentWithGun : Agent
{
// ---- Scene references (public, no initializer; presumably assigned in the Inspector) ----
// GameObject of this agent; renamed every episode and passed to Enemy.ReactToHit as the shooter.
public GameObject thisAgentObj;
// Transform that is positioned by randomInitAgent and yawed by cameraControl.
public Transform thisAgent;
// Camera used for the aiming ray (ballistic) and for the pitch rotation (cameraControl).
public Camera thisCam;
// CharacterController driven by moveAgent.
public CharacterController PlayerController;
// Prefab instantiated by randomInitEnemys.
public GameObject enemyPrefab;
// Object expected to carry the CameraChange component (fetched in Start).
public GameObject cameraChangerOBJ;
// ---- Default reward values; Start() overwrites them from StartSeneData or with fallback values ----
// NOTE(review): the third tooltip below reads "is not read"; probably meant "is not ready".
[Header("Rewards")]
[Tooltip("Nothing happened reward")]
public float nonRewardDefault = -0.05f;
[Tooltip("Agent Do shoot action reward")]
public float shootRewardDefault = -0.1f;
[Tooltip("Agent Do shoot action but gun is not read")]
public float shootWithoutReadyRewardDefault = -1.0f;
[Tooltip("Hit Enemy reward")]
public float hitRewardDefault = 2.0f;
[Tooltip("Episode Win reward")]
public float winRewardDefault = 10.0f;
[Tooltip("Episode Lose reward")]
public float loseRewardDefault = -10.0f;
[Tooltip("Enemy down reward")]
public float killRewardDefault = 5.0f;
// ---- Environment settings ----
[Header("Env")]
// When true, OnEpisodeBegin locks the cursor.
public bool lockMouse = false;
public float Damage = 50; // damage to enemy
// Minimum time between two shots; compared against Time.time differences in gunReady.
public float fireRate = 0.5f;
// Number of enemies spawned at the start of each episode.
public int enemyNum = 3;
// Episode time limit; compared against Time.time differences in checkFinish.
public int timeLimit = 30;
// When true, cameraControl ignores horizontal / vertical mouse input respectively.
public bool lockCameraX = false;
public bool lockCameraY = true;
//public Vector3 startPosition = new Vector3(9, 1, 18);
// Spawn rectangles passed to Random.Range. The "Y" bounds are applied to the Z coordinate.
public int minEnemyAreaX = -12;
public int maxEnemyAreaX = 11;
public int minEnemyAreaY = -20;
public int maxEnemyAreaY = 20;
public int minAgentAreaX = -12;
public int maxAgentAreaX = 11;
public int minAgentAreaY = -28;
public int maxAgentAreaY = -22;
// ---- State used to imitate Input.GetAxis-style smoothed movement ----
[Header("GetAxis() Simulate")]
// Velocity cap used by moveAgent; also multiplied into the final movement vector there.
public float MoveSpeed = 2.0f;
// Current simulated velocity along the agent's right (vX) and forward (vZ) directions.
public float vX = 0f;
public float vZ = 0f;
public float acceleration = 0.1f; // velocity change applied per moveAgent call
// Sensitivity factors applied to the raw mouse axes in Heuristic.
public float mouseXSensitivity = 100;
public float mouseYSensitivity = 200;
public float yRotation = 0.1f;// accumulated camera rotation around the X axis (pitch); clamped to [-90, 90] in cameraControl
// ---- Private runtime state ----
// Time.time at which the current episode started.
private float startTime = 0;
// Pending shoot action for the current step (non-zero = trigger pulled); reset by ballistic.
private int shoot = 0;
// Time.time of the last shot that was actually fired.
private float lastShootTime = 0.0f;
// Enemy count captured at the end of OnEpisodeBegin.
private int nowEnemyNum = 0;
// Kills reported through GotKill that have not been rewarded yet.
private int enemyKillCount = 0;
// Step counter within the current episode and number of finished episodes.
private int step = 0;
private int EP = 0;
// Model load directory (date / time parts) as strings and as the floats sent as observations.
private string LoadDirDate;
private string LoadDirTime;
private float LoadDirDateF;
private float loadDirTimeF;
// True = start in the third-person camera view, false = first-person view.
public bool defaultTPCamera = true;
// Result of gunReady() cached by CollectObservations and used by ballistic.
private bool gunReadyToggle = true;
// Helper components resolved in Start.
private StartSeneData DataTransfer;
private UIController UICon;
private HistoryRecorder HistoryRec;
private RaySensors rayScript;
private CameraChange camChanger;
// ---- Live reward values (not serialized); initialised from the *Default fields in Start ----
[System.NonSerialized] public float nonReward;
[System.NonSerialized] public float shootReward;
[System.NonSerialized] public float shootWithoutReadyReward;
[System.NonSerialized] public float hitReward;
[System.NonSerialized] public float winReward;
[System.NonSerialized] public float loseReward;
[System.NonSerialized] public float killReward;
// Flag sent as an observation and reset to 0 right afterwards (see CollectObservations).
[System.NonSerialized] public float saveNow = 0;
// Remaining episode time, truncated to an int; refreshed in OnActionReceived.
[System.NonSerialized] public int remainTime;
/// <summary>
/// Unity start-up hook. Loads the run settings from the scene object named
/// "StartSeneDataTransfer" when it exists (and carries a StartSeneData component);
/// otherwise applies the built-in fallback settings. Afterwards it caches the helper
/// components, copies the default rewards into the live reward fields, initialises
/// <c>remainTime</c> and selects the initial camera view.
/// </summary>
void Start()
{
    DecisionRequester decisionRequester = GetComponent<DecisionRequester>();

    // Detect the missing settings object with explicit null checks instead of provoking and
    // catching a NullReferenceException.
    GameObject dataTransferObj = GameObject.Find("StartSeneDataTransfer");
    DataTransfer = dataTransferObj != null ? dataTransferObj.GetComponent<StartSeneData>() : null;

    if (DataTransfer != null)
    {
        // Enemy count and time limit
        enemyNum = DataTransfer.EnemyNum;
        timeLimit = DataTransfer.Timelim;
        // Model load directory
        LoadDirDate = DataTransfer.LoadDirDate;
        LoadDirTime = DataTransfer.LoadDirTime;
        // Default rewards
        nonRewardDefault = DataTransfer.nonReward;
        shootRewardDefault = DataTransfer.shootReward;
        shootWithoutReadyRewardDefault = DataTransfer.shootWithoutReadyReward;
        hitRewardDefault = DataTransfer.hitReward;
        killRewardDefault = DataTransfer.killReward;
        winRewardDefault = DataTransfer.winReward;
        loseRewardDefault = DataTransfer.loseReward;
        lockMouse = DataTransfer.lockMouse;
        defaultTPCamera = DataTransfer.defaultTPCamera;
        // Decision Period & Take Actions Between Decisions
        if (decisionRequester != null)
        {
            decisionRequester.DecisionPeriod = DataTransfer.DecisionPeriod;
            decisionRequester.TakeActionsBetweenDecisions = DataTransfer.ActionsBetweenDecisions;
        }
    }
    else
    {
        // Fallback settings. These deliberately differ from the field initialisers above;
        // lockMouse and defaultTPCamera keep whatever value they already have.
        enemyNum = 3;
        timeLimit = 30;
        LoadDirDate = "0";
        LoadDirTime = "0";
        nonRewardDefault = -0.05f;
        shootRewardDefault = -0.06f;
        shootWithoutReadyRewardDefault = -0.06f;
        hitRewardDefault = 5.0f;
        killRewardDefault = 10.0f;
        winRewardDefault = 20.0f;
        loseRewardDefault = -10.0f;
        if (decisionRequester != null)
        {
            decisionRequester.DecisionPeriod = 1;
            decisionRequester.TakeActionsBetweenDecisions = true;
        }
    }

    if (decisionRequester == null)
    {
        Debug.LogWarning("AgentWithGun: no DecisionRequester found; decision settings were not applied.");
    }

    // A malformed directory string must not abort the rest of the initialisation,
    // so parse defensively and fall back to 0.
    LoadDirDateF = ParseLoadDir(LoadDirDate, "LoadDirDate");
    loadDirTimeF = ParseLoadDir(LoadDirTime, "LoadDirTime");

    UICon = transform.GetComponent<UIController>();
    HistoryRec = transform.GetComponent<HistoryRecorder>();
    rayScript = GetComponent<RaySensors>();
    camChanger = cameraChangerOBJ.GetComponent<CameraChange>();

    // Copy the defaults into the reward values that are actually used at runtime.
    nonReward = nonRewardDefault;
    shootReward = shootRewardDefault;
    shootWithoutReadyReward = shootWithoutReadyRewardDefault;
    hitReward = hitRewardDefault;
    winReward = winRewardDefault;
    loseReward = loseRewardDefault;
    killReward = killRewardDefault;

    // Initialise remainTime
    remainTime = (int)(timeLimit - Time.time + startTime);

    // Select the default camera view
    if (defaultTPCamera)
    {
        camChanger.ShowTPSView();
    }
    else
    {
        camChanger.ShowFPSView();
    }

    // Parses one load-directory component with the invariant culture; returns 0 on failure.
    float ParseLoadDir(string raw, string label)
    {
        float parsed;
        bool ok = float.TryParse(
            raw,
            System.Globalization.NumberStyles.Float | System.Globalization.NumberStyles.AllowThousands,
            System.Globalization.CultureInfo.InvariantCulture,
            out parsed);
        if (!ok)
        {
            Debug.LogWarning("AgentWithGun: " + label + " \"" + raw + "\" is not numeric; using 0.");
            return 0f;
        }
        return parsed;
    }
}
/* ---------- Debug-only Update: delete or comment out before building ---------- */
/*void Update()
{
//Debug.Log(RaySensors.rayTagResult[0]);
}*/
/* ---------- Debug-only Update: delete or comment out before building ---------- */
// -------------- Initialization ---------------
// randomInitEnemys: spawn enemies at random positions
/// <summary>
/// Instantiates <paramref name="EnemyNum"/> copies of <c>enemyPrefab</c> at random integer
/// X/Z coordinates drawn from the enemy area bounds, at height 1 and with no rotation.
/// </summary>
/// <param name="EnemyNum">Number of enemies to spawn; nothing is spawned when it is 0 or less.</param>
public void randomInitEnemys(int EnemyNum)
{
    const int spawnHeight = 1;
    int spawned = 0;
    while (spawned < EnemyNum)
    {
        // X is drawn before Z; the "Y" area bounds are applied to the Z coordinate.
        int x = UnityEngine.Random.Range(minEnemyAreaX, maxEnemyAreaX);
        int z = UnityEngine.Random.Range(minEnemyAreaY, maxEnemyAreaY);
        Vector3 spawnPoint = new Vector3(x, spawnHeight, z);
        Instantiate(enemyPrefab, spawnPoint, Quaternion.identity);
        spawned++;
    }
}
// -------------- Initialization ---------------
// randomInitAgent: place the agent at a random position
/// <summary>
/// Moves the agent to a random integer X/Z position inside the agent area bounds, at
/// height 1. The position is written to <c>thisAgent.localPosition</c>.
/// </summary>
public void randomInitAgent()
{
    // X is drawn before Z; the "Y" area bounds are applied to the Z coordinate.
    int x = UnityEngine.Random.Range(minAgentAreaX, maxAgentAreaX);
    int z = UnityEngine.Random.Range(minAgentAreaY, maxAgentAreaY);
    thisAgent.localPosition = new Vector3(x, 1, z);
}
// ------------ Action handling --------------
// moveAgent: simulates Input.GetAxis-style smoothed movement
/// <summary>
/// Updates the simulated axis velocities (<c>vX</c>, <c>vZ</c>) from the discrete inputs and
/// moves the CharacterController accordingly.
/// </summary>
/// <param name="vertical">Forward/backward input: 1, -1, or 0 for no input.</param>
/// <param name="horizontal">Right/left input: 1, -1, or 0 for no input.</param>
public void moveAgent(int vertical, int horizontal)
{
    vX = stepAxisVelocity(vX, horizontal);
    vZ = stepAxisVelocity(vZ, vertical);

    // NOTE(review): the velocities already range up to +/-MoveSpeed and are multiplied by
    // MoveSpeed again here; kept as in the original, confirm this is intended.
    Vector3 thisMovement = (transform.forward * vZ + transform.right * vX) * MoveSpeed;
    PlayerController.Move(thisMovement * Time.deltaTime);
}

// Advances one axis velocity by a single step: accelerate towards the input direction,
// cap at MoveSpeed, or decelerate towards zero when there is no input.
private float stepAxisVelocity(float velocity, int input)
{
    if (input != 0)
    {
        bool atSpeedLimit = velocity >= MoveSpeed || velocity <= -MoveSpeed;
        if (atSpeedLimit && (velocity * input) > 0)
        {
            // Input points the same way as the current velocity: hold the maximum speed.
            return (float)input * MoveSpeed;
        }
        // Below the limit, or the input was reversed while at the limit: keep accelerating.
        return velocity + (float)input * acceleration;
    }

    if (Math.Abs(velocity) > 0.001)
    {
        // MoveTowards stops exactly at 0. The previous "subtract a full step" logic overshot
        // zero whenever |velocity| was smaller than acceleration, so the velocity could
        // oscillate around zero instead of settling.
        return Mathf.MoveTowards(velocity, 0f, acceleration);
    }
    return 0f;
}
// ------------ Action handling --------------
// cameraControl: rotates the agent's view
/// <summary>
/// Applies a look rotation: yaws the agent around the world up axis and pitches the camera
/// around its local X axis.
/// </summary>
/// <param name="Mouse_X">Yaw delta; ignored when <c>lockCameraX</c> is set.</param>
/// <param name="Mouse_Y">Pitch delta; ignored when <c>lockCameraY</c> is set.</param>
public void cameraControl(float Mouse_X, float Mouse_Y)
{
    float yawDelta = lockCameraX ? 0 : Mouse_X;
    float pitchDelta = lockCameraY ? 0 : Mouse_Y;

    // A positive pitch delta decreases yRotation; the result is limited to -90..90.
    yRotation = Mathf.Clamp(yRotation - pitchDelta, -90f, 90f);

    // Yaw is applied to the agent transform, pitch only to the camera's local rotation.
    thisAgent.Rotate(Vector3.up * yawDelta);
    thisCam.transform.localRotation = Quaternion.Euler(yRotation, 0, 0);
}
// GotKill: called when this agent scores a kill
/// <summary>
/// Records one kill. The pending kills are turned into reward and cleared by
/// <c>rewardCalculate</c>.
/// </summary>
public void GotKill()
{
    enemyKillCount++;
}
// check gun is ready to shoot
/// <summary>
/// Reports whether at least <c>fireRate</c> has elapsed (in <c>Time.time</c> units) since
/// the last shot that was fired.
/// </summary>
/// <returns>True when the gun may fire again.</returns>
bool gunReady()
{
    float sinceLastShot = Time.time - lastShootTime;
    return sinceLastShot >= fireRate;
}
// ballistic: resolves the shot for this step and returns the resulting reward
/// <summary>
/// Resolves the pending shoot action. Casts a ray from the centre of the camera view
/// (range 100), damages an "Enemy"-tagged object that is hit, and always clears
/// <c>shoot</c>.
/// </summary>
/// <returns>
/// <c>nonReward</c> when no shot was requested, <c>shootWithoutReadyReward</c> when a shot
/// was requested while the gun was not ready, <c>hitReward</c> when an enemy was hit, and
/// <c>shootReward</c> for a shot that hit nothing or a non-enemy.
/// </returns>
float ballistic()
{
    // Ray origin: centre of the camera's pixel rectangle.
    Vector3 point = new Vector3(thisCam.pixelWidth / 2, thisCam.pixelHeight / 2, 0);
    Ray ray = thisCam.ScreenPointToRay(point);
    Debug.DrawRay(ray.origin, ray.direction * 100, Color.blue);
    UICon.updateShootKeyViewer(shoot, gunReadyToggle);

    // The request is consumed on every path, so clear it once up front.
    bool shotRequested = shoot != 0;
    shoot = 0;

    if (!shotRequested)
    {
        return nonReward;
    }
    if (!gunReadyToggle)
    {
        return shootWithoutReadyReward;
    }

    lastShootTime = Time.time;
    RaycastHit hit;
    // CompareTag instead of `tag == "Enemy"`, as Unity recommends for tag checks.
    if (Physics.Raycast(ray, out hit, 100) && hit.collider.CompareTag("Enemy"))
    {
        GameObject gotHitObj = hit.transform.gameObject; // object struck by the ray
        gotHitObj.GetComponent<Enemy>().ReactToHit(Damage, thisAgentObj);
        return hitReward;
    }
    return shootReward;
}
// destroyAllEnemys: destroys every Enemy-tagged object except the agent itself
/// <summary>
/// Destroys every GameObject tagged "Enemy", except objects located at the agent's own
/// position (which are treated as the agent itself).
/// </summary>
public void destroyAllEnemys()
{
    // NOTE(review): "self" is recognised purely by position equality; an unrelated enemy
    // standing exactly on the agent would be spared as well.
    Vector3 agentPosition = thisAgent.position;
    foreach (GameObject enemyObj in GameObject.FindGameObjectsWithTag("Enemy"))
    {
        if (enemyObj.transform.position != agentPosition)
        {
            Destroy(enemyObj);
        }
    }
}
// checkFinish: checks whether the episode is over and returns an int code
// 1 = success, 2 = overtime, 0 = not over
/// <summary>
/// Determines the episode state.
/// </summary>
/// <returns>
/// 1 when at most one "Enemy"-tagged object remains (win), 2 when the time limit has been
/// reached (lose), 0 when the episode continues. The win check takes precedence.
/// </returns>
int checkFinish()
{
    // NOTE(review): the threshold is 1 rather than 0, presumably because the agent itself
    // also carries the "Enemy" tag; confirm.
    int taggedCount = GameObject.FindGameObjectsWithTag("Enemy").Length;
    if (taggedCount <= 1)
    {
        return 1;
    }

    bool outOfTime = Time.time - startTime >= timeLimit;
    return outOfTime ? 2 : 0;
}
// getEnemyNum: counts the enemies in the scene, excluding the agent itself
/// <summary>
/// Counts the GameObjects tagged "Enemy", skipping objects located at the agent's own
/// position (which are treated as the agent itself).
/// </summary>
/// <returns>Number of enemies other than the agent.</returns>
int getEnemyNum()
{
    // The counter is named `count` so that it no longer shadows the `enemyNum` field.
    int count = 0;
    Vector3 agentPosition = thisAgent.position;
    foreach (GameObject enemyObj in GameObject.FindGameObjectsWithTag("Enemy"))
    {
        if (enemyObj.transform.position != agentPosition)
        {
            count += 1;
        }
    }
    return count;
}
// enemyNumDiff: difference between the configured enemy count and the enemies currently in the scene
/// <summary>
/// Computes how many enemies are missing relative to the configured <c>enemyNum</c>.
/// </summary>
/// <returns><c>enemyNum</c> minus the current result of <c>getEnemyNum()</c>.</returns>
int enemyNumDiff()
{
    // NOTE(review): the original used a local that shadowed the `nowEnemyNum` field, so the
    // field is neither read nor updated here; confirm that is intended.
    int currentCount = getEnemyNum();
    return enemyNum - currentCount;
}
// ------------ Reward --------------
// rewardCalculate: computes the reward for the current action
/// <summary>
/// Computes the reward for the current step: one <c>killReward</c> per kill recorded since
/// the previous call, plus the shooting reward returned by <c>ballistic()</c>.
/// </summary>
/// <returns>The accumulated step reward.</returns>
public float rewardCalculate()
{
    float stepReward = 0f;

    // Take and clear the pending kills before ballistic() runs, so kills reported during
    // this call are rewarded on the next one.
    int pendingKills = enemyKillCount;
    enemyKillCount = 0;
    while (pendingKills > 0)
    {
        stepReward += killReward;
        pendingKills--;
    }

    // Reward for the shooting action
    stepReward += ballistic();
    return stepReward;
}
// ML-AGENTS handling ---------------------------------------------------------------------------------------ML-AGENTS
// Episode start: initialize the environment
/// <summary>
/// ML-Agents episode start: resets the step counter, initialises the chart on the very
/// first episode, optionally locks the cursor, clears old enemies and re-spawns the agent
/// and a fresh set of enemies.
/// </summary>
public override void OnEpisodeBegin()
{
    bool isFirstEpisode = EP == 0;
    step = 0;
    if (isFirstEpisode)
    {
        UICon.iniChart();
    }

    if (lockMouse)
    {
        // hide and lock the mouse
        Cursor.lockState = CursorLockMode.Locked;
    }

    // Name the agent object after its instance id.
    thisAgentObj.name = thisAgentObj.GetInstanceID().ToString();

    destroyAllEnemys();
    // Restart the episode clock.
    startTime = Time.time;
    randomInitAgent();
    randomInitEnemys(enemyNum);

    // NOTE(review): if Destroy() removes objects only later in the frame, this count may
    // still include enemies of the previous episode; confirm.
    nowEnemyNum = getEnemyNum();
}
// ML-AGENTS handling ---------------------------------------------------------------------------------------ML-AGENTS
// Collect observations
// Builds the observation vector for this decision step. The order of the AddObservation
// calls below defines the vector layout, so it must not be changed:
//   1. agent position x, y, z and the w component of its rotation (4 floats)
//   2. ray tag results   3. ray distance results
//   4. gun-ready flag    5. load-directory date   6. load-directory time
//   7. saveNow flag (reset to 0 immediately after it has been sent)
public override void CollectObservations(VectorSensor sensor)
{
//List<float> enemyLDisList = RaySensors.enemyLDisList;// All Enemy Lside Distances
//List<float> enemyRDisList = RaySensors.enemyRDisList;// All Enemy Rside Distances
// Refresh the ray sensor arrays before reading them.
rayScript.updateRayInfo();
float[] myObserve = { thisAgent.position.x, thisAgent.position.y, thisAgent.position.z, thisAgent.rotation.w };
float[] rayTagResult = rayScript.rayTagResult;// ray tag results, float[](raySensorNum,1)
float[] rayDisResult = rayScript.rayDisResult; // ray distance results, float[](raySensorNum,1)
//float[] focusEnemyObserve = RaySensors.focusEnemyInfo;// nearest enemy info float[](3,1) MinEnemyIndex,x,z
//sensor.AddObservation(allEnemyNum); // number of enemies, int
sensor.AddObservation(myObserve); // own position xyz + rotation.w, float[](4,1)
sensor.AddObservation(rayTagResult); // ray tag results, float[](raySensorNum,1)
sensor.AddObservation(rayDisResult); // ray distance results, float[](raySensorNum,1)
//sensor.AddObservation(focusEnemyObserve); // nearest enemy info float[](3,1) MinEnemyIndex,x,z
//sensor.AddObservation(raySensorNum); // number of ray sensors, int
// Cache the gun state; ballistic() uses this cached value when the action is processed.
gunReadyToggle = gunReady();
sensor.AddObservation(gunReadyToggle); // save gun is ready?
sensor.AddObservation(LoadDirDateF); // first-level directory used for loadModel
sensor.AddObservation(loadDirTimeF); // second-level directory used for loadModel
sensor.AddObservation(saveNow); // sent saveNow Toggle to python let agent save weights
saveNow = 0; // reset saveNow Toggle
//sensor.AddObservation(remainTime); // RemainTime int
}
// ML-AGENTS handling ---------------------------------------------------------------------------------------ML-AGENTS
// Process the agent's actions
/// <summary>
/// ML-Agents action callback: decodes the actions, refreshes the UI, applies look and
/// movement, computes the step reward, and ends the episode with the win or lose reward
/// when <c>checkFinish</c> reports that it is over.
/// </summary>
/// <param name="actionBuffers">
/// Discrete branches 0/1/2 = vertical / horizontal / shoot (a move value of 2 is decoded
/// as -1); continuous action 0 = horizontal look delta.
/// </param>
public override void OnActionReceived(ActionBuffers actionBuffers)
{
    // Decode the inputs
    int vertical = actionBuffers.DiscreteActions[0];
    int horizontal = actionBuffers.DiscreteActions[1];
    int mouseShoot = actionBuffers.DiscreteActions[2];
    float Mouse_X = actionBuffers.ContinuousActions[0];
    if (vertical == 2)
    {
        vertical = -1;
    }
    if (horizontal == 2)
    {
        horizontal = -1;
    }

    remainTime = (int)(timeLimit - Time.time + startTime);

    // Apply the inputs and refresh the UI
    shoot = mouseShoot;
    HistoryRec.realTimeKeyCounter(vertical, horizontal, shoot);
    (int wPresses, int sPresses, int aPresses, int dPresses, int shotPresses) = HistoryRec.getKeyCount();
    UICon.updateRemainTime(remainTime);
    // NOTE(review): this passes the configured enemy count, not the number of enemies
    // currently remaining; confirm this is what the viewer should show.
    UICon.updateRemainEnemy(enemyNum);
    UICon.updateWASDKeyViewer(vertical, horizontal);
    UICon.updateKeyCounterChart(wPresses, sPresses, aPresses, dPresses, shotPresses);
    UICon.updateMouseMovementViewer(Mouse_X);
    UICon.updateRewardViewer(nonReward, shootReward, shootWithoutReadyReward, hitReward, winReward, loseReward, killReward);
    cameraControl(Mouse_X, 0);
    moveAgent(vertical, horizontal);
    float thisRoundReward = rewardCalculate();

    // Check whether the episode is over
    int finished = checkFinish();
    if (finished != 1 && finished != 2)
    {
        // Episode continues
        HistoryRec.addRealTimeReward(thisRoundReward);
        UICon.stepUpdateChart(step, thisRoundReward);
        step += 1;
        SetReward(thisRoundReward);
        Debug.Log("reward = " + thisRoundReward);
        return;
    }

    // Win (1) and lose (2) share the same bookkeeping; only the final reward differs.
    // The step reward computed above is replaced by the final reward.
    float finalReward = finished == 1 ? winReward : loseReward;
    HistoryRec.addRealTimeReward(finalReward);
    HistoryRec.EPTotalRewardsUpdate();
    UICon.epUpdateChart(EP, HistoryRec.getLastEPTotalReward());
    UICon.resetStepChart();
    UICon.resetCounterChat();
    EP += 1;
    SetReward(finalReward);
    Debug.Log("reward = " + finalReward);
    EndEpisode();
}
// ML-AGENTS handling ---------------------------------------------------------------------------------------ML-AGENTS
// Manual control for debugging
/// <summary>
/// Manual (keyboard/mouse) control. Fills the action buffers with the same encoding that
/// <c>OnActionReceived</c> decodes: for the two movement branches 0 = no input,
/// 1 = forward/right, 2 = backward/left; branch 2 is the shoot flag; continuous action 0
/// is the horizontal mouse delta.
/// </summary>
/// <param name="actionsOut">Action buffers to fill for this decision.</param>
public override void Heuristic(in ActionBuffers actionsOut)
{
    ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
    ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

    // ---------------- discrete control ----------------
    // The negative direction is emitted as branch index 2, never as -1: OnActionReceived
    // maps 2 to -1, so the in-game behaviour is unchanged while the emitted value stays a
    // non-negative branch index.
    int vertical = EncodeAxis(Input.GetKey(KeyCode.W), Input.GetKey(KeyCode.S));
    int horizontal = EncodeAxis(Input.GetKey(KeyCode.D), Input.GetKey(KeyCode.A));

    // Left mouse button held = shoot.
    shoot = Input.GetMouseButton(0) ? 1 : 0;

    discreteActions[0] = vertical;
    discreteActions[1] = horizontal;
    discreteActions[2] = shoot;

    // ---------------- continuous control ----------------
    // Only the horizontal mouse axis is forwarded; OnActionReceived always passes 0 as the
    // vertical look delta.
    float Mouse_X = Input.GetAxis("Mouse X") * mouseXSensitivity * Time.deltaTime;
    continuousActions[0] = Mouse_X;

    // Maps a pair of opposing keys to a branch index; pressing both or neither gives 0.
    int EncodeAxis(bool positivePressed, bool negativePressed)
    {
        if (positivePressed == negativePressed)
        {
            return 0;
        }
        return positivePressed ? 1 : 2;
    }
}
}