import tensorflow as tf
import numpy as np
from numpy import ndarray

from PPO import PPO
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import optimizers

from GAILConfig import GAILConfig

# Small offset added to the arguments of tf.math.log in the discriminator loss
# so that the logarithm is never evaluated at exactly 0.
EPS = 1e-6


class GAIL(object):
    """GAIL agent: a PPO policy paired with a trajectory discriminator.

    The discriminator is a Keras MLP fed with a state vector concatenated with
    an action vector; it outputs a single sigmoid score. Its loss pushes the
    score towards 1 on demonstration samples and towards 0 on agent samples.
    """

    def __init__(
        self,
        stateSize: int,
        disActShape: list,
        conActSize: int,
        conActRange: float,
        gailConfig: GAILConfig,
    ):
        """Store the sizes, build the PPO agent and the compiled discriminator.

        Args:
            stateSize (int): length of one state vector.
            disActShape (list): one entry per discrete action; ``[0]`` means
                "no discrete actions". Every entry must otherwise be > 1
                (presumably the number of choices of that action).
            conActSize (int): size of the continuous part of the action vector.
            conActRange (float): forwarded unchanged to ``PPO``.
            gailConfig (GAILConfig): supplies ``discrimNNShape``, ``discrimLR``,
                ``discrimTrainEpochs``, ``discrimSaveDir`` and ``ppoConfig``.

        Raises:
            ValueError: if ``disActShape`` is not ``[0]`` and has an entry <= 1.
        """
        if disActShape == [0]:
            # non dis action output
            self.disActSize = 0
            self.disOutputSize = 0
        else:
            if np.any(np.array(disActShape) <= 1):
                raise ValueError(
                    "disActShape error,disActShape should greater than 1 but get", disActShape
                )
            self.disActSize = len(disActShape)
            self.disOutputSize = sum(disActShape)

        self.stateSize = stateSize
        self.disActShape = disActShape
        self.conActSize = conActSize
        self.conActRange = conActRange

        # The discriminator sees one value per discrete action plus the
        # continuous action values, appended to the state.
        self.totalActSize = self.disActSize + conActSize
        self.discrimInputSize = stateSize + self.totalActSize
        self.discriminatorNNShape = gailConfig.discrimNNShape
        self.discrimLR = gailConfig.discrimLR
        self.discrimTrainEpochs = gailConfig.discrimTrainEpochs
        self.discrimSaveDir = gailConfig.discrimSaveDir
        self.ppoConfig = gailConfig.ppoConfig

        self.ppo = PPO(stateSize, disActShape, conActSize, conActRange, self.ppoConfig)
        self.discriminator = self.buildDiscriminatorNet(True)

    def buildDiscriminatorNet(self, compileModel: bool):
        """Build the discriminator MLP.

        The network is ``Input(discrimInputSize)`` -> one ReLU ``Dense`` layer
        per entry of ``discriminatorNNShape`` -> ``Dense(1, sigmoid)``. An empty
        ``discriminatorNNShape`` yields a model with only the output layer.

        Args:
            compileModel (bool): when True the model is compiled with Adam
                (learning rate ``discrimLR``) and the loss from ``discrimLoss``.

        Returns:
            keras.Model: the (optionally compiled) discriminator.
        """
        # -----------Input Layers-----------
        # Layer names are kept stable ("stateInput", "dense<i>").
        trajectoryInput = layers.Input(shape=(self.discrimInputSize,), name="stateInput")

        # -------Intermediate layers--------
        hidden = trajectoryInput
        for layerIndex, neuralUnit in enumerate(self.discriminatorNNShape):
            thisLayerName = "dense" + str(layerIndex)
            hidden = layers.Dense(neuralUnit, activation="relu", name=thisLayerName)(hidden)

        # ----------Output Layers-----------
        output = layers.Dense(1, activation="sigmoid")(hidden)

        # ----------Model Compile-----------
        model = keras.Model(inputs=trajectoryInput, outputs=output)
        if compileModel:
            discrimOPT = optimizers.Adam(learning_rate=self.discrimLR)
            model.compile(optimizer=discrimOPT, loss=self.discrimLoss())
        return model

    def discrimLoss(self):
        """Return the loss closure used to compile the discriminator.

        ``trainDiscriminator`` calls ``fit(x=agentTrajectory, y=demoTrajectory)``,
        so inside the loss ``y_pred`` is the discriminator output on agent
        samples while ``y_true`` carries the raw demonstration samples, which
        are scored here by calling ``self.discriminator`` again.
        """

        def loss(y_true, y_pred):
            """Discriminator loss.

            Args:
                y_true: demonstration trajectories (state + action rows).
                y_pred: discriminator scores of the agent trajectories.

            Returns:
                Scalar tensor ``-mean(log(1 - y_pred + EPS)) - mean(log(D(y_true) + EPS))``.
            """
            # self.discriminator is looked up at call time, i.e. after __init__
            # has assigned the model built by buildDiscriminatorNet.
            demoP = self.discriminator(y_true)
            agentLoss = tf.negative(tf.reduce_mean(tf.math.log(1.0 - y_pred + EPS)))
            demoLoss = tf.negative(tf.reduce_mean(tf.math.log(demoP + EPS)))
            return agentLoss + demoLoss

        return loss

    def _buildTrajectory(self, states: ndarray, actions: ndarray):
        """Return ``states`` and ``actions`` joined column-wise as a 2-D array.

        Inputs that are not already 2-D are reshaped to ``(-1, stateSize)`` and
        ``(-1, totalActSize)`` respectively before being concatenated.
        """
        states = np.asarray(states)
        actions = np.asarray(actions)
        if states.ndim != 2:
            states = states.reshape(-1, self.stateSize)
        if actions.ndim != 2:
            actions = actions.reshape(-1, self.totalActSize)
        return np.concatenate((states, actions), axis=1)

    def inference(self, states: ndarray, actions: ndarray):
        """Score state/action pairs with the discriminator.

        Args:
            states (ndarray): states; reshaped to ``(-1, stateSize)`` if not 2-D.
            actions (ndarray): actions; reshaped to ``(-1, totalActSize)`` if not 2-D.

        Returns:
            The value returned by calling the discriminator on the joined
            state/action rows (one sigmoid score per row).
        """
        thisTrajectory = self._buildTrajectory(states, actions)
        return self.discriminator(thisTrajectory)

    def discriminatorACC(
        self, demoStates: ndarray, demoActions: ndarray, agentStates: ndarray, agentActions: ndarray
    ):
        """Return the mean discriminator score on demo and on agent samples.

        Note that the second value is the raw mean score of the agent samples;
        ``trainDiscriminator`` instead reports ``1 - agentAcc``.

        Returns:
            tuple: ``(demoAcc, agentAcc)``.
        """
        demoAcc = np.mean(self.inference(demoStates, demoActions))
        agentAcc = np.mean(self.inference(agentStates, agentActions))
        return demoAcc, agentAcc

    def trainDiscriminator(
        self,
        demoStates: ndarray,
        demoActions: ndarray,
        agentStates: ndarray,
        agentActions: ndarray,
        epochs: int = None,
    ):
        """Train the discriminator on demonstration and agent samples.

        Args:
            demoStates (ndarray): expert states
            demoActions (ndarray): expert actions
            agentStates (ndarray): agentPPO generated states
            agentActions (ndarray): agentPPO generated actions
            epochs (int): number of fit epochs; defaults to
                ``discrimTrainEpochs`` when None.

        Returns:
            tuple: ``(losses, demoAcc, 1 - agentAcc)`` where ``losses`` is the
            per-epoch loss list from the Keras history and the two scores are
            mean discriminator outputs measured after training.
        """
        if epochs is None:
            epochs = self.discrimTrainEpochs
        demoTrajectory = self._buildTrajectory(demoStates, demoActions)
        agentTrajectory = self._buildTrajectory(agentStates, agentActions)
        # Agent rows are the model input; demo rows travel through the "label"
        # slot and are scored inside the custom loss (see discrimLoss).
        # NOTE(review): Keras pairs x and y per sample, so both arrays are
        # expected to have the same number of rows - confirm against callers.
        his = self.discriminator.fit(x=agentTrajectory, y=demoTrajectory, epochs=epochs, verbose=0)

        demoAcc, agentAcc = self.discriminatorACC(
            demoStates, demoActions, agentStates, agentActions
        )
        return his.history["loss"], demoAcc, 1 - agentAcc

    def getActions(self, state: ndarray):
        """Ask the PPO agent for an action.

        Args:
            state (ndarray): environment state

        Returns:
            tuple: ``(actions, predictResult)`` exactly as returned by
            ``PPO.chooseAction`` (per the original notes: the chosen actions
            and the actor network's raw prediction).
        """
        actions, predictResult = self.ppo.chooseAction(state)
        return actions, predictResult

    def trainPPO(
        self,
        states: ndarray,
        oldActorResult: ndarray,
        actions: ndarray,
        newRewards: ndarray,
        dones: ndarray,
        nextState: ndarray,
        epochs: int = None,
    ):
        """Run one PPO update (critic, then actor) on a collected batch.

        The critic values of ``states`` are used to compute the discounted
        returns and the advantage, which are then passed to ``PPO.trainCritic``
        and ``PPO.trainActor``. ``epochs`` is forwarded unchanged to both.

        Returns:
            tuple: ``(actorLosses, criticLosses)`` as returned by the PPO
            training calls.
        """
        criticV = self.ppo.getCriticV(states)
        discountedR = self.ppo.discountReward(nextState, criticV, dones, newRewards)
        advantage = self.ppo.getGAE(discountedR, criticV)
        # Critic is trained before the actor, matching the original call order.
        criticLosses = self.ppo.trainCritic(states, discountedR, epochs)
        actorLosses = self.ppo.trainActor(states, oldActorResult, actions, advantage, epochs)
        return actorLosses, criticLosses

    def saveWeights(self, score: float):
        """Save the discriminator weights, then the PPO weights.

        The discriminator checkpoint path is ``discrimSaveDir`` directly
        followed by ``"discriminator/discriminator.ckpt"``, so ``discrimSaveDir``
        is expected to end with a path separator.

        Args:
            score (float): forwarded to ``PPO.saveWeights``.
        """
        saveDir = self.discrimSaveDir + "discriminator/discriminator.ckpt"
        self.discriminator.save_weights(saveDir, save_format="tf")
        print("GAIL Model's Weights Saved")
        self.ppo.saveWeights(score=score)

    def generateAction(self, states: ndarray):
        """Alias of ``getActions``; returns ``(actions, actorPrediction)``."""
        return self.getActions(states)