using System;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;

/* TODO:
 * [done] tag-based attack exclusivity
 * [done] shared HP system
 * fix environment tags?
 * fix environment reset via tags
 * standby handling when the Agent dies
 */

/// <summary>
/// ML-Agents agent that moves with simulated Input.GetAxis-style velocity, turns a first-person
/// camera, fires a raycast weapon from the centre of the screen, and computes its own shaped reward.
/// </summary>
public class AgentWithGun : Agent
{
    public GameObject parameterContainerObj;
    public GameObject environmentObj;
    public GameObject enemyContainerObj;
    public GameObject sceneBlockContainerObj;
    public GameObject environmentUIControlObj;
    public GameObject targetControllerObj;
    public GameObject HUDObj;
    public Camera thisCam;

    [Header("GetAxis() Simulate")]
    public float moveSpeed = 9.0f;
    public float vX = 0f;
    public float vZ = 0f;
    public Vector3 thisMovement;
    public float acceleration = 0.1f; // velocity change applied per MoveAgent call
    public float mouseXSensitivity = 100;
    public float mouseYSensitivity = 200;
    public float yRotation = 0.1f; // accumulated camera pitch (rotation around the X axis), in degrees

    [Header("Env")]
    public bool oneHotRayTag = true;

    // Recent horizontal mouse inputs, used for the spin penalty.
    private readonly List<float> spinRecord = new List<float>();
    private bool lockMouse;
    private float damage;
    private float fireRate;
    private int enemyNum;
    private bool lockCameraX;
    private bool lockCameraY;

    // environment
    private int shoot = 0;
    private float lastShootTime = 0.0f;
    // Positions passed to KillRecord that have not been turned into a kill reward yet.
    private readonly List<Vector3> pendingKillPositions = new List<Vector3>();
    private int step = 0;
    private int EP = 0;
    public bool defaultTPCamera = true;
    private bool gunReadyToggle = true;
    private string myTag = "";
    // Last minimum enemy facing distance / last aim-to-target distance.
    // NaN means "no baseline yet this episode", so the first comparison yields no approach reward.
    private float lastEnemyFacingDistance = float.NaN;
    private float lastTargetFacingDistance = float.NaN;

    // scripts
    private RaySensors raySensors;
    private CharacterController playerController;
    private EnvironmentUIControl envUICon;
    private ParameterContainer paramContainer;
    private SceneBlockContainer blockContainer;
    private EnemyContainer eneContainer;
    private TargetController targetCon;
    private HUDController hudController;
    private StartSeneData startSceneData;

    // observation
    private float[] myObserve = new float[4];
    private float[] rayTagResult;
    private float[] rayTagResultOnehot;
    private float[] rayDisResult;
    private float[] targetStates;
    private float remainTime;
    private float inAreaState;
    [System.NonSerialized] public int finishedState;

    // start scene data: 0 = train, 1 = play
    private int gameMode;

    private void Start()
    {
        // Cache referenced components.
        paramContainer = parameterContainerObj.GetComponent<ParameterContainer>();
        eneContainer = enemyContainerObj.GetComponent<EnemyContainer>();
        blockContainer = sceneBlockContainerObj.GetComponent<SceneBlockContainer>();
        envUICon = environmentUIControlObj.GetComponent<EnvironmentUIControl>();
        targetCon = targetControllerObj.GetComponent<TargetController>();
        hudController = HUDObj.GetComponent<HUDController>();
        raySensors = GetComponent<RaySensors>();
        playerController = this.transform.GetComponent<CharacterController>();

        // Game mode and environment parameters come from the ParameterContainer.
        gameMode = paramContainer.gameMode;
        lockMouse = paramContainer.lockMouse;
        damage = paramContainer.damage;
        fireRate = paramContainer.fireRate;
        enemyNum = hudController.enemyNum;
        lockCameraX = paramContainer.lockCameraX;
        lockCameraY = paramContainer.lockCameraY;

        // Colliders carrying this agent's tag are ignored by the shooting and facing raycasts.
        myTag = gameObject.tag;
    }

    // ------------ Actions ------------

