// Aimbot-ParallelEnv/Assets/Script/GameScript/MLAgentsCustomController.cs
// Last modified: 2023-10-23 01:54:30 +09:00
// 275 lines, 10 KiB, C#

using System;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
/// <summary>
/// ML-Agents bridge for the aimbot environment: gathers observations from the
/// scene controllers, applies policy (or heuristic) actions to the
/// <c>AgentController</c>, assigns rewards and reports episode results.
/// </summary>
/// <remarks>
/// Discrete action layout used by this agent:
/// branch 0 = vertical move, branch 1 = horizontal move
/// (0 = none, 1 = positive, 2 = negative), branch 2 = shoot (0/1).
/// Continuous action 0 = horizontal mouse (camera yaw) input.
/// </remarks>
public class MLAgentsCustomController : Agent
{
    [SerializeField] private GameObject paramContainerObj;
    [SerializeField] private GameObject CommonParameterContainer;
    [SerializeField] private GameObject targetControllerObj;
    [SerializeField] private GameObject environmentUIObj;
    [SerializeField] private GameObject sideChannelObj;
    [SerializeField] private GameObject hudUIObj;

    [Header("Env")]
    // When true, ray tags are observed as a one-hot vector; otherwise as raw tag values.
    public bool oneHotRayTag = true;

    // Cached component references, resolved in Start().
    private AgentController agentController;
    private ParameterContainer paramContainer;
    private CommonParameterContainer commonParamCon;
    private TargetController targetController;
    private EnvironmentUIControl envUIController;
    private HUDController hudController;
    private TargetUIController targetUIController;
    private RaySensors raySensors;
    private MessageBoxController messageBoxController;
    private AimBotSideChannelController sideChannelController;

    // Observation buffers, refreshed every CollectObservations call.
    private float[] myObserve = new float[4];
    private float[] rayTagResult;
    private float[] rayTagResultOnehot;
    private float[] rayDisResult;
    private float[] targetStates;
    private float remainTime;
    private float inAreaState;
    private int finishedState;

    // Steps taken in the current episode, and number of finished episodes.
    private int step = 0;
    private int EP = 0;

    /// <summary>Resolves all component references from the serialized GameObjects.</summary>
    private void Start()
    {
        agentController = transform.GetComponent<AgentController>();
        raySensors = transform.GetComponent<RaySensors>();
        paramContainer = paramContainerObj.GetComponent<ParameterContainer>();
        commonParamCon = CommonParameterContainer.GetComponent<CommonParameterContainer>();
        targetController = targetControllerObj.GetComponent<TargetController>();
        envUIController = environmentUIObj.GetComponent<EnvironmentUIControl>();
        hudController = hudUIObj.GetComponent<HUDController>();
        targetUIController = hudUIObj.GetComponent<TargetUIController>();
        messageBoxController = hudUIObj.GetComponent<MessageBoxController>();
        sideChannelController = sideChannelObj.GetComponent<AimBotSideChannelController>();
    }

    #region On episode begin function
    /// <summary>
    /// Resets per-episode state and sets up the scene. Game mode 0 (train)
    /// rolls a new scene; any other mode runs the play-mode initialization
    /// and clears the target UI.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        step = 0;
        agentController.UpdateLockMouse();
        paramContainer.ResetTimeBonusReward();
        if (commonParamCon.gameMode == 0)
        {
            // train mode
            Debug.Log("MLAgentCustomController.OnEpisodeBegin: train mode start");
            targetController.RollNewScene();
        }
        else
        {
            // play mode
            Debug.Log("MLAgentCustomController.OnEpisodeBegin: play mode start");
            targetController.PlayInitialize();
            targetUIController.ClearGamePressed();
        }

        // Re-initialize the reward chart when it is enabled on the HUD.
        if (hudController.chartOn)
        {
            envUIController.InitChart();
        }

