// Add a new reward function in attack mode that calculates the distance between the closest enemy and the facing center line. Let the agent spawn anywhere in the whole map area. Add a penalty while mouseX is moving.
using System;
using System.Reflection;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.UI;
using UnityEditor;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using System.Linq;
using System.Drawing;
using Color = UnityEngine.Color;
using static TargetController;
/// <summary>
/// ML-Agents agent that moves, turns its camera and fires a ray-cast gun.
/// Per-step rewards are assembled in rewardCalculate from kills, shots,
/// facing direction, spin history and mouse movement.
/// </summary>
public class AgentWithGun : Agent
// Scene objects whose components are looked up in Start().
// Presumably assigned in the Unity Inspector — confirm.
public GameObject ParameterContainerObj;
public GameObject EnvironmentObj;
public GameObject EnemyContainerObj;
public GameObject SceneBlockContainerObj;
public GameObject EnvironmentUIControlObj;
public GameObject TargetControllerObj;
// Camera used for the aiming ray and for pitch rotation.
public Camera thisCam;
[Header("GetAxis() Simulate")]
public float MoveSpeed = 2.0f;
// Current simulated velocity along the local right (vX) and forward (vZ) axes.
public float vX = 0f;
public float vZ = 0f;
public float acceleration = 0.1f; // acceleration
// Sensitivities applied to the mouse axes in Heuristic().
public float mouseXSensitivity = 100;
public float mouseYSensitivity = 200;
public float yRotation = 0.1f;//float that tracks the accumulated rotation angle around the X axis (camera pitch)
// Read by rewardCalculate for the spin penalty.
// NOTE(review): nothing in the visible source ever adds to this list — confirm.
private List<float> spinRecord = new List<float>();
// Values copied from ParameterContainer in Start().
private bool lockMouse;
private float Damage;
private float fireRate;
private int enemyNum;
private bool lockCameraX;
private bool lockCameraY;
// environment
// Pending shoot action (non-zero means a shot was requested this step).
private int shoot = 0;
// Time.time of the last accepted shot; compared with fireRate in gunReady().
private float lastShootTime = 0.0f;
private int nowEnemyNum = 0;
// Kills recorded via killRecord() since the last rewardCalculate() call.
private int enemyKillCount = 0;
private Vector3 killEnemyPosition;
// step: actions received in the current episode; EP: finished episodes.
private int step = 0;
private int EP = 0;
public bool defaultTPCamera = true;
// Cached result of gunReady(), refreshed in CollectObservations().
private bool gunReadyToggle = true;
private string myTag = "";
// scripts
private RaySensors rayScript;
private CharacterController PlayerController;
private EnvironmentUIControl EnvUICon;
private ParameterContainer paramContainer;
private SceneBlockContainer blockContainer;
private EnemyContainer eneContainer;
private TargetController targetCon;
// Last end state returned by targetCon.checkOverAndRewards(); not serialized.
[System.NonSerialized] public int finishedState;
/// <summary>
/// Caches component references from the linked scene objects and copies the
/// environment parameters from ParameterContainer into local fields.
/// </summary>
private void Start()
paramContainer = ParameterContainerObj.GetComponent<ParameterContainer>();
eneContainer = EnemyContainerObj.GetComponent<EnemyContainer>();
blockContainer = SceneBlockContainerObj.GetComponent<SceneBlockContainer>();
EnvUICon = EnvironmentUIControlObj.GetComponent<EnvironmentUIControl>();
targetCon = TargetControllerObj.GetComponent<TargetController>();
// These two components live on the agent's own GameObject.
rayScript = GetComponent<RaySensors>();
PlayerController = this.transform.GetComponent<CharacterController>();
// Environment parameters
lockMouse = paramContainer.lockMouse;
Damage = paramContainer.Damage;
fireRate = paramContainer.fireRate;
enemyNum = paramContainer.enemyNum;
lockCameraX = paramContainer.lockCameraX;
lockCameraY = paramContainer.lockCameraY;
//initialize remainTime
// NOTE(review): no statement follows the comment above in this copy of the file.
// this agent's tag
myTag = gameObject.tag;
/* ----------此Update用于debug,Build前删除或注释掉!----------*/
/*void Update()
/* ----------此Update用于debug,Build前删除或注释掉!----------*/
// ------------ Action handling --------------
// moveAgent simulates Input.GetAxis-style movement
/// <summary>
/// Updates the simulated axis velocities vX / vZ from the given inputs and moves
/// the CharacterController accordingly.
/// </summary>
/// <param name="vertical">Forward/backward input; OnActionReceived remaps 2 to -1 before calling.</param>
/// <param name="horizontal">Right/left input; OnActionReceived remaps 2 to -1 before calling.</param>
// NOTE(review): brace-only and `else` lines are missing from this copy of the file, so the
// exact nesting of the branches below cannot be recovered with certainty — confirm against the original.
public void moveAgent(int vertical, int horizontal)
Vector3 thisMovement;
if (horizontal != 0)// a horizontal key is pressed
if (vX < MoveSpeed && vX > -MoveSpeed)// current speed is below the maximum speed
vX += (float)horizontal * acceleration;// apply acceleration
if ((vX * horizontal) > 0)// input points the same way as the current velocity
vX = (float)horizontal * MoveSpeed; // clamp to the maximum speed
vX += (float)horizontal * acceleration;// apply acceleration
// Presumably the no-input branch: decelerate toward zero, then snap to zero — confirm.
if (Math.Abs(vX) > 0.001)
vX -= (vX / Math.Abs(vX)) * acceleration;// decelerate
vX = 0;
if (vertical != 0)// a vertical key is pressed
if (vZ < MoveSpeed && vZ > -MoveSpeed)// current speed is below the maximum speed
vZ += (float)vertical * acceleration;// apply acceleration
if ((vZ * vertical) > 0)// input points the same way as the current velocity
vZ = (float)vertical * MoveSpeed; // clamp to the maximum speed
vZ += (float)vertical * acceleration;// apply acceleration
// Presumably the no-input branch: decelerate toward zero, then snap to zero — confirm.
if (Math.Abs(vZ) > 0.001)
vZ -= (vZ / Math.Abs(vZ)) * acceleration;// decelerate
vZ = 0;
// NOTE(review): vX / vZ are already bounded by MoveSpeed and are multiplied by MoveSpeed again here — confirm intended.
thisMovement = (transform.forward * vZ + transform.right * vX) * MoveSpeed;
PlayerController.Move(thisMovement * Time.deltaTime);
// update Key Viewer
// NOTE(review): no statement follows the comment above in this copy of the file.
// ------------ Action handling --------------
// cameraControl rotates the agent's view
/// <summary>
/// Applies one step of view rotation: yaw turns the whole agent around the world-up
/// axis, pitch tilts only the camera and is limited to [-90, 90] degrees.
/// </summary>
/// <param name="Mouse_X">Yaw delta; ignored when lockCameraX is set.</param>
/// <param name="Mouse_Y">Pitch delta; ignored when lockCameraY is set.</param>
public void cameraControl(float Mouse_X, float Mouse_Y)
{
    // A locked axis contributes no rotation.
    float yawDelta = lockCameraX ? 0f : Mouse_X;
    float pitchDelta = lockCameraY ? 0f : Mouse_Y;

    // Accumulate pitch and keep it between straight down and straight up.
    yRotation = Mathf.Clamp(yRotation - pitchDelta, -90f, 90f);

    // Vector3.up is (0, 1, 0), so this rotates the agent around the y axis by yawDelta.
    transform.Rotate(Vector3.up * yawDelta);

    // Only the x (pitch) Euler component of the camera's local rotation is driven here.
    thisCam.transform.localRotation = Quaternion.Euler(yRotation, 0, 0);
}
// killRecord is called when this agent gets a kill
/// <summary>
/// Registers one kill; the count and position are consumed by rewardCalculate.
/// </summary>
/// <param name="thiskillEnemyPosition">Position of the enemy that was killed.</param>
public void killRecord(Vector3 thiskillEnemyPosition)
{
    killEnemyPosition = thiskillEnemyPosition;
    enemyKillCount++;
}
// check whether the gun is ready to shoot
/// <summary>
/// Reports whether at least fireRate seconds have elapsed since the last accepted shot.
/// </summary>
/// <returns>true when the gun may fire again; otherwise false.</returns>
bool gunReady()
{
    float sinceLastShot = Time.time - lastShootTime;
    return sinceLastShot >= fireRate;
}
// ballistic handles the shot ray and returns the resulting reward
/// <summary>
/// Resolves the pending shoot action with a ray cast from the screen center and
/// returns the reward for it. Always clears the shoot field before returning.
/// </summary>
/// <returns>
/// One of: targetCon.hitEnemyReward(...) on a hit, shootTargetAreaReward, shootReward,
/// shootWithoutReadyReward or nonReward from paramContainer.
/// </returns>
// NOTE(review): brace-only lines are missing from this copy of the file; the nesting described
// in the comments below is inferred from statement order — confirm against the original.
float ballistic()
Vector3 point = new Vector3(thisCam.pixelWidth / 2, thisCam.pixelHeight / 2, 0);// firing origin (screen center)
Ray ray = thisCam.ScreenPointToRay(point);
RaycastHit hit;
// Debug.DrawRay(ray.origin, ray.direction * 100,;
// Shot requested and the gun was ready at the last observation.
if (shoot != 0 && gunReadyToggle == true)
lastShootTime = Time.time;
// Ray length is limited to 100 units.
if (Physics.Raycast(ray, out hit, 100))
// Anything that is neither this agent's own tag nor a wall is treated as an enemy.
if (hit.collider.tag != myTag && hit.collider.tag != "Wall")
// kill enemy
GameObject gotHitObj = hit.transform.gameObject;// the object hit by the ray
gotHitObj.GetComponent<states>().ReactToHit(Damage, gameObject);
shoot = 0;
return targetCon.hitEnemyReward(gotHitObj.transform.position);
if (targetCon.targetTypeInt == (int)TargetController.Targets.Attack)
// while if attack mode
float targetDis = Vector3.Distance(blockContainer.thisBlock.transform.position, transform.position);
if (targetDis <= rayScript.viewDistance)
// Debug.DrawRay(new Vector3(0,0,0), viewPoint,;
// Point on the aiming ray at the block's distance, compared with the block's area radius.
if (Vector3.Distance(ray.origin + (ray.direction * targetDis), blockContainer.thisBlock.transform.position) <= blockContainer.thisBlock.firebasesAreaDiameter / 2)
// im shooting at target but didn't hit enemy
// Debug.DrawRay(ray.origin, viewPoint-ray.origin,;
// NOTE(review): this return does not reset shoot first, unlike the other exits — confirm intended.
return paramContainer.shootTargetAreaReward;
shoot = 0;
return paramContainer.shootReward;
else if (shoot != 0 && gunReadyToggle == false)
// shoot without ready
shoot = 0;
return paramContainer.shootWithoutReadyReward;
// do not shoot
shoot = 0;
return paramContainer.nonReward;
/// <summary>
/// Computes a reward for where the camera is pointing, depending on the target mode:
/// in Free mode for facing (or nearly facing) an enemy, in Attack mode for looking at
/// the target block's area.
/// </summary>
/// <returns>The facing reward for this step; 0 when no condition matches.</returns>
// NOTE(review): brace-only lines are missing from this copy of the file; the nesting described
// in the comments below is inferred from statement order — confirm against the original.
float facingReward()
float thisReward = 0;
bool isFacingtoEnemy = false;
// Aiming ray through the screen center.
Ray ray = thisCam.ScreenPointToRay(new Vector3(thisCam.pixelWidth / 2, thisCam.pixelHeight / 2, 0));
if (targetCon.targetTypeInt == (int)TargetController.Targets.Free)
//free mode
RaycastHit hit;
if (Physics.Raycast(ray, out hit, 100))
// facing to an enemy
if (hit.collider.tag != myTag && hit.collider.tag != "Wall")
thisReward = paramContainer.facingReward;
isFacingtoEnemy = true;
if (rayScript.inViewEnemies.Count > 0 && !isFacingtoEnemy) {
// have enemy in view
List<float> projectionDis = new List<float>();
foreach (GameObject thisEnemy in rayScript.inViewEnemies)
// for each enemy in view
// Project the agent-to-enemy vector onto the aiming direction; verticalToRay is then
// the offset from the enemy to that projected point on the aiming line.
Vector3 projection = Vector3.Project(thisEnemy.transform.position - transform.position, (ray.direction * 10));
Vector3 verticalToRay = transform.position + projection - thisEnemy.transform.position;
// Debug.Log("enemy!" + verticalToRay.magnitude);
// Debug.DrawRay(transform.position, (ray.direction * 100), Color.cyan);
// Debug.DrawRay(transform.position, thisEnemy.transform.position - transform.position, Color.yellow);
// Debug.DrawRay(transform.position, projection,;
// Debug.DrawRay(thisEnemy.transform.position, verticalToRay, Color.magenta);
// NOTE(review): no statement adding verticalToRay.magnitude to projectionDis is visible here.
// If it is really absent, projectionDis.Min() below is called on an empty list and throws
// InvalidOperationException — confirm against the original file.
// enemy in view Reward
thisReward = 1 / MathF.Sqrt(paramContainer.facingInviewEnemyDisCOEF* projectionDis.Min()+0.00001f);
if (thisReward >= paramContainer.facingReward) thisReward = paramContainer.facingReward; // limit
Debug.Log("ninimum = " + thisReward);
else if(targetCon.targetTypeInt == (int)TargetController.Targets.Attack)
// attack mode
float targetDis = Vector3.Distance(blockContainer.thisBlock.transform.position, transform.position);
if(targetDis <= rayScript.viewDistance)
// Debug.DrawRay(new Vector3(0,0,0), viewPoint,;
// Point on the aiming ray at the block's distance, compared with the block's area radius.
if (Vector3.Distance(ray.origin + (ray.direction * targetDis), blockContainer.thisBlock.transform.position) <= blockContainer.thisBlock.firebasesAreaDiameter / 2)
// im watching target
// Debug.DrawRay(ray.origin, viewPoint-ray.origin,;
thisReward = paramContainer.facingReward;
return thisReward;
// getEnemyNum returns the number of enemies in the scene other than this agent
/// <summary>
/// Counts GameObjects tagged "Enemy", comparing each one's local position with this
/// agent's local position.
/// </summary>
/// <returns>The resulting enemy count.</returns>
int getEnemyNum()
// Local counter; it shadows the enemyNum field for the duration of this method.
int enemyNum = 0;
GameObject[] EnemyGameObjs;
EnemyGameObjs = GameObject.FindGameObjectsWithTag("Enemy");
foreach (GameObject EnemyObj in EnemyGameObjs)
Vector3 thisEnemyPosition = EnemyObj.transform.localPosition;
// thisEnemyScale is not used anywhere in the visible code.
Vector3 thisEnemyScale = EnemyObj.transform.localScale;
Vector3 MyselfPosition = transform.localPosition;
// NOTE(review): an `else` line appears to be missing from this copy of the file; presumably the
// increment below belongs to the branch where the positions differ (i.e. the object is not
// this agent) — confirm against the original.
if (thisEnemyPosition == MyselfPosition)
//Debug.Log("OH It's me");
enemyNum += 1;
return enemyNum;
// enemyNumDiff returns how the enemy count differs from the previous round
/// <summary>
/// Returns the enemyNum field (copied from ParameterContainer in Start) minus the
/// count currently reported by getEnemyNum(). No fields are modified.
/// </summary>
/// <returns>enemyNum - getEnemyNum().</returns>
int enemyNumDiff()
{
    return enemyNum - getEnemyNum();
}
// ------------Reward--------------
// rewardCalculate computes the reward for this action
/// <summary>
/// Sums the reward for the current step: kill rewards, the shot result from ballistic(),
/// the supplied scene reward, the facing reward, minus spin and mouse-movement penalties.
/// </summary>
/// <param name="sceneReward">Reward supplied by the caller; OnActionReceived passes sceneReward + endReward.</param>
/// <param name="mouseX">Yaw input of this step; its absolute value scales the mouse penalty.</param>
/// <returns>The total reward for this step.</returns>
// NOTE(review): brace-only and `else` lines are missing from this copy of the file; the nesting
// described in the comments below is inferred — confirm against the original.
public float rewardCalculate(float sceneReward,float mouseX)
float epreward = 0f;
// kill reward
if (enemyKillCount > 0)
for (int i = 0; i < enemyKillCount; i++)
// get
// Every pending kill is rewarded using the single stored killEnemyPosition.
epreward += targetCon.killReward(killEnemyPosition);
enemyKillCount = 0;
// Presumably the else branch; either way the counter ends at 0.
enemyKillCount = 0;
// shooting action reward
epreward += ballistic() + sceneReward;
// facing reward
epreward += facingReward();
// spin penalty
if (spinRecord.Count >= paramContainer.spinRecordMax)
float spinPenaltyReward = Math.Abs(spinRecord.ToArray().Sum() * paramContainer.spinPenalty);
// The penalty is subtracted only once it reaches the configured threshold.
if(spinPenaltyReward >= paramContainer.spinPenaltyThreshold)
epreward -= spinPenaltyReward;
// Penalty proportional to how far the view was turned this step.
epreward -= Math.Abs(mouseX) * paramContainer.mousePenalty;
return epreward;
// ML-AGENTS handling -------------------------------------------------------------------------------------------ML-AGENTS
// Initialization when an episode starts
/// <summary>
/// Episode start hook: resets the step counter, optionally locks the cursor and
/// refreshes the stored enemy count.
/// </summary>
public override void OnEpisodeBegin()
step = 0;
if (lockMouse)
Cursor.lockState = CursorLockMode.Locked; // hide and lock the mouse
// = thisAgentObj.GetInstanceID().ToString();
nowEnemyNum = getEnemyNum(); // Reset Enemy number
// give default Reward to Reward value will be used.
// NOTE(review): the body of the following `if` is missing from this copy of the file.
if (paramContainer.chartOn)
// ML-AGENTS handling -------------------------------------------------------------------------------------------ML-AGENTS
// Collect observations
/// <summary>
/// Writes the observation vector: target state, in-area state, remaining time,
/// gun-ready flag, own position and heading, then the ray-sensor tag and distance results.
/// Afterwards refreshes gunReadyToggle for the next step.
/// </summary>
/// <param name="sensor">Vector sensor that receives the observations, in the order added below.</param>
public override void CollectObservations(VectorSensor sensor)
//List<float> enemyLDisList = RaySensors.enemyLDisList;// All Enemy Lside Distances
//List<float> enemyRDisList = RaySensors.enemyRDisList;// All Enemy Rside Distances
float[] myObserve = { transform.localPosition.x, transform.localPosition.y, transform.localPosition.z, transform.eulerAngles.y };
float[] rayTagResult = rayScript.rayTagResult;// ray-sensor tag results float[](raySensorNum,1)
float[] rayDisResult = rayScript.rayDisResult; // ray-sensor distance results float[](raySensorNum,1)
float[] targetStates = targetCon.targetState; // (6) targettype, target x,y,z, firebasesAreaDiameter
float remainTime = targetCon.leftTime;
//float[] focusEnemyObserve = RaySensors.focusEnemyInfo;// nearest enemy info float[](3,1) MinEnemyIndex,x,z
//sensor.AddObservation(allEnemyNum); // enemy count int
sensor.AddObservation(targetStates);// (6) targettype, target x,y,z, firebasesAreaDiameter
sensor.AddObservation(targetCon.getInAreaState()); // (1)
sensor.AddObservation(remainTime); // (1)
sensor.AddObservation(gunReadyToggle); //(1) save gun is ready?
sensor.AddObservation(myObserve); // (4) own position xyz + heading float[](4,1)
sensor.AddObservation(rayTagResult); // ray-sensor tag results float[](raySensorNum,1)
sensor.AddObservation(rayDisResult); // ray-sensor distance results float[](raySensorNum,1)
//sensor.AddObservation(focusEnemyObserve); // nearest enemy info float[](3,1) MinEnemyIndex,x,z
//sensor.AddObservation(raySensorNum); // number of ray sensors int
// Updated after the flag was observed, so the observation above carries the previous value.
gunReadyToggle = gunReady();
//sensor.AddObservation(remainTime); // RemainTime int
// ML-AGENTS handling -------------------------------------------------------------------------------------------ML-AGENTS
// Agent action input handling
/// <summary>
/// Applies one action: decodes the discrete move/shoot branches and the continuous yaw,
/// moves and rotates the agent, refreshes the ray sensors, queries the target controller
/// for the end state and computes this step's reward.
/// </summary>
/// <param name="actionBuffers">
/// DiscreteActions[0..2] = vertical, horizontal, shoot; ContinuousActions[0] = yaw delta.
/// </param>
// NOTE(review): brace-only lines and several statement bodies are missing from this copy of the
// file (e.g. after the chartOn check and inside the switch) — confirm against the original.
public override void OnActionReceived(ActionBuffers actionBuffers)
int vertical = actionBuffers.DiscreteActions[0];
int horizontal = actionBuffers.DiscreteActions[1];
int mouseShoot = actionBuffers.DiscreteActions[2];
float Mouse_X = actionBuffers.ContinuousActions[0];
// Discrete value 2 encodes the negative direction.
if (vertical == 2) vertical = -1;
if (horizontal == 2) horizontal = -1;
shoot = mouseShoot;
// Pitch is passed as 0 here, so only yaw comes from the policy.
cameraControl(Mouse_X, 0);
moveAgent(vertical, horizontal);
rayScript.updateRayInfo(); // update raycast
float sceneReward = 0f;
float endReward = 0f;
(finishedState, sceneReward, endReward) = targetCon.checkOverAndRewards();
float thisRoundReward = rewardCalculate(sceneReward+ endReward,Mouse_X);
if (paramContainer.chartOn)
//Debug.Log("reward = " + thisRoundReward);
if (finishedState != (int)TargetController.EndType.Running)
// Win or lose Finished
Debug.Log("Finish reward = " + thisRoundReward);
EP += 1;
string targetString = Enum.GetName(typeof(TargetController.Targets), targetCon.targetTypeInt);
switch (finishedState)
case (int)TargetController.EndType.Win:
case (int)TargetController.EndType.Lose:
// game not over yet
step += 1;
// ML-AGENTS handling -------------------------------------------------------------------------------------------ML-AGENTS
// Manual control for debugging
/// <summary>
/// Manual control: maps W/S and A/D to the vertical/horizontal discrete actions, the left
/// mouse button to the shoot action and the "Mouse X" axis to the continuous yaw action.
/// </summary>
/// <param name="actionsOut">Action buffers to fill from keyboard and mouse input.</param>
// NOTE(review): `else` lines are missing from this copy of the file; the trailing `= 0`
// assignments after each key check are presumably the else branches — confirm against the original.
public override void Heuristic(in ActionBuffers actionsOut)
ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
ActionSegment<int> discreteActions = actionsOut.DiscreteActions;
int vertical = 0;
int horizontal = 0;
// Opposite keys held together cancel out.
if (Input.GetKey(KeyCode.W) && !Input.GetKey(KeyCode.S))
vertical = 1;
else if (Input.GetKey(KeyCode.S) && !Input.GetKey(KeyCode.W))
vertical = -1;
vertical = 0;
if (Input.GetKey(KeyCode.D) && !Input.GetKey(KeyCode.A))
horizontal = 1;
else if (Input.GetKey(KeyCode.A) && !Input.GetKey(KeyCode.D))
horizontal = -1;
horizontal = 0;
// This writes the shoot field directly, in addition to the discrete action below.
if (Input.GetMouseButton(0))
// Debug.Log("mousebuttonhit");
shoot = 1;
shoot = 0;
// NOTE(review): OnActionReceived decodes the value 2 as -1, while -1 is written directly here — confirm intended.
discreteActions[0] = vertical;
discreteActions[1] = horizontal;
discreteActions[2] = shoot;
float Mouse_X = Input.GetAxis("Mouse X") * mouseXSensitivity * Time.deltaTime;
// Mouse_Y is computed but only referenced by the commented-out line further down.
float Mouse_Y = Input.GetAxis("Mouse Y") * mouseYSensitivity * Time.deltaTime;
continuousActions[0] = Mouse_X;
//continuousActions[1] = nonReward;
//continuousActions[2] = shootReward;
//continuousActions[3] = shootWithoutReadyReward;
//continuousActions[4] = hitReward;
//continuousActions[5] = winReward;
//continuousActions[6] = loseReward;
//continuousActions[7] = killReward;
//continuousActions[1] = Mouse_Y;
//continuousActions[2] = timeLimit;
} |