import tensorflow as tf
import numpy as np
from numpy import ndarray

from PPO import PPO
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import optimizers

from GAILConfig import GAILConfig

EPS = 1e-6


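# GAIL couples two learners: a binary discriminator that scores concatenated
# (state, action) vectors, and a PPO agent (the generator) trained on rewards
# derived from those scores. The discriminator is pushed to rate expert (demo)
# pairs high and agent pairs low; how its scores are turned into PPO rewards is
# left to the caller (see the newRewards argument of trainPPO).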
class GAIL(object):
    def __init__(
        self,
        stateSize: int,
        disActShape: list,
        conActSize: int,
        conActRange: float,
        gailConfig: GAILConfig,
    ):
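        """Initialize the GAIL discriminator and the underlying PPO agent.

        Args (as inferred from how they are used below):
            stateSize: dimension of the environment state vector.
            disActShape: sizes of the discrete action branches, e.g. [3, 2];
                [0] means the agent has no discrete actions.
            conActSize: number of continuous action components.
            conActRange: value range of the continuous actions.
            gailConfig: GAILConfig holding discriminator and PPO hyperparameters.
        """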
        if disActShape == [0]:
            # no discrete action output
            self.disActSize = 0
            self.disOutputSize = 0
        else:
            if np.any(np.array(disActShape) <= 1):
                raise ValueError(
                    "disActShape error: each discrete action branch must offer more than one choice, got "
                    f"{disActShape}"
                )
            self.disActSize = len(disActShape)
            self.disOutputSize = sum(disActShape)

        self.stateSize = stateSize
        self.disActShape = disActShape
        self.conActSize = conActSize
        self.conActRange = conActRange

        self.totalActSize = self.disActSize + conActSize
        self.discrimInputSize = stateSize + self.totalActSize
        self.discriminatorNNShape = gailConfig.discrimNNShape
        self.discrimLR = gailConfig.discrimLR
        self.discrimTrainEpochs = gailConfig.discrimTrainEpochs
        self.discrimSaveDir = gailConfig.discrimSaveDir
        self.ppoConfig = gailConfig.ppoConfig

        self.ppo = PPO(stateSize, disActShape, conActSize, conActRange, self.ppoConfig)
        self.discriminator = self.buildDiscriminatorNet(True)

    def buildDiscriminatorNet(self, compileModel: bool):
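        """Build the discriminator MLP.

        The input is a flat vector of size discrimInputSize (state concatenated
        with action), the hidden layer widths come from the config, and the
        single sigmoid output is read as the probability that the pair came
        from the expert demonstrations (see discrimLoss).
        """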
        # -----------Input Layers-----------
        stateInput = layers.Input(shape=(self.discrimInputSize,), name="stateInput")

        # -------Intermediate layers--------
        interLayers = []
        interLayersIndex = 0
        for neuralUnit in self.discriminatorNNShape:
            thisLayerName = "dense" + str(interLayersIndex)
            if interLayersIndex == 0:
                interLayers.append(
                    layers.Dense(neuralUnit, activation="relu", name=thisLayerName)(stateInput)
                )
            else:
                interLayers.append(
                    layers.Dense(neuralUnit, activation="relu", name=thisLayerName)(interLayers[-1])
                )
            interLayersIndex += 1

        # ----------Output Layers-----------
        output = layers.Dense(1, activation="sigmoid")(interLayers[-1])

        # ----------Model Compile-----------
        model = keras.Model(inputs=stateInput, outputs=output)
        if compileModel:
            discrimOPT = optimizers.Adam(learning_rate=self.discrimLR)
            model.compile(optimizer=discrimOPT, loss=self.discrimLoss())
        return model

    def discrimLoss(self):
        def loss(y_true, y_pred):
            """Discriminator loss function.

            Args:
                y_true (tf.Tensor): demo (expert) trajectory, passed through fit() as the labels
                y_pred (tf.Tensor): discriminator output on the agent trajectory

            Returns:
                tf.Tensor: scalar discriminator loss (agent term + demo term)
            """
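            # Binary-logistic GAIL objective:
            #   L = -E_demo[log D(tau_E)] - E_agent[log(1 - D(tau_pi))]
            # y_true carries the raw demo trajectory rather than a 0/1 label,
            # so it is re-scored through the discriminator here; y_pred already
            # is the model's score on the agent trajectory fed as x in fit().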
            demoP = self.discriminator(y_true)
            agentLoss = tf.negative(tf.reduce_mean(tf.math.log(1.0 - y_pred + EPS)))
            demoLoss = tf.negative(tf.reduce_mean(tf.math.log(demoP + EPS)))
            loss = agentLoss + demoLoss
            return loss

        return loss

    def inference(self, states: ndarray, actions: ndarray):
        """Run the discriminator on (state, action) pairs.

        Args:
            states (ndarray): states
            actions (ndarray): actions

        Returns:
            tf.Tensor: discriminator predictions
        """
        # check dimensions; flat inputs are reshaped to (batch, featureSize)
        if states.ndim != 2:
            stateNum = int(len(states) / self.stateSize)
            states = states.reshape([stateNum, self.stateSize])
        if actions.ndim != 2:
            actionsNum = int(len(actions) / self.totalActSize)
            actions = actions.reshape([actionsNum, self.totalActSize])

        thisTrajectory = np.append(states, actions, axis=1)
        discrimPredict = self.discriminator(thisTrajectory)
        return discrimPredict

    def discriminatorACC(
        self, demoStates: ndarray, demoActions: ndarray, agentStates: ndarray, agentActions: ndarray
    ):
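        # Accuracy-style diagnostic: mean discriminator score on expert pairs
        # (ideally near 1) and on agent pairs (ideally near 0).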
        demoAcc = np.mean(self.inference(demoStates, demoActions))
        agentAcc = np.mean(self.inference(agentStates, agentActions))
        return demoAcc, agentAcc

    def trainDiscriminator(
        self,
        demoStates: ndarray,
        demoActions: ndarray,
        agentStates: ndarray,
        agentActions: ndarray,
        epochs: int = None,
    ):
        """Train the discriminator.

        Args:
            demoStates (ndarray): expert states
            demoActions (ndarray): expert actions
            agentStates (ndarray): agent (PPO) generated states
            agentActions (ndarray): agent (PPO) generated actions
            epochs (int): number of training epochs; defaults to discrimTrainEpochs

        Returns:
            tuple: per-epoch losses, demo accuracy, agent accuracy (as 1 - mean agent score)
        """
        if epochs is None:
            epochs = self.discrimTrainEpochs
        demoTrajectory = np.append(demoStates, demoActions, axis=1)
        agentTrajectory = np.append(agentStates, agentActions, axis=1)
        his = self.discriminator.fit(x=agentTrajectory, y=demoTrajectory, epochs=epochs, verbose=0)

        demoAcc = np.mean(self.inference(demoStates, demoActions))
        agentAcc = np.mean(self.inference(agentStates, agentActions))
        return his.history["loss"], demoAcc, 1 - agentAcc

    def getActions(self, state: ndarray):
        """Let the agent choose an action for the given state.

        Args:
            state (ndarray): environment state

        Returns:
            actions (np.ndarray): chosen actions, 2-dimensional, e.g. [[0], [1], [1.5]]
            predictResult: raw actor network output
        """
        actions, predictResult = self.ppo.chooseAction(state)
        return actions, predictResult

    def trainPPO(
        self,
        states: ndarray,
        oldActorResult: ndarray,
        actions: ndarray,
        newRewards: ndarray,
        dones: ndarray,
        nextState: ndarray,
        epochs: int = None,
    ):
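        # PPO update step. newRewards is expected to hold the surrogate rewards
        # the caller derives from the discriminator; the update itself follows
        # the usual pipeline: critic values -> discounted returns -> GAE
        # advantages -> critic and actor training.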
        criticV = self.ppo.getCriticV(states)
        discountedR = self.ppo.discountReward(nextState, criticV, dones, newRewards)
        advantage = self.ppo.getGAE(discountedR, criticV)
        criticLosses = self.ppo.trainCritic(states, discountedR, epochs)
        actorLosses = self.ppo.trainActor(states, oldActorResult, actions, advantage, epochs)
        return actorLosses, criticLosses

    def saveWeights(self, score: float):
        saveDir = self.discrimSaveDir + "discriminator/discriminator.ckpt"
        self.discriminator.save_weights(saveDir, save_format="tf")
        print("GAIL Model's Weights Saved")
        self.ppo.saveWeights(score=score)

    def generateAction(self, states: ndarray):
        act, actorP = self.ppo.chooseAction(states)
        return act, actorP