RL-05-04-结构-Prioritized-Replay

← 上级:RL-05.专属数据结构 · 算法:RL-03-07-算法-DQN变体

PER 按 TD 误差 $|\delta|$ 优先采样「学得差」的 transition,提升样本效率。


一、优先级

$$
p_i = (|\delta_i| + \epsilon)^\alpha, \quad P(i) = \frac{p_i}{\sum_j p_j}
$$

超参 典型
$\alpha$ 0.6(0=均匀,1=完全优先)
$\epsilon$ 1e-6 防零

二、重要性采样权重

$$
w_i = \left( N \cdot P(i) \right)^{-\beta}, \quad \text{归一化 } w_i \leftarrow \frac{w_i}{\max_j w_j}
$$

$\beta$ 从 0.4 退火到 1.0。损失:

$$
L = \mathbb{E}[ w_i \cdot \delta_i^2 ]
$$


三、SumTree 结构

二叉树叶子存 priority,父节点 = 子节点和,$O(\log N)$ 采样与更新。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class SumTree:
def __init__(self, capacity):
self.capacity = capacity
self.tree = np.zeros(2 * capacity - 1)
self.data = [None] * capacity
self.ptr = 0
self.size = 0

def _propagate(self, idx, change):
parent = (idx - 1) // 2
self.tree[parent] += change
if parent != 0:
self._propagate(parent, change)

def add(self, priority, data):
idx = self.ptr + self.capacity - 1
self.data[self.ptr] = data
self.update(idx, priority)
self.ptr = (self.ptr + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)

def update(self, idx, priority):
change = priority - self.tree[idx]
self.tree[idx] = priority
self._propagate(idx, change)

def get(self, value):
# 从根向下找叶子
...

工程可用开源 PER Buffer 或 SB3 内置。


四、流程

  1. sample 按 SumTree 抽 index + 算 $w_i$
  2. 算 TD loss 加权
  3. 用 $|\delta|+\epsilon$ 更新该 index 优先级

五、小结

-------------本文结束感谢您的阅读-------------