194 lines
6.2 KiB
C#
194 lines
6.2 KiB
C#
using System.Collections.Generic;
|
||
using UnityEngine;
|
||
using Unity.MLAgents;
|
||
using Unity.MLAgents.Sensors;
|
||
using Unity.MLAgents.Actuators;
|
||
|
||
/// <summary>
/// ML-Agents agent that pushes a rolling ball (via its Rigidbody) towards a randomly
/// placed target on the XZ plane.
/// </summary>
/// <remarks>
/// Actions: two discrete branches. Branch 0 drives the X axis, branch 1 drives the Z axis.
/// Each branch is expected to have size 3 (set in Behavior Parameters, which is not visible
/// here; confirm in the scene): index 0 = no force, 1 = push towards +axis, 2 = push
/// towards -axis. ML-Agents only produces discrete indices in [0, branchSize), so the
/// negative direction must be encoded as a non-negative index rather than as -1.
///
/// Observations: six floats (target X/Z, agent X/Z, agent velocity X/Z).
/// </remarks>
public class RollerAgent : Agent
{
    // Half-width of the square (in the agent's local space) in which targets are spawned.
    private const float TargetSpawnHalfExtent = 4f;

    // Local Y coordinate given to a freshly placed target.
    private const float TargetHeight = 0.5f;

    // A new target is re-sampled until it is strictly farther than this from the agent.
    // Kept larger than TargetReachedDistance so an episode never starts already "won".
    private const float MinTargetSpawnDistance = 1.45f;

    // The episode is won once the agent is closer than this to the target.
    private const float TargetReachedDistance = 1.42f;

    // Placement of the direction arrows relative to the agent's world position.
    private const float ArrowOffset = 1.0f;
    private const float ArrowHeight = 0.5f;

    /// <summary>
    /// Target to reach. It is compared with the agent using local positions, so it
    /// presumably shares the agent's parent transform.
    /// </summary>
    public Transform Target;

    /// <summary>Marker moved to where the agent has been closest to the target in the current episode.</summary>
    public Transform MinLocation;

    /// <summary>Arrow template instantiated while pushing towards +X.</summary>
    public GameObject UPArrow;

    /// <summary>Arrow template instantiated while pushing towards -X.</summary>
    public GameObject DownArrow;

    /// <summary>Arrow template instantiated while pushing towards +Z.</summary>
    public GameObject LArrow;

    /// <summary>Arrow template instantiated while pushing towards -Z.</summary>
    public GameObject RPArrow;

    /// <summary>Episode length limit, in seconds of Time.time.</summary>
    public int timeLimit = 8;

    /// <summary>Scale applied to the unit control signal before it is passed to Rigidbody.AddForce.</summary>
    public float forceMultiplier = 10;

    /// <summary>Step reward given when the target is reached (ends the episode).</summary>
    public float WINREWARD = 10.0f;

    /// <summary>Step reward given on timeout or when the agent falls below local y = 0 (ends the episode).</summary>
    public float FAILREWARD = -10.0f;

    // Smallest agent-to-target distance seen so far in the current episode.
    private float minDistance = 0f;

    private Rigidbody rBody;

    // Time.time at which the current episode began.
    private float startTime = 0;

    // Arrow instances spawned by THIS agent. Clearing only these avoids destroying arrows
    // that belong to other agents (e.g. several training areas in one scene).
    private readonly List<GameObject> spawnedArrows = new List<GameObject>();

    /// <summary>
    /// One-time setup: caches the Rigidbody. ML-Agents calls Initialize() once when the
    /// agent is first enabled, as part of its own setup, and recommends it for this purpose.
    /// </summary>
    public override void Initialize()
    {
        rBody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Starts a new episode. It restarts the timer and resets the agent if it fell off.
    /// It then places the target at a random position not too close to the agent and
    /// resets the closest-approach tracking.
    /// </summary>
    /// <remarks>
    /// An agent that did not fall keeps its position and velocity from the previous episode.
    /// </remarks>
    public override void OnEpisodeBegin()
    {
        startTime = Time.time;

        if (transform.localPosition.y < 0)
        {
            // The agent fell: zero its momentum and put it back at local (0, 0.5, 0).
            rBody.angularVelocity = Vector3.zero;
            rBody.velocity = Vector3.zero;
            transform.localPosition = new Vector3(0, 0.5f, 0);
        }

        // Rejection-sample a target position until it is far enough from the agent.
        Vector3 targetPosition;
        float distance;
        do
        {
            targetPosition = RandomTargetPosition();
            distance = Vector3.Distance(transform.localPosition, targetPosition);
        }
        while (distance <= MinTargetSpawnDistance);

        Target.localPosition = targetPosition;
        minDistance = distance;
        MinLocation.localPosition = transform.localPosition;
    }

    // Random position in the spawn square (X and Z each drawn from Random.value), at TargetHeight.
    private static Vector3 RandomTargetPosition()
    {
        float x = Random.value * 2f * TargetSpawnHalfExtent - TargetSpawnHalfExtent;
        float z = Random.value * 2f * TargetSpawnHalfExtent - TargetSpawnHalfExtent;
        return new Vector3(x, TargetHeight, z);
    }

    /// <summary>
    /// Adds six observations: target X/Z and agent X/Z (local space), then the agent's
    /// Rigidbody velocity X/Z.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Target and agent positions.
        sensor.AddObservation(Target.localPosition.x);
        sensor.AddObservation(Target.localPosition.z);
        sensor.AddObservation(transform.localPosition.x);
        sensor.AddObservation(transform.localPosition.z);

        // Agent velocity.
        sensor.AddObservation(rBody.velocity.x);
        sensor.AddObservation(rBody.velocity.z);
    }

