using System;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;

/// <summary>
/// ML-Agents <see cref="Agent"/> that bridges the game's controllers (movement, ray sensors,
/// target/scene state, rewards and UI) to the ML-Agents observation / action / reward loop.
/// </summary>
public class MLAgentsCustomController : Agent
{
    /// <summary>
    /// Discrete action value that encodes the negative direction (backward / left).
    /// Discrete actions must be non-negative branch indices, so -1 is transported as 2
    /// and decoded back to -1 in <see cref="OnActionReceived"/>.
    /// </summary>
    private const int ReverseAction = 2;

    [SerializeField] private GameObject paramContainerObj;
    [SerializeField] private GameObject targetControllerObj;
    [SerializeField] private GameObject environmentUIObj;
    [SerializeField] private GameObject sideChannelObj;
    [SerializeField] private GameObject worldUIControllerObj;
    [SerializeField] private GameObject hudObj;

    // Script references, resolved in Start().
    private AgentController agentController;
    private ParameterContainer paramContainer;
    private CommonParameterContainer commonParamCon;
    private TargetController targetController;
    private EnvironmentUIControl envUIController;
    private HUDController hudController;
    private TargetUIController targetUIController;
    private RaySensors raySensors;
    private MessageBoxController messageBoxController;
    private AimBotSideChannelController sideChannelController;
    private WorldUIController worldUICon;
    private RewardFunction rewardFunction;

    // Observation buffers.
    private float[] myObserve = new float[5]; // own position x, y, z + sin/cos of the heading
    private float[] rayTagResult;
    private float[] rayTagResultOnehot;
    private float[] rayDisResult;
    private float remainTime;
    private float inFireBaseState;
    private int endTypeInt;

    /// <summary>
    /// Resolves every component this agent talks to, either from its own GameObject or from
    /// the GameObjects assigned in the inspector.
    /// </summary>
    private void Start()
    {
        agentController = transform.GetComponent<AgentController>();
        raySensors = transform.GetComponent<RaySensors>();
        paramContainer = paramContainerObj.GetComponent<ParameterContainer>();
        commonParamCon = CommonParameterContainer.Instance;
        targetController = targetControllerObj.GetComponent<TargetController>();
        envUIController = environmentUIObj.GetComponent<EnvironmentUIControl>();
        hudController = hudObj.GetComponent<HUDController>();
        targetUIController = hudObj.GetComponent<TargetUIController>();
        messageBoxController = hudObj.GetComponent<MessageBoxController>();
        sideChannelController = sideChannelObj.GetComponent<AimBotSideChannelController>();
        rewardFunction = gameObject.GetComponent<RewardFunction>();
        worldUICon = worldUIControllerObj.GetComponent<WorldUIController>();
    }

    /// <summary>
    /// Prepares a new episode: resets the time bonus, sets up the scene for the current game
    /// mode, optionally re-initialises the reward chart and refreshes the ray sensors.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        agentController.UpdateLockMouse();
        paramContainer.ResetTimeBonusReward();

        if (commonParamCon.gameMode == 0)
        {
            // Train mode.
            Debug.Log("MLAgentCustomController.OnEpisodeBegin: train mode start");
            targetController.RollNewScene();
        }
        else
        {
            // Play mode.
            Debug.Log("MLAgentCustomController.OnEpisodeBegin: play mode start");
            targetController.PlayModeInitialize();
            // Reset target UI.
            targetUIController.ClearGamePressed();
        }

        if (hudController.chartOn)
        {
            envUIController.InitChart();
        }