    /// <summary>
    /// Simulates Input.GetAxis-style movement: each axis accelerates while its input is held,
    /// decelerates toward zero when released, and the combined planar velocity is capped at moveSpeed.
    /// </summary>
    /// <param name="vertical">-1, 0 or 1 along the agent's forward axis.</param>
    /// <param name="horizontal">-1, 0 or 1 along the agent's right axis.</param>
    public void MoveAgent(int vertical, int horizontal)
    {
        vX = StepAxisVelocity(vX, horizontal);
        vZ = StepAxisVelocity(vZ, vertical);

        thisMovement = transform.forward * vZ + transform.right * vX;
        if (thisMovement.magnitude > moveSpeed)
        {
            thisMovement = thisMovement.normalized * moveSpeed;
        }
        playerController.Move(thisMovement * Time.deltaTime);
    }

    /// <summary>Advances one axis velocity by one step of simulated input.</summary>
    private float StepAxisVelocity(float velocity, int input)
    {
        if (input == 0)
        {
            if (Math.Abs(velocity) > 0.001)
            {
                // Move toward zero without overshooting. Subtracting a full step unconditionally
                // makes the velocity oscillate around zero forever when it is smaller than one step.
                return Mathf.MoveTowards(velocity, 0f, acceleration);
            }
            return 0f;
        }

        bool belowMaxSpeed = velocity < moveSpeed && velocity > -moveSpeed;
        if (!belowMaxSpeed && velocity * input > 0)
        {
            // Already at/over max speed in the input direction: clamp to max speed.
            return input * moveSpeed;
        }
        // Accelerate. This also covers reversing direction while above max speed.
        return velocity + input * acceleration;
    }

    /// <summary>
    /// Turns the agent body around Y by <paramref name="Mouse_X"/> and pitches the camera by
    /// <paramref name="Mouse_Y"/>. Pitch is clamped to [-90, 90] degrees. Either axis can be
    /// disabled through the lockCameraX / lockCameraY parameters.
    /// </summary>
    public void CameraControl(float Mouse_X, float Mouse_Y)
    {
        if (lockCameraX)
        {
            Mouse_X = 0;
        }
        if (lockCameraY)
        {
            Mouse_Y = 0;
        }

        // Moving the mouse up (positive Mouse_Y) lowers the pitch value, which tilts the view upward.
        yRotation = Mathf.Clamp(yRotation - Mouse_Y, -90f, 90f);

        // Yaw rotates the whole agent so movement follows the view. Pitch only rotates the camera
        // (in its parent's local space), like nodding the head.
        transform.Rotate(Vector3.up * Mouse_X);
        thisCam.transform.localRotation = Quaternion.Euler(yRotation, 0, 0);
    }

    /// <summary>Records a kill credited to this agent. It is rewarded on the next RewardCalculate call.</summary>
    /// <param name="thiskillEnemyPosition">World position of the killed enemy.</param>
    public void KillRecord(Vector3 thiskillEnemyPosition)
    {
        pendingKillPositions.Add(thiskillEnemyPosition);
    }

    /// <summary>True when at least fireRate seconds have passed since the last shot.</summary>
    private bool GunReady()
    {
        return (Time.time - lastShootTime) >= fireRate;
    }

    /// <summary>Ray through the centre of the camera's screen (integer pixel centre).</summary>
    private Ray ScreenCenterRay()
    {
        return thisCam.ScreenPointToRay(new Vector3(thisCam.pixelWidth / 2, thisCam.pixelHeight / 2, 0));
    }

    /// <summary>
    /// Raycasts up to 100 units along <paramref name="ray"/>. Returns true if it hits a collider
    /// whose tag is neither this agent's tag nor "Wall".
    /// </summary>
    private bool RaycastHitsHostile(Ray ray, out RaycastHit hit)
    {
        return Physics.Raycast(ray, out hit, 100)
            && hit.collider.tag != myTag
            && hit.collider.tag != "Wall";
    }

    /// <summary>Distance from the agent to the current scene block (the attack target).</summary>
    private float TargetDistance()
    {
        return Vector3.Distance(blockContainer.thisBlock.transform.position, transform.position);
    }

    /// <summary>
    /// Distance from the target to the point on the aim ray that lies <paramref name="targetDis"/>
    /// units from the ray origin.
    /// </summary>
    private float AimPointToTargetDistance(Ray ray, float targetDis)
    {
        return Vector3.Distance(ray.origin + (ray.direction * targetDis), blockContainer.thisBlock.transform.position);
    }

