import os
import random

import numpy as np


class GAILMem(object):
    """Replay buffer for GAIL that stores (state, actorProb, action, reward, done) transitions."""

    def __init__(self):
        self.states = []
        self.actorProbs = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.memNum = 0
        print("√√√√√ Buffer initialized successfully √√√√√")

    def clearMem(self):
        """Clear all stored memories."""
        self.states = []
        self.actorProbs = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.memNum = 0

    def saveMemtoFile(self, dir: str):
        """Save memories as ndarrays to an npz file.

        Args:
            dir (str): save directory, like "GAIL-Expert-Data/", ending with "/"
        """
        statesNP = np.asarray(self.states)
        actorProbsNP = np.asarray(self.actorProbs)
        actionsNP = np.asarray(self.actions)
        rewardsNP = np.asarray(self.rewards)
        donesNP = np.asarray(self.dones)
        thisSaveDir = dir + "pack-" + str(self.memNum)
        # Make sure the target directory exists before saving.
        os.makedirs(dir, exist_ok=True)
        np.savez(
            thisSaveDir,
            states=statesNP,
            actorProbs=actorProbsNP,
            actions=actionsNP,
            rewards=rewardsNP,
            dones=donesNP,
        )

    def loadMemFile(self, dir: str):
        """Load memories from an npz file.

        Args:
            dir (str): path to the npz file
        """
        self.clearMem()
        memFile = np.load(dir, allow_pickle=True)
        self.states = memFile["states"].tolist()
        self.actorProbs = memFile["actorProbs"].tolist()
        self.actions = memFile["actions"].tolist()
        self.rewards = memFile["rewards"].tolist()
        self.dones = memFile["dones"].tolist()
        self.memNum = len(self.states)

    def getRandomSample(self, sampleNum: int = 0):
        """Get a random set of unique samples.

        Args:
            sampleNum (int, optional): number of samples; 0 returns all samples. Defaults to 0.

        Returns:
            tuple: (states, actorProbs, actions, rewards, dones)
        """
        if sampleNum == 0:
            return (
                self.getStates(),
                self.getActorProbs(),
                self.getActions(),
                self.getRewards(),
                self.getDones(),
            )
        else:
            # Sample unique indices without replacement, then apply them to every buffer.
            randIndex = random.sample(range(0, self.memNum), sampleNum)
            return (
                self.standDims(np.asarray(self.states)[randIndex]),
                self.standDims(np.asarray(self.actorProbs)[randIndex]),
                self.standDims(np.asarray(self.actions)[randIndex]),
                self.standDims(np.asarray(self.rewards)[randIndex]),
                self.standDims(np.asarray(self.dones)[randIndex]),
            )

    def getStates(self):
        """Get all states data as an ndarray.

        Returns:
            ndarray: states data
        """
        return self.standDims(np.asarray(self.states))

    def getActorProbs(self):
        """Get all actorProbs data as an ndarray.

        Returns:
            ndarray: actorProbs data
        """
        return self.standDims(np.asarray(self.actorProbs))

    def getActions(self):
        """Get all actions data as an ndarray.

        Returns:
            ndarray: actions data
        """
        return self.standDims(np.asarray(self.actions))

    def getRewards(self):
        """Get all rewards data as an ndarray.

        Returns:
            ndarray: rewards data
        """
        return self.standDims(np.asarray(self.rewards))

    def getDones(self):
        """Get all dones data as an ndarray.

        Returns:
            ndarray: dones data
        """
        return self.standDims(np.asarray(self.dones))

    def standDims(self, data):
        """Standardize data's dimensions to a 2-D ndarray.

        Args:
            data (list or ndarray): data to standardize

        Returns:
            ndarray: 2-D ndarray of the data
        """
        if np.ndim(data) > 2:
            # e.g. (N, 1, D) -> (N, D): drop the singleton axis.
            return np.squeeze(data, axis=1)
        elif np.ndim(data) < 2:
            # e.g. (N,) -> (N, 1): add a feature axis.
            return np.expand_dims(data, axis=1)
        else:
            return np.asarray(data)

    def saveMems(self, state, actorProb, action, reward, done):
        """Save one transition to the buffer.

        Args:
            state: environment state
            actorProb: actor's predicted action probabilities
            action: action chosen by the actor
            reward: reward received from the environment
            done (bool): whether the episode terminated at this step
        """
        self.states.append(state)
        self.actorProbs.append(actorProb)
        self.actions.append(action)
        self.rewards.append(reward)
        self.dones.append(done)
        self.memNum += 1
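

# Minimal usage sketch of the buffer's round trip (store, sample, save, load).
# The 4-dim states, 2-action probability vectors, and the "GAIL-Expert-Data/"
# path below are illustrative assumptions, not values taken from this module.
if __name__ == "__main__":
    mem = GAILMem()
    for step in range(8):
        state = np.random.rand(4)
        actorProb = np.asarray([0.6, 0.4])
        action = int(np.argmax(actorProb))
        mem.saveMems(state, actorProb, action, reward=1.0, done=(step == 7))

    # Draw 4 unique transitions; standDims returns every array as 2-D.
    states, actorProbs, actions, rewards, dones = mem.getRandomSample(4)
    print(states.shape, actions.shape, rewards.shape)  # (4, 4) (4, 1) (4, 1)

    # np.savez appends ".npz", so "pack-8" is written as "pack-8.npz".
    mem.saveMemtoFile("GAIL-Expert-Data/")
    mem.loadMemFile("GAIL-Expert-Data/pack-8.npz")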