// Aimbot-ParallelEnv/Assets/Script/GameScript/MLAgentsCustomController.cs
// (source-viewer metadata: 258 lines, 10 KiB, C#)

using System;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
/// <summary>
/// ML-Agents agent for the aimbot environment. Each step it collects observations from
/// the scene (target state, fire-base state, remaining time, gun state, own pose, ray
/// sensors), applies the received actions through <c>AgentController</c>, computes the
/// step reward via <c>RewardFunction</c>, and reports win/lose results to the side
/// channel and the HUD message box.
/// </summary>
public class MLAgentsCustomController : Agent
{
    [SerializeField] private GameObject paramContainerObj;
    [SerializeField] private GameObject targetControllerObj;
    [SerializeField] private GameObject environmentUIObj;
    [SerializeField] private GameObject sideChannelObj;
    [SerializeField] private GameObject worldUIControllerObj;
    [SerializeField] private GameObject hudObj;

    // Discrete value that carries a "negative" direction (backward / left) on the two
    // movement branches. ML-Agents discrete actions must be in [0, branchSize), so -1 is
    // transported as 2 and decoded back to -1 in OnActionReceived.
    // NOTE(review): assumes movement branches 0 and 1 have size 3 — confirm in Behavior Parameters.
    private const int NegativeDirectionAction = 2;

    // component references, resolved in Start()
    private AgentController agentController;
    private ParameterContainer paramContainer;
    private CommonParameterContainer commonParamCon;
    private TargetController targetController;
    private EnvironmentUIControl envUIController;
    private HUDController hudController;
    private TargetUIController targetUIController;
    private RaySensors raySensors;
    private MessageBoxController messageBoxController;
    private AimBotSideChannelController sideChannelController;
    private WorldUIController worldUICon;
    private RewardFunction rewardFunction;

    // observation buffers
    // Own state: local position x, y, z, then sin(yaw) and cos(yaw) — 5 values.
    // NOTE(review): the vector observation size in Behavior Parameters must account for
    // 5 self-state values.
    private float[] myObserve = new float[5];
    private float[] rayTagResult;
    private float[] rayTagResultOnehot;
    private float[] rayDisResult;
    private float[] targetStates;
    private float remainTime;
    private float inFireBaseState;
    // End type of the last step as returned by RewardFunction.CheckOverAndRewards,
    // compared against (int)TargetController.EndType values.
    private int endTypeInt;

    /// <summary>Resolves all component references from this object and the serialized objects.</summary>
    private void Start()
    {
        agentController = transform.GetComponent<AgentController>();
        raySensors = transform.GetComponent<RaySensors>();
        paramContainer = paramContainerObj.GetComponent<ParameterContainer>();
        commonParamCon = CommonParameterContainer.Instance;
        targetController = targetControllerObj.GetComponent<TargetController>();
        envUIController = environmentUIObj.GetComponent<EnvironmentUIControl>();
        hudController = hudObj.GetComponent<HUDController>();
        targetUIController = hudObj.GetComponent<TargetUIController>();
        messageBoxController = hudObj.GetComponent<MessageBoxController>();
        sideChannelController = sideChannelObj.GetComponent<AimBotSideChannelController>();
        rewardFunction = gameObject.GetComponent<RewardFunction>();
        worldUICon = worldUIControllerObj.GetComponent<WorldUIController>();
    }