    /// <summary>True when the aim point lies inside the target block's firebase area.</summary>
    private bool IsWithinTargetArea(float aimPointToTarget)
    {
        return aimPointToTarget <= blockContainer.thisBlock.firebasesAreaDiameter / 2;
    }

    /// <summary>
    /// Resolves this step's shoot action and returns the shooting reward.
    /// The pending shoot request is always consumed.
    /// </summary>
    private float Ballistic()
    {
        bool wantsToShoot = shoot != 0;
        shoot = 0;

        if (!wantsToShoot)
        {
            return paramContainer.nonReward;
        }
        if (gunReadyToggle != true)
        {
            // Tried to fire while the gun was still cooling down.
            return paramContainer.shootWithoutReadyReward;
        }

        lastShootTime = Time.time;
        Ray ray = ScreenCenterRay();

        if (RaycastHitsHostile(ray, out RaycastHit hit))
        {
            GameObject gotHitObj = hit.transform.gameObject;
            // NOTE(review): the component type argument was missing from the reviewed text;
            // `States` is presumed to be the shared HP component exposing ReactToHit. Confirm.
            gotHitObj.GetComponent<States>().ReactToHit(damage, gameObject);
            return targetCon.HitEnemyReward(gotHitObj.transform.position);
        }

        if (targetCon.targetTypeInt == (int)SceneBlockContainer.Targets.Attack)
        {
            float targetDis = TargetDistance();
            if (targetDis <= raySensors.viewDistance && IsWithinTargetArea(AimPointToTargetDistance(ray, targetDis)))
            {
                // Shooting into the target area without hitting an enemy.
                return paramContainer.shootTargetAreaReward;
            }
        }
        return paramContainer.shootReward;
    }

    /// <summary>Aim-shaping reward for the current target mode (0 for other modes).</summary>
    private float FacingReward()
    {
        Ray ray = ScreenCenterRay();
        if (targetCon.targetTypeInt == (int)SceneBlockContainer.Targets.Free)
        {
            return FreeModeFacingReward(ray);
        }
        if (targetCon.targetTypeInt == (int)SceneBlockContainer.Targets.Attack)
        {
            return AttackModeFacingReward(ray);
        }
        return 0f;
    }

    /// <summary>
    /// Free mode. Aiming directly at an enemy gives facingReward. Otherwise, when enemies are in
    /// view, bringing the aim line closer to the nearest one gives 1/sqrt(coef * distance + 1e-5).
    /// The result is clamped to [-facingReward, facingReward].
    /// </summary>
    private float FreeModeFacingReward(Ray ray)
    {
        if (RaycastHitsHostile(ray, out _))
        {
            return paramContainer.facingReward;
        }
        if (raySensors.inViewEnemies.Count == 0)
        {
            return 0f;
        }

        // Perpendicular distance from each in-view enemy to the aim line through the agent.
        List<float> projectionDis = new List<float>();
        foreach (GameObject thisEnemy in raySensors.inViewEnemies)
        {
            Vector3 projection = Vector3.Project(thisEnemy.transform.position - transform.position, (ray.direction * 10));
            Vector3 verticalToRay = transform.position + projection - thisEnemy.transform.position;
            projectionDis.Add(verticalToRay.magnitude);
        }
        float enemyFacingDistance = projectionDis.Min();

        float reward = 0f;
        if (!float.IsNaN(lastEnemyFacingDistance) && enemyFacingDistance <= lastEnemyFacingDistance)
        {
            // The aim is closing in on an enemy.
            reward = 1 / MathF.Sqrt(paramContainer.facingInviewEnemyDisCOEF * enemyFacingDistance + 0.00001f);
        }
        lastEnemyFacingDistance = enemyFacingDistance;

        if (reward >= paramContainer.facingReward) reward = paramContainer.facingReward;
        if (reward <= -paramContainer.facingReward) reward = -paramContainer.facingReward;
        return reward;
    }

