Can Pre-Trained Text-to-Image Models Generate Visual Goals for Reinforcement Learning

发布时间 2023-11-28 15:56:53作者: Eirrac

概述

Learning form the Void (LfVoid) 根据给定的language instruction对observation进行appearance-based and structure-based修改得到goal images,为RL提供奖励信号。提升了example-based RL methods,无需reward function或者demonstration就可以解决一些robot control tasks

问题

provide guidance or goals to a robot

现有方法:

  • a set of expert demonstrations
  • goal images
  • natual language instructions

不足:

  • 收集数据比较费力甚至prohibitive
  • 多模态之间存在语义模糊

相似工作

  • 利用 DALL-E 2 生成 object rearrangement 的 goal images
    • portal
      • Dall-e-bot: Introducing web-scale diffusion models to robotics
    • drawback
      • 生成的图像太diverse,与真实世界场景不符合
      • 需要用segmentation masks进行物体对应
      • 需要用rule-based transformation planner来消除视觉差异
  • 对生成图像进行编辑,与language description对齐
    • portal
      • Prompt-to-prompt image editing with cross attention control
      • Imagic: Text-based real image editing with diffusion models
      • Directed diffusion: Direct control of object placement through attention guidance
    • advantage
      • 保留其他objects的visual appearance,背景几乎没有影响
    • drawback
      • 生成的图片examined by visual appearance而不是embodied tasks

相关领域

  • Image editing for diffusion models
    • 用 text-to-image diffusion models 生成图像,与给定的文本提示词在语义上是对齐的:Prompt-to-Prompt, Imagic[20], InstructPix2Pix[26], Textual-Inversion[27], DreamBooth[28], Directed Diffusion[21]
  • Example-based visual RL
    • 利用goal states的observation来引导RL学习过程:VICE[5], VICE-RAQ[6], [29], [30], RCE[7]
  • Large generative models for robot control
    • 在机器人控制中的大模型主要是用LLM进行planning或者学习一个language conditioned policy,图像生成模型没怎么用
    • 数据增广:CACTI[15], GenAug[16], ROSIE[17]
    • 生成plans:[36], [37], [38]
    • 训练一个text-to-video generation model,为planning和inverse modeling生成image序列:UniPi[39]
    • DALL-E-Bot[18](见相似工作)

方法

两步走:从observation生成goal image dataset;用example-based visual RL进行学习

Visual goal generation

Given: source prompt \(\mathcal P\), source image \(x_{src}\), editing instruction

Output: target image \(x_{tgt}\) (via Latent Diffusion Model: LDM)

  • 两种 Editing instruction
    • appearance-based: target prompt \(\mathcal{P}^*\),描述外观变化
    • structure-based: 边界框 \(\mathcal{B}\)(目标位置), tokens \(\mathcal{I}\) 对应object的描述

三个模块,diffusion 过程的图片序列\(\{x_t,t=T,T-1,\cdots,0\}\),其中\(x_T\)表示高斯噪声,\(x_0\)是生成图像

  • feature extracting module(参照DreamBooh[28])

    • 训练一个diffusion model和特殊token \(sks\),采用包含target object的target image
    • 最终model可以精准保留\(x_{src}\)的关键细节,例如:the color and texture of a cube to be positioned, the shape of a Franka robot arm for manipulation
  • inversion module

    • 采用DDIM对\(x_{src}\)进行invert,\(x_0\)是近似\(x'_{src}\)
    • 为了消除累计误差,采用了Null-text inversion[40]
  • editing module

    • appearance-based: 采用Latent Diffusion Model (LDM)来生成\(x_0\)

      • Prompt-to-Prompt表示\(x_0\)的空间设置很大程度取决于cross-attention maps中的cross-attention maps \(M_t\),尤其是在前几步扩散中;故而可以通过更换attention map来实现p2p

      • \[P2PEdit(M,M^*,t)=\begin{cases} M_t,\text{if }t>T-N\\ M_t^*,\text{otherwise} \end{cases} \]

    • structure-based: 提出了P2P-DD,是结合了Directed Diffusion[21]的P2P,用来更改图片的空间设置

      • DD实现object replacement的方式是对边界框\(\mathcal B\)计算一个高斯增强掩膜(stregthening mask SM),对其他区域进行注意力退火,施加一个恒定衰弱掩膜(weakening mask WM)

      • 两个掩膜由一个标量\(c\)进行权衡:

      • \[\text{DDEdit}(M_t,\mathcal B,\mathcal I)=\begin{cases} M_t\odot\text{WM}(\bar{\mathcal B},\mathcal I)+c\cdot\text{SM}(\mathcal B,\mathcal I)& \text{if } t>T-N\\ M_t&\text{otherwise} \end{cases} \]

      • 提出的P2P-DD保留了更多的source image的内容,形式为:

      • \[\text{P2P-DDEdit}(M_t,M_t^*,\mathcal B, \mathcal I)=\text{DDEdit}(\text{P2PEdit}(M,M*,t),\mathcal B,\mathcal I) \]

      • 上式只用于前N步diffusion中,之后停止对attention map的控制,让diffusion过程以传统去噪方式进行

      image

      上图P2P表示的就是把\(M_t^*\)\(M_t\)替换的过程,然后采用DD中对两个区域分别进行增强和衰弱。这两步分别对应了原有特征的保留和object位置的更改。

Example-based visual reinforcement learning

对VICE[5]进行了修改。流程:

  • 用随机初始化的环境的初始obs生成target image,每个task有1024个target image
  • 将target image作为正例,replay buffer中agent随机探索若干步之后的image作为负例
  • 用BCE loss对上面的数据训练一个鉴别器,用鉴别器的正输出作为agent的reward

细节:

  • 鉴别器和RL agent共享CNN encoder
  • 采用了label mixup method[6],对0-1标签和他们中间的向量进行随机线性插值,让labels变成0-1之间的连续值;此外作者发现限制重播缓冲区最近部分(最后5%)的负面实例可以提高鉴别器辨别目标和当前观察之间细微差异的能力
  • RL backbone采用了DrQ-v2[42],是对TD3进行图像表示增强的算法

实验

在Sim和Real-world做了三个task:LED-light, Push, Wipe

sim环境是Robosuite

实验图参见原文,这块儿比较直观。