import tensorflow as tf
import numpy as np
from numpy import ndarray

from PPO import PPO
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import optimizers

from GAILConfig import GAILConfig

EPS = 1e-8


class GAIL(object):
    def __init__(
        self,
        stateSize: int,
        disActShape: list,
        conActSize: int,
        conActRange: float,
        gailConfig: GAILConfig,
    ):
        """GAIL agent: a PPO generator plus a discriminator that separates
        demonstration (state, action) pairs from agent-generated ones.

        Args:
            stateSize (int): dimension of the state vector
            disActShape (list): discrete action branch sizes, one entry per branch
            conActSize (int): number of continuous actions
            conActRange (float): range of the continuous actions
            gailConfig (GAILConfig): discriminator and PPO hyperparameters
        """
        self.stateSize = stateSize
        self.disActShape = disActShape
        self.disActSize = len(disActShape)
        self.conActSize = conActSize
        self.conActRange = conActRange

        self.totalActSize = self.disActSize + conActSize
        self.discrimInputSize = stateSize + self.totalActSize
        self.discriminatorNNShape = gailConfig.discrimNNShape
        self.discrimLR = gailConfig.discrimLR
        self.discrimTrainEpochs = gailConfig.discrimTrainEpochs
        self.ppoConfig = gailConfig.ppoConfig

        self.ppo = PPO(stateSize, disActShape, conActSize, conActRange, self.ppoConfig)
        self.discriminator = self.buildDiscriminatorNet(True)
    def buildDiscriminatorNet(self, compileModel: bool):
        """build the discriminator network

        The discriminator takes a concatenated (state, action) vector and outputs
        the probability that the pair comes from the demonstration data.

        Args:
            compileModel (bool): whether to compile the model with optimizer and loss

        Returns:
            keras.Model: discriminator model
        """
        # -----------Input Layers-----------
        # input width = stateSize + totalActSize (state and action are concatenated)
        stateInput = layers.Input(shape=(self.discrimInputSize,), name="stateInput")

        # -------Intermediate layers--------
        interLayers = []
        interLayersIndex = 0
        for neuralUnit in self.discriminatorNNShape:
            thisLayerName = "dense" + str(interLayersIndex)
            if interLayersIndex == 0:
                interLayers.append(
                    layers.Dense(neuralUnit, activation="relu", name=thisLayerName)(stateInput)
                )
            else:
                interLayers.append(
                    layers.Dense(neuralUnit, activation="relu", name=thisLayerName)(interLayers[-1])
                )
            interLayersIndex += 1

        # ----------Output Layers-----------
        # single sigmoid unit: D(s, a) in (0, 1)
        output = layers.Dense(1, activation="sigmoid")(interLayers[-1])

        # ----------Model Compile-----------
        model = keras.Model(inputs=stateInput, outputs=output)
        if compileModel:
            criticOPT = optimizers.Adam(learning_rate=self.discrimLR)
            model.compile(optimizer=criticOPT, loss=self.discrimLoss())
        return model
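    # Illustrative shape note (hypothetical values, not taken from GAILConfig): with
    # discrimNNShape = [128, 64] and discrimInputSize = 10, buildDiscriminatorNet builds
    #     Input(10) -> Dense(128, relu) -> Dense(64, relu) -> Dense(1, sigmoid)
    # i.e. one hidden Dense layer per entry of discrimNNShape, then the sigmoid head.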
    def discrimLoss(self):
        def loss(y_true, y_pred):
            """discriminator loss function

            Args:
                y_true (tf.constant): demo trajectory
                y_pred (tf.constant): agent trajectory predict value

            Returns:
                tf.constant: scalar discriminator loss
            """
            # y_pred is D(agent trajectory); y_true carries the raw demo trajectory,
            # which is passed through the discriminator here to get D(demo trajectory).
            demoP = self.discriminator(y_true)
            # push D(agent) towards 0 and D(demo) towards 1; EPS keeps the logs finite
            agentLoss = tf.negative(tf.reduce_mean(tf.math.log(1.0 - y_pred + EPS)))
            demoLoss = tf.negative(tf.reduce_mean(tf.math.log(demoP + EPS)))
            loss = agentLoss + demoLoss
            return loss

        return loss
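    # For reference, the inner loss above is the standard GAIL discriminator objective
    #     L_D = -E_demo[log D(s, a)] - E_agent[log(1 - D(s, a))]
    # where D is the sigmoid output of the discriminator and EPS guards the logarithms.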
    def inference(self, states: ndarray, actions: ndarray):
        """discriminator predict result

        Args:
            states (ndarray): states
            actions (ndarray): actions

        Returns:
            tf.constant: discriminator predict result, D(s, a) in (0, 1)
        """
        # check dimensions; flat inputs are reshaped to (batch, featureSize)
        if states.ndim != 2:
            stateNum = int(len(states) / self.stateSize)
            states = states.reshape([stateNum, self.stateSize])
        if actions.ndim != 2:
            actionsNum = int(len(actions) / self.totalActSize)
            actions = actions.reshape([actionsNum, self.totalActSize])

        thisTrajectory = tf.concat([states, actions], axis=1)
        discrimPredict = self.discriminator(thisTrajectory)
        return discrimPredict
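    # A common choice (illustrative, not dictated by this class) for turning the
    # discriminator output into an imitation reward for the PPO generator is
    #     reward = -log(1 - D(s, a) + EPS)
    # so that actions the discriminator mistakes for demonstrations earn higher reward.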
    def discriminatorACC(
        self, demoStates: ndarray, demoActions: ndarray, agentStates: ndarray, agentActions: ndarray
    ):
        """mean discriminator output on demo data and on agent data

        A well-trained discriminator drives demoAcc towards 1 and agentAcc towards 0.
        """
        demoAcc = np.mean(self.inference(demoStates, demoActions))
        agentAcc = np.mean(self.inference(agentStates, agentActions))
        return demoAcc, agentAcc
    def trainDiscriminator(
        self,
        demoStates: ndarray,
        demoActions: ndarray,
        agentStates: ndarray,
        agentActions: ndarray,
        epochs: int = None,
    ):
        """train Discriminator

        Args:
            demoStates (ndarray): expert states
            demoActions (ndarray): expert actions
            agentStates (ndarray): agentPPO generated states
            agentActions (ndarray): agentPPO generated actions
            epochs (int, optional): training epochs; defaults to discrimTrainEpochs

        Returns:
            tuple: loss history list, demo accuracy, agent accuracy
        """
        if epochs is None:
            epochs = self.discrimTrainEpochs
        # x is the agent trajectory batch, y is the demo trajectory batch; the custom
        # loss treats y_pred as D(agent) and re-evaluates the discriminator on y_true.
        demoTrajectory = tf.concat([demoStates, demoActions], axis=1)
        agentTrajectory = tf.concat([agentStates, agentActions], axis=1)
        his = self.discriminator.fit(x=agentTrajectory, y=demoTrajectory, epochs=epochs, verbose=0)

        demoAcc = np.mean(self.inference(demoStates, demoActions))
        agentAcc = np.mean(self.inference(agentStates, agentActions))
        return his.history["loss"], demoAcc, 1 - agentAcc
    def getActions(self, state: ndarray):
        actions, predictResult = self.ppo.chooseAction(state)
        return actions, predictResult
    def trainPPO(
        self,
        states: ndarray,
        oldActorResult: ndarray,
        actions: ndarray,
        newRewards: ndarray,
        dones: ndarray,
        nextState: ndarray,
        epochs: int = None,
    ):
        # critic values, discounted returns and GAE advantages, then update both networks
        criticV = self.ppo.getCriticV(states)
        discountedR = self.ppo.discountReward(nextState, criticV, dones, newRewards)
        advantage = self.ppo.getGAE(discountedR, criticV)
        criticLosses = self.ppo.trainCritic(states, discountedR, epochs)
        actorLosses = self.ppo.trainActor(states, oldActorResult, actions, advantage, epochs)
        return actorLosses, criticLosses
    def generateAction(self, states: ndarray):
        act, actorP = self.ppo.chooseAction(states)
        return act, actorP
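

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; constructor arguments, shapes, and the helper
# names loadDemoData / collectAgentData are hypothetical, and gailConfig is
# assumed to be a GAILConfig instance built elsewhere):
#
#     gail = GAIL(stateSize=4, disActShape=[3, 3], conActSize=1, conActRange=1.0,
#                 gailConfig=gailConfig)
#     demoStates, demoActions = loadDemoData()
#     agentStates, agentActions = collectAgentData(gail)
#     discrimLosses, demoAcc, agentAcc = gail.trainDiscriminator(
#         demoStates, demoActions, agentStates, agentActions
#     )
#     d = gail.inference(agentStates, agentActions).numpy()
#     gailRewards = -np.log(1.0 - d + EPS)  # one common surrogate reward choice
# ---------------------------------------------------------------------------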