    /// <summary>
    /// Attack mode. While the target is within view distance: facingReward when the aim point is
    /// inside the target area, otherwise (decrease in aim-to-target distance) * facingTargetReward.
    /// </summary>
    private float AttackModeFacingReward(Ray ray)
    {
        float targetDis = TargetDistance();
        float camCenterToTarget = AimPointToTargetDistance(ray, targetDis);

        float reward = 0f;
        if (targetDis <= raySensors.viewDistance)
        {
            if (IsWithinTargetArea(camCenterToTarget))
            {
                reward = paramContainer.facingReward;
            }
            else if (!float.IsNaN(lastTargetFacingDistance))
            {
                reward = (lastTargetFacingDistance - camCenterToTarget) * paramContainer.facingTargetReward;
            }
        }
        lastTargetFacingDistance = camCenterToTarget;
        return reward;
    }

    // ------------ Reward ------------

    /// <summary>
    /// Computes this step's reward: pending kill rewards + shooting reward + sceneReward + facing
    /// reward, minus the spin/mouse penalty and the movement penalty.
    /// </summary>
    /// <param name="sceneReward">Scene/end reward supplied by the caller.</param>
    /// <param name="mouseX">Horizontal camera input applied this step.</param>
    /// <param name="movement">Non-zero when any movement key was pressed this step.</param>
    public float RewardCalculate(float sceneReward, float mouseX, float movement)
    {
        float epreward = 0f;

        // Kills recorded since the previous call. A kill reported during this step's Ballistic()
        // is rewarded on the next call.
        foreach (Vector3 killPosition in pendingKillPositions)
        {
            epreward += targetCon.KillReward(killPosition);
        }
        pendingKillPositions.Clear();

        epreward += Ballistic() + sceneReward;
        epreward += FacingReward();

        // Spin penalty. The oldest sample is dropped once Count reaches spinRecordMax,
        // so the window holds at most spinRecordMax - 1 samples.
        spinRecord.Add(mouseX);
        if (spinRecord.Count >= paramContainer.spinRecordMax)
        {
            spinRecord.RemoveAt(0);
        }
        float spinPenaltyReward = Math.Abs(spinRecord.Sum() * paramContainer.spinPenalty);
        if (spinPenaltyReward >= paramContainer.spinPenaltyThreshold)
        {
            epreward -= spinPenaltyReward;
        }
        else
        {
            epreward -= Math.Abs(mouseX) * paramContainer.mousePenalty;
        }

        // Movement penalty.
        if (movement != 0)
        {
            epreward -= paramContainer.movePenalty;
        }
        return epreward;
    }

    /// <summary>Clears reward-shaping state that must not carry over between episodes.</summary>
    private void ResetEpisodeRewardState()
    {
        pendingKillPositions.Clear();
        spinRecord.Clear();
        lastEnemyFacingDistance = float.NaN;
        lastTargetFacingDistance = float.NaN;
    }

    // ------------ ML-Agents ------------

    /// <summary>Episode start: resets counters and shaping state, then builds a new scene (train) or initializes play mode.</summary>
    public override void OnEpisodeBegin()
    {
        Debug.LogWarning("GameState|START TEST!");
        step = 0;
        ResetEpisodeRewardState();
        if (lockMouse)
        {
            Cursor.lockState = CursorLockMode.Locked; // hide and lock the mouse
        }
        paramContainer.ResetTimeBonusReward();
        if (gameMode == 0)
        {
            targetCon.RollNewScene(); // train mode
        }
        else
        {
            targetCon.PlayInitialize(); // play mode
        }
        if (hudController.chartOn)
        {
            envUICon.InitChart();
        }
        raySensors.UpdateRayInfo();
    }

    /// <summary>
    /// Adds observations in a fixed order (the trained model depends on it): target state,
    /// in-area state, remaining time, gun-ready flag, own position + yaw, ray tags
    /// (one-hot or raw), ray distances.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        myObserve[0] = transform.localPosition.x;
        myObserve[1] = transform.localPosition.y;
        myObserve[2] = transform.localPosition.z;
        // NOTE(review): an earlier version divided by 360f; /36f maps yaw to [0, 10). Confirm intended.
        myObserve[3] = transform.eulerAngles.y / 36f;

        rayTagResult = raySensors.rayTagResult;                        // raw ray tags, one per ray
        rayTagResultOnehot = raySensors.rayTagResultOneHot.ToArray();  // one-hot ray tags
        rayDisResult = raySensors.rayDisResult;                        // ray distances, one per ray
        targetStates = targetCon.targetState; // (6) target type, target x,y,z, firebasesAreaDiameter
        remainTime = targetCon.leftTime;
        inAreaState = targetCon.GetInAreaState();
        gunReadyToggle = GunReady();