    /// <summary>
    /// Prepares a new episode. In train mode (<c>gameMode == 0</c>) a new scene is rolled;
    /// otherwise the play-mode scene is initialized and the target UI selection is cleared.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        agentController.UpdateLockMouse();
        paramContainer.ResetTimeBonusReward();
        if (commonParamCon.gameMode == 0)
        {
            // train mode
            Debug.Log("MLAgentCustomController.OnEpisodeBegin: train mode start");
            targetController.RollNewScene();
        }
        else
        {
            // play mode
            Debug.Log("MLAgentCustomController.OnEpisodeBegin: play mode start");
            targetController.PlayInitialize();
            targetUIController.ClearGamePressed();
        }
        // (re)initialize the reward chart when it is enabled on the HUD
        if (hudController.chartOn)
        {
            envUIController.InitChart();
        }
        // refresh ray casts so the first observation of the episode is current
        raySensors.UpdateRayInfo();
    }

    /// <summary>
    /// Writes the vector observation, in this order: target state, in-fire-base state (1),
    /// remaining time (1), gun-ready flag (1), own state (5), ray tags (one-hot or raw,
    /// depending on <c>oneHotRayTag</c>), ray distances.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Yaw is encoded as sin/cos so the value is continuous across the 0/360 degree boundary.
        float yawRad = transform.eulerAngles.y * Mathf.Deg2Rad;
        Vector3 localPos = transform.localPosition;
        myObserve[0] = localPos.x;
        myObserve[1] = localPos.y;
        myObserve[2] = localPos.z;
        myObserve[3] = MathF.Sin(yawRad);
        myObserve[4] = MathF.Cos(yawRad);

        rayTagResult = raySensors.rayTagResult;             // ray tag per sensor, one value per ray
        rayTagResultOnehot = raySensors.rayTagResultOneHot; // ray tags one-hot encoded (rays * tags)
        rayDisResult = raySensors.rayDisResult;             // ray hit distance per sensor
        // target type, target x, y, z, fire-base area diameter
        // NOTE(review): original comments disagreed on the length (5 vs 6) — confirm in TargetController.
        targetStates = targetController.targetState;
        remainTime = targetController.leftTime;
        inFireBaseState = targetController.GetInAreaState();
        agentController.UpdateGunState();

        sensor.AddObservation(targetStates);
        sensor.AddObservation(inFireBaseState);               // (1)
        sensor.AddObservation(remainTime);                    // (1)
        sensor.AddObservation(agentController.gunReadyToggle); // (1) is the gun ready to fire
        sensor.AddObservation(myObserve);                     // (5) position xyz + sin/cos yaw
        if (commonParamCon.oneHotRayTag)
        {
            sensor.AddObservation(rayTagResultOnehot);
        }
        else
        {
            sensor.AddObservation(rayTagResult);
        }
        sensor.AddObservation(rayDisResult);
        envUIController.UpdateStateText(targetStates, inFireBaseState, remainTime, agentController.gunReadyToggle, myObserve, rayTagResultOnehot, rayDisResult);
    }

    /// <summary>
    /// Applies one action and assigns the step reward.
    /// Discrete branches: 0 = vertical move, 1 = horizontal move (0 none, 1 positive,
    /// 2 negative), 2 = shoot. Continuous action 0 = horizontal mouse delta.
    /// Ends the episode when <c>RewardFunction</c> reports a non-running end type.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // read input
        int vertical = DecodeDirection(actionBuffers.DiscreteActions[0]);
        int horizontal = DecodeDirection(actionBuffers.DiscreteActions[1]);
        int mouseShoot = actionBuffers.DiscreteActions[2];
        float mouseX = actionBuffers.ContinuousActions[0];

        // apply input
        agentController.CameraControl(mouseX, 0);
        agentController.MoveAgent(vertical, horizontal);
        raySensors.UpdateRayInfo(); // update raycast

        // check for episode end and compute the reward
        float sceneReward;
        float endReward;
        (endTypeInt, sceneReward, endReward) = rewardFunction.CheckOverAndRewards();
        float nowReward = rewardFunction.RewardCalculate(sceneReward + endReward, mouseX, Math.Abs(vertical) + Math.Abs(horizontal), mouseShoot);
        if (hudController.chartOn)
        {
            envUIController.UpdateChart(nowReward);
        }
        else
        {
            envUIController.RemoveChart();
        }
        worldUICon.UpdateChart(targetController.targetType, endTypeInt);

        // Set the reward exactly once and before EndEpisode(): EndEpisode sends the pending
        // reward and then clears it, so a SetReward after it would credit this step's
        // reward to the first step of the next episode.
        SetReward(nowReward);

        if (endTypeInt != (int)TargetController.EndType.Running)
        {
            // win or lose: episode finished
            Debug.Log("Finish reward = " + nowReward);
            ReportEpisodeResult(endTypeInt);
            EndEpisode();
        }
    }

    /// <summary>
    /// Keyboard/mouse control. W/S and D/A drive the movement branches using the same
    /// encoding OnActionReceived decodes (1 positive, 2 negative, 0 none or both pressed);
    /// left mouse button shoots; horizontal mouse movement drives continuous action 0.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

        // discrete control
        discreteActions[0] = EncodeDirection(Input.GetKey(KeyCode.W), Input.GetKey(KeyCode.S));
        discreteActions[1] = EncodeDirection(Input.GetKey(KeyCode.D), Input.GetKey(KeyCode.A));
        discreteActions[2] = Input.GetMouseButton(0) ? 1 : 0;

        // continuous control
        continuousActions[0] = Input.GetAxis("Mouse X") * agentController.mouseXSensitivity * Time.deltaTime;
    }

    /// <summary>Maps a discrete movement value to a signed direction (2 becomes -1).</summary>
    private static int DecodeDirection(int action)
    {
        return action == NegativeDirectionAction ? -1 : action;
    }

    /// <summary>
    /// Maps a pair of opposing keys to a discrete movement value: 1 when only the positive
    /// key is held, <see cref="NegativeDirectionAction"/> when only the negative key is held,
    /// otherwise 0.
    /// </summary>
    private static int EncodeDirection(bool positivePressed, bool negativePressed)
    {
        if (positivePressed && !negativePressed)
        {
            return 1;
        }
        if (negativePressed && !positivePressed)
        {
            return NegativeDirectionAction;
        }
        return 0;
    }

    /// <summary>
    /// Sends "&lt;target&gt;|Win" or "&lt;target&gt;|Lose" on the "Result" side channel and shows
    /// the matching HUD message; logs a warning for any other end type.
    /// </summary>
    private void ReportEpisodeResult(int endType)
    {
        string targetString = Enum.GetName(typeof(Targets), targetController.targetType);
        switch (endType)
        {
            case (int)TargetController.EndType.Win:
                sideChannelController.SendSideChannelMessage("Result", targetString + "|Win");
                messageBoxController.PushMessage(
                    new List<string> { "Game Win" },
                    new List<string> { "green" });
                break;
            case (int)TargetController.EndType.Lose:
                sideChannelController.SendSideChannelMessage("Result", targetString + "|Lose");
                messageBoxController.PushMessage(
                    new List<string> { "Game Lose" },
                    new List<string> { "red" });
                break;
            default:
                Debug.LogWarning("TypeError");
                break;
        }
    }
}