// File:   Aimbot-ParallelEnv/Assets/Script/InGame/AgentWithGun.cs
// Commit: Koha9 cfccd12820 "V3.1.1 优化代码" (V3.1.1: code optimization)
//         优化可读性与规范化命名方式 (improve readability and standardize naming)
// Date:   2023-06-30 18:30:12 +09:00
// Size:   654 lines, 25 KiB, C#
// Note (repository viewer warning): this file contains ambiguous Unicode characters that
// might be confused with other characters.

using System;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
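/* TODO:
√ Tag-based attack exclusivity (a shot never damages objects sharing the shooter's tag)
√ Generic HP system
  Fix environment tags
  Fix resetting the environment by tag
  Standby (idle) handling while the Agent is dead */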
/// <summary>
/// ML-Agents agent that moves with a <see cref="CharacterController"/>, turns with simulated
/// mouse input and fires a hitscan gun from the centre of <see cref="thisCam"/>.
/// Each action step it computes a reward from kills, shooting, aiming ("facing") and
/// spin/move penalties.
/// </summary>
public class AgentWithGun : Agent
{
    // Scene references (assigned in the Inspector).
    public GameObject parameterContainerObj;
    public GameObject environmentObj;
    public GameObject enemyContainerObj;
    public GameObject sceneBlockContainerObj;
    public GameObject environmentUIControlObj;
    public GameObject targetControllerObj;
    public GameObject HUDObj;
    public Camera thisCam;

    [Header("GetAxis() Simulate")]
    public float moveSpeed = 9.0f;
    public float vX = 0f;
    public float vZ = 0f;
    public Vector3 thisMovement;
    public float acceleration = 0.1f; // velocity change applied per MoveAgent call
    public float mouseXSensitivity = 100;
    public float mouseYSensitivity = 200;
    public float yRotation = 0.1f; // camera pitch in degrees (rotation around the local X axis)

    [Header("Env")]
    public bool oneHotRayTag = true;
    // Recent horizontal mouse inputs, summed for the spin penalty.
    private readonly List<float> spinRecord = new List<float>();
    private bool lockMouse;
    private float damage;
    private float fireRate;
    private int enemyNum;
    private bool lockCameraX;
    private bool lockCameraY;

    // environment
    private int shoot = 0; // shoot action of the current step (non-zero = trigger pulled)
    private float lastShootTime = 0.0f;
    // Positions of enemies killed since the last reward calculation, one entry per kill.
    private readonly List<Vector3> killedEnemyPositions = new List<Vector3>();
    private int step = 0;
    private int EP = 0;
    public bool defaultTPCamera = true;
    private bool gunReadyToggle = true;
    private string myTag = "";
    private float lastEnemyFacingDistance = 0f; // last minimum distance between an in-view enemy and the aim ray
    private float lastTargetFacingDistance = 0f; // last distance between the aim point and the target

    // scripts
    private RaySensors raySensors;
    private CharacterController playerController;
    private EnvironmentUIControl envUICon;
    private ParameterContainer paramContainer;
    private SceneBlockContainer blockContainer;
    private EnemyContainer eneContainer;
    private TargetController targetCon;
    private HUDController hudController;
    private StartSeneData startSceneData;

    // observation
    private float[] myObserve = new float[4];
    private float[] rayTagResult;
    private float[] rayTagResultOnehot;
    private float[] rayDisResult;
    private float[] targetStates;
    private float remainTime;
    private float inAreaState;
    [System.NonSerialized] public int finishedState;

    // Game mode read from the start scene: 0 = train, 1 = play.
    private int gamemode;

    /// <summary>
    /// Resolves the game mode and caches component references and environment parameters.
    /// </summary>
    private void Start()
    {
        // The start scene leaves a "StartSceneDataTransfer" object behind. It is missing when
        // this scene is launched directly, in which case we fall back to play mode.
        GameObject dataTransferObj = GameObject.Find("StartSceneDataTransfer");
        startSceneData = dataTransferObj != null ? dataTransferObj.GetComponent<StartSeneData>() : null;
        if (startSceneData != null)
        {
            gamemode = startSceneData.gamemode;
        }
        else
        {
            Debug.LogError("Run WithOut StartScreen");
            gamemode = 1;
        }

        // initialize scripts
        paramContainer = parameterContainerObj.GetComponent<ParameterContainer>();
        eneContainer = enemyContainerObj.GetComponent<EnemyContainer>();
        blockContainer = sceneBlockContainerObj.GetComponent<SceneBlockContainer>();
        envUICon = environmentUIControlObj.GetComponent<EnvironmentUIControl>();
        targetCon = targetControllerObj.GetComponent<TargetController>();
        hudController = HUDObj.GetComponent<HUDController>();
        raySensors = GetComponent<RaySensors>();
        playerController = this.transform.GetComponent<CharacterController>();

        // initialize environment parameters
        lockMouse = paramContainer.lockMouse;
        damage = paramContainer.damage;
        fireRate = paramContainer.fireRate;
        enemyNum = hudController.enemyNum;
        lockCameraX = paramContainer.lockCameraX;
        lockCameraY = paramContainer.lockCameraY;

        // this agent's tag; shots never count objects sharing it as hostile
        myTag = gameObject.tag;
    }

    // ------------ Actions ------------

    /// <summary>
    /// Simulates <c>Input.GetAxis</c>-style movement. Each axis velocity accelerates toward
    /// ±<see cref="moveSpeed"/> while input is held and decelerates toward zero when released.
    /// The combined planar speed is capped at <see cref="moveSpeed"/>.
    /// </summary>
    /// <param name="vertical">Forward (1), backward (-1) or none (0).</param>
    /// <param name="horizontal">Right (1), left (-1) or none (0).</param>
    public void MoveAgent(int vertical, int horizontal)
    {
        vX = StepAxisVelocity(vX, horizontal);
        vZ = StepAxisVelocity(vZ, vertical);

        thisMovement = transform.forward * vZ + transform.right * vX;
        // Cap the combined (e.g. diagonal) speed.
        if (thisMovement.magnitude > moveSpeed)
        {
            thisMovement = thisMovement.normalized * moveSpeed;
        }
        playerController.Move(thisMovement * Time.deltaTime);
    }

    // Advances one velocity component by one step of simulated axis input.
    private float StepAxisVelocity(float velocity, int input)
    {
        if (input == 0)
        {
            // No input: decelerate toward zero. MoveTowards never overshoots zero, so the
            // velocity cannot oscillate when it is not an exact multiple of acceleration.
            if (Mathf.Abs(velocity) <= 0.001f)
            {
                return 0f;
            }
            return Mathf.MoveTowards(velocity, 0f, acceleration);
        }
        if (Mathf.Abs(velocity) >= moveSpeed && velocity * input > 0)
        {
            // Already at (or past) top speed in the input direction: hold at the cap.
            return input * moveSpeed;
        }
        // Accelerate, or brake when the input opposes the current velocity.
        return velocity + input * acceleration;
    }

    /// <summary>
    /// Turns the agent: <paramref name="Mouse_X"/> yaws the whole agent around its Y axis, and
    /// <paramref name="Mouse_Y"/> pitches only the camera, with pitch clamped to [-90, 90] degrees.
    /// Either axis is ignored when locked by the parameter container.
    /// </summary>
    public void CameraControl(float Mouse_X, float Mouse_Y)
    {
        if (lockCameraX)
        {
            Mouse_X = 0;
        }
        if (lockCameraY)
        {
            Mouse_Y = 0;
        }
        // Positive Mouse_Y (mouse moved up) lowers the pitch angle, which tilts the view up,
        // so the view follows the mouse direction.
        yRotation = Mathf.Clamp(yRotation - Mouse_Y, -90f, 90f);
        // Yaw rotates the agent itself so the camera (a child) and movement follow it.
        transform.Rotate(Vector3.up * Mouse_X);
        // Pitch only tilts the camera around its local X axis, like raising or lowering the head.
        thisCam.transform.localRotation = Quaternion.Euler(yRotation, 0, 0);
    }

    /// <summary>
    /// Records a kill credited to this agent. It is rewarded on the next call to
    /// <see cref="RewardCalculate"/>, using this kill's own position.
    /// </summary>
    public void KillRecord(Vector3 thiskillEnemyPosition)
    {
        killedEnemyPositions.Add(thiskillEnemyPosition);
    }

    // True when at least fireRate seconds have passed since the last shot.
    private bool GunReady()
    {
        return (Time.time - lastShootTime) >= fireRate;
    }

    // Ray through the centre of the camera's pixel rect (integer pixel centre, as before).
    private Ray ScreenCenterRay()
    {
        return thisCam.ScreenPointToRay(new Vector3(thisCam.pixelWidth / 2, thisCam.pixelHeight / 2, 0));
    }

    // Any collider that is not tagged like this agent and not tagged "Wall" counts as hostile.
    private bool IsHostileCollider(Collider collider)
    {
        return collider.tag != myTag && collider.tag != "Wall";
    }

    // Distance from this agent to the current target block.
    private float DistanceToTarget()
    {
        return Vector3.Distance(blockContainer.thisBlock.transform.position, transform.position);
    }

    // Distance between the target block and the point on the aim ray at the target's distance.
    private float AimOffsetFromTarget(Ray ray, float targetDis)
    {
        return Vector3.Distance(ray.origin + (ray.direction * targetDis), blockContainer.thisBlock.transform.position);
    }

    /// <summary>
    /// Resolves this step's shoot action and returns its reward. The shoot action is always
    /// consumed (reset to 0) by this call.
    /// </summary>
    private float Ballistic()
    {
        bool triggerPulled = shoot != 0;
        shoot = 0;
        if (!triggerPulled)
        {
            return paramContainer.nonReward;
        }
        if (!gunReadyToggle)
        {
            return paramContainer.shootWithoutReadyReward;
        }

        lastShootTime = Time.time;
        Ray ray = ScreenCenterRay();
        if (Physics.Raycast(ray, out RaycastHit hit, 100) && IsHostileCollider(hit.collider))
        {
            GameObject gotHitObj = hit.transform.gameObject;
            States hitStates = gotHitObj.GetComponent<States>();
            if (hitStates != null)
            {
                hitStates.ReactToHit(damage, gameObject);
                return targetCon.HitEnemyReward(gotHitObj.transform.position);
            }
            // A hostile-tagged collider without a States component cannot take damage:
            // treat the shot as a miss instead of throwing NullReferenceException.
        }

        if (targetCon.targetTypeInt == (int)SceneBlockContainer.Targets.Attack)
        {
            // Attack mode: shooting into the target area without hitting an enemy still pays off.
            float targetDis = DistanceToTarget();
            if (targetDis <= raySensors.viewDistance
                && AimOffsetFromTarget(ray, targetDis) <= blockContainer.thisBlock.firebasesAreaDiameter / 2)
            {
                return paramContainer.shootTargetAreaReward;
            }
        }
        return paramContainer.shootReward;
    }

    // Reward for aiming well; the rule depends on the current target type.
    private float FacingReward()
    {
        Ray ray = ScreenCenterRay();
        if (targetCon.targetTypeInt == (int)SceneBlockContainer.Targets.Free)
        {
            return FreeModeFacingReward(ray);
        }
        if (targetCon.targetTypeInt == (int)SceneBlockContainer.Targets.Attack)
        {
            return AttackModeFacingReward(ray);
        }
        return 0f;
    }

    // Free mode: full reward when the aim ray hits a hostile collider. Otherwise, with enemies in
    // view, reward moving the aim ray closer to the nearest enemy, clamped to ±facingReward.
    private float FreeModeFacingReward(Ray ray)
    {
        if (Physics.Raycast(ray, out RaycastHit hit, 100) && IsHostileCollider(hit.collider))
        {
            return paramContainer.facingReward;
        }
        if (raySensors.inViewEnemies.Count == 0)
        {
            return 0f;
        }

        // Smallest perpendicular distance from any in-view enemy to the aim direction.
        float enemyFacingDistance = float.MaxValue;
        foreach (GameObject thisEnemy in raySensors.inViewEnemies)
        {
            Vector3 projection = Vector3.Project(thisEnemy.transform.position - transform.position, (ray.direction * 10));
            Vector3 verticalToRay = transform.position + projection - thisEnemy.transform.position;
            enemyFacingDistance = Mathf.Min(enemyFacingDistance, verticalToRay.magnitude);
        }

        float thisReward = 0f;
        if (enemyFacingDistance <= lastEnemyFacingDistance)
        {
            // Aim is closing in on an enemy: the closer, the larger the reward.
            thisReward = 1 / MathF.Sqrt(paramContainer.facingInviewEnemyDisCOEF * enemyFacingDistance + 0.00001f);
        }
        lastEnemyFacingDistance = enemyFacingDistance;

        if (thisReward >= paramContainer.facingReward)
        {
            thisReward = paramContainer.facingReward;
        }
        if (thisReward <= -paramContainer.facingReward)
        {
            thisReward = -paramContainer.facingReward;
        }
        return thisReward;
    }

    // Attack mode: full reward while aiming inside the target area (target within view
    // distance); otherwise reward the reduction of the aim-to-target distance since last step.
    private float AttackModeFacingReward(Ray ray)
    {
        float targetDis = DistanceToTarget();
        float camCenterToTarget = AimOffsetFromTarget(ray, targetDis);
        float thisReward = 0f;
        if (targetDis <= raySensors.viewDistance)
        {
            if (camCenterToTarget <= blockContainer.thisBlock.firebasesAreaDiameter / 2)
            {
                thisReward = paramContainer.facingReward;
            }
            else
            {
                thisReward = (lastTargetFacingDistance - camCenterToTarget) * paramContainer.facingTargetReward;
            }
        }
        lastTargetFacingDistance = camCenterToTarget;
        return thisReward;
    }

    // ------------ Reward ------------

    /// <summary>
    /// Computes the reward for the current action step.
    /// </summary>
    /// <param name="sceneReward">Scene/end reward supplied by the target controller.</param>
    /// <param name="mouseX">Horizontal look input of this step (used for spin/mouse penalties).</param>
    /// <param name="movement">Non-zero when the agent tried to move this step.</param>
    /// <returns>Kill, shooting, scene and facing rewards minus spin/mouse and move penalties.</returns>
    public float RewardCalculate(float sceneReward, float mouseX, float movement)
    {
        float epreward = 0f;

        // Kill rewards: one per recorded kill, each evaluated at that enemy's position.
        foreach (Vector3 killPosition in killedEnemyPositions)
        {
            epreward += targetCon.KillReward(killPosition);
        }
        killedEnemyPositions.Clear();

        // Shooting reward plus the scene reward.
        epreward += Ballistic() + sceneReward;
        epreward += FacingReward();

        // Spin penalty over a sliding window of at most spinRecordMax recent inputs.
        spinRecord.Add(mouseX);
        while (spinRecord.Count > 0 && spinRecord.Count > paramContainer.spinRecordMax)
        {
            spinRecord.RemoveAt(0);
        }
        float spinPenaltyReward = Math.Abs(spinRecord.Sum() * paramContainer.spinPenalty);
        if (spinPenaltyReward >= paramContainer.spinPenaltyThreshold)
        {
            epreward -= spinPenaltyReward;
        }
        else
        {
            // Below the spin threshold, only penalize this step's mouse movement.
            epreward -= Math.Abs(mouseX) * paramContainer.mousePenalty;
        }

        // Move penalty.
        if (movement != 0)
        {
            epreward -= paramContainer.movePenalty;
        }
        return epreward;
    }

    // ------------ ML-Agents ------------

    /// <summary>
    /// Episode initialization: resets per-episode state and builds a new scene, either randomly
    /// (train mode) or via the play-mode initializer.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        Debug.LogWarning("GameState|START TEST!");
        step = 0;
        // Kills recorded after the previous episode's final reward must not leak into this one.
        killedEnemyPositions.Clear();
        if (lockMouse)
        {
            Cursor.lockState = CursorLockMode.Locked; // hide and lock the mouse
        }
        paramContainer.ResetTimeBonusReward();
        if (gamemode == 0)
        {
            // train mode
            targetCon.RollNewScene();
        }
        else
        {
            // play mode
            targetCon.PlayInitialize();
        }
        if (hudController.chartOn)
        {
            envUICon.InitChart();
        }
        raySensors.UpdateRayInfo(); // update raycast
    }

    /// <summary>
    /// Adds this step's observations and mirrors them to the environment UI. In order, they are:
    /// target state, in-area state, remaining time, gun-ready flag, own position and yaw,
    /// ray tags (one-hot or raw), and ray distances.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Own local position and yaw. eulerAngles.y is in [0, 360), so /36 maps it to [0, 10).
        // NOTE(review): an earlier variant normalized by 360 — confirm /36 is intended.
        myObserve[0] = transform.localPosition.x;
        myObserve[1] = transform.localPosition.y;
        myObserve[2] = transform.localPosition.z;
        myObserve[3] = transform.eulerAngles.y / 36f;
        rayTagResult = raySensors.rayTagResult;                       // tag result per ray
        rayTagResultOnehot = raySensors.rayTagResultOneHot.ToArray(); // one-hot tag result per ray, flattened
        rayDisResult = raySensors.rayDisResult;                       // distance result per ray
        // Described by the author as 6 values (target type, target x/y/z, firebasesAreaDiameter, ...) — TODO confirm layout.
        targetStates = targetCon.targetState;
        remainTime = targetCon.leftTime;
        inAreaState = targetCon.GetInAreaState();
        gunReadyToggle = GunReady();

        sensor.AddObservation(targetStates);
        sensor.AddObservation(inAreaState);    // 1 value
        sensor.AddObservation(remainTime);     // 1 value
        sensor.AddObservation(gunReadyToggle); // 1 value: can the gun fire now?
        sensor.AddObservation(myObserve);      // 4 values: x, y, z, yaw
        if (oneHotRayTag)
        {
            sensor.AddObservation(rayTagResultOnehot);
        }
        else
        {
            sensor.AddObservation(rayTagResult);
        }
        sensor.AddObservation(rayDisResult);
        envUICon.UpdateStateText(targetStates, inAreaState, remainTime, gunReadyToggle, myObserve, rayTagResultOnehot, rayDisResult);
    }

    // Discrete axis branches use 0 = none, 1 = positive, 2 = negative.
    private static int DecodeAxisAction(int action)
    {
        return action == 2 ? -1 : action;
    }

    // Inverse of DecodeAxisAction: maps -1 to branch index 2.
    private static int EncodeAxisAction(int axis)
    {
        return axis < 0 ? 2 : axis;
    }

    /// <summary>
    /// Applies the actions for this step (look, move, shoot), checks for episode end,
    /// and assigns the step reward.
    /// Discrete branches: [0] vertical, [1] horizontal (0 none / 1 positive / 2 negative),
    /// [2] shoot. Continuous: [0] horizontal look.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        int vertical = DecodeAxisAction(actionBuffers.DiscreteActions[0]);
        int horizontal = DecodeAxisAction(actionBuffers.DiscreteActions[1]);
        int mouseShoot = actionBuffers.DiscreteActions[2];
        float Mouse_X = actionBuffers.ContinuousActions[0];

        // Apply inputs.
        shoot = mouseShoot;
        CameraControl(Mouse_X, 0);
        MoveAgent(vertical, horizontal);
        raySensors.UpdateRayInfo(); // update raycast

        // Check for episode end and compute this step's reward.
        float sceneReward = 0f;
        float endReward = 0f;
        (finishedState, sceneReward, endReward) = targetCon.CheckOverAndRewards();
        float thisRoundReward = RewardCalculate(sceneReward + endReward, Mouse_X, Math.Abs(vertical) + Math.Abs(horizontal));
        if (hudController.chartOn)
        {
            envUICon.UpdateChart(thisRoundReward);
        }
        else
        {
            envUICon.RemoveChart();
        }

        // Set the reward exactly once and BEFORE EndEpisode(). EndEpisode() reports the reward
        // and resets the agent, so a later SetReward would be credited to the next episode.
        SetReward(thisRoundReward);

        if (finishedState != (int)TargetController.EndType.Running)
        {
            // Win or lose: episode finished.
            Debug.Log("Finish reward = " + thisRoundReward);
            EP += 1;
            string targetString = Enum.GetName(typeof(SceneBlockContainer.Targets), targetCon.targetTypeInt);
            switch (finishedState)
            {
                case (int)TargetController.EndType.Win:
                    Debug.LogWarning("Result|" + targetString + "|Win");
                    break;
                case (int)TargetController.EndType.Lose:
                    Debug.LogWarning("Result|" + targetString + "|Lose");
                    break;
                default:
                    Debug.LogWarning("TypeError");
                    break;
            }
            EndEpisode();
        }
        else
        {
            // game not over yet
            step += 1;
        }
    }

    // Returns 1 when only the positive key is held, -1 when only the negative key is held, else 0.
    private static int KeyAxis(KeyCode positiveKey, KeyCode negativeKey)
    {
        bool positive = Input.GetKey(positiveKey);
        bool negative = Input.GetKey(negativeKey);
        if (positive && !negative)
        {
            return 1;
        }
        if (negative && !positive)
        {
            return -1;
        }
        return 0;
    }

    /// <summary>
    /// Manual control for debugging: W/S and D/A move, the left mouse button shoots, and
    /// horizontal mouse movement turns. Actions use the same encoding that
    /// <see cref="OnActionReceived"/> decodes.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

        // Encode -1 as branch index 2 so the actions are valid discrete values.
        discreteActions[0] = EncodeAxisAction(KeyAxis(KeyCode.W, KeyCode.S));
        discreteActions[1] = EncodeAxisAction(KeyAxis(KeyCode.D, KeyCode.A));
        discreteActions[2] = Input.GetMouseButton(0) ? 1 : 0;

        continuousActions[0] = Input.GetAxis("Mouse X") * mouseXSensitivity * Time.deltaTime;
    }
}