    /// <summary>
    /// Applies a force of (action_x, 0, action_z) * forceMultiplier to the ball. It also
    /// replaces this agent's direction arrows with ones for the current move.
    /// </summary>
    /// <param name="action_x">Movement along X: 1 = +X, -1 = -X, 0 = none.</param>
    /// <param name="action_z">Movement along Z: 1 = +Z, -1 = -Z, 0 = none.</param>
    public void MoveBall(int action_x, int action_z)
    {
        ClearArrows();

        if (action_x == 1)
        {
            SpawnArrow(UPArrow, new Vector3(ArrowOffset, ArrowHeight, 0f), 0f);
        }
        else if (action_x == -1)
        {
            SpawnArrow(DownArrow, new Vector3(-ArrowOffset, ArrowHeight, 0f), 180f);
        }

        if (action_z == 1)
        {
            SpawnArrow(LArrow, new Vector3(0f, ArrowHeight, ArrowOffset), 270f);
        }
        else if (action_z == -1)
        {
            SpawnArrow(RPArrow, new Vector3(0f, ArrowHeight, -ArrowOffset), 90f);
        }

        Vector3 controlSignal = new Vector3(action_x, 0f, action_z);
        rBody.AddForce(controlSignal * forceMultiplier);
    }

    // Destroys the arrows previously spawned by this agent (only its own ones).
    private void ClearArrows()
    {
        foreach (GameObject arrow in spawnedArrows)
        {
            // Unity's == also treats already-destroyed objects as null.
            if (arrow != null)
            {
                Destroy(arrow);
            }
        }
        spawnedArrows.Clear();
    }

    // Instantiates an arrow template at the agent's WORLD position plus offset, rotated by
    // yaw degrees around Y. Instantiate expects a world-space position, so localPosition
    // would misplace arrows whenever the agent's parent is offset. Missing templates are
    // skipped because the arrows are purely visual.
    private void SpawnArrow(GameObject template, Vector3 offset, float yaw)
    {
        if (template == null)
        {
            return;
        }

        Vector3 position = transform.position + offset;
        spawnedArrows.Add(Instantiate(template, position, Quaternion.Euler(0f, yaw, 0f)));
    }

    /// <summary>
    /// Decodes the two discrete branches into axis directions, moves the ball and assigns
    /// this step's reward.
    /// </summary>
    /// <remarks>
    /// Terminal cases: timeout or falling below local y = 0 gives FAILREWARD; getting closer
    /// than TargetReachedDistance gives WINREWARD. Both end the episode. Otherwise the step
    /// reward is (best distance so far - current distance). That is positive when the agent
    /// sets a new closest approach, negative while it is farther than its best, and 0 when equal.
    /// </remarks>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        int moveX = BranchToDirection(actionBuffers.DiscreteActions[0]);
        int moveZ = BranchToDirection(actionBuffers.DiscreteActions[1]);
        MoveBall(moveX, moveZ);

        float distanceToTarget = Vector3.Distance(transform.localPosition, Target.localPosition);
        bool timedOut = Time.time - startTime >= timeLimit;
        bool fellOff = transform.localPosition.y < 0;

        if (timedOut || fellOff)
        {
            SetReward(FAILREWARD);
            EndEpisode();
            return;
        }

        if (distanceToTarget < TargetReachedDistance)
        {
            SetReward(WINREWARD);
            EndEpisode();
            return;
        }

        // Distance-based shaping, computed against the best distance BEFORE it is updated.
        SetReward(minDistance - distanceToTarget);
        if (distanceToTarget < minDistance)
        {
            minDistance = distanceToTarget;
            MinLocation.localPosition = transform.localPosition;
        }
    }

    // Maps a discrete branch index to an axis direction: 1 -> +1, 2 -> -1, anything else -> 0.
    // -1 is also accepted as "negative" so actions produced by the older heuristic (which
    // wrote -1) keep working.
    private static int BranchToDirection(int branchValue)
    {
        switch (branchValue)
        {
            case 1:
                return 1;
            case 2:
            case -1:
                return -1;
            default:
                return 0;
        }
    }

    /// <summary>
    /// Keyboard control for debugging. W/S drive branch 0 (+X / -X) and A/D drive branch 1
    /// (+Z / -Z), using the same index encoding as OnActionReceived (0 none, 1 positive, 2 negative).
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;
        discreteActions[0] = KeysToBranch(KeyCode.W, KeyCode.S);
        discreteActions[1] = KeysToBranch(KeyCode.A, KeyCode.D);
    }

    // Encodes a key pair as a discrete branch index. The positive key wins if both are held.
    private static int KeysToBranch(KeyCode positiveKey, KeyCode negativeKey)
    {
        if (Input.GetKey(positiveKey))
        {
            return 1;
        }
        if (Input.GetKey(negativeKey))
        {
            return 2;
        }
        return 0;
    }
}