Scroll-ball-env/Assets/Script/RollerAgent.cs
2024-03-05 19:02:19 +09:00

194 lines
6.2 KiB
C#
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/// <summary>
/// ML-Agents agent that pushes a rigidbody ball across a flat area towards a
/// randomly placed target. The episode ends with <see cref="WINREWARD"/> when the
/// ball gets close enough to the target, and with <see cref="FAILREWARD"/> when the
/// time limit expires or the ball drops below y = 0.
/// </summary>
public class RollerAgent : Agent
{
    /// <summary>Transform of the target the ball has to reach.</summary>
    public Transform Target;

    /// <summary>
    /// Marker transform that is moved to the agent position at which the smallest
    /// distance to the target was recorded during the current episode.
    /// </summary>
    public Transform MinLocation;

    /// <summary>Arrow prefab shown when the x action is 1.</summary>
    public GameObject UPArrow;

    /// <summary>Arrow prefab shown when the x action is -1.</summary>
    public GameObject DownArrow;

    /// <summary>Arrow prefab shown when the z action is 1.</summary>
    public GameObject LArrow;

    /// <summary>Arrow prefab shown when the z action is -1.</summary>
    public GameObject RPArrow;

    /// <summary>Episode time limit, compared against elapsed <c>Time.time</c>.</summary>
    public int timeLimit = 8;

    /// <summary>Scale factor applied to the action vector before it is added as a force.</summary>
    public float forceMultiplier = 10;

    /// <summary>Reward set when the target is reached.</summary>
    public float WINREWARD = 10.0f;

    /// <summary>Reward set on time-out or when the ball falls below y = 0.</summary>
    public float FAILREWARD = -10.0f;

    // Target coordinates are sampled in [-ArenaHalfExtent, ArenaHalfExtent] on x and z.
    private const float ArenaHalfExtent = 4f;

    // Local-space height used for the target and for the agent after a fall reset.
    private const float SpawnHeight = 0.5f;

    // A freshly sampled target must be strictly farther away than this.
    private const float MinSpawnDistance = 1.45f;

    // The target counts as reached when the distance drops below this.
    private const float ReachDistance = 1.42f;

    // Horizontal and vertical offset of an arrow relative to the ball.
    private const float ArrowOffset = 1.0f;
    private const float ArrowHeight = 0.5f;

    // Smallest distance to the target seen so far in the current episode.
    private float minDistance = 0f;

    Rigidbody rBody;

    // Time.time at which the current episode started.
    private float startTime = 0;

    // Arrows spawned by this agent on the previous action; destroyed on the next one.
    private readonly List<GameObject> spawnedArrows = new List<GameObject>();

    void Start()
    {
        rBody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Runs when an episode begins: restarts the timer, recovers the ball if it
    /// fell, and places the target at a new random position.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        // Restart the episode timer.
        startTime = Time.time;

        // If the agent fell, zero its momentum and put it back at the centre.
        if (transform.localPosition.y < 0)
        {
            rBody.angularVelocity = Vector3.zero;
            rBody.velocity = Vector3.zero;
            transform.localPosition = new Vector3(0, SpawnHeight, 0);
        }

        // Re-sample until the target is far enough from the agent that the
        // episode cannot be won on the very first step.
        Vector3 newTargetPosition;
        float dist;
        do
        {
            newTargetPosition = SampleTargetPosition();
            dist = Vector3.Distance(transform.localPosition, newTargetPosition);
        }
        while (dist <= MinSpawnDistance);

        Target.localPosition = newTargetPosition;
        minDistance = dist;
        MinLocation.localPosition = transform.localPosition;
    }

    /// <summary>
    /// Adds six observations: target x/z, agent x/z (local positions) and the
    /// rigidbody velocity x/z.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Target and agent positions.
        sensor.AddObservation(Target.localPosition.x);
        sensor.AddObservation(Target.localPosition.z);
        sensor.AddObservation(transform.localPosition.x);
        sensor.AddObservation(transform.localPosition.z);

