from mlagents_envs.base_env import ActionTuple
from mlagents_envs.environment import UnityEnvironment

import numpy as np
from numpy import ndarray


class makeEnv(object):
    """Single-agent convenience wrapper around an ML-Agents ``UnityEnvironment``.

    The wrapper launches the environment, reads the first behavior's specs,
    converts a flat action list into an ``ActionTuple``, and returns the
    tracked agent's observation split into three parts:

    * the state (first ``SINGLE_STATE_SIZE`` values), which is pushed into a
      rolling buffer and returned as a frame-stacked vector,
    * ``loadDir`` (the values between the state and the last element),
    * ``saveNow`` (the last element).

    The meaning of ``loadDir`` / ``saveNow`` is defined by the Unity side and
    is not visible in this file.
    """

    def __init__(
        self,
        envPath: str,
        workerID: int = 1,
        basePort: int = 100,
        stackSize: int = 1,
        stackIntercal: int = 0,
    ):
        """Launch the Unity environment and prepare the state-stacking buffer.

        Args:
            envPath (str): path passed to ``UnityEnvironment`` as ``file_name``.
            workerID (int, optional): passed as ``worker_id``. Defaults to 1.
            basePort (int, optional): passed as ``base_port``. Defaults to 100.
            stackSize (int, optional): number of states concatenated into one
                stacked state. Must be >= 1. Defaults to 1.
            stackIntercal (int, optional): number of buffered states skipped
                between two stacked states. Must be >= 0. Defaults to 0.

        Raises:
            ValueError: if ``stackSize`` < 1 or ``stackIntercal`` < 0.
        """
        # Validate before launching Unity so a bad argument fails fast instead
        # of producing an unusable buffer or an obscure range() error later.
        if stackSize < 1:
            raise ValueError(f"stackSize must be >= 1, got {stackSize}")
        if stackIntercal < 0:
            raise ValueError(f"stackIntercal must be >= 0, got {stackIntercal}")

        self.env = UnityEnvironment(
            file_name=envPath,
            seed=1,
            side_channels=[],
            worker_id=workerID,
            base_port=basePort,
        )
        self.env.reset()

        # get environment specs
        # number of trailing observation values that are not part of the state
        # (loadDir values followed by one saveNow value)
        self.LOAD_DIR_SIZE_IN_STATE = 3
        # -1 means "no agent tracked yet"; replaced by the first agent id seen
        self.TRACKED_AGENT = -1
        self.BEHA_SPECS = self.env.behavior_specs
        # only the first behavior of the environment is used
        self.BEHA_NAME = list(self.BEHA_SPECS)[0]
        self.SPEC = self.BEHA_SPECS[self.BEHA_NAME]
        self.OBSERVATION_SPECS = self.SPEC.observation_specs[0]  # observation spec
        self.ACTION_SPEC = self.SPEC.action_spec  # action specs

        self.DISCRETE_SIZE = self.ACTION_SPEC.discrete_size
        self.DISCRETE_SHAPE = list(self.ACTION_SPEC.discrete_branches)
        self.CONTINUOUS_SIZE = self.ACTION_SPEC.continuous_size
        self.SINGLE_STATE_SIZE = self.OBSERVATION_SPECS.shape[0] - self.LOAD_DIR_SIZE_IN_STATE
        self.STATE_SIZE = self.SINGLE_STATE_SIZE * stackSize

        # stacked state
        self.STACK_SIZE = stackSize
        # the buffer must hold the stacked states plus the skipped ones between them
        self.STATE_BUFFER_SIZE = stackSize + ((stackSize - 1) * stackIntercal)
        # buffer rows that make up the stacked state; the last one is always
        # the newest row (STATE_BUFFER_SIZE - 1)
        self.STACK_INDEX = list(range(0, self.STATE_BUFFER_SIZE, stackIntercal + 1))
        self.statesBuffer = self._emptyStatesBuffer()
        print("√√√√√Enviroment Initialized Success√√√√√")

    def _emptyStatesBuffer(self):
        """Return a zero-filled (STATE_BUFFER_SIZE, SINGLE_STATE_SIZE) float buffer."""
        return np.zeros((self.STATE_BUFFER_SIZE, self.SINGLE_STATE_SIZE))

    def _splitObservation(self, observation):
        """Flatten one raw observation and split it into (state, loadDir, saveNow)."""
        flatObs = np.reshape(
            observation, [self.SINGLE_STATE_SIZE + self.LOAD_DIR_SIZE_IN_STATE]
        )
        state = flatObs[: self.SINGLE_STATE_SIZE]
        # everything after the state except the very last value
        loadDir = flatObs[self.SINGLE_STATE_SIZE : -1]
        saveNow = flatObs[-1]
        return state, loadDir, saveNow

    def step(
        self,
        actions: list,
        behaviorName: str = None,
        trackedAgent: int = None,
    ):
        """Convert the actions list to an ActionTuple, send it and advance one step.

        The first ``DISCRETE_SIZE`` entries of ``actions`` are used as the
        discrete actions and the remaining entries as the continuous actions.

        Args:
            actions (list): flat action list (discrete entries first).
            behaviorName (str, optional): behavior to act on. Defaults to the
                first behavior of the environment.
            trackedAgent (int, optional): id of the agent to observe. Defaults
                to the agent tracked by this wrapper.

        Returns:
            tuple: nextState, reward, done, loadDir, saveNow
        """
        if self.DISCRETE_SIZE == 0:
            # no discrete branch: send an empty (1 agent, 0 actions) array
            discreteActions = np.zeros((1, 0), dtype=np.int32)
        else:
            # create discrete action from actions list
            discreteActions = np.asanyarray([actions[0 : self.DISCRETE_SIZE]])
        if self.CONTINUOUS_SIZE == 0:
            # no continuous branch: send an empty (1 agent, 0 actions) array
            continuousActions = np.zeros((1, 0), dtype=np.float32)
        else:
            # create continuous actions from actions list
            continuousActions = np.asanyarray([actions[self.DISCRETE_SIZE :]])

        if behaviorName is None:
            behaviorName = self.BEHA_NAME

        # create actionTuple
        thisActionTuple = ActionTuple(continuous=continuousActions, discrete=discreteActions)
        # take action to env
        self.env.set_actions(behavior_name=behaviorName, action=thisActionTuple)
        self.env.step()
        # get nextState & reward & done after this action; a None trackedAgent
        # is resolved to the wrapper's tracked agent inside getSteps
        return self.getSteps(behaviorName, trackedAgent)

    def getSteps(self, behaviorName=None, trackedAgent=None):
        """Read the current observation of the tracked agent.

        If the agent appears in the terminal steps, the terminal data is
        returned with ``done=True``; otherwise its decision-step data is
        returned with ``done=False``.

        When the wrapper's own tracked agent (the default) is in neither set
        but other agents are requesting decisions, tracking switches to the
        first of them.
        NOTE(review): ML-Agents may assign a new agent id after an episode
        ends, which is the case this fallback is meant for - confirm against
        the ML-Agents version in use.

        Args:
            behaviorName (str, optional): behavior to query. Defaults to the
                first behavior of the environment.
            trackedAgent (int, optional): id of the agent to observe. Defaults
                to the agent tracked by this wrapper.

        Returns:
            tuple: nextState (stacked), reward, done, loadDir, saveNow

        Raises:
            RuntimeError: if the requested agent is in neither the decision
                steps nor the terminal steps and no fallback is possible.
        """
        if behaviorName is None:
            behaviorName = self.BEHA_NAME
        decisionSteps, terminalSteps = self.env.get_steps(behaviorName)

        # The caller relies on default tracking when it passes nothing or
        # simply echoes the wrapper's current tracked id (including the -1
        # "not tracked yet" sentinel).
        usingDefault = trackedAgent is None or trackedAgent == self.TRACKED_AGENT
        if self.TRACKED_AGENT == -1 and len(decisionSteps) >= 1:
            self.TRACKED_AGENT = decisionSteps.agent_id[0]
        if usingDefault:
            trackedAgent = self.TRACKED_AGENT

        if trackedAgent in terminalSteps:
            # episode finished: the final observation is stored in terminal steps
            agentStep = terminalSteps[trackedAgent]
            done = True
        elif trackedAgent in decisionSteps:
            # episode still running: the observation is stored in decision steps
            agentStep = decisionSteps[trackedAgent]
            done = False
        elif usingDefault and len(decisionSteps) >= 1:
            # the tracked id disappeared: follow the first available agent
            self.TRACKED_AGENT = decisionSteps.agent_id[0]
            trackedAgent = self.TRACKED_AGENT
            agentStep = decisionSteps[trackedAgent]
            done = False
        else:
            raise RuntimeError(
                f"agent {trackedAgent} is in neither the decision steps nor the "
                f"terminal steps of behavior {behaviorName!r}"
            )

        nextState, loadDir, saveNow = self._splitObservation(agentStep.obs[0])
        reward = agentStep.reward

        # stack state
        stackedStates = self.stackStates(nextState)
        return stackedStates, reward, done, loadDir, saveNow

    def reset(self):
        """Clear the state buffer, reset the environment and read observations.

        Returns:
            tuple: nextState, reward, done, loadDir, saveNow
        """
        # reset buffer so states of the previous episode are not stacked
        self.statesBuffer = self._emptyStatesBuffer()
        # reset env
        self.env.reset()
        return self.getSteps()

    def stackStates(self, state):
        """Push ``state`` into the rolling buffer and return the stacked state.

        Args:
            state (ndarray): newest single state of length SINGLE_STATE_SIZE.

        Returns:
            ndarray: the buffer rows selected by STACK_INDEX, flattened to
            length STATE_SIZE (oldest first, newest last).
        """
        # drop the oldest row by shifting everything one row towards index 0
        self.statesBuffer[0:-1] = self.statesBuffer[1:]
        self.statesBuffer[-1] = state

        # return stacked states
        return np.reshape(self.statesBuffer[self.STACK_INDEX], (self.STATE_SIZE))

    def render(self):
        """render environment"""
        # NOTE(review): confirm that the installed UnityEnvironment exposes
        # render(); if it does not, this call raises AttributeError.
        self.env.render()