using System;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;

/// <summary>
/// ML-Agents <see cref="Agent"/> that wires the project's controller scripts
/// (movement, ray sensors, target/scene logic, UI) into the ML-Agents
/// episode / observation / action / heuristic callbacks.
/// </summary>
public class MLAgentsCustomController : Agent
{
    // Discrete branch values used for the two movement branches.
    // OnActionReceived decodes MoveNegativeBranch (2) into a direction of -1,
    // so Heuristic must emit the same non-negative encoding.
    private const int MoveNoneBranch = 0;
    private const int MovePositiveBranch = 1;
    private const int MoveNegativeBranch = 2;

    public GameObject paramContainerObj;
    public GameObject targetControllerObj;
    public GameObject environmentUIObj;
    public GameObject hudUIObj;

    [Header("Env")]
    public bool oneHotRayTag = true;

    // Scripts resolved in Start().
    private AgentController agentController;
    private ParameterContainer paramContainer;
    private TargetController targetController;
    private EnvironmentUIControl envUIController;
    private HUDController hudController;
    private TargetUIController targetUIController;
    private RaySensors raySensors;

    // Observation buffers.
    private float[] myObserve = new float[4];
    private float[] rayTagResult;
    private float[] rayTagResultOnehot;
    private float[] rayDisResult;
    private float[] targetStates;
    private float remainTime;
    private float inAreaState;
    private int finishedState;

    // Counters: step is reset at episode begin and incremented per non-terminal
    // action; EP is incremented once per finished episode.
    private int step = 0;
    private int EP = 0;

    /// <summary>
    /// Resolves the component references this agent depends on: the agent's own
    /// <c>AgentController</c> and <c>RaySensors</c>, plus the scripts attached to
    /// the inspector-assigned GameObjects.
    /// </summary>
    private void Start()
    {
        agentController = transform.GetComponent<AgentController>();
        raySensors = transform.GetComponent<RaySensors>();
        paramContainer = paramContainerObj.GetComponent<ParameterContainer>();
        targetController = targetControllerObj.GetComponent<TargetController>();
        envUIController = environmentUIObj.GetComponent<EnvironmentUIControl>();
        // Both HUD scripts are looked up on the same HUD GameObject.
        hudController = hudUIObj.GetComponent<HUDController>();
        targetUIController = hudUIObj.GetComponent<TargetUIController>();
    }

    // ML-AGENTS ----------------------------------------------------------------
    /// <summary>
    /// Called by ML-Agents when an episode starts: resets the step counter and
    /// time-bonus reward, then initialises the scene for the current game mode.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        Debug.LogWarning("GameState|START TEST!");
        step = 0;
        agentController.UpdateLockMouse();
        paramContainer.ResetTimeBonusReward();

        if (paramContainer.gameMode == 0)
        {
            // Train mode.
            targetController.RollNewScene();
        }
        else
        {
            // Play mode: initialise the scene and reset the target UI.
            targetController.PlayInitialize();
            targetUIController.ClearGamePressed();
        }

        if (hudController.chartOn)
        {
            envUIController.InitChart();
        }