        // Agent velocity.
        sensor.AddObservation(rBody.velocity.x);
        sensor.AddObservation(rBody.velocity.z);
    }

    /// <summary>
    /// Applies a force of <c>(action_x, 0, action_z) * forceMultiplier</c> to the
    /// ball and replaces the direction arrows with ones matching the action.
    /// </summary>
    /// <param name="action_x">Force component along x; 1 and -1 also show an arrow.</param>
    /// <param name="action_z">Force component along z; 1 and -1 also show an arrow.</param>
    public void MoveBall(int action_x, int action_z)
    {
        Vector3 controlSignal = new Vector3(action_x, 0f, action_z);

        ClearArrows();

        if (action_x == 1)
        {
            SpawnArrow(UPArrow, Vector3.right, 0f);
        }
        else if (action_x == -1)
        {
            SpawnArrow(DownArrow, Vector3.left, 180f);
        }

        if (action_z == 1)
        {
            SpawnArrow(LArrow, Vector3.forward, 270f);
        }
        else if (action_z == -1)
        {
            SpawnArrow(RPArrow, Vector3.back, 90f);
        }

        rBody.AddForce(controlSignal * forceMultiplier);
    }

    /// <summary>
    /// Handles one agent decision: moves the ball, then either ends the episode
    /// (fail / win) or sets a shaping reward based on the distance to the target.
    /// </summary>
    /// <param name="actionBuffers">Action buffers; discrete branches 0 and 1 are x and z.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // NOTE(review): MoveBall reacts to -1, but ML-Agents discrete branches
        // normally yield non-negative indices - confirm what the policy side sends.
        MoveBall(actionBuffers.DiscreteActions[0], actionBuffers.DiscreteActions[1]);

        float nowDistanceToTarget = Vector3.Distance(transform.localPosition, Target.localPosition);

        bool timedOut = Time.time - startTime >= timeLimit;
        bool fell = transform.localPosition.y < 0;
        if (timedOut || fell)
        {
            // Time is up or the ball fell from the game area.
            SetReward(FAILREWARD);
            EndEpisode();
            return;
        }

        if (nowDistanceToTarget < ReachDistance)
        {
            // Got the target.
            SetReward(WINREWARD);
            EndEpisode();
            return;
        }

        // Shaping reward: positive by the amount the best distance improved this
        // step, negative by how far the ball currently is behind its best
        // distance, and zero when it sits exactly at the best distance.
        float thisReward = minDistance - nowDistanceToTarget;
        if (nowDistanceToTarget < minDistance)
        {
            minDistance = nowDistanceToTarget;
            MinLocation.localPosition = transform.localPosition;
        }

        SetReward(thisReward);
    }

    /// <summary>
    /// Keyboard control for debugging: W/S write 1/-1 to branch 0, A/D write 1/-1
    /// to branch 1; 0 is written when neither key of a pair is held.
    /// </summary>
    /// <param name="actionsOut">Action buffers to fill in.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;
        discreteActions[0] = ReadAxis(KeyCode.W, KeyCode.S);
        discreteActions[1] = ReadAxis(KeyCode.A, KeyCode.D);
    }

    // Returns a random local position on the target plane, x sampled before z.
    private static Vector3 SampleTargetPosition()
    {
        float x = Random.value * (2f * ArenaHalfExtent) - ArenaHalfExtent;
        float z = Random.value * (2f * ArenaHalfExtent) - ArenaHalfExtent;
        return new Vector3(x, SpawnHeight, z);
    }

    // Maps a key pair to 1 / -1 / 0; the positive key wins when both are held.
    private static int ReadAxis(KeyCode positive, KeyCode negative)
    {
        if (Input.GetKey(positive))
        {
            return 1;
        }

        if (Input.GetKey(negative))
        {
            return -1;
        }

        return 0;
    }

    // Destroys the arrows this agent spawned on its previous action. Tracking
    // them here avoids a scene-wide tag search on every step and leaves arrows
    // that belong to other agents alone.
    private void ClearArrows()
    {
        foreach (GameObject arrow in spawnedArrows)
        {
            // An arrow may already have been destroyed elsewhere (e.g. scene unload).
            if (arrow != null)
            {
                Destroy(arrow);
            }
        }

        spawnedArrows.Clear();
    }

    // Spawns one arrow next to the ball and remembers it for ClearArrows.
    private void SpawnArrow(GameObject prefab, Vector3 direction, float yawDegrees)
    {
        // The arrows are only a visual aid; skip them when no prefab is assigned
        // instead of letting Instantiate throw in the middle of a step.
        if (prefab == null)
        {
            return;
        }

        // Instantiate expects a world-space position, so offset from
        // transform.position rather than localPosition; the two differ whenever
        // the agent's parent is not at the world origin.
        Vector3 worldPosition = transform.position + direction * ArrowOffset + Vector3.up * ArrowHeight;
        spawnedArrows.Add(Instantiate(prefab, worldPosition, Quaternion.Euler(0f, yawDegrees, 0f)));
    }
}