        raySensors.UpdateRayInfo(); // update raycast
    }

    /// <summary>
    /// Feeds the vector observation: target state, in-area state, remaining time, gun-ready
    /// flag, own pose, ray tag results (one-hot or plain, depending on
    /// <c>commonParamCon.oneHotRayTag</c>) and ray distance results.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // The heading is encoded as sin/cos so the observation is continuous across 0/360 degrees.
        // The position is passed as raw local coordinates (not normalised).
        float angleInRadians = transform.eulerAngles.y * Mathf.Deg2Rad;
        myObserve[0] = transform.localPosition.x;
        myObserve[1] = transform.localPosition.y;
        myObserve[2] = transform.localPosition.z;
        myObserve[3] = MathF.Sin(angleInRadians);
        myObserve[4] = MathF.Cos(angleInRadians);

        rayTagResult = raySensors.rayTagResult;             // ray tag type per ray
        rayTagResultOnehot = raySensors.rayTagResultOneHot; // one-hot encoded ray tags
        rayDisResult = raySensors.rayDisResult;             // hit distance per ray
        remainTime = targetController.leftTime;
        inFireBaseState = targetController.GetInAreaState();
        agentController.UpdateGunState();

        // Per the original author's notes: target type, target x/y/z, fire-base area diameter.
        sensor.AddObservation(targetController.targetState);
        sensor.AddObservation(inFireBaseState);                // (1)
        sensor.AddObservation(remainTime);                     // (1)
        sensor.AddObservation(agentController.gunReadyToggle); // (1) is the gun ready?
        sensor.AddObservation(myObserve);                      // (5) own position xyz + heading

        // Running count of observations, logged for debugging the sensor size.
        // NOTE(review): these logs fire on every observation collection — consider removing.
        int obsNum = targetController.targetState.Length + 1 + 1 + 1 + myObserve.Length;
        Debug.Log(obsNum);

        if (commonParamCon.oneHotRayTag)
        {
            sensor.AddObservation(rayTagResultOnehot);
            obsNum += rayTagResultOnehot.Length;
        }
        else
        {
            sensor.AddObservation(rayTagResult);
            obsNum += rayTagResult.Length;
        }
        Debug.Log(obsNum);

        sensor.AddObservation(rayDisResult);
        obsNum += rayDisResult.Length;

        envUIController.UpdateStateText(targetController.targetState, inFireBaseState, remainTime, agentController.gunReadyToggle, myObserve, rayTagResultOnehot, rayDisResult);
        Debug.Log(obsNum);
    }

    /// <summary>
    /// Applies the received actions, evaluates the reward and ends the episode on win or lose.
    /// </summary>
    /// <param name="actionBuffers">
    /// Discrete branch 0: vertical move, branch 1: horizontal move (0 = none, 1 = positive,
    /// 2 = negative), branch 2: shoot. Continuous action 0: horizontal camera rotation.
    /// </param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Read input.
        int vertical = actionBuffers.DiscreteActions[0];
        int horizontal = actionBuffers.DiscreteActions[1];
        int mouseShoot = actionBuffers.DiscreteActions[2];
        float mouseX = actionBuffers.ContinuousActions[0];

        if (vertical == ReverseAction)
        {
            vertical = -1;
        }
        if (horizontal == ReverseAction)
        {
            horizontal = -1;
        }

        // Apply input.
        agentController.CameraControl(mouseX, 0);
        agentController.MoveAgent(vertical, horizontal);
        raySensors.UpdateRayInfo(); // update raycast

        // Check whether the episode is over and compute this step's reward.
        float sceneReward;
        float endReward;
        (endTypeInt, sceneReward, endReward) = rewardFunction.CheckOverAndRewards();
        float nowReward = rewardFunction.RewardCalculate(sceneReward + endReward, mouseX, Math.Abs(vertical) + Math.Abs(horizontal), mouseShoot);

        if (hudController.chartOn)
        {
            envUIController.UpdateChart(nowReward);
        }
        else
        {
            envUIController.RemoveChart();
        }
        worldUICon.UpdateChart(targetController.targetType, endTypeInt);

        if (endTypeInt != (int)TargetController.EndType.Running)
        {
            // Win or lose: report the result, then finish the episode.
            Debug.Log("Finish reward = " + nowReward);
            string targetString = Enum.GetName(typeof(Targets), targetController.targetType);
            switch (endTypeInt)
            {
                case (int)TargetController.EndType.Win:
                    sideChannelController.SendSideChannelMessage("Result", targetString + "|Win");
                    messageBoxController.PushMessage(
                        new List<string> { "Game Win" },
                        new List<string> { "green" });
                    break;
                case (int)TargetController.EndType.Lose:
                    sideChannelController.SendSideChannelMessage("Result", targetString + "|Lose");
                    messageBoxController.PushMessage(
                        new List<string> { "Game Lose" },
                        new List<string> { "red" });
                    break;
                default:
                    Debug.LogWarning("TypeError");
                    break;
            }

            SetReward(nowReward);
            EndEpisode();
            // Return here: EndEpisode() has already reset the agent, so setting the reward
            // again would credit the terminal reward to the first step of the next episode.
            return;
        }

        // Game not over yet.
        SetReward(nowReward);
    }

    /// <summary>
    /// Manual control: W/S and D/A drive the two movement branches, the left mouse button
    /// shoots and the "Mouse X" axis drives the continuous camera action.
    /// </summary>
    /// <param name="actionsOut">Action buffers to fill with the player's input.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

        // Discrete control.
        discreteActions[0] = EncodeAxis(Input.GetKey(KeyCode.W), Input.GetKey(KeyCode.S));
        discreteActions[1] = EncodeAxis(Input.GetKey(KeyCode.D), Input.GetKey(KeyCode.A));
        discreteActions[2] = Input.GetMouseButton(0) ? 1 : 0;

        // Continuous control.
        float mouseX = Input.GetAxis("Mouse X") * agentController.mouseXSensitivity * Time.deltaTime;
        continuousActions[0] = mouseX;
    }

    /// <summary>
    /// Encodes a pair of opposing keys as a discrete action: 1 when only the positive key is
    /// held, <see cref="ReverseAction"/> when only the negative key is held, otherwise 0.
    /// </summary>
    private static int EncodeAxis(bool positive, bool negative)
    {
        if (positive && !negative)
        {
            return 1;
        }
        if (negative && !positive)
        {
            return ReverseAction;
        }
        return 0;
    }
}