using System;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;

/// <summary>
/// ML-Agents agent that ties the scene controllers together.
/// It collects observations from the agent transform, <see cref="RaySensors"/> and
/// <see cref="TargetController"/>, and applies discrete move/shoot actions plus a
/// continuous mouse-X action. Each step's reward is computed through
/// <see cref="RewardFunction"/>, and finished episodes are reported to the side channel
/// and the HUD message box.
/// </summary>
public class MLAgentsCustomController : Agent {
    // Discrete movement-axis encoding shared by OnActionReceived and Heuristic.
    // Discrete branch values are indices, so "negative direction" is encoded as 2 and
    // decoded back to -1 before being passed to AgentController.MoveAgent.
    // NOTE(review): assumes the two movement branches have size 3 in BehaviorParameters — confirm.
    private const int AxisNone = 0;
    private const int AxisPositive = 1;
    private const int AxisNegative = 2;

    [SerializeField] private GameObject paramContainerObj;
    [SerializeField] private GameObject targetControllerObj;
    [SerializeField] private GameObject environmentUIObj;
    [SerializeField] private GameObject sideChannelObj;
    [SerializeField] private GameObject worldUIControllerObj;
    [SerializeField] private GameObject hudObj;

    [Header("Env")]
    // When true, ray tags are observed as one-hot vectors; otherwise as raw tag values.
    public bool oneHotRayTag = true;

    // Component references, resolved in Start().
    private AgentController agentController;
    private ParameterContainer paramContainer;
    private CommonParameterContainer commonParamCon;
    private TargetController targetController;
    private EnvironmentUIControl envUIController;
    private HUDController hudController;
    private TargetUIController targetUIController;
    private RaySensors raySensors;
    private MessageBoxController messageBoxController;
    private AimBotSideChannelController sideChannelController;
    private WorldUIController worldUICon;
    private RewardFunction rewardFunction;

    // Observation buffers, refreshed in CollectObservations().
    private float[] myObserve = new float[4]; // own local position x, y, z + heading
    private float[] rayTagResult;
    private float[] rayTagResultOnehot;
    private float[] rayDisResult;
    private float[] targetStates;
    private float remainTime;
    private float inAreaState;

    // Result of the last RewardFunction.CheckOverAndRewards() call (a TargetController.EndType value).
    private int endTypeInt;

    /// <summary>Resolves all component references from the serialized GameObjects.</summary>
    private void Start() {
        agentController = transform.GetComponent<AgentController>();
        raySensors = transform.GetComponent<RaySensors>();
        paramContainer = paramContainerObj.GetComponent<ParameterContainer>();
        commonParamCon = CommonParameterContainer.Instance;
        targetController = targetControllerObj.GetComponent<TargetController>();
        envUIController = environmentUIObj.GetComponent<EnvironmentUIControl>();
        hudController = hudObj.GetComponent<HUDController>();
        targetUIController = hudObj.GetComponent<TargetUIController>();
        messageBoxController = hudObj.GetComponent<MessageBoxController>();
        sideChannelController = sideChannelObj.GetComponent<AimBotSideChannelController>();
        rewardFunction = gameObject.GetComponent<RewardFunction>();
        worldUICon = worldUIControllerObj.GetComponent<WorldUIController>();
    }

    /// <summary>
    /// Resets per-episode state.
    /// In train mode (gameMode == 0) a new scene is rolled. In play mode the targets are
    /// initialised for play and the target UI is cleared.
    /// </summary>
    public override void OnEpisodeBegin() {
        agentController.UpdateLockMouse();
        paramContainer.ResetTimeBonusReward();
        if (commonParamCon.gameMode == 0) {
            // train mode
            Debug.Log("MLAgentCustomController.OnEpisodeBegin: train mode start");
            targetController.RollNewScene();
        } else {
            // play mode
            Debug.Log("MLAgentCustomController.OnEpisodeBegin: play mode start");
            targetController.PlayInitialize();
            // reset target UI
            targetUIController.ClearGamePressed();
        }
        // (Re)initialise the reward chart when the HUD chart is enabled.
        if (hudController.chartOn) {
            envUIController.InitChart();
        }
        raySensors.UpdateRayInfo(); // update raycast results
    }

    /// <summary>
    /// Adds the vector observation, in this order: target state, in-area state, remaining
    /// time, gun-ready flag, own position + heading, ray tags (one-hot or raw depending on
    /// <see cref="oneHotRayTag"/>), and ray distances. It also pushes the same values to
    /// the environment UI.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor) {
        myObserve[0] = transform.localPosition.x;
        myObserve[1] = transform.localPosition.y;
        myObserve[2] = transform.localPosition.z;
        // NOTE(review): a previous version normalised heading by 360f. 36f gives values in
        // [0, 10). It is kept as-is because changing it alters the observation space that
        // trained policies expect. Confirm this is intentional.
        myObserve[3] = transform.eulerAngles.y / 36f;

        rayTagResult = raySensors.rayTagResult;                       // per-ray tag result
        rayTagResultOnehot = raySensors.rayTagResultOneHot.ToArray(); // per-ray one-hot tag result
        rayDisResult = raySensors.rayDisResult;                       // per-ray distance result
        targetStates = targetController.targetState;                  // (6) target type, target x,y,z, firebasesAreaDiameter
        remainTime = targetController.leftTime;
        inAreaState = targetController.GetInAreaState();
        agentController.UpdateGunState();

