在人工智能的广阔领域中,强化学习(RL)一直扮演着至关重要的角色,它让机器智能体能够像我们一样,通过与环境的试错交互来掌握复杂技能。然而,传统RL算法的核心——学习规则(或称更新规则),通常是由人类专家精心设计和固化的,例如我们熟知的Adam、SGD等优化器。它们就像一套固定的工具箱,虽然功能强大,但在面对千变万化的新任务时,未必总是最高效的。
一个革命性的问题随之而来:我们能否让机器自己学会如何学习?这便是元学习(Meta-Learning)的终极目标。Google DeepMind的disco_rl项目正是对这一宏大构想的精彩实践。它不满足于设计固定的学习算法,而是致力于发现和优化学习规则本身。本文将深入解析disco_rl提供的两个核心示例Notebook,带领读者一窥这个前沿领域的内部运作:首先,我们将学会如何使用一个已经“被发现”的强大更新规则Disco103来训练智能体;然后,我们将更进一步,探索如何从零开始,对一个更新规则进行元训练或微调。
#### 第一幕:挥舞神兵 —— 使用预训练的Disco103更新规则
想象一位铁匠学徒,他得到了一把由宗师打造的、近乎完美的锤子。他的任务不是去研究如何造锤子,而是直接用它来锻造最好的剑。eval.ipynb这个Notebook,正是指导我们如何扮演这位学徒的角色。
1. 舞台搭建:环境与智能体
整个过程始于环境的搭建。代码中使用了一个名为Catch的经典RL环境,这是一个可被JAX即时编译(JIT)的版本,极大地提升了运算效率。JAX是整个框架的基石,它提供的自动微分、JIT编译和便捷的并行化能力(jax.pmap)是实现这种大规模计算实验的关键。
# @title Instantiate a simple MLP agent.
def get_env(batch_size: int) -> base_env.Environment:
return jittable_envs.CatchJittableEnvironment(...)
# Create settings for an agent.
agent_settings = agent_lib.get_settings_disco()
agent_settings.net_settings.name = 'mlp'
agent_settings.net_settings.net_args = dict(...)
接着,我们定义了智能体的“大脑”——一个基于多层感知机(MLP)和长短期记忆网络(LSTM)的神经网络。LSTM的存在至关重要,它赋予了智能体记忆能力,使其能够处理需要依赖历史信息才能做出最优决策的任务。
2. 核心引擎:加载Disco103
这部戏剧的主角——Disco103更新规则,以一个.npz权重文件的形式登场。它不是一个像Adam那样由几行数学公式定义的算法,而是一个本身就由神经网络构成的、拥有大量参数的复杂函数。这些参数是DeepMind通过大规模元学习实验“发现”的。
# @title Download and unpack `Disco103` weights.
with open(f'/content/{disco_103_fname}', 'rb') as file:
disco_103_params = unflatten_params(np.load(file))
# Ensure that the agent's update rule's parameters have the same specs.
chex.assert_trees_all_equal_shapes_and_dtypes(
random_update_rule_params, disco_103_params
)
代码通过加载这些权重,并将它们应用到我们智能体的更新规则模块中,相当于给智能体装上了一个预先训练好的、高性能的“学习引擎”。
3. 训练循环:交互、存储与学习
训练过程遵循一个经典的“演员-学习者(Actor-Learner)”模式,但其实现方式充分利用了JAX的优势:
- 数据生成 (
unroll_jittable_actor): “演员”(Actor)在环境中执行策略,生成一系列经验轨迹(rollout)。这个过程被封装在unroll_jittable_actor函数中,并通过jax.lax.scan实现高效的循环展开,整个过程可以在TPU/GPU上高速运行。 - 经验回放 (
SimpleReplayBuffer): 收集到的经验数据被存入一个简单的回放缓冲区。这就像学徒把每次挥锤的经验都记录下来,以便日后复盘。 - 参数更新 (
learner_step_fn): “学习者”(Learner)从缓冲区中采样一批数据,然后调用agent.learner_step函数。这里的关键在于,参数更新不再依赖固定的Adam或SGD,而是由Disco103这个参数化的更新规则来计算梯度和新的网络权重。
# The training loop
for step in tqdm.tqdm(range(num_steps)):
# Generate new trajectories and add them to the buffer.
actor_rollout, actor_state, ts, env_state = unroll_actor(...)
buffer.add(actor_rollout)
# Update agent's parameters on the samples from the buffer.
if len(buffer) >= min_buffer_size:
learner_rollout = buffer.sample(batch_size)
learner_state, _, metrics = learner_step_fn(
...,
update_rule_params, # Applying Disco103 here
...
)
最终,通过绘制智能体在训练过程中的平均回报(avg_returns),我们可以直观地看到Disco103的强大效果。它指导着智能体的神经网络参数,高效地走向最优策略。
#### 第二幕:铸造神兵 —— 元训练自定义更新规则
如果我们不仅仅满足于使用现成的工具,而是想成为那位能够铸造神兵的宗师呢?meta_train.ipynb为我们揭示了这条更具挑战性的道路。这里的目标不再是训练一个解决Catch问题的智能体,而是训练那个指导智能体学习的“更新规则”本身。
这个过程好比一位铁匠宗师,他同时指导着一个班的学徒(a population of agents)。他让学徒们用当前版本的工具(update rule)去练习锻造(inner loop)。然后,他观察学徒们的最终作品(validation performance),并根据这些结果来改进工具的设计(outer loop / meta-update)。
1. 双层优化结构:内部循环与外部循环
元训练的核心是一个嵌套的优化结构:
- 内部循环 (Inner Loop): 在这一层,每个智能体都像
eval.ipynb中那样进行学习。它们使用当前固定的更新规则参数,在环境中收集经验,并更新自己的策略网络参数。这个过程会进行好几步(num_inner_steps)。 - 外部循环 (Outer Loop): 在这一层,我们评估经过内部循环训练后,智能体在一个新的验证任务(
valid_rollout)上的表现。基于这个表现,我们计算出元损失(meta-loss),并反向传播,用这个损失的元梯度(meta-gradient)来更新更新规则本身的参数。
calculate_meta_gradientcalculate_meta_gradient函数是元训练的心脏。它精妙地实现了上述双层优化:
# Inside calculate_meta_gradient
def _outer_loss(update_rule_params, ...):
# Perform N inner steps
(_, new_learner_state, ...), _ = jax.lax.scan(
_inner_step,
(update_rule_params, ...),
(train_rollouts, learner_rngs),
)
# Run inference on the validation rollout with the NEW agent params
agent_rollout_on_valid, _ = hk.BatchApply(...)
# Calculate meta loss (e.g., policy gradient loss on validation data)
meta_loss = pg_loss_per_step.mean() + reg_loss
return meta_loss, ...
# Calculate meta gradients using jax.grad
meta_grads, outputs = jax.grad(_outer_loss, has_aux=True)(...)
这段代码的逻辑可以分解为:
1. _outer_loss函数将update_rule_params作为输入。
2. 内部使用jax.lax.scan执行_inner_step数次,模拟了智能体的内部学习过程。重要的是,这个过程是完全可微分的。JAX会追踪update_rule_params如何影响_inner_step,并最终影响到new_learner_state(智能体的新参数)。
3. 接着,使用这个new_learner_state在全新的验证数据valid_rollout上进行评估,计算出一个策略梯度损失(pg_loss)。这个损失衡量了当前更新规则的好坏——一个好的更新规则应该能让智能体在经过几步学习后,在新数据上表现出色。
4. 最后,jax.grad对整个_outer_loss函数求导,自动计算出meta_loss相对于update_rule_params的梯度,即“元梯度”。3. 群体智能:meta_update与并行化
为了得到更稳定、更泛化的更新规则,元训练通常采用一群(num_agents)智能体并行进行。meta_update函数协调了整个过程:
1. 为每个智能体生成用于内部训练的轨迹(train_rollouts)和用于验证的轨迹(valid_rollouts)。
2. 并行地为每个智能体计算元梯度。
3. 将所有智能体的元梯度平均起来,得到一个最终的更新方向。
4. 使用一个元优化器(meta_opt = optax.adam(...))来应用这个平均梯度,更新全局的update_rule_params。
整个meta_update函数再次通过jax.pmap被分发到所有可用的计算设备上,实现了多智能体、多设备的双重并行,极大地加速了这个计算量巨大的过程。
#### 结论:从使用者到创造者的飞跃
disco_rl通过这两个精心设计的示例,为我们清晰地展示了元学习在强化学习领域的应用路径。eval.ipynb教会我们如何利用元学习的成果——像Disco103这样强大的、被发现的学习规则,来高效解决具体任务。它代表了AI工具的“使用者”视角。
而meta_train.ipynb则是一次深刻的范式转换,它将我们带到了“创造者”的层面。我们不再仅仅是优化一个策略网络,而是在一个更高的抽象层次上,优化学习过程本身。这就像从学习如何驾驶一辆车,跃升到了设计赛车引擎的层面。
尽管元训练的计算成本高昂且过程复杂,但它所开启的可能性是无限的。通过“学习如何学习”,我们有望创造出能够快速适应未知环境、泛化能力更强的通用智能体,这无疑是通往通用人工智能(AGI)道路上,一块坚实而又充满希望的基石。disco_rl不仅是一个代码库,更是一扇通往未来AI算法设计新世界的窗户。