        // Refresh raycasts so the first observation of the episode is current.
        raySensors.UpdateRayInfo();
    }
    #endregion On episode begin function

    #region Observation sensor function
    /// <summary>
    /// Adds the observation vector, in this order:
    /// target state (6 per original comment: target type, target x/y/z, fire-base area diameter),
    /// in-area state (1), remaining time (1), gun-ready flag (1),
    /// own position x/y/z + yaw (4), ray tags (one-hot or raw), ray distances.
    /// Also pushes the same values to the environment UI.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        myObserve[0] = transform.localPosition.x;
        myObserve[1] = transform.localPosition.y;
        myObserve[2] = transform.localPosition.z;
        // NOTE(review): an earlier version divided yaw by 360f (range [0, 1));
        // 36f yields roughly [0, 10). Confirm this is intentional - changing it
        // alters the observation space seen by already-trained models.
        myObserve[3] = transform.eulerAngles.y / 36f;

        rayTagResult = raySensors.rayTagResult;                         // ray tag per sensor ray
        rayTagResultOnehot = raySensors.rayTagResultOneHot.ToArray();  // one-hot ray tags (rays * tags)
        rayDisResult = raySensors.rayDisResult;                         // hit distance per sensor ray
        targetStates = targetController.targetState;
        remainTime = targetController.leftTime;
        inAreaState = targetController.GetInAreaState();
        agentController.UpdateGunState();

        sensor.AddObservation(targetStates);
        sensor.AddObservation(inAreaState);
        sensor.AddObservation(remainTime);
        sensor.AddObservation(agentController.gunReadyToggle); // is the gun ready to fire?
        sensor.AddObservation(myObserve);
        if (oneHotRayTag)
        {
            sensor.AddObservation(rayTagResultOnehot);
        }
        else
        {
            sensor.AddObservation(rayTagResult);
        }
        sensor.AddObservation(rayDisResult);

        envUIController.UpdateStateText(targetStates, inAreaState, remainTime, agentController.gunReadyToggle, myObserve, rayTagResultOnehot, rayDisResult);
    }
    #endregion Observation sensor function

    #region Action received function
    /// <summary>
    /// Applies one action, computes this step's reward, and ends the episode
    /// (reporting Win/Lose) when the target controller says the game is over.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Read inputs
        int vertical = DecodeAxis(actionBuffers.DiscreteActions[0]);
        int horizontal = DecodeAxis(actionBuffers.DiscreteActions[1]);
        int mouseShoot = actionBuffers.DiscreteActions[2];
        float Mouse_X = actionBuffers.ContinuousActions[0];

        // Apply inputs
        agentController.CameraControl(Mouse_X, 0);
        agentController.MoveAgent(vertical, horizontal);
        raySensors.UpdateRayInfo(); // update raycast

        // Check for game over and compute this step's reward
        float sceneReward = 0f;
        float endReward = 0f;
        (finishedState, sceneReward, endReward) = targetController.CheckOverAndRewards();
        float nowRoundReward = agentController.RewardCalculate(sceneReward + endReward, Mouse_X, Math.Abs(vertical) + Math.Abs(horizontal), mouseShoot);

        if (hudController.chartOn)
        {
            envUIController.UpdateChart(nowRoundReward);
        }
        else
        {
            envUIController.RemoveChart();
        }

        // Set the reward exactly once and BEFORE a possible EndEpisode():
        // EndEpisode() resets the agent and runs OnEpisodeBegin(), so a
        // SetReward() after it would leak this reward into the next episode.
        SetReward(nowRoundReward);

        if (finishedState == (int)TargetController.EndType.Running)
        {
            // game not over yet
            step += 1;
            return;
        }

        // Win or lose: episode finished
        Debug.Log("Finish reward = " + nowRoundReward);
        EP += 1;
        ReportEpisodeResult(finishedState);
        EndEpisode();
    }

    /// <summary>
    /// Sends the episode result over the side channel and shows it in the HUD
    /// message box. Logs a warning for any state other than Win/Lose.
    /// </summary>
    private void ReportEpisodeResult(int endState)
    {
        string targetString = Enum.GetName(typeof(Targets), targetController.targetTypeInt);
        switch (endState)
        {
            case (int)TargetController.EndType.Win:
                sideChannelController.SendSideChannelMessage("Result", targetString + "|Win");
                messageBoxController.PushMessage(
                    new List<string> { "Game Win" },
                    new List<string> { "green" });
                break;
            case (int)TargetController.EndType.Lose:
                sideChannelController.SendSideChannelMessage("Result", targetString + "|Lose");
                messageBoxController.PushMessage(
                    new List<string> { "Game Lose" },
                    new List<string> { "red" });
                break;
            default:
                Debug.LogWarning("TypeError");
                break;
        }
    }

    /// <summary>Maps a movement branch value (0 none, 1 positive, 2 negative) to -1/0/1.</summary>
    private static int DecodeAxis(int action)
    {
        return action == 2 ? -1 : action;
    }
    #endregion Action received function

    #region Heuristic function
    /// <summary>
    /// Manual control: W/S and D/A drive the two movement branches, the left
    /// mouse button shoots, and horizontal mouse movement turns the camera.
    /// Movement branches use the same 0/1/2 encoding the policy produces.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

        // Discrete control. Negative directions are encoded as 2 (not -1)
        // because discrete branch values must lie in [0, branchSize).
        discreteActions[0] = EncodeAxis(Input.GetKey(KeyCode.W), Input.GetKey(KeyCode.S));
        discreteActions[1] = EncodeAxis(Input.GetKey(KeyCode.D), Input.GetKey(KeyCode.A));
        discreteActions[2] = Input.GetMouseButton(0) ? 1 : 0;

        // Continuous control: horizontal mouse input scaled by sensitivity and frame time.
        continuousActions[0] = Input.GetAxis("Mouse X") * agentController.mouseXSensitivity * Time.deltaTime;
    }

    /// <summary>
    /// Encodes two opposing keys as a movement branch value: 1 when only the
    /// positive key is held, 2 when only the negative key is held, otherwise 0.
    /// </summary>
    private static int EncodeAxis(bool positiveHeld, bool negativeHeld)
    {
        if (positiveHeld && !negativeHeld)
        {
            return 1;
        }
        if (negativeHeld && !positiveHeld)
        {
            return 2;
        }
        return 0;
    }
    #endregion Heuristic function
}