import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import math
import copy
import datetime
import os

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import optimizers
from keras_radam import RAdam


class PPO(object):
    """PPO agent with a hybrid action space (several discrete + continuous actions).

    The actor network outputs, for each state, one softmax distribution per entry
    of ``disActShape`` followed by one ``(mu, sigma)`` pair per continuous action:
    ``[probs_0..., probs_1..., ..., mu_0, sigma_0, mu_1, sigma_1, ...]``.
    The critic network outputs a single state value.
    """

    def __init__(self, stateSize, disActShape, conActSize, conActRange, criticLR, actorLR, gamma, epsilon, entropyWeight, saveDir, loadModelDir):
        """Build the actor/critic networks and optionally load saved weights.

        Args:
            stateSize (int): length of the flat state vector.
            disActShape (list[int]): number of choices of each discrete action,
                e.g. ``[3, 3, 2]``. Every entry must be greater than 1.
            conActSize (int): number of continuous actions.
            conActRange (float): continuous actions are scaled/clipped to
                ``[-conActRange, conActRange]``.
            criticLR (float): critic learning rate (Adam).
            actorLR (float): actor learning rate (RAdam).
            gamma (float): reward discount factor.
            epsilon (float): PPO clipping range.
            entropyWeight (float): weight of the entropy bonus in the actor loss.
            saveDir (str): path prefix used by ``saveWeights``.
            loadModelDir (str | None): directory handed to ``loadWeightToModels``;
                ``None`` starts from freshly initialised weights.

        Raises:
            ValueError: if any entry of ``disActShape`` is <= 1.
        """
        if np.any(np.asarray(disActShape) <= 1):
            raise ValueError(
                f"disActShape error, every entry of disActShape should be greater than 1 but got {disActShape}")

        self.stateSize = stateSize
        self.disActShape = disActShape  # choices per discrete action, like [3,3,2]
        self.disActSize = len(disActShape)
        self.conActSize = conActSize
        self.conActRange = conActRange
        self.criticLR = criticLR
        self.actorLR = actorLR
        self.GAMMA = gamma
        self.EPSILON = epsilon
        self.saveDir = saveDir
        self.entropyWeight = entropyWeight

        self.disOutputSize = sum(disActShape)
        self.conOutputSize = conActSize * 2

        # The networks are identical whether or not weights are loaded afterwards.
        self.critic = self.buildCriticNet(self.stateSize, 1, compileModel=True)
        self.actor = self.buildActorNet(self.stateSize, self.conActRange, compileModel=True)
        if loadModelDir is not None:
            self.loadWeightToModels(loadModelDir)

    # Build Net
    def buildActorNet(self, inputSize, continuousActionRange, compileModel):
        """Build (and optionally compile) the actor network.

        Output layout: one softmax head per entry of ``self.disActShape``, then
        ``mu, sigma`` for each of the ``self.conActSize`` continuous actions,
        all concatenated into a single ``totalOut`` tensor.

        Args:
            inputSize (int): size of the state input layer.
            continuousActionRange (float): scale applied to the tanh mu output.
            compileModel (bool): compile with RAdam and the PPO actor loss.

        Returns:
            keras.Model: the actor network.
        """
        # Names of the original hard-coded heads (WS keys, AD keys, mouse shoot),
        # kept so the default [3,3,2] network is laid out exactly as before.
        legacyDisNames = ('WSAction', 'ADAction', 'ShootAction')

        stateInput = layers.Input(shape=(inputSize,), name='stateInput')
        dense0 = layers.Dense(500, activation='relu', name='dense0')(stateInput)
        dense1 = layers.Dense(200, activation='relu', name='dense1')(dense0)
        dense2 = layers.Dense(100, activation='relu', name='dense2')(dense1)

        heads = []
        for index, size in enumerate(self.disActShape):
            name = legacyDisNames[index] if index < len(legacyDisNames) else f'disAction{index}'
            heads.append(layers.Dense(size, activation='softmax', name=name)(dense2))
        for index in range(self.conActSize):
            suffix = '' if index == 0 else str(index)
            # mu: mean of the normal distribution, scaled into the action range
            mu = continuousActionRange * layers.Dense(1, activation='tanh', name='muOut' + suffix)(dense2)
            # sigma: standard deviation of the normal distribution, kept strictly positive
            sigma = 1e-8 + layers.Dense(1, activation='softplus', name='sigmaOut' + suffix)(dense2)
            heads.extend([mu, sigma])

        totalOut = layers.concatenate(heads, name='totalOut') if len(heads) > 1 else heads[0]

        model = keras.Model(inputs=stateInput, outputs=totalOut)
        if compileModel:
            actorOPT = RAdam(self.actorLR)
            model.compile(optimizer=actorOPT, loss=self.aLoss())
        return model

    def buildCriticNet(self, inputSize, outputSize, compileModel):
        """Build (and optionally compile) the critic network.

        Args:
            inputSize (int): size of the state input layer.
            outputSize (int): number of value outputs (1 for a state value).
            compileModel (bool): compile with Adam and the critic loss.

        Returns:
            keras.Model: the critic network.
        """
        stateInput = keras.Input(shape=(inputSize,))
        dense0 = layers.Dense(500, activation='relu', name='dense0')(stateInput)
        dense1 = layers.Dense(200, activation='relu')(dense0)
        dense2 = layers.Dense(100, activation='relu')(dense1)
        output = layers.Dense(outputSize)(dense2)
        model = keras.Model(inputs=stateInput, outputs=output)
        if compileModel:
            criticOPT = optimizers.Adam(learning_rate=self.criticLR)
            model.compile(optimizer=criticOPT, loss=self.cLoss())
        return model

    # loss Function
    def cLoss(self):
        """Return the critic loss: mean squared TD error."""
        def loss(y_true, y_pred):
            # y_true: discounted returns; y_pred: critic value predictions
            advantage = y_true - y_pred  # TD error
            return tf.reduce_mean(tf.square(advantage))
        return loss

    def aLoss(self):
        """Return the PPO clipped-surrogate loss used to train the actor.

        ``y_true`` is packed by ``trainActor`` as
        ``[oldDisProbs (D), takenDisOneHot (D), oldConProbs (K), conActions (K), advantage (1)]``
        where ``D = self.disOutputSize`` and ``K = self.conActSize``.
        ``y_pred`` is the actor output (see ``buildActorNet``).
        The returned scalar is the mean loss over all action heads.
        """
        def clippedSurrogate(ratio, advantage):
            # Negative PPO clipped objective, averaged over the batch.
            clipRatio = tf.clip_by_value(ratio, 1. - self.EPSILON, 1. + self.EPSILON)
            return -tf.reduce_mean(tf.math.minimum(ratio * advantage, clipRatio * advantage))

        def getDiscreteALoss(nowProbs, oldProbs, takenMask, advantage):
            """Loss of one discrete head.

            Args:
                nowProbs: (length, n) current action probabilities.
                oldProbs: (length, n) action probabilities of the old policy.
                takenMask: (length, n) one-hot of the action actually taken.
                advantage: (length,) advantages.

            Returns:
                scalar tensor.
            """
            # p*log(p) averaged over batch and categories = negative entropy;
            # adding it to the loss therefore rewards higher entropy.
            negEntropy = tf.reduce_mean(tf.math.multiply(nowProbs, tf.math.log(nowProbs + 1e-6)))
            # PPO ratio uses only the probability of the action that was taken.
            nowTaken = tf.reduce_sum(nowProbs * takenMask, axis=1)
            oldTaken = tf.reduce_sum(oldProbs * takenMask, axis=1)
            ratio = tf.math.divide(nowTaken, oldTaken + 1e-6)
            return clippedSurrogate(ratio, advantage) + self.entropyWeight * negEntropy

        def getContinuousALoss(musig, actions, oldProbs, advantage):
            """Loss of one continuous head.

            Args:
                musig: (length, 2) current ``[mu, sigma]``.
                actions: (length,) actions that were taken.
                oldProbs: (length,) old-policy densities of those actions.
                advantage: (length,) advantages.

            Returns:
                scalar tensor.
            """
            dist = tfp.distributions.Normal(musig[:, 0], musig[:, 1])
            ratio = tf.math.divide(dist.prob(actions), oldProbs + 1e-6)
            entropy = tf.reduce_mean(dist.entropy())
            # Entropy is a bonus: subtract it so minimizing the loss increases it.
            return clippedSurrogate(ratio, advantage) - self.entropyWeight * entropy

        def loss(y_true, y_pred):
            disSize = self.disOutputSize
            conSize = self.conActSize

            oldDisProbs = y_true[:, 0:disSize]
            takenOneHot = y_true[:, disSize:2 * disSize]
            oldConProbs = y_true[:, 2 * disSize:2 * disSize + conSize]
            conActions = y_true[:, 2 * disSize + conSize:2 * disSize + 2 * conSize]
            advantage = y_true[:, -1]

            nowDisProbs = y_pred[:, 0:disSize]   # [disAct1, disAct2, ...]
            nowConMusigs = y_pred[:, disSize:]   # [mu1, sigma1, mu2, sigma2, ...]

            totalALoss = tf.constant(0.)
            offset = 0
            for size in self.disActShape:
                totalALoss += getDiscreteALoss(
                    nowDisProbs[:, offset:offset + size],
                    oldDisProbs[:, offset:offset + size],
                    takenOneHot[:, offset:offset + size],
                    advantage)
                offset += size
            for act in range(conSize):
                totalALoss += getContinuousALoss(
                    nowConMusigs[:, 2 * act:2 * act + 2],
                    conActions[:, act],
                    oldConProbs[:, act],
                    advantage)

            totalActionNum = max(len(self.disActShape) + conSize, 1)
            return tf.divide(totalALoss, totalActionNum)
        return loss

    # get Action&V
    def chooseAction(self, state):
        """Let the actor choose an action for one state.

        Discrete actions are the argmax of each softmax head; each continuous
        action is sampled from Normal(mu, sigma) and clipped to
        ``[-conActRange, conActRange]``.

        Args:
            state (np.array): environment state (reshaped to ``[1, stateSize]``
                if it is not already 2-D).

        Returns:
            tuple: ``(*discreteActions, conAction, predictResult)``. For the
            default ``[3, 3, 2]`` / one continuous action setup this is
            ``(disAct1, disAct2, disAct3, conAction, predictResult)``.
            ``conAction`` is a scalar when ``conActSize == 1`` and an array
            otherwise; ``predictResult`` is the raw actor output.
        """
        if state.ndim != 2:
            state = state.reshape([1, self.stateSize])

        predictResult = self.actor(state).numpy()
        row = predictResult[0]

        discreteActions = []
        offset = 0
        for size in self.disActShape:
            discreteActions.append(np.argmax(row[offset:offset + size]))
            offset += size

        musig = row[offset:offset + 2 * self.conActSize].reshape(self.conActSize, 2)
        mu, sigma = musig[:, 0], musig[:, 1]
        if np.isnan(musig).any():
            print("mu or sigma is nan")

        # sample an action from the normal distribution, then clip it into range
        sampled = np.random.normal(loc=mu, scale=sigma)
        conActions = np.clip(sampled, -self.conActRange, self.conActRange)
        conAction = conActions[0] if self.conActSize == 1 else conActions
        return (*discreteActions, conAction, predictResult)

    def getCriticV(self, state):
        """Return the critic's value prediction for ``state``.

        Args:
            state (np.array): environment state (reshaped to ``[1, stateSize]``
                if it is not already 2-D).

        Returns:
            np.array: critic prediction.
        """
        if state.ndim != 2:
            state = state.reshape([1, self.stateSize])
        return self.critic.predict(state)

    def discountReward(self, nextState, rewards):
        """Compute discounted returns, bootstrapped from the critic at ``nextState``.

        Args:
            nextState (np.array): environment state after the last reward.
            rewards (sequence): rewards of this rollout, oldest first.

        Returns:
            np.array: shape ``(len(rewards), 1)`` discounted returns.
        """
        running = self.getCriticV(nextState)
        discountedRewards = []
        for r in rewards[::-1]:
            running = r + self.GAMMA * running
            discountedRewards.append(running)
        discountedRewards.reverse()
        # reshape (not squeeze+expand_dims) so a single-step rollout still works
        return np.reshape(np.asarray(discountedRewards), (-1, 1))

    def conProb(self, mu, sig, x):
        """Probability density of ``x`` under Normal(mu, sig), element-wise.

        Args:
            mu (np.array): means.
            sig (np.array): standard deviations.
            x (np.array): values to evaluate.

        Returns:
            np.array: densities with shape ``(len(x), 1)``.
        """
        mu = np.reshape(mu, (np.size(mu),))
        sig = np.reshape(sig, (np.size(sig),))
        x = np.reshape(x, (np.size(x),))

        dist = tfp.distributions.Normal(mu, sig)
        prob = dist.prob(x)
        return np.asarray(prob).reshape((np.size(x), 1))

    def trainCritcActor(self, states, actions, rewards, nextState, criticEpochs, actorEpochs):
        """Train the critic and then the actor on one rollout buffer.

        Args:
            states: buffered states.
            actions: buffered actions, ``[discrete..., continuous...]`` per row.
            rewards: buffered rewards, not yet discounted.
            nextState: the single state following the buffer.
            criticEpochs (int): epochs for the critic.
            actorEpochs (int): epochs for the actor.

        Returns:
            tuple: ``(actorMeanLoss, criticMeanLoss)``.
        """
        discountedR = self.discountReward(nextState, rewards)

        criticMeanLoss = self.trainCritic(states, discountedR, criticEpochs)
        actorMeanLoss = self.trainActor(states, actions, discountedR, actorEpochs)
        print("A_Loss:", actorMeanLoss, "C_Loss:", criticMeanLoss)
        return actorMeanLoss, criticMeanLoss

    def trainCritic(self, states, discountedR, epochs):
        """Fit the critic to the discounted returns.

        Args:
            states: buffered states.
            discountedR (np.array): discounted returns, shape ``(N, 1)``.
            epochs (int): number of epochs.

        Returns:
            float: mean training loss over the epochs.
        """
        his = self.critic.fit(x=states, y=discountedR, epochs=epochs, verbose=0)
        return np.mean(his.history['loss'])

    def trainActor(self, states, actions, discountedR, epochs):
        """Fit the actor with the PPO clipped objective.

        Args:
            states (np.array): buffered states.
            actions (np.array): buffered actions; the first ``disActSize`` columns
                are assumed to be discrete action indices and the next
                ``conActSize`` columns the continuous actions (as produced by
                ``chooseAction``) — TODO confirm against the caller.
            discountedR (np.array): discounted returns.
            epochs (int): number of epochs.

        Returns:
            float: mean actor loss over the epochs.
        """
        states = np.asarray(states)
        actions = np.asarray(actions, dtype=np.float32)
        # predict with the old (pre-update) actor
        oldActorResult = self.actor.predict(states)

        disActions = actions[:, 0:self.disActSize].astype(np.int64)
        conActions = actions[:, self.disActSize:self.disActSize + self.conActSize]

        oldDisProbs = oldActorResult[:, 0:self.disOutputSize]  # [disAct1, disAct2, ...]
        oldConMusigs = oldActorResult[:, self.disOutputSize:]  # [mu1, sigma1, mu2, sigma2, ...]

        # one-hot of the discrete action actually taken, per discrete head
        takenOneHots = [np.eye(size, dtype=np.float32)[disActions[:, index]]
                        for index, size in enumerate(self.disActShape)]
        # old-policy density of each continuous action that was taken
        oldConProbs = [self.conProb(oldConMusigs[:, 2 * act], oldConMusigs[:, 2 * act + 1], conActions[:, act])
                       for act in range(self.conActSize)]

        criticV = self.critic.predict(states)
        advantage = np.reshape(discountedR, (-1, 1)) - np.reshape(criticV, (-1, 1))

        # pack [oldDisProbs, takenOneHot, oldConProbs, conActions, advantage] as y_true
        y_true = np.hstack([oldDisProbs, *takenOneHots, *oldConProbs, conActions, advantage]).astype(np.float32)

        if np.isnan(y_true).any():
            print("y_true got nan")
            print("oldConMusigs", oldConMusigs)
            print("oldPiProbs", oldConProbs)
            print("conActions", conActions)
        his = self.actor.fit(x=states, y=y_true, epochs=epochs, verbose=0)
        if np.isnan(his.history['loss']).any():
            print("his.history['loss'] is nan!")
            print(his.history['loss'])
        return np.mean(his.history['loss'])

    def saveWeights(self, score=None):
        """Save both networks' weights in TF checkpoint format.

        Files go to ``saveDir + HHMMSS + "/actor/actor.ckpt"`` and
        ``.../critic/critic.ckpt``; the timestamp is taken once so both models
        (and the score marker) land in the same directory.

        Args:
            score (float | None): if given, an empty file named after the rounded
                score is created in the same directory as a marker.
        """
        baseDir = self.saveDir + datetime.datetime.now().strftime("%H%M%S")
        self.actor.save_weights(baseDir + "/actor/" + "actor.ckpt", save_format="tf")
        self.critic.save_weights(baseDir + "/critic/" + "critic.ckpt", save_format="tf")
        if score is not None:
            # create an empty file named as the score to record it
            with open(baseDir + "/" + str(round(score)), 'w'):
                pass
        print("Model's Weights Saved")

    def loadWeightToModels(self, loadDir):
        """Load actor/critic weights saved by ``saveWeights``.

        Args:
            loadDir (str): directory containing ``actor/actor.ckpt`` and
                ``critic/critic.ckpt``.
        """
        actorDir = loadDir + "/actor/" + "actor.ckpt"
        criticDir = loadDir + "/critic/" + "critic.ckpt"
        self.actor.load_weights(actorDir)
        self.critic.load_weights(criticDir)

        print("++++++++++++++++++++++++++++++++++++")
        print("++++++++++++Model Loaded++++++++++++")
        print(loadDir)
        print("++++++++++++++++++++++++++++++++++++")