python - 如何提高 DQN 的性能?

标签 python machine-learning reinforcement-learning

我创建了一个深度 Q 网络来玩贪吃蛇。该代码运行良好,但在训练周期中性能并未真正提高。最后,它与采取随机操作的代理几乎没有区别。这是训练代码:

def train(self):
        self.build_model()
        for episode in range(self.max_episodes):
            self.current_episode = episode
            env = SnakeEnv(self.screen)
            episode_reward = 0
            for timestep in range(self.max_steps):
                env.render(self.screen)
                state = env.get_state()
                action = None
                epsilon = self.current_eps
                if epsilon > random.random():
                    action = np.random.choice(env.action_space) #explore
                else:
                    values = self.policy_model.predict(env.get_state()) #exploit
                    action = np.argmax(values)
                experience = env.step(action)
                if(experience['done'] == True):
                    episode_reward += 5 * (len(env.snake.List) - 1)
                    episode_reward += experience['reward']
                    break
                episode_reward += experience['reward']
                if(len(self.memory) < self.memory_size):
                    self.memory.append(Experience(experience['state'], experience['action'], experience['reward'], experience['next_state']))
                else:
                    self.memory[self.push_count % self.memory_size] = Experience(experience['state'], experience['action'], experience['reward'], experience['next_state'])
                self.push_count += 1
                self.decay_epsilon(episode)
                if self.can_sample_memory():
                    memory_sample = self.sample_memory()
                    #q_pred = np.zeros((self.batch_size, 1))
                    #q_target = np.zeros((self.batch_size, 1))
                    #i = 0
                    for memory in memory_sample:
                        memstate = memory.state
                        action = memory.action
                        next_state = memory.next_state
                        reward = memory.reward
                        max_q = reward + self.discount_rate * self.replay_model.predict(next_state)
                        #q_pred[i] = q_value
                        #q_target[i] = max_q
                        #i += 1
                        self.policy_model.fit(memstate, max_q, epochs=1, verbose=0)
            print("Episode: ", episode, " Total Reward: ", episode_reward)
            if episode % self.target_update == 0:
                self.replay_model.set_weights(self.policy_model.get_weights())
        self.policy_model.save_weights('weights.hdf5')
        pygame.quit() 

以下是超参数:

learning_rate = 0.5
discount_rate = 0.99
eps_start = 1
eps_end = .01
eps_decay = .001
memory_size = 100000
batch_size = 256
max_episodes = 1000
max_steps = 5000
target_update = 10

这是网络架构:

model = models.Sequential()
model.add(Dense(500, activation = 'relu', kernel_initializer = 'random_uniform', bias_initializer = 'zeros', input_dim = 400))
model.add(Dense(500, activation = 'relu', kernel_initializer = 'random_uniform', bias_initializer = 'zeros'))
model.add(Dense(5, activation = 'tanh', kernel_initializer = 'random_uniform', bias_initializer = 'zeros')) #tanh for last layer because q value can be > 1
model.compile(loss='mean_squared_error', optimizer = 'adam')

作为引用,网络输出 5 个值,因为蛇可以移动 4 个方向,另外 1 个值表示不采取任何行动。此外,我没有像传统 DQN 那样作为游戏的屏幕截图,而是传入一个 400 维向量作为游戏发生所在的 20 x 20 网格的表示。代理因靠近游戏而获得 1 的奖励。食物或吃掉它,如果它死了,它会得到-1的奖励。如何才能提高性能?

最佳答案

我认为主要问题是你的学习率太高。尝试使用低于 0.001 的值。 Atari DQN 使用 0.00025。

同时将 traget_update 设置为高于 10。例如 500 或更多。

要看到某些内容,步数应至少为 10000。

将批处理大小降低至 32 或 64。

您是否考虑过实现其他一些改进?喜欢 PER、决斗 DQN 吗? 看看这个:https://www.freecodecamp.org/news/improvements-in-deep-q-learning-dueling-double-dqn-prioritized-experience-replay-and-fixed-58b130cc5682/

也许您不想再次重新实现轮子,请考虑https://stable-baselines.readthedocs.io/en/master/

最后,您可以查看类似的项目:https://github.com/lukaskiss222/agarDQNbot

关于python - 如何提高 DQN 的性能?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57486988/

相关文章:

python - 将变量的变量列表作为文件传递给Docker

java - 使用 Java-ML 包进行集群

machine-learning - 强化学习的策略是什么?

machine-learning - 值(value)迭代和策略迭代有什么区别?

python - 函数逼近 : How is tile coding different from highly discretized state space?

python - 如何从azure文件共享下载整个目录

python - 莎士比亚编程语言帮助 - Windows

python - 如何高效地查找文档中的短语

machine-learning - RNN 的输入数据格式

tensorflow - 分布式 Tensorflow : who applies the parameter update?