import gym
import numpy as np

from numpy import ndarray
from mlagents_envs.base_env import ActionTuple
from mlagents_envs.environment import UnityEnvironment


class Aimbot(gym.Env):
    """Gym-style wrapper around a single-behavior Unity ML-Agents environment.

    On construction the Unity build is launched and reset. The first behavior
    it exposes is selected, and the wrapper caches that behavior's
    observation/action specs and the agent IDs present in the first
    ``DecisionSteps`` after reset.
    """

    def __init__(
        self,
        envPath: str,
        workerID: int = 1,
        basePort: int = 100,
    ):
        """Launch and reset the Unity environment and cache its specs.

        Args:
            envPath (str): path to the Unity environment build (``file_name``).
            workerID (int): passed to ``UnityEnvironment`` as ``worker_id``.
            basePort (int): passed to ``UnityEnvironment`` as ``base_port``.
        """
        super().__init__()
        self.env = UnityEnvironment(
            file_name=envPath,
            seed=1,
            side_channels=[],
            worker_id=workerID,
            base_port=basePort,
        )
        self.env.reset()
        # all behavior_specs (mapping: behavior name -> behavior spec)
        self.unity_specs = self.env.behavior_specs
        # environment behavior name: only the first behavior is used
        self.unity_beha_name = list(self.unity_specs)[0]
        # environment behavior spec (replaces the mapping stored above)
        self.unity_specs = self.unity_specs[self.unity_beha_name]
        # environment observation spec: only the first observation is used
        self.unity_obs_specs = self.unity_specs.observation_specs[0]
        # environment action spec
        self.unity_action_spec = self.unity_specs.action_spec
        # sample the first decision steps to discover the agent IDs
        decisionSteps, _ = self.env.get_steps(self.unity_beha_name)

        # OBSERVATION SPECS
        # environment state shape, a tuple, e.g. (93,)
        self.unity_observation_shape = self.unity_obs_specs.shape

        # ACTION SPECS
        # number of continuous actions. int
        self.unity_continuous_size = self.unity_action_spec.continuous_size
        # size of each discrete branch, e.g. (3, 3, 2)
        self.unity_discrete_branches = self.unity_action_spec.discrete_branches
        # number of discrete branches, e.g. 3
        self.unity_discrete_type = self.unity_action_spec.discrete_size
        # total number of discrete choices over all branches, e.g. 3+3+2=8
        self.unity_discrete_size = sum(self.unity_discrete_branches)

        # AGENT SPECS
        # all agent IDs seen in the first decision steps
        self.unity_agent_IDS = decisionSteps.agent_id
        # number of agents
        self.unity_agent_num = len(self.unity_agent_IDS)

    def reset(self):
        """Reset the environment and get the current observations.

        Returns:
            tuple: nextState (ndarray), reward (list), done (list)
        """
        self.env.reset()
        nextState, reward, done = self.getSteps()
        return nextState, reward, done

    # TODO:
    # delete all stack state DONE
    # getstep State disassembly function DONE
    # delete agent selection function DONE
    # self.step action wrapper function DONE
    def step(
        self,
        actions: ndarray,
    ):
        """Convert an action array to an ActionTuple and step the environment.

        Args:
            actions (ndarray): PPO chooseAction output, shape (agentNum, actionNum).
                The first ``unity_discrete_type`` columns hold the chosen index
                for each discrete branch. The next ``unity_continuous_size``
                columns, if present, hold the continuous actions. If those
                continuous columns are missing, zeros are sent instead.

        Returns:
            tuple: nextStates (ndarray), rewards (list), dones (list)
        """
        actions = np.asarray(actions)
        agentNum = actions.shape[0]
        # ActionTuple expects discrete actions of shape (agentNum, branch count),
        # i.e. one chosen index per branch -- not the sum of branch sizes.
        branchNum = self.unity_discrete_type
        continuousNum = self.unity_continuous_size
        continuousEnd = branchNum + continuousNum

        # discrete actions; an empty (agentNum, 0) array when there are no branches
        discreteActions = actions[:, :branchNum].astype(np.int32)

        if continuousNum > 0 and actions.shape[1] >= continuousEnd:
            # continuous actions follow the discrete columns
            continuousActions = actions[:, branchNum:continuousEnd].astype(np.float32)
        else:
            # caller supplied no continuous columns: send zeros sized to the
            # actual agent count instead of a fixed number of rows
            continuousActions = np.zeros((agentNum, continuousNum), dtype=np.float32)

        thisActionTuple = ActionTuple(continuous=continuousActions, discrete=discreteActions)
        self.env.set_actions(behavior_name=self.unity_beha_name, action=thisActionTuple)
        self.env.step()
        # get nextState & reward & done after this action
        nextStates, rewards, dones = self.getSteps()
        return nextStates, rewards, dones

    def getSteps(self):
        """Get the environment's current observations, rewards and done flags.

        Agents are reported in the order of ``self.unity_agent_IDS``.

        Returns:
            tuple: nextStates (ndarray), rewards (list), dones (list)
        """
        decisionSteps, terminalSteps = self.env.get_steps(self.unity_beha_name)
        nextStates = []
        dones = []
        rewards = []
        for agentID in self.unity_agent_IDS:
            # When an episode ends an agent can appear in both terminalSteps and
            # decisionSteps; prefer the terminal entry so each agent is reported once.
            if agentID in terminalSteps:
                agentStep = terminalSteps[agentID]
                isDone = True
            elif agentID in decisionSteps:
                agentStep = decisionSteps[agentID]
                isDone = False
            else:
                # NOTE(review): an agent in neither set is skipped, so the outputs
                # can be shorter than unity_agent_num -- confirm callers expect this.
                continue
            nextStates.append(agentStep.obs[0])
            rewards.append(agentStep.reward)
            dones.append(isDone)

        return np.asarray(nextStates), rewards, dones

    def close(self):
        """Close the underlying Unity environment."""
        self.env.close()