        sensor.AddObservation(targetStates);
        sensor.AddObservation(inAreaState);    // (1)
        sensor.AddObservation(remainTime);     // (1)
        sensor.AddObservation(gunReadyToggle); // (1) whether the gun is ready
        sensor.AddObservation(myObserve);      // (4) own position xyz + yaw
        if (oneHotRayTag)
        {
            sensor.AddObservation(rayTagResultOnehot);
        }
        else
        {
            sensor.AddObservation(rayTagResult);
        }
        sensor.AddObservation(rayDisResult);

        envUICon.UpdateStateText(targetStates, inAreaState, remainTime, gunReadyToggle, myObserve, rayTagResultOnehot, rayDisResult);
    }

    /// <summary>Maps a movement branch value to a direction. Branch value 2 means -1.</summary>
    private static int DecodeDirection(int branchValue)
    {
        return branchValue == 2 ? -1 : branchValue;
    }

    /// <summary>
    /// Applies the actions, checks for episode end, and assigns this step's reward.
    /// Discrete: [0] vertical, [1] horizontal (0 none, 1 positive, 2 negative), [2] shoot.
    /// Continuous: [0] horizontal mouse.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        int vertical = DecodeDirection(actionBuffers.DiscreteActions[0]);
        int horizontal = DecodeDirection(actionBuffers.DiscreteActions[1]);
        int mouseShoot = actionBuffers.DiscreteActions[2];
        float mouseX = actionBuffers.ContinuousActions[0];

        // Apply the actions.
        shoot = mouseShoot;
        CameraControl(mouseX, 0);
        MoveAgent(vertical, horizontal);
        raySensors.UpdateRayInfo();

        // Check for episode end and compute the reward.
        float sceneReward;
        float endReward;
        (finishedState, sceneReward, endReward) = targetCon.CheckOverAndRewards();
        float thisRoundReward = RewardCalculate(sceneReward + endReward, mouseX, Math.Abs(vertical) + Math.Abs(horizontal));
        if (hudController.chartOn)
        {
            envUICon.UpdateChart(thisRoundReward);
        }
        else
        {
            envUICon.RemoveChart();
        }

        // Assign the reward exactly once, BEFORE EndEpisode. EndEpisode sends the terminal reward
        // and resets the accumulator. Setting it again afterwards would leak this reward into the
        // first decision of the next episode.
        SetReward(thisRoundReward);

        if (finishedState != (int)TargetController.EndType.Running)
        {
            Debug.Log("Finish reward = " + thisRoundReward);
            EP += 1;
            string targetString = Enum.GetName(typeof(SceneBlockContainer.Targets), targetCon.targetTypeInt);
            switch (finishedState)
            {
                case (int)TargetController.EndType.Win:
                    Debug.LogWarning("Result|" + targetString + "|Win");
                    break;
                case (int)TargetController.EndType.Lose:
                    Debug.LogWarning("Result|" + targetString + "|Lose");
                    break;
                default:
                    Debug.LogWarning("TypeError");
                    break;
            }
            EndEpisode();
        }
        else
        {
            step += 1;
        }
    }

    /// <summary>Encodes a pair of opposing keys as a movement branch value: 0 none/both, 1 positive, 2 negative.</summary>
    private static int EncodeKeyPair(KeyCode positiveKey, KeyCode negativeKey)
    {
        bool positive = Input.GetKey(positiveKey);
        bool negative = Input.GetKey(negativeKey);
        if (positive && !negative)
        {
            return 1;
        }
        if (negative && !positive)
        {
            return 2;
        }
        return 0;
    }

    /// <summary>
    /// Manual control: WASD movement, left mouse button to shoot, horizontal mouse to turn.
    /// Discrete values stay within branch indices. "Negative" is written as 2, which
    /// OnActionReceived decodes to -1, so recorded demonstrations contain only valid actions.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

        discreteActions[0] = EncodeKeyPair(KeyCode.W, KeyCode.S);
        discreteActions[1] = EncodeKeyPair(KeyCode.D, KeyCode.A);
        discreteActions[2] = Input.GetMouseButton(0) ? 1 : 0;

        continuousActions[0] = Input.GetAxis("Mouse X") * mouseXSensitivity * Time.deltaTime;
    }
}