(Research Idea) Transformer-Based Multi-Task RL with Policy Distillation


Introduction

Autonomous agents are often driven by reinforcement learning (RL) algorithms, which train an agent from the feedback it receives while interacting with its environment. However, in high-risk scenarios, agents cannot learn by trial and error in real time, as there is no tolerance for error (e.g., crashing on the highway). To overcome this challenge, one research area in RL is offline learning, where the agent learns not by interacting with its environment in real time but from previously collected data.

Additionally, reinforcement learning algorithms currently have limited applications because they are notoriously data-inefficient and require vast amounts of high-quality data. In complex settings, it is difficult and very expensive to obtain a large amount of task-specific data. One way to increase the availability of data is multi-task learning, which leverages data from similar tasks. Multi-task learning is a challenging research area, as training on multiple tasks poses a much more difficult optimization problem than training on one task. One prevalent challenge in learning from various tasks is conflicting updates, where the gradients of different tasks point in opposing directions.[1] This problem creates the need to devise some order (a curriculum) of tasks to learn from.

Another problem that reinforcement learning algorithms face in training is a lack of diversity and state space coverage. For example, an autonomous vehicle that must traverse a map, but initially learns from only one area before learning from others, can overfit to local optima, leading to poor generalization and transfer to new environments. One natural solution is to add a novelty bonus for visiting less familiar states, as done in random network distillation.[2] Another method encourages state space coverage by rewarding policy updates that diverge more from previous policies.[3]

Reinforcement learning algorithms can often be unstable and unpredictable in performance. One reason for this instability is the “deadly triad” problem, which results from combining function approximation, bootstrapping, and off-policy learning. Chen et al.[4] show that transformers can sidestep this problem through their ability to model long sequences with self-attention and to perform credit assignment without bootstrapping. With its many applications (such as text generation with ChatGPT and image generation with DALL-E), the transformer architecture has proven to be a powerful model and has advanced rapidly thanks to active research. Transformers also show great promise for RL: because they are flexible in modeling various distributions, transformer models can generalize to new data distributions and transfer to new environments better than other RL algorithms.

Although transformer-based reinforcement learning algorithms can be quite powerful, a major limitation is precisely their size. Large models may not be applicable in settings with compute or latency restrictions. To extend the applicability of transformers to reinforcement learning settings, we propose reducing the size of the model through knowledge distillation. Knowledge distillation, first introduced by Hinton et al.,[5] is a common technique for reducing the size of a model. Rusu et al.[6] apply knowledge distillation in RL, where a much smaller policy network learns from much larger networks in the DQN setting.

Proposal

My project aims to alleviate the aforementioned challenges of limited data, large model size, non-diverse state coverage, and conflicting updates through novelty-driven multi-task policy distillation in an offline transformer setting. Rusu et al.[6] show the promise of multi-task policy distillation through results obtained in a DQN setting. However, in their work, the smaller distilled (student) model learns from different agents (teachers) by cycling through the task data stores in an arbitrary fixed order. Similarly, Tseng et al.[7] show the promise of policy distillation in a multi-agent learning framework; however, the centralized agent simply learns from the data of all agents at once. My hypothesis is that data from different teachers are not all equally useful and that a well-constructed order of which teacher to learn from can improve learning. Specifically, by encouraging a more diverse curriculum, where the student learns from the teacher it is most dissimilar from, the student can overcome local optima that result from multi-task learning. Because the student learns from different teachers, we keep what it learns well-aligned with what it previously learned by selecting non-conflicting gradient updates.

Method

To describe the setup more concretely, we have $n$ tasks and $n$ corresponding offline datasets. For each task dataset, we learn a policy using a transformer model as described in Chen et al.[4] We then distill knowledge from the $n$ teacher policies into a student by learning a policy that aligns with each teacher policy on its corresponding dataset. To determine which teacher to learn from next, the student finds the teacher that it is most different from. The motivating intuition is that each teacher, trained on a unique task, will impart knowledge based on that task. The more the student learns from a particular teacher, the more familiar it becomes with knowledge derived from that teacher's task data. To stay well balanced, the student should next learn the knowledge that it is least familiar with.
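The teacher-selection step described above can be sketched in a few lines. This is a minimal illustration, not the project's implementation: it assumes discrete action spaces so that each policy outputs a probability vector per state, it assumes dissimilarity is measured as the mean KL divergence from teacher to student over a batch of states from that teacher's dataset, and the function names (`kl_divergence`, `mean_kl`, `most_dissimilar_teacher`) and the toy numbers are hypothetical.

```python
import math


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete action distributions given as lists."""
    # Terms with p_i == 0 contribute nothing; q_i is clipped to avoid log(0).
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def mean_kl(teacher_probs, student_probs):
    """Average KL(teacher || student) over a batch of states."""
    kls = [kl_divergence(t, s) for t, s in zip(teacher_probs, student_probs)]
    return sum(kls) / len(kls)


def most_dissimilar_teacher(teacher_batches, student_batches):
    """Return (index, scores) of the teacher the student differs from most.

    teacher_batches[i] holds teacher i's action distributions on states
    sampled from dataset i; student_batches[i] holds the student's action
    distributions on the same states.
    """
    scores = [mean_kl(t, s) for t, s in zip(teacher_batches, student_batches)]
    best = max(range(len(scores)), key=scores.__getitem__)
    return best, scores


# Hypothetical toy data: 2 teachers, one state each, 2 actions.
teacher_batches = [[[0.5, 0.5]], [[0.9, 0.1]]]
student_batches = [[[0.5, 0.5]], [[0.5, 0.5]]]
best, scores = most_dissimilar_teacher(teacher_batches, student_batches)
```

In this toy case the student already matches teacher 0 exactly, so teacher 1 is selected. For continuous actions, some other divergence or distance between policies would have to be substituted.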

However, we want to ensure that what the student learns is well-aligned with what it learned previously. To prevent clashing updates, for each potential gradient update induced by learning from a specific teacher, we consider its alignment with the previous gradient update. We draw from Yu et al.[1] and project conflicting gradients onto the normal plane of the previous gradient, while keeping non-conflicting gradients the same. Our objective is bi-objective: we consider both the Kullback-Leibler divergence between the teacher policy and the student policy and the alignment of the gradient update. Formally, we select the teacher $\arg\max_{i \in [n]} \mathrm{KL}(\pi_i(\cdot \mid s) \,\|\, \pi_{stud}(\cdot \mid s)) + \gamma \min(0, \cos(\nabla_i, \nabla_{prev}))$, where $\pi_i$ is the policy of teacher $i$, $\nabla_i$ is the gradient update induced by learning from teacher $i$, $\nabla_{prev}$ is the previous gradient update, and $\gamma$ weights the alignment term. By seeking out knowledge from the teacher that it is most unfamiliar with, while ensuring non-conflicting gradient updates, our student model learns through a diverse and aligned curriculum.
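The selection rule and the gradient projection can be illustrated with a small sketch. It rests on several assumptions: gradients are flattened into plain Python lists, the KL values are computed elsewhere, and the projection follows the rule from Yu et al.[1] (remove the component along the previous gradient only when the dot product is negative). The helper names (`project_if_conflicting`, `teacher_score`, `select_teacher`) and the toy values are hypothetical and not part of the proposal.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cosine(u, v, eps=1e-12):
    """Cosine similarity between two flattened gradient vectors."""
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)) + eps)


def project_if_conflicting(grad, prev_grad, eps=1e-12):
    """If grad conflicts with prev_grad (negative dot product), project it
    onto the normal plane of prev_grad; otherwise return it unchanged."""
    d = dot(grad, prev_grad)
    if d >= 0:
        return list(grad)
    scale = d / (dot(prev_grad, prev_grad) + eps)
    return [g - scale * p for g, p in zip(grad, prev_grad)]


def teacher_score(kl, grad, prev_grad, gamma=1.0):
    """KL term plus a penalty that is active only for conflicting gradients."""
    return kl + gamma * min(0.0, cosine(grad, prev_grad))


def select_teacher(kls, grads, prev_grad, gamma=1.0):
    """Pick the teacher index maximizing the bi-objective score."""
    scores = [teacher_score(k, g, prev_grad, gamma) for k, g in zip(kls, grads)]
    best = max(range(len(scores)), key=scores.__getitem__)
    return best, scores


# Hypothetical toy values: teacher 0 is more dissimilar but its gradient conflicts.
prev_grad = [1.0, 0.0]
grads = [[-1.0, 1.0], [1.0, 1.0]]
kls = [0.5, 0.3]
best, scores = select_teacher(kls, grads, prev_grad, gamma=1.0)
best_no_penalty, _ = select_teacher(kls, grads, prev_grad, gamma=0.0)
projected = project_if_conflicting(grads[0], prev_grad)
```

With `gamma=1.0` the conflicting teacher 0 is penalized and teacher 1 is chosen, while with `gamma=0.0` the choice reverts to the most dissimilar teacher. How large $\gamma$ should be relative to the scale of the KL term is left open here.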

References

[1] Yu, T., Kumar, S., Gupta, A., Levine, S., Hausman, K., & Finn, C. (2020). Gradient Surgery for Multi-Task Learning. arXiv preprint arXiv:2001.06782.

[2] Burda, Y., Edwards, H., Storkey, A., & Klimov, O. (2018). Exploration by Random Network Distillation. arXiv preprint arXiv:1810.12894.

[3] Hong, Z.-W., Shann, T.-Y., Su, S.-Y., Chang, Y.-H., & Lee, C.-Y. (2018). Diversity-Driven Exploration Strategy for Deep Reinforcement Learning. arXiv preprint arXiv:1802.04564.

[4] Chen, L., Lu, K., Rajeswaran, A., Lee, K., Grover, A., Laskin, M., Abbeel, P., Srinivas, A., & Mordatch, I. (2021). Decision Transformer: Reinforcement Learning via Sequence Modeling. arXiv preprint arXiv:2106.01345.

[5] Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network. arXiv preprint arXiv:1503.02531.

[6] Rusu, A. A., Colmenarejo, S. G., Gulcehre, C., Desjardins, G., Kirkpatrick, J., Pascanu, R., Mnih, V., Kavukcuoglu, K., & Hadsell, R. (2016). Policy Distillation. arXiv preprint arXiv:1511.06295.

[7] Tseng, W.-C., Wang, T.-H. J., Lin, Y.-C., & Isola, P. (2022). Offline Multi-Agent Reinforcement Learning with Knowledge Distillation. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, & A. Oh (Eds.), Advances in Neural Information Processing Systems (Vol. 35, pp. 226-237).