Dict Observation Space for Stable Baselines3 Not Working
Below is a simplified, reproducible example that can be run in a fresh Google Colab notebook for convenience. Once the first install finishes, just do Runtime > Restart and Run All and it will take effect.
I created a simple roulette game environment for testing. For the observation space, I created a gym.spaces.Dict(), as you will see (the code is well commented).
Training runs fine, but when it reaches the test iterations, I get this error:
ValueError                                Traceback (most recent call last)
in
      6 obs = env.reset()
      7 for i in range(1000):
----> 8     action, _state = model.predict(obs, deterministic=True)
      9     obs, reward, done, info = env.step(action)
     10     env.render()

ValueError: Error: Unexpected observation shape () for Box environment, please use (1,) or (n_env, 1) for the observation shape.
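Out of curiosity I printed some shapes, and the error might be about each Dict entry holding a plain Python int rather than an array, though I'm not certain that's the whole story. A minimal illustration of what I mean (the asarray call is just my way of checking, not part of the environment):

from numpy import array, asarray

# My reset() stores plain ints, e.g. self.current_state["0"] = 0
print(asarray(0).shape)    # () -- the shape mentioned in the error
print(array([0]).shape)    # (1,) -- the shape the error says it wants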
I read somewhere that Dict spaces need to be flattened with gym.wrappers.FlattenObservation, so I changed this line:
action, _state = model.predict(FlattenObservation(obs), deterministic=True)
...which instead produced this error:
AttributeError                            Traceback (most recent call last)
in
      6 obs = env.reset()
      7 for i in range(1000):
----> 8     action, _state = model.predict(FlattenObservation(obs), deterministic=True)
      9     obs, reward, done, info = env.step(action)
     10     env.render()

AttributeError: 'collections.OrderedDict' object has no attribute 'observation_space'
I also tried this, which produced the same error as the previous one:
obs = env.reset()
obs = FlattenObservation(obs)
So clearly I'm doing something wrong, but I don't know what, since this is the first time I've used a Dict space.
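From skimming the gym docs, I think gym.wrappers.FlattenObservation is meant to wrap the environment itself, not an individual observation, roughly like the sketch below (using my Roulette_Environment class from the full code that follows). But I'm not sure that's the right approach here either, since flattening would turn the Dict observation space into a single Box, and I believe 'MultiInputPolicy' would then have to become 'MlpPolicy':

from gym.wrappers import FlattenObservation

# Wrap the *environment*, not a single observation dict
env = FlattenObservation(Roulette_Environment())
obs = env.reset()  # obs should now be a flat numpy array, not an OrderedDict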
import os, sys
if not os.path.isdir('/usr/local/lib/python3.7/dist-packages/stable_baselines3'):
    !pip3 install stable_baselines3
    print("\n\n\n Stable Baselines3 has been installed, Restart and Run All now. DO NOT factory reset, or you'll have to start over\n")
    sys.exit(0)

from random import randint
from numpy import inf, float32, array, int32, int64
import gym
from gym.wrappers import FlattenObservation
from stable_baselines3 import A2C, DQN, PPO


"""Roulette environment class"""
class Roulette_Environment(gym.Env):

    metadata = {'render.modes': ['human', 'text']}

    """Initialize the environment"""
    def __init__(self):
        super(Roulette_Environment, self).__init__()

        # Some global variables
        self.max_table_limit = 1000
        self.initial_bankroll = 2000

        # Spaces
        # Each number on roulette board can have 0-1000 units placed on it
        self.action_space = gym.spaces.Box(low=0, high=1000, shape=(37,))

        # We're going to keep track of how many times each number shows up
        # while we're playing, plus our current bankroll and the max
        # table betting limit so the agent knows how much $ in total is allowed
        # to be placed on the table. Going to use a Dict space for this.
        # (One Box per roulette number, "0" through "36".)
        self.observation_space = gym.spaces.Dict(
            {
                **{str(i): gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int)
                   for i in range(37)},
                "current_bankroll": gym.spaces.Box(low=-inf, high=inf, shape=(1,), dtype=int),
                "max_table_limit": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
            }
        )

    """Reset the Environment"""
    def reset(self):
        self.current_bankroll = self.initial_bankroll
        self.done = False

        # Take a sample from the observation_space to modify the values of
        self.current_state = self.observation_space.sample()

        # Reset each number being tracked throughout gameplay to 0
        for i in range(0, 37):
            self.current_state[str(i)] = 0

        # Reset our globals
        self.current_state['current_bankroll'] = self.current_bankroll
        self.current_state['max_table_limit'] = self.max_table_limit

        return self.current_state

    """Step Through the Environment"""
    def step(self, action):
        # Convert actions to ints cuz they show up as floats,
        # even when defined as ints in the environment.
        # https://github.com/openai/gym/issues/3107
        for i in range(len(action)):
            action[i] = int(action[i])
        self.current_action = action

        # Subtract your bets from bankroll
        sum_of_bets = sum([bet for bet in self.current_action])

        # Spin the wheel
        self.current_number = randint(a=0, b=36)

        # Calculate payout/reward
        self.reward = 36 * self.current_action[self.current_number] - sum_of_bets

        self.current_bankroll += self.reward

        # Update the current state
        self.current_state['current_bankroll'] = self.current_bankroll
        self.current_state[str(self.current_number)] += 1

        # If we've doubled our money, or lost our money
        if self.current_bankroll >= self.initial_bankroll * 2 or self.current_bankroll <= 0:
            self.done = True

        return self.current_state, self.reward, self.done, {}

    """Render the Environment"""
    def render(self, mode='text'):
        # Text rendering
        if mode == "text":
            print(f'Bets Placed: {self.current_action}')
            print(f'Number rolled: {self.current_number}')
            print(f'Reward: {self.reward}')
            print(f'New Bankroll: {self.current_bankroll}')


env = Roulette_Environment()

model = PPO('MultiInputPolicy', env, verbose=1)
model.learn(total_timesteps=10000)

obs = env.reset()
# obs = FlattenObservation(obs)

for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    # action, _state = model.predict(FlattenObservation(obs), deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()