        sensor.AddObservation(targetStates);                   // (6) target type, target x,y,z, firebasesAreaDiameter
        sensor.AddObservation(inAreaState);                    // (1)
        sensor.AddObservation(remainTime);                     // (1)
        sensor.AddObservation(agentController.gunReadyToggle); // (1) whether the gun is ready
        sensor.AddObservation(myObserve);                      // (4) own position x,y,z + heading
        if (oneHotRayTag) {
            sensor.AddObservation(rayTagResultOnehot);         // ray tags, one-hot encoded
        } else {
            sensor.AddObservation(rayTagResult);               // ray tags, raw
        }
        sensor.AddObservation(rayDisResult);                   // ray distances

        envUIController.UpdateStateText(targetStates, inAreaState, remainTime,
            agentController.gunReadyToggle, myObserve, rayTagResultOnehot, rayDisResult);
    }

    /// <summary>
    /// Applies one step of actions, computes the step reward, updates the charts and ends
    /// the episode on win/lose.
    /// Discrete branches: [0] vertical move, [1] horizontal move (0 none, 1 positive,
    /// 2 negative), [2] shoot. Continuous: [0] mouse-X camera delta.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers) {
        // Decode inputs.
        int vertical = DecodeAxis(actionBuffers.DiscreteActions[0]);
        int horizontal = DecodeAxis(actionBuffers.DiscreteActions[1]);
        int mouseShoot = actionBuffers.DiscreteActions[2];
        float mouseX = actionBuffers.ContinuousActions[0];

        // Apply inputs.
        agentController.CameraControl(mouseX, 0);
        agentController.MoveAgent(vertical, horizontal);
        raySensors.UpdateRayInfo(); // update raycast

        // Check for episode end and compute this step's reward.
        float sceneReward;
        float endReward;
        (endTypeInt, sceneReward, endReward) = rewardFunction.CheckOverAndRewards();
        float nowReward = rewardFunction.RewardCalculate(
            sceneReward + endReward, mouseX, Math.Abs(vertical) + Math.Abs(horizontal), mouseShoot);

        if (hudController.chartOn) {
            envUIController.UpdateChart(nowReward);
        } else {
            envUIController.RemoveChart();
        }
        worldUICon.UpdateChart(targetController.targetType, endTypeInt);

        // Set the reward exactly once, and before EndEpisode(). EndEpisode() resets the
        // agent (calling OnEpisodeBegin), so a SetReward after it would assign this
        // terminal reward to the next episode.
        SetReward(nowReward);

        if (endTypeInt != (int)TargetController.EndType.Running) {
            // Win or lose: episode finished.
            Debug.Log("Finish reward = " + nowReward);
            ReportEpisodeResult(endTypeInt);
            EndEpisode();
        }
    }

    /// <summary>
    /// Reports a finished episode. For Win/Lose it sends a "Result" side-channel message
    /// ("&lt;target&gt;|Win" / "&lt;target&gt;|Lose") and pushes a HUD message. Any other
    /// end type only logs a warning.
    /// </summary>
    private void ReportEpisodeResult(int endType) {
        string targetString = Enum.GetName(typeof(Targets), targetController.targetType);
        switch (endType) {
            case (int)TargetController.EndType.Win:
                sideChannelController.SendSideChannelMessage("Result", targetString + "|Win");
                messageBoxController.PushMessage(
                    new List<string> { "Game Win" },
                    new List<string> { "green" });
                break;
            case (int)TargetController.EndType.Lose:
                sideChannelController.SendSideChannelMessage("Result", targetString + "|Lose");
                messageBoxController.PushMessage(
                    new List<string> { "Game Lose" },
                    new List<string> { "red" });
                break;
            default:
                Debug.LogWarning("TypeError");
                break;
        }
    }

    /// <summary>
    /// Keyboard/mouse control that writes actions in the same encoding OnActionReceived
    /// decodes. W/S and D/A map to 1/2 (never -1, which is not a valid discrete index),
    /// the left mouse button maps to shoot, and mouse X maps to the continuous camera action.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut) {
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

        // Discrete control.
        discreteActions[0] = EncodeAxis(Input.GetKey(KeyCode.W), Input.GetKey(KeyCode.S));
        discreteActions[1] = EncodeAxis(Input.GetKey(KeyCode.D), Input.GetKey(KeyCode.A));
        discreteActions[2] = Input.GetMouseButton(0) ? 1 : 0;

        // Continuous control.
        continuousActions[0] = Input.GetAxis("Mouse X") * agentController.mouseXSensitivity * Time.deltaTime;
    }

    /// <summary>Maps a discrete axis action to a move direction: 2 becomes -1; other values pass through.</summary>
    private static int DecodeAxis(int action) {
        return action == AxisNegative ? -1 : action;
    }

    /// <summary>
    /// Encodes a pair of opposing keys as a discrete axis action. Returns 1 if only the
    /// positive key is held, 2 if only the negative key is held, and 0 otherwise.
    /// </summary>
    private static int EncodeAxis(bool positivePressed, bool negativePressed) {
        if (positivePressed && !negativePressed) {
            return AxisPositive;
        }
        if (negativePressed && !positivePressed) {
            return AxisNegative;
        }
        return AxisNone;
    }
}