        raySensors.UpdateRayInfo(); // update raycast
    }

    // ML-AGENTS ----------------------------------------------------------------
    /// <summary>
    /// Collects the observation vector: target state, in-area flag, remaining
    /// time, gun-ready flag, own pose, ray tags (one-hot or raw, depending on
    /// <see cref="oneHotRayTag"/>) and ray distances. Also pushes the same values
    /// to the environment UI.
    /// </summary>
    /// <param name="sensor">Vector sensor the observations are appended to.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        myObserve[0] = transform.localPosition.x;
        myObserve[1] = transform.localPosition.y;
        myObserve[2] = transform.localPosition.z;
        // NOTE(review): an earlier, commented-out version normalised the heading
        // with "/ 360f" (and the position by raySensors.viewDistance). Confirm
        // that dividing by 36f is intentional before changing it, since trained
        // models depend on this scale.
        myObserve[3] = transform.eulerAngles.y / 36f;

        rayTagResult = raySensors.rayTagResult;                     // ray tag per ray sensor
        rayTagResultOnehot = raySensors.rayTagResultOneHot.ToArray(); // one-hot ray tags (raySensorNum * Tags)
        rayDisResult = raySensors.rayDisResult;                     // ray distance per ray sensor
        targetStates = targetController.targetState;                // (6) targettype, target x,y,z, firebasesAreaDiameter
        remainTime = targetController.leftTime;
        inAreaState = targetController.GetInAreaState();
        agentController.UpdateGunState();

        sensor.AddObservation(targetStates);                   // (6) targettype, target x,y,z, firebasesAreaDiameter
        sensor.AddObservation(inAreaState);                    // (1)
        sensor.AddObservation(remainTime);                     // (1)
        sensor.AddObservation(agentController.gunReadyToggle); // (1) is the gun ready?
        sensor.AddObservation(myObserve);                      // (4) own position xyz + heading

        if (oneHotRayTag)
        {
            sensor.AddObservation(rayTagResultOnehot);
        }
        else
        {
            sensor.AddObservation(rayTagResult);
        }

        sensor.AddObservation(rayDisResult);

        // The UI always receives the one-hot tags, regardless of oneHotRayTag.
        envUIController.UpdateStateText(
            targetStates,
            inAreaState,
            remainTime,
            agentController.gunReadyToggle,
            myObserve,
            rayTagResultOnehot,
            rayDisResult);
    }

    // ML-AGENTS ----------------------------------------------------------------
    /// <summary>
    /// Applies the received actions, computes this step's reward and ends the
    /// episode when the target controller reports a finished state.
    /// </summary>
    /// <param name="actionBuffers">
    /// Discrete branches: [0] vertical, [1] horizontal (0 = none, 1 = positive,
    /// 2 = negative), [2] shoot. Continuous: [0] horizontal mouse movement.
    /// </param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Read input.
        int vertical = actionBuffers.DiscreteActions[0];
        int horizontal = actionBuffers.DiscreteActions[1];
        int mouseShoot = actionBuffers.DiscreteActions[2];
        float Mouse_X = actionBuffers.ContinuousActions[0];

        // Decode branch value 2 into the negative direction.
        if (vertical == MoveNegativeBranch)
        {
            vertical = -1;
        }

        if (horizontal == MoveNegativeBranch)
        {
            horizontal = -1;
        }

        // Apply input.
        agentController.CameraControl(Mouse_X, 0);
        agentController.MoveAgent(vertical, horizontal);
        raySensors.UpdateRayInfo(); // update raycast

        // Check for episode end and compute the reward.
        float sceneReward = 0f;
        float endReward = 0f;
        (finishedState, sceneReward, endReward) = targetController.CheckOverAndRewards();
        float thisRoundReward = agentController.RewardCalculate(
            sceneReward + endReward,
            Mouse_X,
            Math.Abs(vertical) + Math.Abs(horizontal),
            mouseShoot);

        if (hudController.chartOn)
        {
            envUIController.UpdateChart(thisRoundReward);
        }
        else
        {
            envUIController.RemoveChart();
        }

        // Set the step reward exactly once, BEFORE a possible EndEpisode().
        // EndEpisode() reports the reward and resets the agent; calling
        // SetReward() afterwards would assign this terminal reward to the first
        // step of the next episode.
        SetReward(thisRoundReward);

        if (finishedState == (int)TargetController.EndType.Running)
        {
            // Game not over yet.
            step += 1;
            return;
        }

        // Win or lose: the episode is finished.
        Debug.Log("Finish reward = " + thisRoundReward);
        EP += 1;
        string targetString = Enum.GetName(typeof(SceneBlockContainer.Targets), targetController.targetTypeInt);
        switch (finishedState)
        {
            case (int)TargetController.EndType.Win:
                Debug.LogWarning("Result|" + targetString + "|Win");
                break;
            case (int)TargetController.EndType.Lose:
                Debug.LogWarning("Result|" + targetString + "|Lose");
                break;
            default:
                Debug.LogWarning("TypeError");
                break;
        }

        EndEpisode();
    }

    // ML-AGENTS ----------------------------------------------------------------
    /// <summary>
    /// Manual control for debugging: W/S and D/A drive the two movement
    /// branches, the left mouse button drives the shoot branch and the
    /// "Mouse X" axis drives the continuous camera action.
    /// </summary>
    /// <param name="actionsOut">Action buffers to fill with the player's input.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

        // Discrete control.
        discreteActions[0] = ReadMoveBranch(KeyCode.W, KeyCode.S);
        discreteActions[1] = ReadMoveBranch(KeyCode.D, KeyCode.A);
        discreteActions[2] = Input.GetMouseButton(0) ? 1 : 0;

        // Continuous control.
        float Mouse_X = Input.GetAxis("Mouse X") * agentController.mouseXSensitivity * Time.deltaTime;
        continuousActions[0] = Mouse_X;
    }

    /// <summary>
    /// Maps a pair of opposing keys to a movement branch value. Returns
    /// <see cref="MovePositiveBranch"/> when only <paramref name="positive"/> is
    /// held, <see cref="MoveNegativeBranch"/> when only <paramref name="negative"/>
    /// is held, and <see cref="MoveNoneBranch"/> when neither or both are held.
    /// </summary>
    private static int ReadMoveBranch(KeyCode positive, KeyCode negative)
    {
        bool positiveHeld = Input.GetKey(positive);
        bool negativeHeld = Input.GetKey(negative);

        if (positiveHeld == negativeHeld)
        {
            // Neither key, or both keys cancelling each other out.
            return MoveNoneBranch;
        }

        return positiveHeld ? MovePositiveBranch : MoveNegativeBranch;
    }
}